我們看到torchvision提供的detection訓(xùn)練代碼中
if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1
if args.output_dir:
utils.save_on_master({
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'args': args,
'epoch': epoch},
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
都是保存和加載了optimizer和lr_scheduler盛末,為什么不直接保存model呢训挡,因為考慮到adam和sgd兩種常用的優(yōu)化器沮脖,adam的原理 可以看 https://stats.stackexchange.com/questions/220494/how-does-the-adam-method-of-stochastic-gradient-descent-work
adam是動態(tài)調(diào)整的,和當(dāng)前parameter有關(guān),所以resume時需要加載optimizer.state_dict()
。
sgd的learning rate一般都是epoch相關(guān)拓哟,所以需要lr_scheduler.state_dict()