此文翻譯自:A quick complete tutorial to save and restore Tensorflow models
這篇tensorflow的教程棉安,將解釋:
1. Tensorflow模型是什么樣的摩骨?
2. 如何保存一個Tensorflow模型颖对?鳖孤?
3. 如何恢復(fù)一個Tensorflow模型轻抱,用于預(yù)測或者遷移學(xué)習(xí)奕剃?
4. 如何利用導(dǎo)入的預(yù)訓(xùn)練好模型渣窜,進(jìn)行fine-tuning或改造吏垮。
這篇教程障涯,假定讀者對神經(jīng)網(wǎng)絡(luò)的訓(xùn)練有基本的了解。如果不是膳汪,請先閱讀Tensorflow Tutorial 2: image classifier using convolutional neural network唯蝶,然后閱讀本文。
1. Tensorflow 模型是什么遗嗽?
當(dāng)訓(xùn)練完一個神經(jīng)網(wǎng)絡(luò)粘我,你就會保存它,以便日后使用和產(chǎn)品發(fā)布痹换。所以征字,Tensorflow模型是如何表示的呢?Tensorflow模型主要包含網(wǎng)絡(luò)設(shè)計(jì)(Graph)和訓(xùn)練好的參數(shù)的值娇豫。因此匙姜,Tensorflow模型包含兩個主要的文件:
a) Meta graph:
這是一個協(xié)議緩沖區(qū)(protocol buffer,google推出的數(shù)據(jù)存儲格式)冯痢,保存完整的Tensorflow的graph信息氮昧;例如:所有的變量,操作(ops)浦楣,集合(collection)等袖肥。此文件帶有.meta擴(kuò)展。
b) Checkpoint file:
它是一個二進(jìn)制文件振劳,包含所有的權(quán)重椎组,偏置,導(dǎo)數(shù)和其他保存變量的值历恐。文件后綴為: .ckpt庐杨。但自從0.11版本之后,Temsorflow作了改變夹供,不再是一個單獨(dú)的.ckpt文件灵份,取而代之的是兩個文件:
<<mymodel.data-00000-of-00001>>
<<mymodel.index>>
.data文件包含著訓(xùn)練好的變量的值,除此之外哮洽,Tensorflow還有一個名為checkpoint的文件填渠,持續(xù)記錄著最新的保存數(shù)據(jù)。
所以,總結(jié)下來氛什,0.10之后的Tensorflow模型如下圖所示:
而莺葫,0.11版本之前的Tensorflow模型,僅僅包含三個文件:
<<inception_v1.meta>>
<<inception_v1.ckpt>>
<<checkpoint>>
2. 保存一個Tensorflow模型:
假設(shè)枪眉,你正在訓(xùn)練一個卷積神經(jīng)網(wǎng)絡(luò)捺檬,用于圖片分類。作為一個標(biāo)準(zhǔn)操作贸铜,你持續(xù)觀測Loss function和Accuracy堡纬。一旦你看到網(wǎng)絡(luò)收斂,你可以人為停止訓(xùn)練或者只訓(xùn)練固定數(shù)目的epochs蒿秦。當(dāng)訓(xùn)練完成之后烤镐,我們想要保存所有的變量和網(wǎng)絡(luò)圖(network graph)到一個文件,以便日后使用棍鳖。因此炮叶,在Tensorflow中,為了保存graph和變量渡处,我們應(yīng)該新建一個tf.train.Saver()類镜悉。
謹(jǐn)記Tensorflow的變量只有在一個session中才是有效的。因此医瘫,你不得不在一個session中保存模型积瞒,使用剛剛新建的saver對象,調(diào)用save方法登下,如下:
這里,sess是一個session對象叮喳,“my-test-model”是你想要保存的模型的名字被芳。完整的例子如下:
如果,我們想要在1000次迭代之后保存模型馍悟,可以傳入表示步數(shù)的參數(shù):
這行代碼將添加‘-1000’至模型的名字畔濒,以下文件將被建立:
假設(shè),訓(xùn)練時(shí)锣咒,我們每隔1000次迭代保存一次模型侵状,因此,.meta文件第1000次迭代生成.meta文件后勇劣,我們不必要每次新建.meta文件(即在2000,3000次等迭代無須新建.meta文件)嫁艇。我們僅僅保存最新的迭代模型致开。因?yàn)間raph結(jié)構(gòu)并沒有改變,因此艇潭,也沒必要寫meta-graph,使用如下代碼:
如果你想要只記錄最新的4個模型,并每隔2個小時(shí)保存一個模型蹋凝,可以使用這兩個參數(shù):max_to_keep和keep_checkpoint_every_n_hours鲁纠,如下:
需要指出的是,如果我們在tf.train.Saver()中不指定任何事情鳍寂,它將保存所有的變量改含。如果,我們不想保存所有的變量迄汛,僅僅是一部分捍壤。我們可以指定想要保存的變量或集合。當(dāng)新建tf.train.Saver實(shí)例時(shí)隔心,傳遞給它一個想要保存的變量的列表或者字典白群。看下面的例子:
可以保存Tensorflow Graph的指定的需要的部分硬霍。
3. 導(dǎo)入預(yù)訓(xùn)練的模型
如果你想要使用別人訓(xùn)練好的模型做fine-tuning帜慢,有兩件事需要做:
a) 構(gòu)建網(wǎng)絡(luò):
你可以寫python代碼,像寫預(yù)訓(xùn)練的網(wǎng)絡(luò)一樣唯卖,人為地復(fù)原每一層或者每一個模塊粱玲。但是,如果你想到我們已經(jīng)將網(wǎng)絡(luò)保存到.meta文件里了拜轨,就可以使用tf.train.import()函數(shù)抽减,恢復(fù)網(wǎng)絡(luò)結(jié)構(gòu),如下:
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
記住橄碾,import_meta_graph方法將預(yù)定義在.meta文件的網(wǎng)絡(luò)添加到當(dāng)前網(wǎng)絡(luò)卵沉。因此,該方法構(gòu)造graph結(jié)構(gòu)法牲,但我們?nèi)孕枰虞d預(yù)訓(xùn)練的參數(shù)的值史汗。
b) 加載參數(shù):
我們可以通過tf.train.Saver()的restore方法,恢復(fù)網(wǎng)絡(luò)的參數(shù):
執(zhí)行完上述代碼拒垃,w1和w2張量的值就被恢復(fù)了停撞,可以通過如下代碼獲取:
所以悼瓮,至此你已經(jīng)理解了如何保存和導(dǎo)入Tensorflow模型的工作戈毒。下一章節(jié),我將描述加載任意預(yù)訓(xùn)練模型的實(shí)際使用横堡。
4. 使用恢復(fù)模型
既然你已經(jīng)理解如何保存并恢復(fù)Tensorflow模型埋市,讓我們養(yǎng)成一個規(guī)范去恢復(fù)任意預(yù)訓(xùn)練模型,并使用它做預(yù)測命贴,fine-tuning或者進(jìn)一步訓(xùn)練恐疲。不管什么時(shí)候使用Tensorflow腊满,你將定義一個Graph,包含輸入培己,一些超參數(shù)碳蛋,如learning rate, global step等省咨。一個標(biāo)準(zhǔn)的喂入數(shù)據(jù)和超參數(shù)的方式是使用placeholders肃弟。讓我們構(gòu)建一個小的使用placeholders的網(wǎng)絡(luò),并保存它零蓉。值得指出的是笤受。當(dāng)網(wǎng)絡(luò)被保存。placeholders的值并未保存敌蜂。
現(xiàn)在箩兽,當(dāng)我們想要恢復(fù)模型時(shí),不僅需要恢復(fù)graph和權(quán)重章喉,也需要準(zhǔn)備新的feed_dict去喂新的訓(xùn)練數(shù)據(jù)給網(wǎng)絡(luò)汗贫。我們可以通過graph.get_tensor_by_name()等方法得到保存的ops和placeholder變量的引用。
如果我們僅僅想要在網(wǎng)絡(luò)上跑不同的數(shù)據(jù)秸脱,可以通過feed_dict傳遞新的數(shù)據(jù)給網(wǎng)絡(luò)落包。
如果想要增加更多的操作(增加更多的layers)到graph里,并訓(xùn)練它摊唇。當(dāng)然咐蝇,你也可以如下:
但是,可以只恢復(fù)一部分的graph然后增加一些操作進(jìn)行fine-tuning么巷查?當(dāng)然可以有序。利用graph.get_tensor_by_name()方法得到相應(yīng)操作的引用,在頂層構(gòu)建網(wǎng)絡(luò)岛请。這里有個實(shí)際的例子旭寿。我們加載一個預(yù)訓(xùn)練的VGG網(wǎng)絡(luò),改變輸出的單元數(shù)目為2髓需,利用新的訓(xùn)練數(shù)據(jù)fine-tuning。
希望這篇文章能讓你清晰地理解Tensorflow模型的保存和恢復(fù)房蝉。
轉(zhuǎn)載請注明來源僚匆,謝謝。