(1)直接加載預(yù)訓(xùn)練模型
????????如果我們使用的模型和原模型完全一樣源武,那么我們可以直接加載別人訓(xùn)練好的模型:
????????my_resnet = MyResNet(*args, **kwargs)
????????my_resnet.load_state_dict(torch.load("my_resnet.pth"))
????????當(dāng)然這樣的加載方法是基于PyTorch推薦的存儲模型的方法:
????????torch.save(my_resnet.state_dict(),"my_resnet.pth")
????????還有第二種加載方法:
????????my_resnet=torch.load("my_resnet.pth")
(2)加載部分預(yù)訓(xùn)練模型
????????其實大多數(shù)時候我們需要根據(jù)我們的任務(wù)調(diào)節(jié)我們的模型盆繁,所以很難保證模型和公開的模型完全一樣汁掠,但是預(yù)訓(xùn)練模型的參數(shù)確實有助于提高訓(xùn)練的準(zhǔn)確率所袁,為了結(jié)合二者的優(yōu)點蒜危,就需要我們加載部分預(yù)訓(xùn)練模型趁俊。
????????pretrained_dict=model_zoo.load_url(model_urls['resnet152'])
????????model_dict=model.state_dict()#?將pretrained_dict里不屬于model_dict的鍵剔除掉
????????pretrained_dict={k:vfork,vinpretrained_dict.items()ifkinmodel_dict}#?更新現(xiàn)有的model_dict
????????model_dict.update(pretrained_dict)#?加載我們真正需要的state_dict
????????model.load_state_dict(model_dict)
????????因為需要剔除原模型中不匹配的鍵脚作,也就是層的名字警没,所以我們的新模型改變了的層需要和原模型對應(yīng)層的名字不一樣匈辱,比如:resnet最后一層的名字是fc(PyTorch中),那么我們修改過的resnet的最后一層就不能取這個名字杀迹,可以叫fc_亡脸,層的名字要改變
(3)微改基礎(chǔ)模型預(yù)訓(xùn)練
????????對于改動比較大的模型,我們可能需要自己實現(xiàn)一下再加載別人的預(yù)訓(xùn)練參數(shù)树酪。但是浅碾,對于一些基本模型PyTorch中已經(jīng)有了,而且我只想進(jìn)行一些小的改動那么怎么辦呢续语?難道我又去實現(xiàn)一遍嗎垂谢?當(dāng)然不是。
????????我們首先看看怎么進(jìn)行微改模型疮茄。
????????微改基礎(chǔ)模型
????????????PyTorch中的torchvision里已經(jīng)有很多常用的模型了滥朱,可以直接調(diào)用:
????????????????AlexNet、VGG力试、ResNet徙邻、SqueezeNet、DenseNet畸裳、
????????????????importtorchvision.modelsasmodels:
????????????????????????????resnet18=models.resnet18()
????????????????????????????alexnet=models.alexnet()
????????????????????????????squeezenet=models.squeezenet1_0()
????????????????????????????densenet=models.densenet_161()
????????????但是對于我們的任務(wù)而言有些層并不是直接能用缰犁,需要我們微微改一下,比如躯畴,resnet最后的全連接層是分1000類民鼓,而我們只有21類;又比如蓬抄,resnet第一層卷積接收的通道是3丰嘉, 我們可能輸入圖片的通道是4,那么可以通過以下方法修改:
????????????resnet.conv1=nn.Conv2d(4,64,kernel_size=7,stride=2,padding=3,bias=False)resnet.fc=nn.Linear(2048,21)
?(4)簡單預(yù)訓(xùn)練
????????????模型已經(jīng)改完了嚷缭,接下來我們就進(jìn)行簡單預(yù)訓(xùn)練吧饮亏。
????????????我們先從torchvision中調(diào)用基本模型耍贾,加載預(yù)訓(xùn)練模型,然后路幸,重點來了荐开,將其中的層直接替換為我們需要的層即可:
????????????????resnet=torchvision.models.resnet152(pretrained=True)#?原本為1000類,改為10類
????????????????resnet.fc=torch.nn.Linear(2048,10)
????????????其中使用了pretrained參數(shù)简肴,會直接加載預(yù)訓(xùn)練模型晃听,內(nèi)部實現(xiàn)和前文提到的加載預(yù)訓(xùn)練的方法一樣。因為是先加載的預(yù)訓(xùn)練參數(shù)砰识,相當(dāng)于模型中已經(jīng)有參數(shù)了能扒,所以替換掉最后一層即可。OK辫狼!