Inverse design using neural network (nanoparticle) --- Train the models (Table 1)

The demo file: passing in the parameters

The demo.sh file provided by the author is a bash script that runs on Linux.
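
To run it from the repository root, an invocation like the following should work (assuming the Python dependencies used by scatter_net.py are installed):

bash demo.sh

Its contents: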

#!/bin/bash
# This file trains all the models presented here. 

echo "python scatter_net.py --data data/8_layer_tio2 --output_folder results/8_layer_tio2 --n_batch 100 --numEpochs 5000 --lr_rate .0006 --lr_decay .99 --num_layers 4 --n_hidden 250 --percent_val .2 --patience 10"
python scatter_net.py --data data/8_layer_tio2 --output_folder results/8_layer_tio2 --n_batch 100 --numEpochs 5000 --lr_rate .0006 --lr_decay .99 --num_layers 4 --n_hidden 250 --percent_val .2 --patience 10

echo "python scatter_net.py --data data/7_layer_tio2 --output_folder results/7_layer_tio2 --n_batch 100 --numEpochs 5000 --lr_rate .0006 --lr_decay .99 --num_layers 4 --n_hidden 225 --percent_val .2 --patience 10"
python scatter_net.py --data data/7_layer_tio2 --output_folder results/7_layer_tio2 --n_batch 100 --numEpochs 5000 --lr_rate .0006 --lr_decay .99 --num_layers 4 --n_hidden 225 --percent_val .2 --patience 10

echo "python scatter_net.py --data data/6_layer_tio2 --output_folder results/6_layer_tio2 --n_batch 100 --numEpochs 5000 --lr_rate .0006 --lr_decay .99 --num_layers 4 --n_hidden 225 --percent_val .2 --patience 10"
python scatter_net.py --data data/6_layer_tio2 --output_folder results/6_layer_tio2 --n_batch 100 --numEpochs 5000 --lr_rate .0006 --lr_decay .99 --num_layers 4 --n_hidden 225 --percent_val .2 --patience 10

echo "python scatter_net.py --data data/5_layer_tio2 --output_folder results/5_layer_tio2 --n_batch 100 --numEpochs 5000 --lr_rate .0006 --lr_decay .99 --num_layers 4 --n_hidden 200 --percent_val .2 --patience 10"
python scatter_net.py --data data/5_layer_tio2 --output_folder results/5_layer_tio2 --n_batch 100 --numEpochs 5000 --lr_rate .0006 --lr_decay .99 --num_layers 4 --n_hidden 200 --percent_val .2 --patience 10

echo "python scatter_net.py --data data/4_layer_tio2 --output_folder results/4_layer_tio2 --n_batch 100 --numEpochs 5000 --lr_rate .0006 --lr_decay .99 --num_layers 4 --n_hidden 125 --percent_val .2 --patience 10"
python scatter_net.py --data data/4_layer_tio2 --output_folder results/4_layer_tio2 --n_batch 100 --numEpochs 5000 --lr_rate .0006 --lr_decay .99 --num_layers 4 --n_hidden 125 --percent_val .2 --patience 10

echo "python scatter_net.py --data data/3_layer_tio2 --output_folder results/3_layer_tio2 --n_batch 100 --numEpochs 5000 --lr_rate .0006 --lr_decay .99 --num_layers 4 --n_hidden 100 --percent_val .2 --patience 10"
python scatter_net.py --data data/3_layer_tio2 --output_folder results/3_layer_tio2 --n_batch 100 --numEpochs 5000 --lr_rate .0006 --lr_decay .99 --num_layers 4 --n_hidden 100 --percent_val .2 --patience 10

echo "python scatter_net.py --data data/2_layer_tio2 --output_folder results/2_layer_tio2 --n_batch 100 --numEpochs 5000 --lr_rate .0006 --lr_decay .99 --num_layers 4 --n_hidden 100 --percent_val .2 --patience 10"
python scatter_net.py --data data/2_layer_tio2 --output_folder results/2_layer_tio2 --n_batch 100 --numEpochs 5000 --lr_rate .0006 --lr_decay .99 --num_layers 4 --n_hidden 100 --percent_val .2 --patience 10

The file walks through the training of the TiO2/silica nanoparticle models with 2 to 8 shells. At the end of scatter_net.py, the author wrote a dedicated argument-parsing block (parser) to parse the incoming arguments.

if __name__=="__main__":
    parser = argparse.ArgumentParser(description="Physics Net Training")
    parser.add_argument("--data",type=str,default='data/5_layer_tio2') # Where the data file is. Note: This assumes a file of _val.csv and .csv 
    parser.add_argument("--reuse_weights",type=str,default='False') # Whether to load the weights or not. Note this just needs to be set to true, then the output folder directed to the same location. 
    parser.add_argument("--output_folder",type=str,default='results/5_layer_tio2') #Where to output the results to. Note: No / at the end. 
    parser.add_argument("--weight_name_load",type=str,default="")#This would be something that goes infront of w_1.txt. This would be used in saving the weights. In most cases, just leave this as is, it will naturally take care of it. 
    parser.add_argument("--weight_name_save",type=str,default="") #Similiar to above, but for saving now. 
    parser.add_argument("--n_batch",type=int,default=100) # Batch Size
    parser.add_argument("--numEpochs",type=int,default=5000) #Max number of epochs to consider at maximum, if patience condition is not met. 
    parser.add_argument("--lr_rate",type=float,default=.001) # Learning Rate. 
    parser.add_argument("--lr_decay",type=float,default=.7) # Learning rate decay. It decays by this factor every epoch.
    parser.add_argument("--num_layers",default=4) # Number of layers in the network. 
    parser.add_argument("--n_hidden",default=225) # Number of neurons per layer. Fully connected layers. 
    parser.add_argument("--percent_val",default=.2) # Amount of the data to split for validation/test. The validation/test are both split equally. 
    parser.add_argument("--patience",default=10) # Patience for stopping. If validation loss has not decreased in this many steps, it will stop the training. 
    parser.add_argument("--compare",default='False') # Whether it should output the comparison or not. 
    parser.add_argument("--sample_val",default='True') # Wether it should sample from validation or not, for the purposes of graphing. 
    parser.add_argument("--spect_to_sample",type=int,default=300) # Zero Indexing for this. Position in the data file to sample from (note it will take from validation)
    parser.add_argument("--matchSpectrum",default='False') # If it should match an already existing spectrum file. 
    parser.add_argument("--match_test_file",default='results/2_layer_tio2/test_47.5_45.3') # Location of the file with the spectrum in it. 
    parser.add_argument("--designSpectrum",default='False') # If it should 
    parser.add_argument("--design_test_file",default='data/test_gen_spect.csv') # This is a file that should contain 0's and 1's where it should maximize and not maximize. 

    args = parser.parse_args() # an object whose attributes hold the parsed values (backed by a __dict__)
    dict = vars(args) # returns the non-private attributes of args as key-value pairs, i.e., a plain dict (note: this shadows the builtin `dict`)
    print(dict)

    for key,value in dict.items():
        if (dict[key]=="False"):
            dict[key] = False
        elif dict[key]=="True":
            dict[key] = True
        try:
            if dict[key].is_integer():
                dict[key] = int(dict[key])
            else:
                dict[key] = float(dict[key])
        except:
            pass
    print (dict)
# The loop above converts the dict's values into the appropriate bool, int, or float types: the strings 'False'/'True' become booleans, and floats with integral values become ints. Strings other than 'True'/'False' are left unchanged by the bare except (which is why num_layers and n_hidden are cast with int() further down).



    #Note that reuse MUST be set to true.
    if (dict['compare'] or dict['matchSpectrum'] or dict['designSpectrum']):
        if dict['reuse_weights'] != True:
            print("Reuse weights must be set true for comparison, matching, or designing. Setting it to true....")
            time.sleep(1)
        dict['reuse_weights'] = True
        
    kwargs = {  
            'data':dict['data'],
            'reuse_weights':dict['reuse_weights'],
            'output_folder':dict['output_folder'],
            'weight_name_save':dict['weight_name_save'],
            'weight_name_load':dict['weight_name_load'],
            'n_batch':dict['n_batch'],
            'numEpochs':dict['numEpochs'],
            'lr_rate':dict['lr_rate'],
            'lr_decay':dict['lr_decay'],
            'num_layers':int(dict['num_layers']),
            'n_hidden':int(dict['n_hidden']),
            'percent_val':dict['percent_val'],
            'patienceLimit':dict['patience'],
            'compare':dict['compare'],
            'sample_val':dict['sample_val'],
            'spect_to_sample':dict['spect_to_sample'],
            'matchSpectrum':dict['matchSpectrum'],
            'match_test_file':dict['match_test_file'],
            'designSpectrum':dict['designSpectrum'],
            'design_test_file':dict['design_test_file']
            }
# collect everything into a kwargs dict

    if kwargs['designSpectrum'] == True: # run design_spectrum when spectrum design is requested
        design_spectrum(**kwargs) # the dict defined above is passed in as keyword arguments
    elif kwargs['matchSpectrum'] == True: # run match_spectrum when spectrum matching is requested
        match_spectrum(**kwargs)
    else:
        main(**kwargs)   # main() trains the model

To understand the code above, let's print a few intermediate results:

  1. print (args)
    This yields a Namespace object. Python uses namespaces to keep track of variables: a namespace is essentially a dictionary whose keys are the variable names and whose values are those variables' values.
Namespace(compare='False', data='data/5_layer_tio2', designSpectrum='False', design_test_file='data/test_gen_spect.csv', lr_decay=0.7, lr_rate=0.001, matchSpectrum='False', match_test_file='results/2_layer_tio2/test_47.5_45.3', n_batch=100, n_hidden=225, numEpochs=5000, num_layers=4, output_folder='results/5_layer_tio2', patience=10, percent_val=0.2, reuse_weights='False', sample_val='True', spect_to_sample=300, weight_name_load='', weight_name_save='')
  2. print (vars(args))
    Using vars(), the args object is rendered as a standard dict:
{'sample_val': 'True', 'weight_name_load': '', 'compare': 'False', 'match_test_file': 'results/2_layer_tio2/test_47.5_45.3', 'patience': 10, 'numEpochs': 5000, 'design_test_file': 'data/test_gen_spect.csv', 'matchSpectrum': 'False', 'n_batch': 100, 'spect_to_sample': 300, 'output_folder': 'results/5_layer_tio2', 'n_hidden': 225, 'percent_val': 0.2, 'designSpectrum': 'False', 'num_layers': 4, 'lr_rate': 0.001, 'weight_name_save': '', 'reuse_weights': 'False', 'data': 'data/5_layer_tio2', 'lr_decay': 0.7}
  3. After the dict is processed, the values are converted to suitable bool, int, or float types; for example, 'True' arrives as a str and becomes the boolean True:
{'sample_val': True, 'weight_name_load': '', 'compare': False, 'match_test_file': 'results/2_layer_tio2/test_47.5_45.3', 'patience': 10, 'numEpochs': 5000, 'design_test_file': 'data/test_gen_spect.csv', 'matchSpectrum': False, 'n_batch': 100, 'spect_to_sample': 300, 'output_folder': 'results/5_layer_tio2', 'n_hidden': 225, 'percent_val': 0.2, 'designSpectrum': False, 'num_layers': 4, 'lr_rate': 0.001, 'weight_name_save': '', 'reuse_weights': False, 'data': 'data/5_layer_tio2', 'lr_decay': 0.7}
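
For reference, most of this string juggling can be avoided by letting argparse do the conversions itself via its type= hook; a minimal sketch of the same idea (not the author's code):

import argparse

def str2bool(s):
    # map the strings "True"/"False" (any casing) to real booleans at parse time
    return s.lower() in ("true", "1", "yes")

parser = argparse.ArgumentParser()
parser.add_argument("--compare", type=str2bool, default=False)
parser.add_argument("--percent_val", type=float, default=0.2)
parser.add_argument("--patience", type=int, default=10)

args = parser.parse_args(["--compare", "True", "--patience", "15"])
print(vars(args))  # {'compare': True, 'percent_val': 0.2, 'patience': 15}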

Passing the parameters on Linux

On Linux, the script is invoked with named command-line flags.

python scatter_net.py --data data/5_layer_tio2 --output_folder results/5_layer_tio2 --n_batch 100 --numEpochs 5000 --lr_rate .0006 --lr_decay .99 --num_layers 4 --n_hidden 250 --percent_val .2 --patience 10
  1. --data data/5_layer_tio2 gives the location of the data;
  2. --output_folder results/5_layer_tio2 gives where the results are stored;
  3. --n_batch 100 sets the batch size to 100;
  4. --numEpochs 5000 sets the maximum number of epochs;
  5. --lr_rate .0006 sets the learning rate;
  6. --lr_decay .99 sets the learning-rate decay factor;
  7. --num_layers 4 builds a network with four layers;
  8. --n_hidden 250 puts 250 neurons in each layer;
  9. --percent_val .2 sets the fraction of the data held out for validation/testing;
  10. --patience 10 sets the early-stopping condition.

With these arguments passed in, main(**kwargs) runs; every parameter not passed in keeps its default value.
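
The **kwargs call is plain Python keyword-argument unpacking; a toy illustration (the three-argument main below is a stand-in, not the real function):

def main(data, n_batch, lr_rate):
    # stand-in for the real main(): just echo what was received
    print(data, n_batch, lr_rate)

kwargs = {'data': 'data/5_layer_tio2', 'n_batch': 100, 'lr_rate': 0.0006}
main(**kwargs)  # identical to main(data='data/5_layer_tio2', n_batch=100, lr_rate=0.0006)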

Data handling in main()

Moving on to main() (scatter_net.py, line 167): the first part of the function mostly handles file names, so we pick up the thread at line 189 (# getting the data). The key call is on line 191:

train_X, train_Y , test_X, test_Y, val_X, val_Y , x_mean, x_std = get_data(data,percentTest=percent_val)

This calls the get_data() function, which lives in scatter_net_core.py:

import numpy as np
from sklearn.model_selection import train_test_split

def get_data(data,percentTest=.2,random_state=42):
    x_file = data+"_val.csv"
    y_file = data+".csv"
    train_X = np.genfromtxt(x_file,delimiter=',')#[0:20000,:]
    train_Y = np.transpose(np.genfromtxt(y_file,delimiter=','))#[0:20000,:]
    train_x_mean = train_X.mean(axis=0) #train_X is an array; mean of each input feature (column) over all samples
    train_x_std = train_X.std(axis=0) #standard deviation of each input feature over all samples
    train_X = (train_X-train_X.mean(axis=0))/train_X.std(axis=0) #zero-mean, unit-variance normalization, so each input feature is standardized
    X_train, test_X, y_train, test_Y = train_test_split(train_X,train_Y,test_size=float(percentTest),random_state=random_state)
#feed in the raw train_X, train_Y with the train/test split ratio to divide the data into train and test parts
    X_test, X_val, y_test, y_val = train_test_split(test_X,test_Y,test_size=.5,random_state=random_state)
#split the held-out data again 50/50: one half becomes the test set, the other half is kept as X_val and y_val
    return X_train, y_train, X_test, y_test, X_val, y_val, train_x_mean, train_x_std
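
A quick sanity check of the resulting 80/10/10 proportions on synthetic data (a sketch, not part of the repository):

import numpy as np
from sklearn.model_selection import train_test_split

X = np.random.rand(1000, 5)    # 1000 samples, 5 input features (layer thicknesses)
Y = np.random.rand(1000, 200)  # 1000 spectra, 200 frequency points each

X_train, X_rest, y_train, y_rest = train_test_split(X, Y, test_size=0.2, random_state=42)
X_test, X_val, y_test, y_val = train_test_split(X_rest, y_rest, test_size=0.5, random_state=42)
print(X_train.shape, X_test.shape, X_val.shape)  # (800, 5) (100, 5) (100, 5)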

For each case (e.g., 5_layer), the data folder contains two files, "_val.csv" and ".csv". The data folder's README reads:

These are all the data files used in the paper.
Note all these were generated using the "ScatterNet_Matlab" directory here in the repository. Be cautious of the order of the harmonics - as the particle gets more layers, more orders must be added to compensate for more modes.
Directory:
Data for n layer particle with alternating silica/TiO2 shells:
n_layer_tio2.csv
n_layer_tio2_val.csv
The _val file indicates what the values of the thicknesses are (in nanometers). The other file - 2_layer_tio2.csv - indicates the values of the spectrum for each corresponding particle. That is, the first line in 2_layer_tio2 corresponds to the first line in 2_layer_tio2_val.
The 2 layer particle has 30k records. The 3, 4, 5, 6, and 7 layer particles have 40k each. The 8 layer has 50k.
Data for 3 layer jaggregate particle:
jagg_layer_tio2.csv
jagg_layer_tio2_val.csv
Same format as above. The _val file indicates the thickness of the metallic silver core, the dielectric layer of silica, and the outside layer of the J-Aggregate dye, respectively. The last number is the tuned resonance for the J-Aggregate dye.

As described there:

  • "_val.csv"存儲(chǔ)的是input layer (X_train)的值,形式是一個(gè)m \times n的matrix肮雨,m代表訓(xùn)練的sample數(shù)遵堵,n代表輸入層的neuron數(shù),比如是5_layer怨规,那么n就是5陌宿。
  • ".csv"存儲(chǔ)的是output layer (Y_train)的值,形式是一個(gè)m \times n的matrix波丰,m代表輸出層(Y)的neuron數(shù)(本例中是離散的frequency的值)壳坪,n代表sample量。

Building the neural network in main()

Weight initialization

Let's look at the code, starting from scatter_net.py line 200:

    x_size = train_X.shape[1] # number of neurons in the input layer
    y_size = train_Y.shape[1] # number of neurons in the output layer

    # Symbols
    X = tf.placeholder("float", shape=[None, x_size]) # placeholder is TensorFlow's stand-in for data fed in at run time; X is a 2-D float tensor with x_size columns and an as-yet-unspecified number of rows
    y = tf.placeholder("float", shape=[None, y_size])
    weights = [] 
    biases = []

    # Weight initializations
    if reuse_weights:
        (weights, biases) = load_weights(output_folder,weight_name_load,num_layers) 
# if reuse_weights is True, load the weights straight from output_folder; when training a model reuse_weights defaults to False, so the else branch below runs
    else:
        for i in xrange(0,num_layers): # for each layer, the initial weights and biases come from the init_weights/init_bias functions
            if i ==0:
                weights.append(init_weights((x_size,n_hidden))) 
            else:
                weights.append(init_weights((n_hidden,n_hidden)))
            biases.append(init_bias(n_hidden))
        weights.append(init_weights((n_hidden,y_size)))
        biases.append(init_bias(y_size))

The input (X) and output (y) tensors above are declared with placeholders; for more on how placeholders are used, see: http://www.reibang.com/p/e4ff91317f7e
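
A minimal placeholder example (assuming TensorFlow 1.x, which the repository uses):

import tensorflow as tf  # TensorFlow 1.x

X = tf.placeholder("float", shape=[None, 3])  # 3 columns, any number of rows
doubled = 2.0 * X

with tf.Session() as sess:
    print(sess.run(doubled, feed_dict={X: [[1., 2., 3.], [4., 5., 6.]]}))
    # [[ 2.  4.  6.]
    #  [ 8. 10. 12.]]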

The weight and bias initialization above relies on the two functions init_weights() and init_bias(), which are defined in scatter_net_core.py as follows:

#As per Xavier init, this should be 2/n(input), though many different initializations can be tried. 
def init_weights(shape,stddev=.1):
    """ Weight initialization """
    weights = tf.random_normal(shape, stddev=stddev)
    return tf.Variable(weights)

def init_bias(shape, stddev=.1):
    """ Weight initialization """
    biases = tf.random_normal([shape], stddev=stddev)
    return tf.Variable(biases)

Here tf.random_normal() simply draws the initial values from a normal distribution. Many other weight-initialization schemes can be chosen; see: https://zhuanlan.zhihu.com/p/25110150
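
For instance, the 2/n(input) rule mentioned in the comment (He initialization, well suited to ReLU layers) could replace the fixed stddev=.1; a sketch, not the repository's code:

import numpy as np
import tensorflow as tf  # TensorFlow 1.x

def init_weights_scaled(shape):
    # He initialization: stddev = sqrt(2 / n_inputs), per the 2/n(input) remark above
    stddev = np.sqrt(2.0 / shape[0])
    return tf.Variable(tf.random_normal(shape, stddev=stddev))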

Forward propagation

Now that we have the network structure, the initial values, and the input layer, the forward network can be built:

    # Forward propagation
    yhat    = forwardprop(X, weights,biases,num_layers)

The author wrote a forwardprop function for this, defined in scatter_net.py:

def forwardprop(X, weights, biases, num_layers, dropout=False, minLimit=None, maxLimit=None):
    if minLimit is not None:
        X = tf.maximum(X, minLimit)
        X = tf.minimum(X, maxLimit)
    htemp = None
    for i in xrange(0, num_layers):
        if i ==0:
            htemp = tf.nn.relu(tf.add(tf.matmul(X, weights[i]), biases[i]))
        else:
            htemp = tf.nn.relu(tf.add(tf.matmul(htemp, weights[i]), biases[i]))
    yval = tf.add(tf.matmul(htemp, weights[-1]), biases[-1])
    return yval

What do minLimit and maxLimit do in the code above? When they are given, tf.maximum/tf.minimum clamp the input element-wise into the range [minLimit, maxLimit]. During ordinary training they are None and the clamp is skipped; presumably they come into play during inverse design, where X itself is being optimized and must stay within physically meaningful thickness bounds.
The activation function used is ReLU.
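
A tiny numeric illustration of the clamp (the bounds 10 and 100 are made-up values):

import tensorflow as tf  # TensorFlow 1.x

x = tf.constant([[-5.0, 20.0, 150.0]])
clamped = tf.minimum(tf.maximum(x, 10.0), 100.0)  # element-wise clamp to [10, 100]

with tf.Session() as sess:
    print(sess.run(clamped))  # [[ 10.  20. 100.]]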

Backward propagation

    # Backward propagation
    dif = tf.abs(y-yhat)
    peroff = tf.reduce_mean(dif/tf.abs(y))
    cost = tf.reduce_mean(tf.square(y-yhat))
    global_step = tf.Variable(0, trainable=False)
    print("LR Rate: " , lr_rate) # learning rate
    print(int(train_X.shape[0]/n_batch))
    print(lr_decay) # learning rate decay; it decays by this factor every epoch
    print("--done--")
    learning_rate = tf.train.exponential_decay(lr_rate,global_step,int(train_X.shape[0]/n_batch),lr_decay,staircase=False) # the learning rate is updated with exponential_decay
    optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate).minimize(cost,global_step=global_step) # the RMSProp optimizer updates the weights and biases
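
With staircase=False and decay_steps set to the number of batches per epoch, the learning rate decays smoothly; per the TF1 documentation the schedule amounts to the following (the numbers in the usage line are illustrative):

def decayed_lr(lr_rate, lr_decay, global_step, decay_steps):
    # exponential_decay: learning_rate = lr_rate * lr_decay ** (global_step / decay_steps)
    return lr_rate * lr_decay ** (global_step / float(decay_steps))

# e.g. lr_rate=.0006, lr_decay=.99, 320 batches per epoch, after 10 epochs (3200 steps):
print(decayed_lr(0.0006, 0.99, global_step=3200, decay_steps=320))  # ~0.000543
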
Training

Next, a loop carries out the training:

    #Now do the training. 
    step =0; curEpoch =0; cum_loss =0; perinc = 0;
    patience = 0 # counter of validation checks without improvement (assumed initialization; the counter is used below)
    lowVal = 1000000.0 #Just make this some high number. lowVal tracks the lowest validation loss seen so far;
                       #if the validation loss fails to drop below it for (patience=10) consecutive checks,
                       #training is judged to have stalled and is stopped.

    start_time=time.time()

    # A Session is TensorFlow's handle for executing the graph and producing output.
    # Calling session.run() evaluates the tensors or operations you ask for.
    with tf.Session() as sess:                
        init = tf.global_variables_initializer()
        sess.run(init)

        if (compare): #Just run a comparison. In demo.sh this defaults to False, so we skip it for now.
            x_set = train_X
            y_set = train_Y
            if sample_val:
                x_set = val_X
                y_set = val_Y
            batch_x = x_set[spect_to_sample : (spect_to_sample+1) ]
            batch_y = y_set[spect_to_sample : (spect_to_sample+1) ]
            mycost = sess.run(cost,feed_dict={X:batch_x,y:batch_y})
            myvals0 = sess.run(yhat,feed_dict={X:batch_x,y:batch_y})
            outputSpectsToFile(output_folder,spect_to_sample,batch_x,batch_y,myvals0,mycost,x_mean,x_std)
            return
        print("========                         Iterations started                  ========")
        while curEpoch < numEpochs:    # loop controlling the amount of training; after one full pass over all samples, curEpoch += 1
            batch_x = train_X[step * n_batch : (step+1) * n_batch] # n_batch is the batch size, i.e., the number of samples per batch; the default is 100
            batch_y = train_Y[step * n_batch : (step+1) * n_batch]
            peroffinc, cuminc, _ = sess.run([peroff,cost,optimizer], feed_dict={X: batch_x, y: batch_y})
            cum_loss += cuminc # cuminc is the cost, i.e., the mean loss over the samples in the batch
            perinc += peroffinc
            step += 1  # step increases by 1 per batch fed in, i.e., one iteration
            #End of each epoch. 
            if step ==  int(train_X.shape[0]/n_batch): # train_X.shape[0]/n_batch is the total number of batches; when step reaches it, every sample has been seen once, i.e., one epoch is complete
                curEpoch +=1            # curEpoch += 1
                cum_loss = cum_loss/float(step) # average loss over all batches
                perinc = perinc/float(step)
                step = 0
                train_loss_file.write(str(float(cum_loss))+"," + str(perinc) + str("\n"))
                # Every 10 epochs, do a validation. 
                if (curEpoch % 10 == 0 or curEpoch == 1):
                    val_loss, peroff2 = sess.run([cost,peroff],feed_dict={X:test_X,y:test_Y}) # note: this actually evaluates on the test split
                    val_loss_file.write(str(float(val_loss))+","+str(peroff2)+str("\n"))
                    val_loss_file.flush()
                    train_loss_file.flush()
                    if (val_loss > lowVal):
                          patience += 1
                    else:
                          patience = 0
                    lowVal = min(val_loss,lowVal)
                    # print once every ten epochs
                    print("Validation loss: " , str(val_loss) , " per off: " , peroff2)
                    print("Epoch: " + str(curEpoch+1) + " : Loss: " + str(cum_loss) + " : " + str(perinc)) 
                    if (patience > patienceLimit):
                        print("Reached patience limit. Terminating")
                        break
                cum_loss = 0
                perinc = 0
       # save the final weights and biases
        save_weights(weights,biases,output_folder,weight_name_save,num_layers) 
  • Much has been written about batch size, epochs, and iterations. In short, these terms only matter when the dataset is large (which in machine learning is nearly always the case), because then it is impossible to feed all the data to the machine at once. To cope, the data is cut into small batches that are passed in one at a time, and the network's weights are updated at the end of each step so the model fits the given data, as in the worked example below. For details see: http://www.reibang.com/p/005d05e18c7d
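
As a concrete back-of-the-envelope check with the 5-layer settings used above:

n_samples = 40000     # 5-layer dataset size, per the data README above
percent_val = 0.2     # fraction held out for validation + test
n_batch = 100

n_train = int(n_samples * (1 - percent_val))  # 32000 samples remain for training
iters_per_epoch = n_train // n_batch          # 320 iterations make up one epoch
print(n_train, iters_per_epoch)               # 32000 320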