選取官方鏈接里面的例子汗洒,介紹如何用PyTorch訓(xùn)練一個ResNet模型用于圖像分類秃励,代碼邏輯非常清晰怜奖,基本上和許多深度學(xué)習(xí)框架的代碼思路類似,非常適合初學(xué)者想上手PyTorch訓(xùn)練模型(不必每次都跑mnist的demo了)芥颈。接下來從個人使用角度加以解釋。解釋的思路是從數(shù)據(jù)導(dǎo)入開始到模型訓(xùn)練結(jié)束做粤,基本上就是搭積木的方式來寫代碼浇借。
首先是數(shù)據(jù)導(dǎo)入部分,這里采用官方寫好的torchvision.datasets.ImageFolder接口實現(xiàn)數(shù)據(jù)導(dǎo)入怕品。這個接口需要你提供圖像所在的文件夾妇垢,就是下面的data_dir=‘/data’這句,然后對于一個分類問題肉康,這里data_dir目錄下一般包括兩個文件夾:train和val闯估,每個文件件下面包含N個子文件夾,N是你的分類類別數(shù)吼和,且每個子文件夾里存放的就是這個類別的圖像涨薪。這樣torchvision.datasets.ImageFolder就會返回一個列表(比如下面代碼中的image_datasets[‘train’]或者image_datasets[‘val]),列表中的每個值都是一個tuple炫乓,每個tuple包含圖像和標(biāo)簽信息刚夺。
————————————
data_dir = '/data'
image_datasets = {x: datasets.ImageFolder(
os.path.join(data_dir, x),
data_transforms[x]),
for x in ['train', 'val']}
另外這里的data_transforms是一個字典末捣,如下侠姑。主要是進(jìn)行一些圖像預(yù)處理,比如resize箩做、crop等莽红。實現(xiàn)的時候采用的是torchvision.transforms模塊,比如torchvision.transforms.Compose是用來管理所有transforms操作的邦邦,torchvision.transforms.RandomSizedCrop是做crop的安吁。需要注意的是對于torchvision.transforms.RandomSizedCrop和transforms.RandomHorizontalFlip()等,輸入對象都是PIL Image燃辖,也就是用python的PIL庫讀進(jìn)來的圖像內(nèi)容鬼店,而transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])的作用對象需要是一個Tensor,因此在transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])之前有一個 transforms.ToTensor()就是用來生成Tensor的黔龟。另外transforms.Scale(256)其實就是resize操作妇智,目前已經(jīng)被transforms.Resize類取代了确沸。
————————————————
data_transforms = {
'train': transforms.Compose([
transforms.RandomSizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Scale(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
前面torchvision.datasets.ImageFolder只是返回list,list是不能作為模型輸入的俘陷,因此在PyTorch中需要用另一個類來封裝list罗捎,那就是:torch.utils.data.DataLoader。torch.utils.data.DataLoader類可以將list類型的輸入數(shù)據(jù)封裝成Tensor數(shù)據(jù)格式拉盾,以備模型使用桨菜。注意,這里是對圖像和標(biāo)簽分別封裝成一個Tensor捉偏。這里要提到另一個很重要的類:torch.utils.data.Dataset倒得,這是一個抽象類,在pytorch中所有和數(shù)據(jù)相關(guān)的類都要繼承這個類來實現(xiàn)夭禽。比如前面說的torchvision.datasets.ImageFolder類是這樣的霞掺,以及這里的torch.util.data.DataLoader類也是這樣的。所以當(dāng)你的數(shù)據(jù)不是按照一個類別一個文件夾這種方式存儲時讹躯,你就要自定義一個類來讀取數(shù)據(jù)菩彬,自定義的這個類必須繼承自torch.utils.data.Dataset這個基類,最后同樣用torch.utils.data.DataLoader封裝成Tensor潮梯。
dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],
batch_size=4,
shuffle=True,
num_workers=4)
for x in ['train', 'val']}
生成dataloaders后再有一步就可以作為模型的輸入了骗灶,那就是將Tensor數(shù)據(jù)類型封裝成Variable數(shù)據(jù)類型,來看下面這段代碼秉馏。dataloaders是一個字典耙旦,dataloders[‘train’]存的就是訓(xùn)練的數(shù)據(jù),這個for循環(huán)就是從dataloders[‘train’]中讀取batch_size個數(shù)據(jù)萝究,batch_size在前面生成dataloaders的時候就設(shè)置了免都。因此這個data里面包含圖像數(shù)據(jù)(inputs)這個Tensor和標(biāo)簽(labels)這個Tensor。然后用torch.autograd.Variable將Tensor封裝成模型真正可以用的Variable數(shù)據(jù)類型帆竹。
為什么要封裝成Variable呢绕娘?在pytorch中,torch.tensor和torch.autograd.Variable是兩種比較重要的數(shù)據(jù)結(jié)構(gòu)馆揉,Variable可以看成是tensor的一種包裝业舍,其不僅包含了tensor的內(nèi)容抖拦,還包含了梯度等信息升酣,因此在神經(jīng)網(wǎng)絡(luò)中常常用Variable數(shù)據(jù)結(jié)構(gòu)。那么怎么從一個Variable類型中取出tensor呢态罪?也很簡單噩茄,比如下面封裝后的inputs是一個Variable,那么inputs.data就是對應(yīng)的tensor复颈。
for data in dataloders['train']:
inputs, labels = data
if use_gpu:
inputs = Variable(inputs.cuda())
labels = Variable(labels.cuda())
else:
inputs, labels = Variable(inputs), Variable(labels)
封裝好了數(shù)據(jù)后绩聘,就可以作為模型的輸入了沥割。所以要先導(dǎo)入你的模型。在PyTorch中已經(jīng)默認(rèn)為大家準(zhǔn)備了一些常用的網(wǎng)絡(luò)結(jié)構(gòu)凿菩,比如分類中的VGG机杜,ResNet,DenseNet等等衅谷,可以用torchvision.models模塊來導(dǎo)入椒拗。比如用torchvision.models.resnet18(pretrained=True)來導(dǎo)入ResNet18網(wǎng)絡(luò),同時指明導(dǎo)入的是已經(jīng)預(yù)訓(xùn)練過的網(wǎng)絡(luò)获黔。因為預(yù)訓(xùn)練網(wǎng)絡(luò)一般是在1000類的ImageNet數(shù)據(jù)集上進(jìn)行的蚀苛,所以要遷移到你自己數(shù)據(jù)集的2分類,需要替換最后的全連接層為你所需要的輸出玷氏。因此下面這三行代碼進(jìn)行的就是用models模塊導(dǎo)入resnet18網(wǎng)絡(luò)堵未,然后獲取全連接層的輸入channel個數(shù),用這個channel個數(shù)和你要做的分類類別數(shù)(這里是2)替換原來模型中的全連接層盏触。這樣網(wǎng)絡(luò)結(jié)果也準(zhǔn)備好渗蟹。
model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
但是只有網(wǎng)絡(luò)結(jié)構(gòu)和數(shù)據(jù)還不足以讓代碼運行起來,還需要定義損失函數(shù)赞辩。在PyTorch中采用torch.nn模塊來定義網(wǎng)絡(luò)的所有層拙徽,比如卷積、降采樣诗宣、損失層等等膘怕,這里采用交叉熵函數(shù),因此可以這樣定義:
criterion = nn.CrossEntropyLoss()
然后你還需要定義優(yōu)化函數(shù)召庞,比如最常見的隨機(jī)梯度下降岛心,在PyTorch中是通過torch.optim模塊來實現(xiàn)的。另外這里雖然寫的是SGD篮灼,但是因為有momentum忘古,所以是Adam的優(yōu)化方式。這個類的輸入包括需要優(yōu)化的參數(shù):model.parameters()诅诱,學(xué)習(xí)率髓堪,還有Adam相關(guān)的momentum參數(shù)。現(xiàn)在很多優(yōu)化方式的默認(rèn)定義形式就是這樣的娘荡。
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
然后一般還會定義學(xué)習(xí)率的變化策略干旁,這里采用的是torch.optim.lr_scheduler模塊的StepLR類,表示每隔step_size個epoch就將學(xué)習(xí)率降為原來的gamma倍炮沐。
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
開始訓(xùn)練
首先訓(xùn)練開始的時候需要先更新下學(xué)習(xí)率争群,這是因為我們前面制定了學(xué)習(xí)率的變化策略,所以在每個epoch開始時都要更新下:
scheduler.step()
然后設(shè)置模型狀態(tài)為訓(xùn)練狀態(tài):
model.train(True)
然后先將網(wǎng)絡(luò)中的所有梯度置0:
optimizer.zero_grad()
然后就是網(wǎng)絡(luò)的前向傳播了:
outputs = model(inputs)
然后將輸出的outputs和原來導(dǎo)入的labels作為loss函數(shù)的輸入就可以得到損失了
loss = criterion(outputs, labels)
輸出的outputs也是torch.autograd.Variable格式大年,得到輸出后(網(wǎng)絡(luò)的全連接層的輸出)還希望能到到模型預(yù)測該樣本屬于哪個類別的信息换薄,這里采用torch.max玉雾。torch.max()的第一個輸入是tensor格式,所以用outputs.data而不是outputs作為輸入轻要;第二個參數(shù)1是代表dim的意思复旬,也就是取每一行的最大值,其實就是我們常見的取概率最大的那個index冲泥;第三個參數(shù)loss也是torch.autograd.Variable格式赢底。
_, preds = torch.max(outputs.data, 1)
計算得到loss后就要回傳損失。要注意的是這是在訓(xùn)練的時候才會有的操作柏蘑,測試時候只有forward過程幸冻。
loss.backward()
回傳損失過程中會計算梯度,然后需要根據(jù)這些梯度更新參數(shù)咳焚,optimizer.step()就是用來更新參數(shù)的洽损。optimizer.step()后,你就可以從optimizer.param_groups[0][‘params’]里面看到各個層的梯度和權(quán)值信息革半。
optimizer.step()
這樣一個batch數(shù)據(jù)的訓(xùn)練就結(jié)束了碑定!當(dāng)你不斷重復(fù)這樣的訓(xùn)練過程,最終就可以達(dá)到你想要的結(jié)果了又官。
另外如果你有g(shù)pu可用延刘,那么包括你的數(shù)據(jù)和模型都可以在gpu上操作,這在PyTorch中也非常簡單六敬。判斷你是否有g(shù)pu可以用可以通過下面這行代碼碘赖,如果有,則use_gpu是true外构。
use_gpu = torch.cuda.is_available()
參考
PyTorch使用及源碼解讀
Pytorch-Transfer Learning
Welcome to PyTorch Tutorials
PyTorch學(xué)習(xí)之路(level1)——訓(xùn)練一個圖像分類模型
https://github.com/miraclewkf/ImageClassification-PyTorch/blob/master/level1/train.py