最近看到一個(gè)巨牛的人工智能教程唆铐,分享一下給大家。教程不僅是零基礎(chǔ)奔滑,通俗易懂艾岂,而且非常風(fēng)趣幽默,像看小說(shuō)一樣朋其!覺(jué)得太牛了王浴,所以分享給大家。平時(shí)碎片時(shí)間可以當(dāng)小說(shuō)看梅猿,【點(diǎn)這里可以去膜拜一下大神的“小說(shuō)”】氓辣。
Tensorflow官方提供的Tensorboard可以可視化神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)圖,但是說(shuō)實(shí)話袱蚓,我?guī)缀鯊膩?lái)不用钞啸。主要是因?yàn)門(mén)ensorboard中查看到的圖結(jié)構(gòu)太混亂了,包含了網(wǎng)絡(luò)中所有的計(jì)算節(jié)點(diǎn)(讀取數(shù)據(jù)節(jié)點(diǎn)喇潘、網(wǎng)絡(luò)節(jié)點(diǎn)体斩、loss計(jì)算節(jié)點(diǎn)等等)。更可怕的是颖低,如果一個(gè)計(jì)算節(jié)點(diǎn)是由多個(gè)基礎(chǔ)計(jì)算(如加減乘除等)構(gòu)成几颜,那么在Tensorboard中會(huì)將基礎(chǔ)計(jì)算節(jié)點(diǎn)顯示而不是作為一個(gè)整體顯示(典型的如Squeeze計(jì)算節(jié)點(diǎn))缨睡。最近為了排查網(wǎng)絡(luò)結(jié)構(gòu)BUG花費(fèi)一周時(shí)間窘俺,因此赌髓,狠下心來(lái)決定自己寫(xiě)一個(gè)工具腕巡,將Tensorflow中的圖以最簡(jiǎn)單的方式顯示最關(guān)鍵的網(wǎng)絡(luò)結(jié)構(gòu)砂碉。
1 Tensor對(duì)象與Operation對(duì)象
Tensorflow中恋谭,Tensor對(duì)象主要用于存儲(chǔ)數(shù)據(jù)如常量和變量(訓(xùn)練參數(shù))弥虐,Operation對(duì)象是計(jì)算節(jié)點(diǎn)话浇,如卷積計(jì)算脏毯、反卷積計(jì)算、ReLU等等幔崖。每一個(gè)Operation對(duì)象均有輸入和輸出Tensor食店,同理渣淤,每個(gè)Tensor對(duì)象均有對(duì)應(yīng)生成該Tensor的Operation對(duì)象和使用該Tensor對(duì)象作為輸入的Operation對(duì)象。Tensor和Operation對(duì)象內(nèi)均有相關(guān)屬性和函數(shù)來(lái)獲取其關(guān)聯(lián)的Operation和Tensor對(duì)象吉嫩,相關(guān)屬性如下所示价认。
Tensor對(duì)象的op屬性指向生成該Tensor的Operation對(duì)象。
Tensor對(duì)象的consumers()函數(shù)獲取使用該Tensor對(duì)象作為輸入的Operation對(duì)象自娩。
Operation對(duì)象的inputs屬性指向該計(jì)算節(jié)點(diǎn)的輸入Tensor對(duì)象用踩。
Operation對(duì)象的outputs屬性執(zhí)行該計(jì)算節(jié)點(diǎn)的輸出Tensor對(duì)象。
如下圖所示的網(wǎng)絡(luò)結(jié)構(gòu)中忙迁,調(diào)用Tensor_2
對(duì)象的consumers()
函數(shù)脐彩,返回的是[op_1,op_2]
。Tensor_3
的op屬性指向的是op_1
姊扔。op_1
的inputs屬性指向的是[Tensor_1,Tensor_2]
惠奸,op_1
的output屬性指向的是[Tensor_3]
。
有了Tensor與Operation對(duì)應(yīng)在圖中的關(guān)聯(lián)關(guān)系恰梢,就可以將網(wǎng)絡(luò)結(jié)構(gòu)給畫(huà)出來(lái)佛南。
2 提取pb文件中的網(wǎng)絡(luò)結(jié)構(gòu)圖
pb文件是將模型參數(shù)固化到圖文件中,并合并了一些基礎(chǔ)計(jì)算和刪除了反向傳播相關(guān)計(jì)算得到的protobuf協(xié)議文件嵌言。如果讀者還不懂如何將CKPT模型文件轉(zhuǎn)pb文件共虑,請(qǐng)參考我另一篇文章《 Tensorflow MobileNet移植到Android》的第1節(jié)部分。有了pb模型文件后呀页,接下來(lái)是加載模型妈拌,加載pb模型示例代碼如下所示。
def read_graph_from_pb(tf_model_path ,input_names,output_name):
with open(tf_model_path, 'rb') as f:
serialized = f.read()
tf.reset_default_graph()
gdef = tf.GraphDef()
gdef.ParseFromString(serialized)
with tf.Graph().as_default() as g:
tf.import_graph_def(gdef, name='')
with tf.Session(graph=g) as sess:
OPS=get_ops_from_pb(g,input_names,output_name)
return OPS
其中蓬蝶,倒數(shù)第2行調(diào)用到的函數(shù)get_ops_from_pb()
用于獲取網(wǎng)絡(luò)結(jié)構(gòu)圖中指定輸入節(jié)點(diǎn)和指定輸出節(jié)點(diǎn)之間的計(jì)算節(jié)點(diǎn)尘分。之所以要指定輸入和輸出,是為了將輸入之前的計(jì)算節(jié)點(diǎn)(如加載數(shù)據(jù)隊(duì)列等相關(guān)計(jì)算節(jié)點(diǎn))和輸出之后的計(jì)算節(jié)點(diǎn)(如計(jì)算loss等相關(guān)計(jì)算節(jié)點(diǎn))去除丸氛,免得礙眼培愁。函數(shù)get_ops_from_pb()
實(shí)現(xiàn)代碼如下。
def get_ops_from_pb(graph,input_names,output_name,save_ori_network=True):
if save_ori_network:
with open('ori_network.txt','w+') as w:
OPS=graph.get_operations()
for op in OPS:
txt = str([v.name for v in op.inputs])+'---->'+op.type+'--->'+str([v.name for v in op.outputs])
w.write(txt+'\n')
inputs_tf = [graph.get_tensor_by_name(input_name) for input_name in input_names]
output_tf =graph.get_tensor_by_name(output_name)
OPS =get_ops_from_inputs_outputs(graph, inputs_tf,[output_tf] )
with open('network.txt','w+') as w:
for op in OPS:
txt = str([v.name for v in op.inputs])+'---->'+op.type+'--->'+str([v.name for v in op.outputs])
w.write(txt+'\n')
OPS = sort_ops(OPS)
OPS = merge_layers(OPS)
return OPS
在裁剪網(wǎng)絡(luò)結(jié)構(gòu)(即只保留input_names和output_name之間節(jié)點(diǎn))之前缓窜,先將原始的網(wǎng)絡(luò)結(jié)構(gòu)寫(xiě)入到ori_network.txt
中定续,文件中,每一行寫(xiě)入:輸入Tensor---->op---->輸出Tensor
禾锤。接下來(lái)調(diào)用函數(shù)get_ops_from_inputs_outputs
獲取指定節(jié)點(diǎn)之間的節(jié)點(diǎn)私股。并調(diào)用sort_ops
函數(shù)對(duì)所有的節(jié)點(diǎn)排序,以保證被依賴(lài)的節(jié)點(diǎn)總是出現(xiàn)在相關(guān)節(jié)點(diǎn)之前恩掷。最后調(diào)用merge_layers
函數(shù)倡鲸,將一些可以合并的計(jì)算合并成一個(gè)獨(dú)立的節(jié)點(diǎn),例如黄娘,Squeeze
計(jì)算相關(guān)節(jié)點(diǎn)合并成一個(gè)單獨(dú)的Squeeze節(jié)點(diǎn)峭状,又如const-->identity
兩個(gè)計(jì)算節(jié)點(diǎn)可以直接忽略(即刪除)克滴。
注意:篇幅有限,這里不再將函數(shù)
get_ops_from_inputs_outputs
优床、sort_ops
劝赔、merge_layers
貼出,相關(guān)代碼請(qǐng)前往文尾提供的源碼地址中閱讀胆敞。
3 繪制網(wǎng)絡(luò)結(jié)構(gòu)
考慮到SVG
繪制圖形的簡(jiǎn)單易用優(yōu)點(diǎn)望忆,將排好序的網(wǎng)絡(luò)計(jì)算節(jié)點(diǎn)和相關(guān)Tensor
對(duì)象數(shù)據(jù)以Javascript
字符串的形式寫(xiě)入到HTML
中,使用<line>
標(biāo)簽繪制箭頭竿秆,使用<rect>
標(biāo)簽繪制矩形启摄,使用<ellipse>
標(biāo)簽繪制橢圓,使用<text>
標(biāo)簽顯示文字幽钢。繪制類(lèi)似于如下所示圖像
注意:篇幅有限歉备,這里不再介紹Javascript代碼解析模型結(jié)構(gòu)和SVG顯示相關(guān)的原理,相關(guān)代碼請(qǐng)前往文尾提供的源碼地址中閱讀匪燕。
4 測(cè)試模型顯示
以《MobileNet V1官方預(yù)訓(xùn)練模型的使用》文中介紹的MobileNet V1網(wǎng)絡(luò)結(jié)構(gòu)為例蕾羊,下載MobileNet_v1_1.0_192
文件并壓縮后,得到mobilenet_v1_1.0_192_frozen.pb
文件帽驯。我們還需要知道mobilenet_v1_1.0_192_frozen.pb
模型對(duì)應(yīng)的輸入和輸出Tensor
對(duì)象的名稱(chēng)龟再,好在MobileNet_v1_1.0_192
壓縮包中包含文件mobilenet_v1_1.0_192_info.txt
。通過(guò)該文件可知尼变,輸入Tensor
的名稱(chēng)為:input:0
利凑,輸出Tensor名稱(chēng)為:MobilenetV1/Predictions/Reshape_1:0
。有了這些信息后嫌术,調(diào)用函數(shù)read_graph_from_pb
得到靜態(tài)圖的節(jié)點(diǎn)列表對(duì)象ops哀澈,調(diào)用函數(shù)gen_graph(ops,"save/path/graph.html")
后,在目錄save/path
中得到graph.html
文件度气,打開(kāi)graph.html
后割按,顯示結(jié)果如下。
顯示網(wǎng)絡(luò)結(jié)構(gòu)分兩種模式:合并模式和展開(kāi)模式磷籍,分別如下圖所示适荣。