Pytorch 實現(xiàn) MobileNet V3 模型,并從 TensorFlow 轉(zhuǎn)化預(yù)訓(xùn)練參數(shù)

????????隨著移動終端的普及悲龟,以及在其上運行深度學(xué)習(xí)模型的需求屋讶,神經(jīng)網(wǎng)絡(luò)小型化越來越得到重視和關(guān)注,已經(jīng)成為研究的熱門之一须教。作為小型化模型的經(jīng)典代表皿渗,MobileNet 系列模型已經(jīng)先后迭代了 3 代,在保持模型參數(shù)量和運算量都極其小的情況下轻腺,其性能越來越優(yōu)異乐疆。本文我們將實現(xiàn)最新一代的 MobileNet V3,為了能不花費時間在 ImageNet 數(shù)據(jù)集上訓(xùn)練而直接使用贬养,我們將從 TensorFlow 官方實現(xiàn)的 MobileNet V3 上轉(zhuǎn)化預(yù)訓(xùn)練參數(shù)挤土。

????????本文將重點關(guān)注以下兩個方面:

  • 詳細(xì)解讀 MobileNet V3 的網(wǎng)絡(luò)結(jié)構(gòu);
  • 詳細(xì)講述從 TensorFlow 轉(zhuǎn)化預(yù)訓(xùn)練參數(shù)的方法误算;

????????本文所有代碼見 GitHub: mobilenet_v3仰美。

一迷殿、MobileNet V3 模型

二、模型實現(xiàn)

三咖杂、預(yù)訓(xùn)練參數(shù)轉(zhuǎn)化

????????完全采用手動指定的方式進行庆寺,即對于 Pytorch 模型的每一參數(shù),從對應(yīng)的 TensorFlow 預(yù)訓(xùn)練參數(shù)里取出诉字,然后賦值給它即可懦尝。為了保證轉(zhuǎn)化的準(zhǔn)確性,我們的目標(biāo)是:

  • 原 TensorFlow 預(yù)訓(xùn)練模型和轉(zhuǎn)化后的 Pytorch 模型的預(yù)測結(jié)果要絕對一致

以下壤圃,我們詳細(xì)的來描述怎么從 TensorFlow 轉(zhuǎn)化預(yù)訓(xùn)練參數(shù)导披。

1.查看 TensorFlow 預(yù)訓(xùn)練模型參數(shù)名

????????首先到 此頁 下載 MobileNet V3 模型的 TensorFlow 預(yù)訓(xùn)練模型,下載后請解壓埃唯。我們以 large dm=1 (float) 預(yù)訓(xùn)練模型為例來說明撩匕。首先,使用如下代碼:

import json
import tensorflow as tf

if __name__ == '__main__':
    checkpoint_path = 'xxx/v3-large_224_1.0_float/ema/model-540000'
    output_path = './mobilenet_v3_large.json'

    reader = tf.train.NewCheckpointReader(checkpoint_path)
    weights = {var: 1 for (var, _) in
               reader.get_variable_to_shape_map().items()}
    
    with open(output_path, 'w') as writer:
        json.dump(weights, writer)

預(yù)訓(xùn)練模型中的所有參數(shù)名都寫到一個 json 文件里墨叛,為了不把高維的數(shù)據(jù)寫進去止毕,我們都將確切的值改成了 1。但直接寫進去的內(nèi)容很亂漠趁,可以借助 json 串格式化的工具(比如扁凛,在線格式化,或者 Google Chrome 瀏覽器插件 FeHelper)將 mobilenet_v3_large.json 文件里的內(nèi)容格式化闯传,這樣你看到的形式就大概如下了:

格式化之后的 mobilenet_v3_large.json 內(nèi)容

接著谨朝,結(jié)合 TensorFlow 官方開源的 MobileNet V3 Large 模型的網(wǎng)絡(luò)定義

MobileNet V3 large 模型 TensorFlow 網(wǎng)絡(luò)定義

就基本可以知道整個模型參數(shù)命名的具體名字和順序了:

MobilenetV3/Conv/
MobilenetV3/expanded_conv/
MobilenetV3/expanded_conv_1/
...
MobilenetV3/expanded_conv_14/
MobilenetV3/Conv_1/
MobilenetV3/Conv_2/
MobilenetV3/Logits/Conv2d_1c_1x1

以上是 large 模型的總共 19 個大的命名空間(scope),每個 / 之后會接小的命名空間甥绿。對于普通的卷積層字币,比如 Conv, Conv_1, Conv_2, Logits/Conv2d_1c_1x1 你要關(guān)注兩個東西:

  • 是否有偏置參數(shù):biases
  • 是否有批標(biāo)準(zhǔn)化:BatchNorm

這既可以幫助你修正你定義的 Pytorch 模型共缕,也可以在轉(zhuǎn)化賦值的時候防止被遺忘洗出。類似的思想可以直接移植到復(fù)雜的模塊 mbv3_op 對應(yīng)的命名空間,expanded_conv, expanded_conv_1, ...图谷。舉個簡單的例子翩活,看 large 模型的第一卷積層:MobilenetV3/Conv/,因為該層使用了批標(biāo)準(zhǔn)化(batch normalization)便贵,因此是沒有偏置參數(shù)的菠镇,那么就只有如下的 5 個參數(shù):

MobilenetV3/Conv/weights,
MobilenetV3/BatchNorm/beta,
MobilenetV3/Conv/BatchNorm/gamma
MobilenetV3/Conv/BatchNorm/moving_mean
MobilenetV3/Conv/BatchNorm/moving_variance

其中后 4 個參數(shù)對應(yīng)于批標(biāo)準(zhǔn)化的公式:
\gamma \frac{x - \mu}{\sigma} + \beta. \\
再看 large 模型的最后一個卷積層(分類層):MobilenetV3/Logits/Conv2d_1c_1x1,因為該層沒有使用批標(biāo)準(zhǔn)化的正規(guī)化函數(shù)承璃,因此帶有偏置項利耍,就只有兩個參數(shù):

MobilenetV3/Logits/Conv2d_1c_1x1/weights
MobilenetV3/Logits/Conv2d_1c_1x1/biases

至于其他復(fù)雜模塊,分割開單獨考慮中間命名空間: project, expand, depthwise, squeeze_excite 之后,其實就是簡單的卷積層了堂竟,因此也很容易處理。

2.查看 Pytorch 模型結(jié)構(gòu)

????????這一步更容易玻佩,直接實例化定義的 Pytorch 模型出嘹,然后打印出來(這里,模型的所有的層都定義在了屬性 _layers 里咬崔,見 mobilenet_v3.MobileNet 類):

import mobilenet_v3

large = mobilenet_v3.large()
print(large._layers[:10])
print(large._layers[10:])

因為模型結(jié)構(gòu)很長税稼,所以打印的時候分成了前后兩部分。保存在 txt 文件里如下:

MobileNet V3 large 模型網(wǎng)絡(luò)結(jié)構(gòu)

這一步我們唯一需要關(guān)注的就是每一層在網(wǎng)絡(luò)結(jié)構(gòu)里的下標(biāo)了垮斯,比如 _layers[0] 就是整個網(wǎng)絡(luò)的第 1 個卷積層模塊郎仆,而 _layers[0]._layers[0] 是這個模塊內(nèi)的二維卷積層,_layers[0]._layers[1] 是這個模塊內(nèi)的批標(biāo)準(zhǔn)化層兜蠕。因為 torch.nn.Sequential 的行為和 list 一樣扰肌,因此它們的順序是確定不變的,取下標(biāo)是非常安全的操作熊杨。

3.對照參數(shù)名逐一賦值

????????經(jīng)過前面兩步之后曙旭,應(yīng)該對 TensorFlow 預(yù)訓(xùn)練模型 和 Pytorch 定義的模型結(jié)構(gòu) 之間的對應(yīng)關(guān)系應(yīng)該有所印象了,下面需要將它們嚴(yán)格的對應(yīng)起來晶府,以便預(yù)訓(xùn)練參數(shù)轉(zhuǎn)化桂躏。

????????首先,看第一個卷積模塊川陆,它包含一個卷積層剂习、批標(biāo)準(zhǔn)化層和一個激活函數(shù)層,其中只有前兩者是有訓(xùn)練參數(shù)的较沪。而且鳞绕,根據(jù)第一步,我們知道對應(yīng)的 TensorFlow 模型這一個模塊的命名空間是:MobilenetV3/Conv/尸曼,因此如果我聲明了

import mobilenet_v3

model = mobilenet_v3.large()

large 模型猾昆,那么對應(yīng)的第 1 個卷積模塊的二維卷積層是 model._layers[0]._layers[0],批標(biāo)準(zhǔn)化層是 model._layers[0]._layers[1]骡苞。它們所含有的參數(shù)如下:

model._layers[0]._layers[0].weight
model._layers[0]._layers[1].bias:
model._layers[0]._layers[1].weight
model._layers[0]._layers[1].running_mean
model._layers[0]._layers[1].running_var

即卷積層的權(quán)重參數(shù)(對于 slim.conv2d()垂蜗,如果指定了正規(guī)化函數(shù),即關(guān)鍵字參數(shù) normalizer_fn 不為 None解幽,那么這個卷積層是沒有偏置項的贴见;反之,則有躲株,除非將偏置的初始化函數(shù) biases_initializer 設(shè)為 None)片部,和批標(biāo)準(zhǔn)化層的 4 個參數(shù):
\gamma \frac{x - \mu}{\sigma} + \beta. \\
很容易的,你可以從 mobilenet_v3_large.json 里找到對應(yīng)的 TensorFlow 變量名:

conversion_map_for_root_block = {
    model._layers[0]._layers[0].weight: 
        'MobilenetV3/Conv/weights',
    model._layers[0]._layers[1].bias: 
        'MobilenetV3/Conv/BatchNorm/beta',
    model._layers[0]._layers[1].weight:
        'MobilenetV3/Conv/BatchNorm/gamma',
    model._layers[0]._layers[1].running_mean: 
        'MobilenetV3/Conv/BatchNorm/moving_mean',
    model._layers[0]._layers[1].running_var: 
        'MobilenetV3/Conv/BatchNorm/moving_variance',
}

然后用函數(shù) tf.train.load_variable霜定,按照 TensorFlow 的變量名從預(yù)訓(xùn)練模型中取出變量的名字賦值給對應(yīng)的 Pytorch 變量档悠,比如:

checkpoint_path = 'xxx/v3-large_224_1.0_float/ema/model-540000'

tf_param = tf.train.load_variable(checkpoint_path, 'MobilenetV3/Conv/weights')
tf_param = np.transpose(tf_param, (3, 2, 0, 1))
model._layers[0]._layers[0].weight.data = torch.from_numpy(tf_param)

就將第 1 個卷積層的參數(shù)轉(zhuǎn)化好了廊鸥。這里,唯一需要注意的是辖所,TensorFlow 權(quán)重的順序是 [kernel_size, kernel_size, in_channels, out_channels]惰说,而 Pytorch 的順序是 [out_channels, in_channels, kernel_size, kernel_size],因此要將它們的順序調(diào)整到一致缘回。

????????其它參數(shù)完全按照一樣的方式轉(zhuǎn)化即可吆视。完整的轉(zhuǎn)化代碼請見 converter.py

????????以上過程結(jié)束之后酥宴,我們來轉(zhuǎn)化幾個模型

1.large 模型

????????執(zhí)行(tf_checkpoint_path 參數(shù)指定 TensorFlow 預(yù)訓(xùn)練模型參數(shù)的保存路徑):

python3 tf_weights_to_pth.py --tf_checkpoint_path xxx/v3-large_224_1.0_float/ema/model-540000

將在當(dāng)前項目路徑下生成一個 pretrained_models 文件夾啦吧,里面保存了轉(zhuǎn)化后的模型:mobilenet_v3_large.pth,同時將輸出測試圖片(熊貓圖片):

panda.jpg

的分類結(jié)果:

large 模型 TensorFlow 原預(yù)訓(xùn)練模型和轉(zhuǎn)化的 Pytorch 模型對熊貓圖片的識別結(jié)果

可以看到兩者的結(jié)果是一模一樣的拙寡。類似的授滓,再指定另一張測試圖片(貓圖片),執(zhí)行以下命令(image_path 參數(shù)指定測試圖片的路徑):

python3 tf_weights_to_pth.py --tf_checkpoint_path xxx/v3-large_224_1.0_float/ema/model-540000 \
    --image_path ./test/cat.jpg
cat.jpg

就可以看到對貓的分類結(jié)果:

large 模型 TensorFlow 原預(yù)訓(xùn)練模型和轉(zhuǎn)化的 Pytorch 模型對貓圖片的識別結(jié)果

顯然肆糕,TensorFlow 官方和本文實現(xiàn)的 Pytorch 模型的預(yù)測結(jié)果也是一模一樣的褒墨。

2.small 模型(depth_multiplier = 0.75)

執(zhí)行(output_name 指定轉(zhuǎn)化來的模型的保存名字,depth_multiplier 指定卷積層的通道數(shù)乘子擎宝,model_name 指定轉(zhuǎn)化的模型名):

python3 tf_weights_to_pth.py --tf_checkpoint_path xxx/v3-small_224_0.75_float/ema/model-497500 \
    --output_name mobilenet_v3_small_0.75.pth --depth_multiplier 0.75 --model_name small

得到熊貓圖片的分類結(jié)果:

small-dm=0.75 模型 TensorFlow 原預(yù)訓(xùn)練模型和轉(zhuǎn)化的 Pytorch 模型對熊貓圖片的識別結(jié)果

也得到一模一樣的結(jié)果郁妈,說明轉(zhuǎn)化參數(shù)是正確的。

????????當(dāng)前支持參數(shù)轉(zhuǎn)化的預(yù)訓(xùn)練模型如下:

本文所有支持參數(shù)轉(zhuǎn)化的預(yù)訓(xùn)練模型

對應(yīng)的模型名(由 model_name 參數(shù)指定)分別為:large, small, large_minimalistic, small_minimalistic绍申,如果 dm=0.75噩咪,請指定參數(shù) depth_multiplier。你可以逐一轉(zhuǎn)化并驗證本文定義的 MobileNet V3 模型的正確性极阅,不出意外應(yīng)該是準(zhǔn)確的(作者未轉(zhuǎn)化 8-bit 的預(yù)訓(xùn)練模型)胃碾。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市筋搏,隨后出現(xiàn)的幾起案子仆百,更是在濱河造成了極大的恐慌,老刑警劉巖奔脐,帶你破解...
    沈念sama閱讀 216,372評論 6 498
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件俄周,死亡現(xiàn)場離奇詭異,居然都是意外死亡髓迎,警方通過查閱死者的電腦和手機峦朗,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,368評論 3 392
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來排龄,“玉大人波势,你說我怎么就攤上這事。” “怎么了尺铣?”我有些...
    開封第一講書人閱讀 162,415評論 0 353
  • 文/不壞的土叔 我叫張陵拴曲,是天一觀的道長。 經(jīng)常有香客問我凛忿,道長澈灼,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,157評論 1 292
  • 正文 為了忘掉前任侄非,我火速辦了婚禮,結(jié)果婚禮上流译,老公的妹妹穿的比我還像新娘逞怨。我一直安慰自己,他們只是感情好福澡,可當(dāng)我...
    茶點故事閱讀 67,171評論 6 388
  • 文/花漫 我一把揭開白布叠赦。 她就那樣靜靜地躺著,像睡著了一般革砸。 火紅的嫁衣襯著肌膚如雪除秀。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,125評論 1 297
  • 那天算利,我揣著相機與錄音册踩,去河邊找鬼。 笑死效拭,一個胖子當(dāng)著我的面吹牛暂吉,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播缎患,決...
    沈念sama閱讀 40,028評論 3 417
  • 文/蒼蘭香墨 我猛地睜開眼慕的,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了挤渔?” 一聲冷哼從身側(cè)響起肮街,我...
    開封第一講書人閱讀 38,887評論 0 274
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎判导,沒想到半個月后嫉父,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,310評論 1 310
  • 正文 獨居荒郊野嶺守林人離奇死亡眼刃,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,533評論 2 332
  • 正文 我和宋清朗相戀三年熔号,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片鸟整。...
    茶點故事閱讀 39,690評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡引镊,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情弟头,我是刑警寧澤吩抓,帶...
    沈念sama閱讀 35,411評論 5 343
  • 正文 年R本政府宣布,位于F島的核電站赴恨,受9級特大地震影響疹娶,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜伦连,卻給世界環(huán)境...
    茶點故事閱讀 41,004評論 3 325
  • 文/蒙蒙 一雨饺、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧惑淳,春花似錦额港、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,659評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽贩幻。三九已至仓手,卻和暖如春睡互,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背舰涌。 一陣腳步聲響...
    開封第一講書人閱讀 32,812評論 1 268
  • 我被黑心中介騙來泰國打工猖任, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人瓷耙。 一個月前我還...
    沈念sama閱讀 47,693評論 2 368
  • 正文 我出身青樓超升,卻偏偏與公主長得像,于是被迫代替她去往敵國和親哺徊。 傳聞我的和親對象是個殘疾皇子室琢,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 44,577評論 2 353

推薦閱讀更多精彩內(nèi)容