問(wèn)題描述
因?yàn)樵趯?shí)際的深度學(xué)習(xí)中音同,可能在自己的base網(wǎng)絡(luò)基礎(chǔ)之上對(duì)網(wǎng)絡(luò)進(jìn)行一些增刪操作词爬,比如說(shuō)有些attention模塊可以說(shuō)是即插即用的,在這樣的情況之下权均,我們修改一小部分網(wǎng)絡(luò)結(jié)構(gòu)后希望在訓(xùn)練的初期將之前的沒(méi)有改變的網(wǎng)絡(luò)層訓(xùn)練好的參數(shù)進(jìn)行加載顿膨,以節(jié)約自己的模型訓(xùn)練時(shí)間:先放一段代碼:
import torch
from network.res_unet import ResUNet
net = ResUNet(in_ch=3, out_ch=3)
old_net = torch.load('runs/resUnet_3class/checkpoint/cp_030.pth', map_location='cpu')
print(type(old_net))
i = 0
for key, v in old_net['net'].items():
if i < 2:
i += 1
print(key)
print(v)
print('--------------------------------------------------------------------------------------------')
i = 0
for key, v in net.state_dict().items():
if i < 2:
i+=1
print(key)
print(v)
net.load_state_dict(old_net['net'],strict=False)
print('--------------------------------------------------------------------------------------------')
i = 0
for key, v in net.state_dict().items():
if i < 2:
i += 1
print(key)
print(v)
print('end_signal')
此時(shí)net就是resnet其實(shí)和加載的cp_030.pth的網(wǎng)絡(luò)是一模一樣的,先以其舉例進(jìn)行一下說(shuō)明(其中只打印了網(wǎng)絡(luò)的前2層)叽赊。
說(shuō)明:
我的torch.load出來(lái)的old_net出來(lái)的type是一個(gè)dict恋沃,其中old_net['net']是具體的orderdict網(wǎng)絡(luò)參數(shù),因而在load_state_dict注意第一個(gè)入?yún)⑹莖ld_net['net']必指。(有時(shí)候torch.load下來(lái)的是一個(gè)orderdict的類囊咏,此時(shí)net.load_state_dict(old_net,strict=False))即可。
輸出的結(jié)果展示:
1塔橡、old_net['net']第一層參數(shù):
2梅割、未經(jīng)load_state_dict的net第一個(gè)參數(shù):
3、經(jīng)過(guò)load_state_dict的net第一個(gè)參數(shù):
由此可見(jiàn)葛家,經(jīng)過(guò)load_state_dict的net已經(jīng)將原始o(jì)ld_net的參數(shù)載入炮捧。
接下來(lái),對(duì)Net結(jié)構(gòu)進(jìn)行修改:
此時(shí)打印了一下網(wǎng)絡(luò)修改先后的前四組參數(shù)名稱惦银,發(fā)現(xiàn)多了兩個(gè)參數(shù):
然后我們?cè)龠M(jìn)行一次load_state_dict對(duì)比新添加的inc.0.bias和舊網(wǎng)絡(luò)有的參數(shù)inc.1.bias:
1、舊網(wǎng)絡(luò)參數(shù):
2、新網(wǎng)絡(luò)未經(jīng)load_state_dict的參數(shù):
3扯俱、新網(wǎng)絡(luò)經(jīng)load_state_dict的參數(shù):
由圖可知:改了網(wǎng)絡(luò)以后再利用Load_state_dict進(jìn)行網(wǎng)絡(luò)參數(shù)裝載會(huì)自適應(yīng)賦值书蚪,有參數(shù)就覆蓋,沒(méi)有就不迅栅。這主要是通過(guò)Load_state_dict的API中strict=False字段決定的殊校。