pytorch finetune模型
文章主要講述如何在pytorch上讀取以往訓(xùn)練的模型參數(shù),在模型的名字已經(jīng)變更的情況下又如何讀取模型的部分參數(shù)等。
?????????????????????????????????????????????????????????????????????????????????????? --------作者:jiangwenj02【轉(zhuǎn)載請注明】
pytorch 模型的存儲(chǔ)與讀取
其中在模型的保存過程有存儲(chǔ)模型和參數(shù)一起的也有單獨(dú)存儲(chǔ)模型參數(shù)的
單獨(dú)存儲(chǔ)模型參數(shù)
存儲(chǔ)時(shí)使用:
torch.save(the_model.state_dict(), PATH)
讀取時(shí):
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
存儲(chǔ)模型與參數(shù)
存儲(chǔ):
torch.save(the_model, PATH)
讀取:
the_model = torch.load(PATH)
模型的參數(shù)
fine-tune的過程是讀取原有模型的參數(shù),但是由于模型的所要處理的數(shù)據(jù)集不同,最后的一層class的總數(shù)不同趋观,所以需要修改模型的最后一層,這樣模型讀取的參數(shù)锋边,和在大數(shù)據(jù)集上訓(xùn)練好下載的模型參數(shù)在形式上不一樣皱坛。需要我們自己去寫函數(shù)讀取參數(shù)。
pytorch模型參數(shù)的形式
模型的參數(shù)是以字典的形式存儲(chǔ)的豆巨。
model_dict = the_model.state_dict(),
for k,v in model_dict.items():
print(k)
即可看到所有的鍵值
如果想修改模型的參數(shù)剩辟,給相應(yīng)的鍵值賦值即可
model_dict[k] = new_value
最后更新模型的參數(shù)
the_model.load_state_dict(model_dict)
如果模型的key值和在大數(shù)據(jù)集上訓(xùn)練時(shí)的key值是一樣的
我們可以通過下列算法進(jìn)行讀取模型
model_dict = model.state_dict()
pretrained_dict = torch.load(model_path)
# 1. filter out unnecessary keys
diff = {k: v for k, v in model_dict.items() if \
k in pretrained_dict and pretrained_dict[k].size() == v.size()}
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
pretrained_dict.update(diff)
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)
如果模型的key值和在大數(shù)據(jù)集上訓(xùn)練時(shí)的key值是不一樣的,但是順序是一樣的
model_dict = model.state_dict()
pretrained_dict = torch.load(model_path)
keys = []
for k,v in pretrained_dict.items():
keys.append(k)
i = 0
for k,v in model_dict.items():
if v.size() == pretrained_dict[keys[i]].size():
print(k, ',', keys[i])
model_dict[k]=pretrained_dict[keys[i]]
i = i + 1
model.load_state_dict(model_dict)
如果模型的key值和在大數(shù)據(jù)集上訓(xùn)練時(shí)的key值是不一樣的,但是順序是也不一樣的
自己找對應(yīng)關(guān)系抹沪,一個(gè)key對應(yīng)一個(gè)key的賦值