[圖像算法]-(yolov5.train)-Pytorch保存和加載模型完全指南: 關(guān)于使用Pytorch讀寫模型的一切方法

??????本文是一篇關(guān)于如何用Pytorch保存和加載模型的指南


文章目錄

  • 1.讀寫tensor
    • 1.1 單個(gè)張量
    • 1.2 張量列表和張量詞典
  • 2.保存和加載模型
    • 2.1 state_dict
    • 2.2 保存和加載
      • 2.2.1 保存和加載state_dict(推薦方式)
      • 2.2.2 保存和讀寫整個(gè)模型
      • 2.2.3 保存和加載checkpiont
      • 2.2.4 在一個(gè)文件中保存多個(gè)模型
    • 2.3 使用來自不同模型的參數(shù)進(jìn)行模型熱啟動(dòng)
  • 3.跨設(shè)備保存和加載模型
    • 3.1 在GPU中保存,在CPU中加載
    • 3.2 在GPU中保存怜森,在GPU中加載
    • 3.3 在CPU中保存,在GPU中加載
  • 4.保存torch.nn.DataParallel的模型

本文主要涉及到3個(gè)函數(shù):

  • 1.torch.save: 使用Python的pickle實(shí)用程序?qū)?duì)象進(jìn)行序列化嘿悬,然后將序列化的對(duì)象保存到disk,可以保存各種對(duì)象,包括模型水泉、張量和字典等善涨。
  • 2.torch.load: 使用pickle unpickle工具將pickle的對(duì)象文件反序列化為內(nèi)存。
  • 3.torch.nn.Module.load_state_dict: 用反序列化的state_dict來加載模型參數(shù)草则。

1.讀寫tensor

1.1單個(gè)張量

import torch

x = torch.tensor([3.,4.])
torch.save(x, 'x.pt')
x1 = torch.load('x.pt')
print(x1)

輸出:

tensor([3., 4.])

1.2張量列表和張量詞典

y = torch.ones((4,2))
torch.save([x,y],'xy.pt')
torch.save({'x':x, 'y':y}, 'xy_dict.pt')
xy = torch.load('xy.pt')
xy_dict = torch.load('xy_dict.pt')
print(xy)
print(xy_dict)

輸出:

[tensor([3., 4.]), tensor([[1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.]])]
{'x': tensor([3., 4.]), 'y': tensor([[1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.]])}

2.保存和加載模型

2.1state_dict

state_dict是一個(gè)從每一個(gè)層的名稱映射到這個(gè)層的參數(shù)Tesnor的字典對(duì)象钢拧。

注意,只有具有可學(xué)習(xí)參數(shù)的層(卷積層畔师、線性層等)和注冊(cè)緩存(batchnorm’s running_mean)才有state_dict中的條目娶靡。優(yōu)化器(torch.optim)也有一個(gè)state_dict牧牢,其中包含關(guān)于優(yōu)化器狀態(tài)以及所使用的超參數(shù)的信息看锉。

from torch import nn
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)

net = MLP()
print(net.state_dict())
print('\n',net.state_dict()['output.weight'])

optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
print(optimizer.state_dict())

輸出:

OrderedDict([('hidden.weight', tensor([[ 0.0405, -0.0659, -0.5540],
        [ 0.2954,  0.0676, -0.1933]])), ('hidden.bias', tensor([-0.1628,  0.0768])), ('output.weight', tensor([[-0.4635,  0.4958]])), ('output.bias', tensor([-0.5440]))])

 tensor([[-0.4635,  0.4958]])
{'state': {}, 'param_groups': [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3]}]}

2.2保存和加載

PyTorch中保存和加載訓(xùn)練模型有兩種常見的方法:

    1. 僅保存和加載模型參數(shù)(state_dict)姿锭;
    1. 保存和加載整個(gè)模型。
2.2.1保存和加載state_dict(推薦方式)
torch.save(net.state_dict(), 'net_state_dict.pt')## 后綴名一般寫為: .pt或.pth
net1 = MLP()
net1.load_state_dict(torch.load('net_state_dict.pt'))
print(net1)

輸出:

MLP(
  (hidden): Linear(in_features=3, out_features=2, bias=True)
  (act): ReLU()
  (output): Linear(in_features=2, out_features=1, bias=True)
)

注意: load_state_dict() 接受一個(gè)詞典對(duì)象伯铣,而不是一個(gè)指向?qū)ο蟮穆窂缴氪恕R虼四阈枰仁褂?code>torch.load()來反序列化。比如腔寡,你不能直接這么用model.load_state_dict(PATH)焚鲜。

2.2.2保存和讀寫整個(gè)模型

這個(gè)就相對(duì)來說比較簡(jiǎn)單了。

torch.save(net, 'net.pt')
net2 = torch.load('net.pt')
print(net2)

輸出:

MLP(
  (hidden): Linear(in_features=3, out_features=2, bias=True)
  (act): ReLU()
  (output): Linear(in_features=2, out_features=1, bias=True)
)

注意:以這種方式保存模型將使用Python的pickle模塊保存整個(gè)模塊放前。 這種方法的缺點(diǎn)是序列化的數(shù)據(jù)被綁定到特定的類忿磅,并且在保存模型時(shí)使用了確切的詞典結(jié)構(gòu)。 這樣做的原因是因?yàn)閜ickle不會(huì)保存模型類本身凭语。 而是將其保存到包含這個(gè)類的文件的路徑葱她,該路徑在加載時(shí)使用。 因此似扔,在其他項(xiàng)目中使用或重構(gòu)后吨些,您的代碼可能會(huì)以各種方式中斷。

2.2.3保存和加載checkpiont
## Save
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

###########################
## Load
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

保存用于檢查或繼續(xù)訓(xùn)練的checkpiont時(shí)炒辉,你必須保存的不只是模型的state_dict豪墅。 保存優(yōu)化器的state_dict也很重要,因?yàn)樗S著模型訓(xùn)練而更新的緩沖區(qū)和參數(shù)黔寇。 你可能要保存的其他項(xiàng)目包括你未設(shè)置的時(shí)間段偶器,最新記錄的訓(xùn)練損失,外部torch.nn.Embedding層等缝裤。

2.2.4在一個(gè)文件中保存多個(gè)模型
#Save
torch.save({
            'modelA_state_dict': modelA.state_dict(),
            'modelB_state_dict': modelB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            ...
            }, PATH)


#Load
modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

常見的PyTorch約定是使用.tar文件擴(kuò)展名保存這些檢查點(diǎn)状囱。

2.3使用來自不同模型的參數(shù)進(jìn)行模型熱啟動(dòng)

這種方法一般用于遷移學(xué)習(xí)。利用經(jīng)過訓(xùn)練的參數(shù)倘是,即使只有少數(shù)幾個(gè)可用的參數(shù)亭枷,也將有助于熱啟動(dòng)訓(xùn)練過程,并希望與從頭開始訓(xùn)練相比搀崭,可以更快地收斂模型叨粘。

torch.save(modelA.state_dict(), PATH)

modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)

無論是從缺少某些鍵的部分state_dict加載,還是加載比要加載的模型更多的keystate_dict瘤睹,都可以在load_state_dict()函數(shù)中將strict參數(shù)設(shè)置為False升敲,以忽略不匹配鍵。

如果你想要將一個(gè)層的參數(shù)加載到另一個(gè)層轰传,但是一些keys不匹配驴党,你只需改變你所加載的state_dict中的名稱即可。

3.跨設(shè)備保存和加載模型

3.1在GPU中保存获茬,在CPU中加載

torch.save(model.state_dict(), PATH)

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

3.2在GPU中保存港庄,在GPU中加載

torch.save(model.state_dict(), PATH)

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

3.3在CPU中保存倔既,在GPU中加載

torch.save(model.state_dict(), PATH)

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)

4.保存torch.nn.DataParallel的模型

torch.save(model.module.state_dict(), PATH)

# Load to whatever device you want,加載方法使用常規(guī)方式即可。

參考鏈接:

  1. 官方文檔
  2. Dive-into-DL-PyTorch
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末鹏氧,一起剝皮案震驚了整個(gè)濱河市渤涌,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌把还,老刑警劉巖实蓬,帶你破解...
    沈念sama閱讀 211,123評(píng)論 6 490
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異吊履,居然都是意外死亡安皱,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 90,031評(píng)論 2 384
  • 文/潘曉璐 我一進(jìn)店門艇炎,熙熙樓的掌柜王于貴愁眉苦臉地迎上來练俐,“玉大人,你說我怎么就攤上這事冕臭∠倭溃” “怎么了?”我有些...
    開封第一講書人閱讀 156,723評(píng)論 0 345
  • 文/不壞的土叔 我叫張陵辜贵,是天一觀的道長(zhǎng)悯蝉。 經(jīng)常有香客問我,道長(zhǎng)托慨,這世上最難降的妖魔是什么鼻由? 我笑而不...
    開封第一講書人閱讀 56,357評(píng)論 1 283
  • 正文 為了忘掉前任,我火速辦了婚禮厚棵,結(jié)果婚禮上蕉世,老公的妹妹穿的比我還像新娘。我一直安慰自己婆硬,他們只是感情好狠轻,可當(dāng)我...
    茶點(diǎn)故事閱讀 65,412評(píng)論 5 384
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著彬犯,像睡著了一般向楼。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上谐区,一...
    開封第一講書人閱讀 49,760評(píng)論 1 289
  • 那天湖蜕,我揣著相機(jī)與錄音,去河邊找鬼宋列。 笑死昭抒,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播灭返,決...
    沈念sama閱讀 38,904評(píng)論 3 405
  • 文/蒼蘭香墨 我猛地睜開眼盗迟,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來了婆殿?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 37,672評(píng)論 0 266
  • 序言:老撾萬榮一對(duì)情侶失蹤罩扇,失蹤者是張志新(化名)和其女友劉穎婆芦,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體喂饥,經(jīng)...
    沈念sama閱讀 44,118評(píng)論 1 303
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡消约,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,456評(píng)論 2 325
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了员帮。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片或粮。...
    茶點(diǎn)故事閱讀 38,599評(píng)論 1 340
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖捞高,靈堂內(nèi)的尸體忽然破棺而出氯材,到底是詐尸還是另有隱情,我是刑警寧澤硝岗,帶...
    沈念sama閱讀 34,264評(píng)論 4 328
  • 正文 年R本政府宣布氢哮,位于F島的核電站,受9級(jí)特大地震影響型檀,放射性物質(zhì)發(fā)生泄漏冗尤。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,857評(píng)論 3 312
  • 文/蒙蒙 一胀溺、第九天 我趴在偏房一處隱蔽的房頂上張望裂七。 院中可真熱鬧,春花似錦仓坞、人聲如沸背零。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,731評(píng)論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽捉兴。三九已至,卻和暖如春录语,著一層夾襖步出監(jiān)牢的瞬間倍啥,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 31,956評(píng)論 1 264
  • 我被黑心中介騙來泰國(guó)打工澎埠, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留虽缕,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 46,286評(píng)論 2 360
  • 正文 我出身青樓蒲稳,卻偏偏與公主長(zhǎng)得像氮趋,于是被迫代替她去往敵國(guó)和親伍派。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 43,465評(píng)論 2 348

推薦閱讀更多精彩內(nèi)容