class Net(nn.Module):
    """Template network: an ``nn.Module`` subclass to be filled in with layers.

    Args:
        input: first constructor argument (presumably the input size —
            TODO confirm). Name kept for call compatibility even though it
            shadows the builtin ``input``.
        output: second constructor argument (presumably the output size —
            TODO confirm).
    """

    def __init__(self, input, output):
        # nn.Module.__init__ must run before any submodule or parameter is
        # assigned on self; PyTorch raises AttributeError otherwise.
        super().__init__()
        # define your network here (e.g. layers sized from `input`/`output`)
# Build the model and wrap it in nn.DataParallel for data-parallel execution.
# The wrapper prefixes every key of net.state_dict() with 'module.'.
# NOTE(review): `input` and `output` are not defined in this snippet
# (`input` would resolve to the builtin function) — presumably placeholders
# for the constructor arguments; define them before running.
net = nn.DataParallel(Net(input, output))

# Choose the compute device: the first CUDA GPU when one is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Module.to moves parameters and buffers in place (and returns the module).
net.to(device)
# 2. 模型的保存 (saving the model)
# Save a checkpoint dict with the weights under the 'state_dict' key, which is
# the key the reload step reads. Because `net` is wrapped in nn.DataParallel,
# these keys carry the 'module.' prefix that the reload step strips.
# NOTE(review): this spot originally held an incomplete `if` and
# `torch.save(net.)`; the intended guard condition is unknown and is omitted.
# `resume` is assumed to be the checkpoint path (the reload step loads from
# it) — confirm, or substitute the intended save path.
torch.save({'state_dict': net.state_dict()}, resume)
# 3. 模型的重載 (reloading the model)
# Reload weights saved from an nn.DataParallel-wrapped model into a model
# that is not wrapped.
# NOTE(review): `resume` (checkpoint path) and `model` (presumably an
# unwrapped Net instance) are not defined in this snippet — confirm.
checkpoint = torch.load(resume)
state_dict = checkpoint['state_dict']
from collections import OrderedDict

# nn.DataParallel prefixes every parameter key with 'module.'. Strip it only
# where it is actually present: slicing a fixed 7 characters off every key
# would silently mangle a checkpoint saved without the wrapper.
prefix = 'module.'
new_state_dict = OrderedDict(
    (k[len(prefix):] if k.startswith(prefix) else k, v)
    for k, v in state_dict.items()
)
model.load_state_dict(new_state_dict)
# 4. 模型的遷移 (moving the model between CPU and GPU)
# cpu or gpu: map_location remaps every tensor in the checkpoint onto the
# given device at load time. 'cpu' lets a GPU-saved checkpoint be opened on a
# CPU-only machine; pass e.g. 'cuda:0' (or a torch.device) to load onto a GPU.
# 'model/path' is a placeholder for the checkpoint file path.
# Bind the result so the loaded checkpoint can actually be used.
checkpoint = torch.load('model/path', map_location='cpu')