在這篇tensorflow教程中侨拦,我會(huì)解釋:
1) Tensorflow的模型(model)長(zhǎng)什么樣子?
2) 如何保存tensorflow的模型?
3) 如何恢復(fù)一個(gè)tensorflow模型來用于預(yù)測(cè)或者遷移學(xué)習(xí)?
4) 如何使用預(yù)訓(xùn)練好的模型(imported pretrained models)來用于fine-tuning和?modification
1. Tensorflow模型是什么?
當(dāng)你已經(jīng)訓(xùn)練好一個(gè)神經(jīng)網(wǎng)絡(luò)之后谷暮,你想要保存它,用于以后的使用盛垦,部署到產(chǎn)品里面去湿弦。所以,Tensorflow模型是什么腾夯?Tensorflow模型主要包含網(wǎng)絡(luò)的設(shè)計(jì)或者圖(graph)颊埃,和我們已經(jīng)訓(xùn)練好的網(wǎng)絡(luò)參數(shù)的值蔬充。因此Tensorflow模型有兩個(gè)主要的文件:
A)?Meta graph:
這是一個(gè)保存完整Tensorflow graph的protocol buffer,比如說班利,所有的?variables, operations, collections等等饥漫。這個(gè)文件的后綴是.meta。
B)?Checkpoint file:
這是一個(gè)包含所有權(quán)重(weights)罗标,偏置(biases)庸队,梯度(gradients)和所有其他保存的變量(variables)的二進(jìn)制文件。它包含兩個(gè)文件:
mymodel.data-00000-of-00001
mymodel.index
其中闯割,.data文件包含了我們的訓(xùn)練變量彻消。
另外,除了這兩個(gè)文件宙拉,Tensorflow有一個(gè)叫做checkpoint的文件宾尚,記錄著已經(jīng)最新的保存的模型文件。
注:Tensorflow 0.11版本以前谢澈,Checkpoint file只有一個(gè)后綴名為.ckpt的文件煌贴。
?因此,總結(jié)來說锥忿,Tensorflow(版本0.10以后)模型長(zhǎng)這個(gè)樣子:
? ? ? ?Tensorflow版本0.11以前牛郑,只包含以下三個(gè)文件:
inception_v1.meta
inception_v1.ckpt
checkpoint
?????? 接下來說明如何保存模型。
2. 保存一個(gè)Tensorflow模型
當(dāng)網(wǎng)絡(luò)訓(xùn)練結(jié)束時(shí)缎谷,我們要保存所有變量和網(wǎng)絡(luò)結(jié)構(gòu)體到文件中井濒。在Tensorflow中灶似,我們可以創(chuàng)建一個(gè)tf.train.Saver()?類的實(shí)例列林,如下:
saver = tf.train.Saver()
由于Tensorflow變量?jī)H僅只在session中存在,因此需要調(diào)用save方法來將模型保存在一個(gè)session中酪惭。
saver.save(sess,'my-test-model')
在這里希痴,sess是一個(gè)session對(duì)象,其中my-test-model是你給模型起的名字春感。下面是一個(gè)完整的例子:
import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model')# This will save following files in Tensorflow v >= 0.11# my_test_model.data-00000-of-00001# my_test_model.index# my_test_model.meta# checkpoint
如果我們想在訓(xùn)練1000次迭代之后保存模型砌创,可以使用如下方法保存
saver.save(sess,'my_test_model',global_step=1000)
這個(gè)將會(huì)在模型名字的后面追加上‘-1000’,下面的文件將會(huì)被創(chuàng)建:
my_test_model-1000.index
my_test_model-1000.meta
my_test_model-1000.data-00000-of-00001
checkpoint
由于網(wǎng)絡(luò)的圖(graph)在訓(xùn)練的時(shí)候是不會(huì)改變的鲫懒,因此嫩实,我們沒有必要每次都重復(fù)保存.meta文件,可以使用如下方法:
saver.save(sess,'my-model',global_step=step,write_meta_graph=False)
如果你只想要保存最新的4個(gè)模型窥岩,并且想要在訓(xùn)練的時(shí)候每2個(gè)小時(shí)保存一個(gè)模型甲献,那么你可以使用max_to_keep 和 keep_checkpoint_every_n_hours,如下所示:
#saves a model every 2 hours and maximum 4 latest models are saved.saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)
注意到颂翼,我們?cè)趖f.train.Saver()中并沒有指定任何東西晃洒,因此它將保存所有變量慨灭。如果我們不想保存所有的變量,只想保存其中一些變量球及,我們可以在創(chuàng)建tf.train.Saver實(shí)例的時(shí)候氧骤,給它傳遞一個(gè)我們想要保存的變量的list或者字典。示例如下:
import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver([w1,w2])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model',global_step=1000)
3. 導(dǎo)入一個(gè)已經(jīng)訓(xùn)練好的模型
如果你想要使用別人已經(jīng)訓(xùn)練好的模型來fine-tuning吃引,那么你需要做兩個(gè)步驟:
A)創(chuàng)建網(wǎng)絡(luò)Create the network:
?????? 你可以通過寫python代碼筹陵,來手動(dòng)地創(chuàng)建每一個(gè)、每一層镊尺,使得跟原始網(wǎng)絡(luò)一樣惶翻。
但是,如果你仔細(xì)想的話鹅心,我們已經(jīng)將模型保存在了.meta文件中吕粗,因此我們可以使用tf.train.import()函數(shù)來重新創(chuàng)建網(wǎng)絡(luò),使用方法如下:
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
? ? ? ?注意旭愧,這僅僅是將已經(jīng)定義的網(wǎng)絡(luò)導(dǎo)入到當(dāng)前的graph中颅筋,但是我們還是需要加載網(wǎng)絡(luò)的參數(shù)值。
B)加載參數(shù)Load the parameters
?????? 我們可以通過調(diào)用restore函數(shù)來恢復(fù)網(wǎng)絡(luò)的參數(shù)输枯,如下:
with tf.Session() as sess:
? new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')
? new_saver.restore(sess, tf.train.latest_checkpoint('./'))
在這之后议泵,像w1和w2的tensor的值已經(jīng)被恢復(fù),并且可以獲取到:
with tf.Session() as sess:? ?
? ? saver = tf.train.import_meta_graph('my-model-1000.meta')
? ? saver.restore(sess,tf.train.latest_checkpoint('./'))
? ? print(sess.run('w1:0'))##Model has been restored. Above statement will print the saved value of w1.
? ? ? ?上面介紹了如何保存和恢復(fù)一個(gè)Tensorflow模型桃熄。下面介紹一個(gè)加載任何預(yù)訓(xùn)練模型的實(shí)用方法先口。
4. Working with restored models
下面介紹如何恢復(fù)任何一個(gè)預(yù)訓(xùn)練好的模型,并使用它來預(yù)測(cè)瞳收,fine-tuning或者進(jìn)一步訓(xùn)練碉京。當(dāng)你使用Tensorflow時(shí),你會(huì)定義一個(gè)圖(graph)螟深,其中谐宙,你會(huì)給這個(gè)圖喂(feed)訓(xùn)練數(shù)據(jù)和一些超參數(shù)(比如說learning rate,global step等)界弧。下面我們使用placeholder建立一個(gè)小的網(wǎng)絡(luò)凡蜻,然后保存該網(wǎng)絡(luò)。注意到垢箕,當(dāng)網(wǎng)絡(luò)被保存時(shí)划栓,placeholder的值并不會(huì)被保存。
import tensorflow as tf#Prepare to feed input, i.e. feed_dict and placeholdersw1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}#Define a test operation that we will restorew3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())#Create a saver object which will save all the variablessaver = tf.train.Saver()#Run the operation by feeding inputprint sess.run(w4,feed_dict)#Prints 24 which is sum of (w1+w2)*b1 #Now, save the graphsaver.save(sess,'my_test_model',global_step=1000)
現(xiàn)在条获,我們想要恢復(fù)這個(gè)網(wǎng)絡(luò)忠荞,我們不僅需要恢復(fù)圖(graph)和權(quán)重,而且也需要準(zhǔn)備一個(gè)新的feed_dict,將新的訓(xùn)練數(shù)據(jù)喂給網(wǎng)絡(luò)钻洒。我們可以通過使用graph.get_tensor_by_name()方法來獲得已經(jīng)保存的操作(operations)和placeholder variables奋姿。
#How to access saved variable/Tensor/placeholders w1 = graph.get_tensor_by_name("w1:0")## How to access saved operationop_to_restore = graph.get_tensor_by_name("op_to_restore:0")
如果我們僅僅想要用不同的數(shù)據(jù)運(yùn)行這個(gè)網(wǎng)絡(luò),可以簡(jiǎn)單的使用feed_dict來將新的數(shù)據(jù)傳遞給網(wǎng)絡(luò)素标。
import tensorflow as tf
sess=tf.Session()? ? #First let's load meta graph and restore weightssaver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))# Now, let's access and create placeholders variables and# create feed-dict to feed new datagraph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}#Now, access the op that you want to run. op_to_restore = graph.get_tensor_by_name("op_to_restore:0")print sess.run(op_to_restore,feed_dict)#This will print 60 which is calculated #using new values of w1 and w2 and saved value of b1.
如果你想要給graph增加更多的操作(operations)然后訓(xùn)練它称诗,可以像如下那么做:
import tensorflow as tf
sess=tf.Session()? ? #First let's load meta graph and restore weightssaver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))# Now, let's access and create placeholders variables and# create feed-dict to feed new datagraph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}#Now, access the op that you want to run. op_to_restore = graph.get_tensor_by_name("op_to_restore:0")#Add more to the current graphadd_on_op = tf.multiply(op_to_restore,2)print sess.run(add_on_op,feed_dict)#This will print 120.
但是,你可以只恢復(fù)舊的graph的一部分头遭,然后插入一些操作用于fine-tuning寓免?當(dāng)然可以。僅僅需要通過?by graph.get_tensor_by_name()?方法來獲取合適的operation计维,然后在這上面建立graph袜香。下面是一個(gè)實(shí)際的例子,我們使用meta graph?加載了一個(gè)預(yù)訓(xùn)練好的vgg模型鲫惶,并且在最后一層將輸出個(gè)數(shù)改成2蜈首,然后用新的數(shù)據(jù)fine-tuning。
......
......
saver = tf.train.import_meta_graph('vgg.meta')# Access the graphgraph = tf.get_default_graph()## Prepare the feed_dict for feeding data for fine-tuning #Access the appropriate output for fine-tuningfc7= graph.get_tensor_by_name('fc7:0')#use this if you only want to change gradients of the last layerfc7 = tf.stop_gradient(fc7)# It's an identity functionfc7_shape= fc7.get_shape().as_list()
new_outputs=2weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)# Now, you run this with fine-tuning data in sess.run()