前言
鑒于發(fā)布這篇博客以來已經(jīng)有不少人私信問我如何去轉(zhuǎn)換,我想可能是我的博客并沒有寫明白箫爷,于是我決定把這篇文章完善一下舌稀。
私信我的人有下面這兩種情況:
- 只是自己寫iOS,想嘗試如何在iOS上跑神經(jīng)網(wǎng)絡(luò)模型剖效,偏向于玩的性質(zhì)嫉入。
- 確實(shí)有需求需要轉(zhuǎn)換模型焰盗,偏向于工作的性質(zhì)。
如果你是第一種的話咒林,我推薦直接使用iOS12的CreateML或者我的另一篇文章Apple機(jī)器學(xué)習(xí)庫turicreate實(shí)戰(zhàn)——使用cifar-10數(shù)據(jù)集做圖像分類熬拒。如果你是第二種的話,你可能需要與交付你模型的人溝通一下垫竞。因?yàn)樯婕暗侥P偷妮斎胼敵龅膕hape澎粟,如果你對tensorflow了解不多的話這個(gè)轉(zhuǎn)化過程是十分不友好的。
我的需求
最近有需求需要把tensorflow訓(xùn)練的模型在iOS上使用欢瞪,然后我在GitHub上發(fā)現(xiàn)了一個(gè)叫tf-coreml的庫活烙,他可以把pb模型轉(zhuǎn)化為mlmodel模型。
轉(zhuǎn)換
獲得模型
你的目標(biāo)一定是通過訓(xùn)練來保存模型遣鼓,最后放在xcode里可以用來調(diào)用啸盏。使用tensorflow訓(xùn)練保存的模型有兩種格式:
- checkpoint格式
- pb格式
第一種格式需要你通過代碼把checkpont固化為pb模型,但是這一步的坑特別多骑祟,我嘗試之后覺得不太友好回懦。第二種是訓(xùn)練完用tensorflow直接保存為pb格式的模型,就已經(jīng)滿足轉(zhuǎn)化的條件了次企。
轉(zhuǎn)換前需要了解的知識
如果你對tensorflow并不了解粉怕,我覺得在看下面的轉(zhuǎn)換過程之前,你需要了解一些基本的知識以便你能順利的轉(zhuǎn)換抒巢。先放上我的demo方便講解贫贝。
一般來說,轉(zhuǎn)換的過程中你需要關(guān)注輸入和輸出蛉谜,我隨意舉例:
input = tf.placeholder(tf.float32, shape=[None, 36], name="Input")
# 可能是這樣稚晚,我記不太清了,輸出可以是千奇百怪的
output = tf.layers.dense(inputs=self.layer2, units=1, activation=tf.nn.relu, name="Prediction")
如果這個(gè)神經(jīng)網(wǎng)絡(luò)模型并不是你一手操辦的型诚,那你肯定是不知道你的輸入輸出到底是啥了客燕。這個(gè)時(shí)候你打開我demo里的一個(gè)叫network-info.txt的文件,里面有兩段內(nèi)容:
---------------------------------------------------------------------------------------------------------------------------------------------
0: op name = import/Input, op type = ( Placeholder ), inputs = , outputs = import/Input:0
@input shapes:
@output shapes:
name = import/Input:0 : (?, 36)
---------------------------------------------------------------------------------------------------------------------------------------------
---------------------------------------------------------------------------------------------------------------------------------------------
19: op name = import/Prediction, op type = ( Relu ), inputs = import/add_5:0, outputs = import/Prediction:0
@input shapes:
name = import/add_5:0 : (?, 1)
@output shapes:
name = import/Prediction:0 : (?, 1)
---------------------------------------------------------------------------------------------------------------------------------------------
這個(gè)是根據(jù)我demo里的inspect_pb.py生成的狰贯,這兩個(gè)就是輸入輸出的op也搓。“outputs = import/Input:0"涵紊,最后那個(gè)Input:0就是你需要關(guān)注的內(nèi)容傍妒。你會(huì)發(fā)現(xiàn)他們倆剛好一個(gè)在最前面一個(gè)在最后面。這不是一定的摸柄,但經(jīng)常是這樣的颤练。實(shí)際情況下你可能需要與交付你模型的同事溝通一下到底輸入和輸出是哪一個(gè)。
轉(zhuǎn)換過程
下載inspect_pb.py文件
進(jìn)入tf-coreml的github驱负,然后下載他們那個(gè)庫里utils/下的一個(gè)inspect_pb.py文件嗦玖,如圖1患雇、2。
把這個(gè)py改一下宇挫,里面的方法可以把pb模型圖里的所有信息寫在一個(gè)txt格式的文件里苛吱,如圖3。
你可以在圖3顯示的txt里找到你需要的輸入和輸出信息的全名器瘪,這里你需要找的就與上文所述的一樣翠储。轉(zhuǎn)換代碼如下:
import tfcoreml
tfcoreml.convert(tf_model_path="./model_0.pb",
mlmodel_path="./model.mlmodel",
output_feature_names=['Prediction:0'],
input_name_shape_dict={'Input:0': [1, 36]})
你可能很納悶為啥名字得是'Prediction:0','Input:0'娱局,這我也不知道彰亥,但是我知道tensorflow里面通過名字獲得變量的時(shí)候也得這么寫咧七。
參數(shù)解釋一下衰齐,第一個(gè)參數(shù)是pb模型的路徑,第二個(gè)參數(shù)是生成mlmodel模型的路徑继阻,第三個(gè)是輸出的名字耻涛,這個(gè)輸出的名字需要在上文所述生成的txt文件里有,且是你需要的瘟檩,第四個(gè)是可選的抹缕,像我的Input的shape是[?, 36],因?yàn)槲业妮斎胧窍旅孢@樣的
input = tf.placeholder(tf.float32, shape=[None, 36], name="Input")
這里要保證輸入固定的墨辛,所以改為 [1, 36]卓研。
demo: GitHub