??????本文是一篇關(guān)于如何用Pytorch保存和加載模型的指南
文章目錄
- 1.讀寫
tensor
- 1.1 單個(gè)張量
- 1.2 張量列表和張量詞典
- 2.保存和加載模型
- 2.1
state_dict
- 2.1
- 2.2 保存和加載
- 2.2.1 保存和加載
state_dict
(推薦方式)
- 2.2.1 保存和加載
- 2.2.2 保存和讀寫整個(gè)模型
- 2.2.3 保存和加載
checkpiont
- 2.2.3 保存和加載
- 2.2.4 在一個(gè)文件中保存多個(gè)模型
- 2.3 使用來自不同模型的參數(shù)進(jìn)行模型熱啟動(dòng)
- 3.跨設(shè)備保存和加載模型
- 3.1 在GPU中保存,在CPU中加載
- 3.2 在GPU中保存怜森,在GPU中加載
- 3.3 在CPU中保存,在GPU中加載
- 4.保存
torch.nn.DataParallel
的模型
本文主要涉及到3個(gè)函數(shù):
- 1.
torch.save
: 使用Python的pickle
實(shí)用程序?qū)?duì)象進(jìn)行序列化嘿悬,然后將序列化的對(duì)象保存到disk,可以保存各種對(duì)象,包括模型水泉、張量和字典等善涨。 - 2.
torch.load
: 使用pickle unpickle
工具將pickle
的對(duì)象文件反序列化為內(nèi)存。 - 3.
torch.nn.Module.load_state_dict
: 用反序列化的state_dict
來加載模型參數(shù)草则。
1.讀寫tensor
1.1單個(gè)張量
import torch
x = torch.tensor([3.,4.])
torch.save(x, 'x.pt')
x1 = torch.load('x.pt')
print(x1)
輸出:
tensor([3., 4.])
1.2張量列表和張量詞典
y = torch.ones((4,2))
torch.save([x,y],'xy.pt')
torch.save({'x':x, 'y':y}, 'xy_dict.pt')
xy = torch.load('xy.pt')
xy_dict = torch.load('xy_dict.pt')
print(xy)
print(xy_dict)
輸出:
[tensor([3., 4.]), tensor([[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.]])]
{'x': tensor([3., 4.]), 'y': tensor([[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.]])}
2.保存和加載模型
2.1state_dict
state_dict
是一個(gè)從每一個(gè)層的名稱映射到這個(gè)層的參數(shù)Tesnor
的字典對(duì)象钢拧。
注意,只有具有可學(xué)習(xí)參數(shù)的層(卷積層畔师、線性層等)和注冊(cè)緩存(batchnorm’s running_mean)
才有state_dict
中的條目娶靡。優(yōu)化器(torch.optim)
也有一個(gè)state_dict
牧牢,其中包含關(guān)于優(yōu)化器狀態(tài)以及所使用的超參數(shù)的信息看锉。
from torch import nn
class MLP(nn.Module):
def __init__(self):
super(MLP, self).__init__()
self.hidden = nn.Linear(3, 2)
self.act = nn.ReLU()
self.output = nn.Linear(2, 1)
def forward(self, x):
a = self.act(self.hidden(x))
return self.output(a)
net = MLP()
print(net.state_dict())
print('\n',net.state_dict()['output.weight'])
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
print(optimizer.state_dict())
輸出:
OrderedDict([('hidden.weight', tensor([[ 0.0405, -0.0659, -0.5540],
[ 0.2954, 0.0676, -0.1933]])), ('hidden.bias', tensor([-0.1628, 0.0768])), ('output.weight', tensor([[-0.4635, 0.4958]])), ('output.bias', tensor([-0.5440]))])
tensor([[-0.4635, 0.4958]])
{'state': {}, 'param_groups': [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3]}]}
2.2保存和加載
PyTorch中保存和加載訓(xùn)練模型有兩種常見的方法:
- 僅保存和加載模型參數(shù)(state_dict)姿锭;
- 保存和加載整個(gè)模型。
2.2.1保存和加載state_dict
(推薦方式)
torch.save(net.state_dict(), 'net_state_dict.pt')## 后綴名一般寫為: .pt或.pth
net1 = MLP()
net1.load_state_dict(torch.load('net_state_dict.pt'))
print(net1)
輸出:
MLP(
(hidden): Linear(in_features=3, out_features=2, bias=True)
(act): ReLU()
(output): Linear(in_features=2, out_features=1, bias=True)
)
注意: load_state_dict()
接受一個(gè)詞典對(duì)象伯铣,而不是一個(gè)指向?qū)ο蟮穆窂缴氪恕R虼四阈枰仁褂?code>torch.load()來反序列化。比如腔寡,你不能直接這么用model.load_state_dict
(PATH)焚鲜。
2.2.2保存和讀寫整個(gè)模型
這個(gè)就相對(duì)來說比較簡(jiǎn)單了。
torch.save(net, 'net.pt')
net2 = torch.load('net.pt')
print(net2)
輸出:
MLP(
(hidden): Linear(in_features=3, out_features=2, bias=True)
(act): ReLU()
(output): Linear(in_features=2, out_features=1, bias=True)
)
注意:以這種方式保存模型將使用Python的pickle模塊保存整個(gè)模塊放前。 這種方法的缺點(diǎn)是序列化的數(shù)據(jù)被綁定到特定的類忿磅,并且在保存模型時(shí)使用了確切的詞典結(jié)構(gòu)。 這樣做的原因是因?yàn)閜ickle不會(huì)保存模型類本身凭语。 而是將其保存到包含這個(gè)類的文件的路徑葱她,該路徑在加載時(shí)使用。 因此似扔,在其他項(xiàng)目中使用或重構(gòu)后吨些,您的代碼可能會(huì)以各種方式中斷。
2.2.3保存和加載checkpiont
## Save
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, PATH)
###########################
## Load
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
保存用于檢查或繼續(xù)訓(xùn)練的checkpiont
時(shí)炒辉,你必須保存的不只是模型的state_dict
豪墅。 保存優(yōu)化器的state_dict
也很重要,因?yàn)樗S著模型訓(xùn)練而更新的緩沖區(qū)和參數(shù)黔寇。 你可能要保存的其他項(xiàng)目包括你未設(shè)置的時(shí)間段偶器,最新記錄的訓(xùn)練損失,外部torch.nn.Embedding
層等缝裤。
2.2.4在一個(gè)文件中保存多個(gè)模型
#Save
torch.save({
'modelA_state_dict': modelA.state_dict(),
'modelB_state_dict': modelB.state_dict(),
'optimizerA_state_dict': optimizerA.state_dict(),
'optimizerB_state_dict': optimizerB.state_dict(),
...
}, PATH)
#Load
modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)
checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
常見的PyTorch約定是使用.tar
文件擴(kuò)展名保存這些檢查點(diǎn)状囱。
2.3使用來自不同模型的參數(shù)進(jìn)行模型熱啟動(dòng)
這種方法一般用于遷移學(xué)習(xí)。利用經(jīng)過訓(xùn)練的參數(shù)倘是,即使只有少數(shù)幾個(gè)可用的參數(shù)亭枷,也將有助于熱啟動(dòng)訓(xùn)練過程,并希望與從頭開始訓(xùn)練相比搀崭,可以更快地收斂模型叨粘。
torch.save(modelA.state_dict(), PATH)
modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)
無論是從缺少某些鍵的部分state_dict
加載,還是加載比要加載的模型更多的key
的state_dict
瘤睹,都可以在load_state_dict()
函數(shù)中將strict
參數(shù)設(shè)置為False
升敲,以忽略不匹配鍵。
如果你想要將一個(gè)層的參數(shù)加載到另一個(gè)層轰传,但是一些keys
不匹配驴党,你只需改變你所加載的state_dict
中的名稱即可。
3.跨設(shè)備保存和加載模型
3.1在GPU中保存获茬,在CPU中加載
torch.save(model.state_dict(), PATH)
device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
3.2在GPU中保存港庄,在GPU中加載
torch.save(model.state_dict(), PATH)
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
3.3在CPU中保存倔既,在GPU中加載
torch.save(model.state_dict(), PATH)
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want
model.to(device)
4.保存torch.nn.DataParallel
的模型
torch.save(model.module.state_dict(), PATH)
# Load to whatever device you want,加載方法使用常規(guī)方式即可。
參考鏈接: