PyTorch: Saving and Loading Models
These notes are translated from the official PyTorch tutorial, for quick reference.
PyTorch has three commonly used functions related to saving and loading models:
- torch.save(): serializes an object and saves it to disk, using Python's pickle module.
- torch.load(): deserializes a pickled object file and loads it into memory.
- torch.nn.Module.load_state_dict(): loads a model's parameters from a deserialized state_dict object.
1. state_dict
In PyTorch, all of a model's learnable parameters are accessible through model.parameters(). A state_dict is simply a Python dictionary that maps each layer to its parameter tensors. torch.optim objects also have a state_dict, which contains the optimizer's state as well as the hyperparameters it was configured with.
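As a minimal sketch (the two-layer model below is a made-up example, not from the tutorial), you can inspect what a model's and an optimizer's state_dict contain like this:

```python
import torch
import torch.nn as nn
import torch.optim as optim

# A small made-up model, only to show what a state_dict looks like
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Each entry maps a parameter name to its tensor
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))

# The optimizer's state_dict holds its state and hyperparameters
print(optimizer.state_dict()["param_groups"][0]["lr"])
```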
2. Save & load the model for inference (recommended)
save
torch.save(model.state_dict(), PATH)
load
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval() # do not forget this when using the model for inference
- The saved file commonly uses the .pt or .pth extension.
- Do not forget to call model.eval() before inference, so that dropout and batch-normalization layers switch to evaluation mode.
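Putting the save and load snippets together, a self-contained round trip might look like the following sketch (the tiny model and the temp-file path are illustrative assumptions, standing in for TheModelClass and PATH):

```python
import os
import tempfile

import torch
import torch.nn as nn

class TinyNet(nn.Module):  # illustrative stand-in for TheModelClass
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 2)

    def forward(self, x):
        return self.fc(x)

model = TinyNet()
path = os.path.join(tempfile.mkdtemp(), "model.pt")
torch.save(model.state_dict(), path)        # save only the parameters

restored = TinyNet()                        # the class must be constructed first
restored.load_state_dict(torch.load(path))
restored.eval()                             # switch to evaluation mode

# Identical weights give identical outputs
x = torch.randn(1, 3)
assert torch.equal(model(x), restored(x))
```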
3. Save & load the entire model (not recommended)
save
torch.save(model, PATH)
load
# The model class must be defined somewhere in the loading script
model = torch.load(PATH)
model.eval()
4. Save & load a checkpoint for inference or resuming training
save
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    ...
}, PATH)
load
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
# or
model.train()
5. Save multiple models in one file
save
torch.save({
    'modelA_state_dict': modelA.state_dict(),
    'modelB_state_dict': modelB.state_dict(),
    'optimizerA_state_dict': optimizerA.state_dict(),
    'optimizerB_state_dict': optimizerB.state_dict(),
    ...
}, PATH)
load
modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)
checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
modelA.eval()
modelB.eval()
# or
modelA.train()
modelB.train()
- This pattern is useful for GANs, sequence-to-sequence models, or ensembles of models.
- Checkpoints are conventionally saved with the .tar file extension.
6. Warmstarting Model Using Parameters From A Different Model
save
torch.save(modelA.state_dict(), PATH)
load
modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)
- In transfer learning you often want to load only part of a model's parameters; set strict=False to ignore the keys that do not match.
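A sketch of this partial loading (both toy model classes are illustrative): only parameters whose names and shapes match are copied, and load_state_dict reports the mismatches back instead of raising an error.

```python
import torch
import torch.nn as nn

class ModelA(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared = nn.Linear(4, 4)
        self.head_a = nn.Linear(4, 2)

class ModelB(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared = nn.Linear(4, 4)  # same name and shape as in ModelA
        self.head_b = nn.Linear(4, 3)  # exists only in ModelB

state = ModelA().state_dict()
modelB = ModelB()
# strict=False skips non-matching keys instead of raising an error
result = modelB.load_state_dict(state, strict=False)
print(result.missing_keys)     # keys ModelB has but the state_dict lacks
print(result.unexpected_keys)  # keys the state_dict has but ModelB lacks
```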
7. Save & load a model across devices
(1) Save on GPU, Load on CPU
save
torch.save(model.state_dict(), PATH)
load
device = torch.device("cpu")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
(2) Save on GPU, Load on GPU
save
torch.save(model.state_dict(), PATH)
load
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
(3) Save on CPU, Load on GPU
save
torch.save(model.state_dict(), PATH)
load
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
model.to(device)
8. Save a torch.nn.DataParallel model
save
torch.save(model.module.state_dict(), PATH)
load
# Load to whatever device you want
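The load step was left open above. Assuming the checkpoint was saved via model.module.state_dict() as shown, one common sketch (the tiny model class and temp-file path are illustrative assumptions) is to load it into a plain, unwrapped model and then move or re-wrap it as needed:

```python
import os
import tempfile

import torch
import torch.nn as nn

class TinyNet(nn.Module):  # illustrative stand-in for the real model class
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 2)

    def forward(self, x):
        return self.fc(x)

# save: unwrapping .module keeps the "module." prefix out of the keys
parallel = nn.DataParallel(TinyNet())
path = os.path.join(tempfile.mkdtemp(), "model.pt")
torch.save(parallel.module.state_dict(), path)

# load into a plain model on whatever device you want
model = TinyNet()
model.load_state_dict(torch.load(path, map_location="cpu"))
# wrap in nn.DataParallel again here if resuming multi-GPU training
```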