我又一次開始了“看不懂你掐死我系列”遍膜。標(biāo)題名稱是仿照知乎的一篇介紹傅里葉變換的文章起的碗硬。當(dāng)時(shí)看完了覺得還真看懂了∑奥可是關(guān)上網(wǎng)頁再自己想的時(shí)候恩尾,就有想掐死博主的沖動(dòng)~~ 為了致敬,這里貼出原文章挽懦,大家共勉翰意。
抄襲標(biāo)題:看不懂傅里葉變換就掐死他
這段時(shí)間做訓(xùn)練的時(shí)候需要分步訓(xùn)練不同的網(wǎng)絡(luò)結(jié)構(gòu)信柿,最后把所有訓(xùn)練好的graph合并成一個(gè)大graph冀偶,前后接起來并且重新定義輸入和輸出再繼續(xù)訓(xùn)練,這樣分步先訓(xùn)練小網(wǎng)絡(luò)再合成大網(wǎng)絡(luò)的話效果會(huì)好一點(diǎn)渔嚷,收斂的也會(huì)快一些进鸠。那么有個(gè)問題,怎么把訓(xùn)練好的好幾個(gè)graph恢復(fù)訓(xùn)練參數(shù)再合并到一起呢形病?Tensorflow到底能不能這么做客年?如果能,那應(yīng)該怎么做漠吻?在讀了這篇量瓜,這篇知乎,和搜了無數(shù)個(gè)stackoverflow上的例子之后途乃,終于有了答案绍傲。
要知道我們需要把每個(gè)pretrain的網(wǎng)絡(luò)的結(jié)構(gòu)和參數(shù)全都讀進(jìn)去,再把它們合并在一起耍共。先不說合并的事唧取,讀取參數(shù)和結(jié)構(gòu)就是個(gè)問題。比如下邊這幾個(gè)stackoverflow的帖子划提。1枫弟,2,3鹏往,4淡诗,5。他們都用了不同的讀取方法伊履。但是到底讀取的是什么韩容?有沒有達(dá)到我們預(yù)期的目的卻不清楚。所以我意識到先要把tensorflow的內(nèi)部結(jié)構(gòu)搞清楚唐瀑,看看存有什么東西群凶,再看看存儲和讀取的方式。先來看結(jié)構(gòu)哄辣。
Tensorflow的內(nèi)部結(jié)構(gòu):
上面的這篇和知乎都說的挺清楚的请梢,我就撿這最重要的總結(jié)一下赠尾。
我們都知道tensorflow里有g(shù)raph,graph的節(jié)點(diǎn)就是運(yùn)算operation毅弧。這個(gè)用tensorboard可視化可以看到气嫁。比如下面這就是個(gè)簡單的graph。
這個(gè)graph在tensorflow里實(shí)際的存儲方式是被序列化以后够坐,以Protocol Buffer的形式存儲的寸宵。這里有中文的對protobuf的介紹,是google開發(fā)的元咙。
graph序列化的protobuf叫做graphDef梯影,就是define graph的意思,一個(gè)graph的定義庶香。這個(gè)graphDef可以用tf.train.write_graph()/tf.Import_graph_def()來寫入和導(dǎo)出甲棍。上面stackoverflow里就有人用這個(gè)方法。然而graphDef里面其實(shí)是沒有存儲變量的脉课,但是可以存常量,就是constant财异√攘悖可以用一種叫freeze_graph的工具把變量替換成常量,這里有官方的介紹戳寸。一般來說沒有必要這么做呈驶,因?yàn)榧热淮媪司W(wǎng)絡(luò),肯定有變量的信息疫鹊,雖然不在graphDef里面袖瞻,但是肯定在別的地方。其實(shí)它存在collectionDef里拆吆。還有一些其他的Def聋迎,所以干脆歸納一下:
MetaGraph - MetaInfoDef 這個(gè)是存metadata的,像版本信息啊枣耀,用戶信息啥的
? ? ? ? ? ? ? ? ? ? - GraphDef 上面說的就是這個(gè)GraphDef
? ? ? ? ? ? ? ? ? ? - SaverDef 這個(gè)就是tf.train.Saver的saver
? ? ? ? ? ? ? ? ? ? - CollectionDef
這些Def的數(shù)據(jù)都存在一個(gè)叫MetaGraph的文件里霉晕。這個(gè)MetaGraph有官方介紹。
最后面的collectionDef就是各種集合捞奕。每個(gè)集合里都是1對多的key/value pairs牺堰。你也可以把你想要的變量存進(jìn)某個(gè)即合理,用tf.add_to_collection(collection_name颅围,變量)就行伟葫。然后再用tf.get_collection()取出來。比如我有l(wèi)oss和train_op院促,就可以:
tf.add_to_collection("training_collection",loss)
tf.add_to_collection("training_collection",train_op)
然后再用
Train_collect = tf.get_collection(“training_collection”)? #得到一個(gè)python list
list里面就是你之前存的東西筏养。所以collection我的理解就是為了方便管理變量用的斧抱。
metagraph可以用export_meta_graph/Import_meta_graph來導(dǎo)入導(dǎo)出。
這里注意了撼玄,如果你用tf.Import_graph_def()導(dǎo)入graphDef的話夺姑,導(dǎo)入的東西一般是不能訓(xùn)練的。但是用Import_meta_graph來導(dǎo)入metagraph之后掌猛,就是導(dǎo)入了一個(gè)完整的結(jié)構(gòu)盏浙,這時(shí)候是可以訓(xùn)練的。
雖然能訓(xùn)練荔茬,metagraph里也有變量废膘,但是都是起始值。也就是說我們之前訓(xùn)練的參數(shù)是沒有導(dǎo)入的慕蔚。這里訓(xùn)練等于是從頭訓(xùn)練丐黄。實(shí)際的訓(xùn)練參數(shù)沒有存在metagraph里,而是在data文件里孔飒。這個(gè)下面會(huì)提到灌闺。
說完了tensorflow的結(jié)構(gòu),再說說存儲的方式坏瞄」鸲裕看完這節(jié),你應(yīng)該完全知道什么api是用來讀什么的了鸠匀。
存儲與讀冉缎薄:
上面那篇中文知乎恰好總結(jié)了這些。一般存讀有3個(gè)API:
tf.train.Saver()/saver.restore()
export_meta_graph/Import_meta_graph
tf.train.write_graph()/tf.Import_graph_def()
后兩個(gè)上一節(jié)都見過了∽汗鳎現(xiàn)在說說第一個(gè)宅此。
我平時(shí)常用的只有第一個(gè)tf.train.Saver()和saver.restore()。我也看到很多代碼里這么寫爬范。但是有一點(diǎn)很坑爹的是tf.train.saver.save() 什么都保存父腕。但是在恢復(fù)圖時(shí),tf.train.saver.restore() 只恢復(fù) Variable青瀑,如果要從MetaGraph恢復(fù)圖侣诵,需要使用 import_meta_graph。看明白了嗎狱窘?saver.save()和saver.restore()保存和讀取的東西不杜顺!一!樣蘸炸!也就是說如果我想重組graph躬络,要么用Import_meta_graph來導(dǎo)入graph,之后再saver.restore()搭儒;要么就從新建立graph穷当,把tensor傳入結(jié)構(gòu)的過程再寫一遍提茁,然后再saver.restore()。不然連變量名都找不到肯定會(huì)報(bào)錯(cuò)馁菜。
說道存儲茴扁,我們必須得看看存儲文件的格式。如果你用saver.save()保存的話(好像也只有這一種方法)汪疮,打開你的保存文件夾峭火,你會(huì)看到4種后綴名的文件(events開頭的不算,那是tf.summary生成給tensorboard用的)智嚷,分別是:
checkpoint?- 就是一個(gè)賬本文件卖丸,可以使用高級幫助程序來加載不同的時(shí)間保存的chkp文件。沒什么用
.meta?- 保存壓縮后的Metagraph的protobufs盏道,其實(shí)就是Metagraph稍浆。
.index -?包含一個(gè)不可變的鍵值表,用于鏈接序列化的張量名稱以及在chkp.data文件中查找其數(shù)據(jù)的位置猜嘱,也沒存什么實(shí)際東西
.data - 這個(gè)里面才是存了訓(xùn)練后的參數(shù)衅枫。通常比.meta要大。有的時(shí)候有多個(gè)data文件用于共享或創(chuàng)建多個(gè)訓(xùn)練的時(shí)間戳朗伶。
其中.data文件的名字一般都是這種格式的:
<prefix>-<global_step>.data-<shard_index>-of-<number_of_shards>.
比如:
所以saver.restore()的時(shí)候其實(shí)是restore的.data文件弦撩。當(dāng)然在restore之前可以用tf.train.latest_checkpoint()來得到最后一次存儲點(diǎn)。還有一點(diǎn)是在saver.save()和restore的時(shí)候腕让,那個(gè)文件對象是xxx.ckpt孤钦。但實(shí)際上在存儲文件夾里你找不到xxx.ckpt文件歧斟。這個(gè)也是正常的纯丸。官方文檔有說.ckpt文件其實(shí)是隱性的的。所以除非你文件名字輸入錯(cuò)了静袖,不然不用擔(dān)心讀錯(cuò)文件觉鼻。
下面結(jié)合我的實(shí)例再看看怎么合并graph。
實(shí)例:
先稍微介紹一下網(wǎng)絡(luò)的結(jié)構(gòu)队橙。我有四個(gè)網(wǎng)絡(luò)結(jié)構(gòu)坠陈。其中3個(gè)網(wǎng)絡(luò)是平行的,這里就叫p1捐康,p2和p3吧仇矾。最后一個(gè)網(wǎng)絡(luò)是微調(diào)用的,就叫m吧解总。這個(gè)m會(huì)得到3個(gè)網(wǎng)絡(luò)的輸出贮匕,合并在一起作為m的輸入,輸入到m花枫,最后得到最終結(jié)果刻盐。為了方便理解我畫了個(gè)圖掏膏。
如果直接訓(xùn)練這么大的網(wǎng)絡(luò),收斂起來一定很費(fèi)勁敦锌,有可能某一個(gè)網(wǎng)絡(luò)落到一個(gè)local minimum就出不去了馒疹。所以我們把p1,p2乙墙,p3拿出來單獨(dú)訓(xùn)練颖变,每次只訓(xùn)練一個(gè)。
我分別用數(shù)據(jù)訓(xùn)練這3個(gè)網(wǎng)絡(luò)伶丐。這個(gè)訓(xùn)練階段算是pretrain悼做。待到三個(gè)網(wǎng)絡(luò)都穩(wěn)定的時(shí)候,我把它們的輸出結(jié)果加在一起哗魂,輸入到第四個(gè)網(wǎng)絡(luò)里訓(xùn)練整個(gè)網(wǎng)絡(luò)肛走。
官方文件稱feed_dicts是效率最低的方法,所以我們改用的tfrecord和dataset api來讀取文件录别。如果你不清楚這是啥朽色,可以參看我們辦公室博導(dǎo)的簡書,這家伙可厲害了~
現(xiàn)在有兩個(gè)問題组题,1是用Import_meta_graph導(dǎo)入metagraph的方法沒法合并graph葫男,因?yàn)槲覍懙臄?shù)據(jù)導(dǎo)入之后拿不出來(或者說我不知道怎么拿出來,可能有api可以取出來)崔列。p1,p2,p3的輸出數(shù)據(jù)是要手動(dòng)連接的梢褐。import_graph_def()也可以設(shè)置input,output mapping赵讯,但是我這里沒有tf.placeholder盈咳。我必須拿到一個(gè)從p1,p2,p3合成出來的tensor,再塞到m里去边翼。所以我選擇了用重建graph的方法鱼响。用
traindata, label = data_iterator(tfrecord_path).get_next()?
得到數(shù)據(jù),再把traindata分別放入p1,p2,p3的架構(gòu)中:
out_p1 = networkp1(trandata_p1)
網(wǎng)絡(luò)結(jié)構(gòu)有了组底,再restore參數(shù):
full_path = tf.train.latest_checkpoint(model_ckp)
saver.restore(sess, full_path)
p2和p3也這么做丈积。
三個(gè)全恢復(fù)了會(huì)得到三個(gè)output,再合并
m_data = out_p1 + out_p2 + out_p3
再輸入m中:
output_m = networkm(m_data)
之后再做loss债鸡,bp江滨,summary啥的,就可以訓(xùn)練了厌均。
需要注意的是唬滑,別恢復(fù)錯(cuò)了graph。不要建3個(gè)session下分別用3個(gè)graph恢復(fù),因?yàn)槟菢拥?/p>
m_data = out_p1 + out_p2 + out_p3 #如果三個(gè)out是不同的graph间雀,這里會(huì)報(bào)錯(cuò)
這一步會(huì)報(bào)錯(cuò)悔详。說不同的graph出來的結(jié)果是不能相互運(yùn)算的。大家必須是在同一個(gè)graph里才行惹挟。所以要建一個(gè)session茄螃,在這個(gè)session下挨個(gè)恢復(fù):
with tf.session as sess: # 下面每個(gè)restore里不要單建 with tf.graph():...?
? ? # restore p1
? ? # restore p2
? ? # retore p3
? ? # ....
等于是把大家依次放進(jìn)default graph里。再填上最后的m就ok了连锯。
references:
https://blog.metaflow.fr/tensorflow-saving-restoring-and-mixing-multiple-models-c4c94d5d7125
https://zhuanlan.zhihu.com/p/31308381
https://www.tensorflow.org/api_guides/python/meta_graph#What_s_in_a_MetaGraph
http://www.reibang.com/p/0f9f2bb962f4
stackoverflow:
https://stackoverflow.com/questions/41990014/load-multiple-models-in-tensorflow
https://stackoverflow.com/questions/45093688/how-to-understand-sess-as-default-and-sess-graph-as-default
https://stackoverflow.com/questions/49864234/tensorflow-restoring-variables-from-two-checkpoints-after-combining-two-graphs
https://stackoverflow.com/questions/49490262/combining-graphs-is-there-a-tensorflow-import-graph-def-equivalent-for-c
https://stackoverflow.com/questions/41607144/loading-two-models-from-saver-in-the-same-tensorflow-session