PyTorch學(xué)習(xí)之路(level1)——訓(xùn)練一個圖像分類模型

選取官方鏈接里面的例子汗洒,介紹如何用PyTorch訓(xùn)練一個ResNet模型用于圖像分類秃励,代碼邏輯非常清晰怜奖,基本上和許多深度學(xué)習(xí)框架的代碼思路類似,非常適合初學(xué)者想上手PyTorch訓(xùn)練模型(不必每次都跑mnist的demo了)芥颈。接下來從個人使用角度加以解釋。解釋的思路是從數(shù)據(jù)導(dǎo)入開始到模型訓(xùn)練結(jié)束做粤,基本上就是搭積木的方式來寫代碼浇借。

首先是數(shù)據(jù)導(dǎo)入部分,這里采用官方寫好的torchvision.datasets.ImageFolder接口實現(xiàn)數(shù)據(jù)導(dǎo)入怕品。這個接口需要你提供圖像所在的文件夾妇垢,就是下面的data_dir=‘/data’這句,然后對于一個分類問題肉康,這里data_dir目錄下一般包括兩個文件夾:train和val闯估,每個文件件下面包含N個子文件夾,N是你的分類類別數(shù)吼和,且每個子文件夾里存放的就是這個類別的圖像涨薪。這樣torchvision.datasets.ImageFolder就會返回一個列表(比如下面代碼中的image_datasets[‘train’]或者image_datasets[‘val]),列表中的每個值都是一個tuple炫乓,每個tuple包含圖像和標(biāo)簽信息刚夺。
————————————

data_dir = '/data'
image_datasets = {x: datasets.ImageFolder(
                    os.path.join(data_dir, x),
                    data_transforms[x]), 
                    for x in ['train', 'val']}

另外這里的data_transforms是一個字典末捣,如下侠姑。主要是進(jìn)行一些圖像預(yù)處理,比如resize箩做、crop等莽红。實現(xiàn)的時候采用的是torchvision.transforms模塊,比如torchvision.transforms.Compose是用來管理所有transforms操作的邦邦,torchvision.transforms.RandomSizedCrop是做crop的安吁。需要注意的是對于torchvision.transforms.RandomSizedCrop和transforms.RandomHorizontalFlip()等,輸入對象都是PIL Image燃辖,也就是用python的PIL庫讀進(jìn)來的圖像內(nèi)容鬼店,而transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])的作用對象需要是一個Tensor,因此在transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])之前有一個 transforms.ToTensor()就是用來生成Tensor的黔龟。另外transforms.Scale(256)其實就是resize操作妇智,目前已經(jīng)被transforms.Resize類取代了确沸。
————————————————

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Scale(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

前面torchvision.datasets.ImageFolder只是返回list,list是不能作為模型輸入的俘陷,因此在PyTorch中需要用另一個類來封裝list罗捎,那就是:torch.utils.data.DataLoader。torch.utils.data.DataLoader類可以將list類型的輸入數(shù)據(jù)封裝成Tensor數(shù)據(jù)格式拉盾,以備模型使用桨菜。注意,這里是對圖像和標(biāo)簽分別封裝成一個Tensor捉偏。這里要提到另一個很重要的類:torch.utils.data.Dataset倒得,這是一個抽象類,在pytorch中所有和數(shù)據(jù)相關(guān)的類都要繼承這個類來實現(xiàn)夭禽。比如前面說的torchvision.datasets.ImageFolder類是這樣的霞掺,以及這里的torch.util.data.DataLoader類也是這樣的。所以當(dāng)你的數(shù)據(jù)不是按照一個類別一個文件夾這種方式存儲時讹躯,你就要自定義一個類來讀取數(shù)據(jù)菩彬,自定義的這個類必須繼承自torch.utils.data.Dataset這個基類,最后同樣用torch.utils.data.DataLoader封裝成Tensor潮梯。

dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],     
                                            batch_size=4, 
                                            shuffle=True,
                                            num_workers=4) 
                                            for x in ['train', 'val']}

生成dataloaders后再有一步就可以作為模型的輸入了骗灶,那就是將Tensor數(shù)據(jù)類型封裝成Variable數(shù)據(jù)類型,來看下面這段代碼秉馏。dataloaders是一個字典耙旦,dataloders[‘train’]存的就是訓(xùn)練的數(shù)據(jù),這個for循環(huán)就是從dataloders[‘train’]中讀取batch_size個數(shù)據(jù)萝究,batch_size在前面生成dataloaders的時候就設(shè)置了免都。因此這個data里面包含圖像數(shù)據(jù)(inputs)這個Tensor和標(biāo)簽(labels)這個Tensor。然后用torch.autograd.Variable將Tensor封裝成模型真正可以用的Variable數(shù)據(jù)類型帆竹。
為什么要封裝成Variable呢绕娘?在pytorch中,torch.tensor和torch.autograd.Variable是兩種比較重要的數(shù)據(jù)結(jié)構(gòu)馆揉,Variable可以看成是tensor的一種包裝业舍,其不僅包含了tensor的內(nèi)容抖拦,還包含了梯度等信息升酣,因此在神經(jīng)網(wǎng)絡(luò)中常常用Variable數(shù)據(jù)結(jié)構(gòu)。那么怎么從一個Variable類型中取出tensor呢态罪?也很簡單噩茄,比如下面封裝后的inputs是一個Variable,那么inputs.data就是對應(yīng)的tensor复颈。

for data in dataloders['train']:
   inputs, labels = data

   if use_gpu:
       inputs = Variable(inputs.cuda())
       labels = Variable(labels.cuda())
   else:
       inputs, labels = Variable(inputs), Variable(labels)

封裝好了數(shù)據(jù)后绩聘,就可以作為模型的輸入了沥割。所以要先導(dǎo)入你的模型。在PyTorch中已經(jīng)默認(rèn)為大家準(zhǔn)備了一些常用的網(wǎng)絡(luò)結(jié)構(gòu)凿菩,比如分類中的VGG机杜,ResNet,DenseNet等等衅谷,可以用torchvision.models模塊來導(dǎo)入椒拗。比如用torchvision.models.resnet18(pretrained=True)來導(dǎo)入ResNet18網(wǎng)絡(luò),同時指明導(dǎo)入的是已經(jīng)預(yù)訓(xùn)練過的網(wǎng)絡(luò)获黔。因為預(yù)訓(xùn)練網(wǎng)絡(luò)一般是在1000類的ImageNet數(shù)據(jù)集上進(jìn)行的蚀苛,所以要遷移到你自己數(shù)據(jù)集的2分類,需要替換最后的全連接層為你所需要的輸出玷氏。因此下面這三行代碼進(jìn)行的就是用models模塊導(dǎo)入resnet18網(wǎng)絡(luò)堵未,然后獲取全連接層的輸入channel個數(shù),用這個channel個數(shù)和你要做的分類類別數(shù)(這里是2)替換原來模型中的全連接層盏触。這樣網(wǎng)絡(luò)結(jié)果也準(zhǔn)備好渗蟹。

model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

但是只有網(wǎng)絡(luò)結(jié)構(gòu)和數(shù)據(jù)還不足以讓代碼運行起來,還需要定義損失函數(shù)赞辩。在PyTorch中采用torch.nn模塊來定義網(wǎng)絡(luò)的所有層拙徽,比如卷積、降采樣诗宣、損失層等等膘怕,這里采用交叉熵函數(shù),因此可以這樣定義:

criterion = nn.CrossEntropyLoss()

然后你還需要定義優(yōu)化函數(shù)召庞,比如最常見的隨機(jī)梯度下降岛心,在PyTorch中是通過torch.optim模塊來實現(xiàn)的。另外這里雖然寫的是SGD篮灼,但是因為有momentum忘古,所以是Adam的優(yōu)化方式。這個類的輸入包括需要優(yōu)化的參數(shù):model.parameters()诅诱,學(xué)習(xí)率髓堪,還有Adam相關(guān)的momentum參數(shù)。現(xiàn)在很多優(yōu)化方式的默認(rèn)定義形式就是這樣的娘荡。

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

然后一般還會定義學(xué)習(xí)率的變化策略干旁,這里采用的是torch.optim.lr_scheduler模塊的StepLR類,表示每隔step_size個epoch就將學(xué)習(xí)率降為原來的gamma倍炮沐。

scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

開始訓(xùn)練

首先訓(xùn)練開始的時候需要先更新下學(xué)習(xí)率争群,這是因為我們前面制定了學(xué)習(xí)率的變化策略,所以在每個epoch開始時都要更新下:

scheduler.step()

然后設(shè)置模型狀態(tài)為訓(xùn)練狀態(tài):

model.train(True)

然后先將網(wǎng)絡(luò)中的所有梯度置0:

optimizer.zero_grad()

然后就是網(wǎng)絡(luò)的前向傳播了:

outputs = model(inputs)

然后將輸出的outputs和原來導(dǎo)入的labels作為loss函數(shù)的輸入就可以得到損失了

loss = criterion(outputs, labels)

輸出的outputs也是torch.autograd.Variable格式大年,得到輸出后(網(wǎng)絡(luò)的全連接層的輸出)還希望能到到模型預(yù)測該樣本屬于哪個類別的信息换薄,這里采用torch.max玉雾。torch.max()的第一個輸入是tensor格式,所以用outputs.data而不是outputs作為輸入轻要;第二個參數(shù)1是代表dim的意思复旬,也就是取每一行的最大值,其實就是我們常見的取概率最大的那個index冲泥;第三個參數(shù)loss也是torch.autograd.Variable格式赢底。

 _, preds = torch.max(outputs.data, 1)

計算得到loss后就要回傳損失。要注意的是這是在訓(xùn)練的時候才會有的操作柏蘑,測試時候只有forward過程幸冻。

loss.backward()

回傳損失過程中會計算梯度,然后需要根據(jù)這些梯度更新參數(shù)咳焚,optimizer.step()就是用來更新參數(shù)的洽损。optimizer.step()后,你就可以從optimizer.param_groups[0][‘params’]里面看到各個層的梯度和權(quán)值信息革半。

optimizer.step()

這樣一個batch數(shù)據(jù)的訓(xùn)練就結(jié)束了碑定!當(dāng)你不斷重復(fù)這樣的訓(xùn)練過程,最終就可以達(dá)到你想要的結(jié)果了又官。

另外如果你有g(shù)pu可用延刘,那么包括你的數(shù)據(jù)和模型都可以在gpu上操作,這在PyTorch中也非常簡單六敬。判斷你是否有g(shù)pu可以用可以通過下面這行代碼碘赖,如果有,則use_gpu是true外构。

use_gpu = torch.cuda.is_available()

參考

PyTorch使用及源碼解讀
Pytorch-Transfer Learning
Welcome to PyTorch Tutorials
PyTorch學(xué)習(xí)之路(level1)——訓(xùn)練一個圖像分類模型
https://github.com/miraclewkf/ImageClassification-PyTorch/blob/master/level1/train.py

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末普泡,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子审编,更是在濱河造成了極大的恐慌撼班,老刑警劉巖,帶你破解...
    沈念sama閱讀 218,941評論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件垒酬,死亡現(xiàn)場離奇詭異砰嘁,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī)勘究,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,397評論 3 395
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來板祝,“玉大人券时,你說我怎么就攤上這事橘洞≌ㄔ妫” “怎么了?”我有些...
    開封第一講書人閱讀 165,345評論 0 356
  • 文/不壞的土叔 我叫張陵,是天一觀的道長逛揩。 經(jīng)常有香客問我麸俘,道長逞泄,這世上最難降的妖魔是什么炭懊? 我笑而不...
    開封第一講書人閱讀 58,851評論 1 295
  • 正文 為了忘掉前任,我火速辦了婚禮父阻,結(jié)果婚禮上加矛,老公的妹妹穿的比我還像新娘。我一直安慰自己苛茂,他們只是感情好妓羊,可當(dāng)我...
    茶點故事閱讀 67,868評論 6 392
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著剥哑,像睡著了一般淹父。 火紅的嫁衣襯著肌膚如雪弹灭。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,688評論 1 305
  • 那天捡鱼,我揣著相機(jī)與錄音,去河邊找鬼缠诅。 笑死管引,一個胖子當(dāng)著我的面吹牛褥伴,可吹牛的內(nèi)容都是我干的重慢。 我是一名探鬼主播,決...
    沈念sama閱讀 40,414評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼囚戚,長吁一口氣:“原來是場噩夢啊……” “哼弯淘!你這毒婦竟也來了吉懊?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 39,319評論 0 276
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體蕉拢,經(jīng)...
    沈念sama閱讀 45,775評論 1 315
  • 正文 獨居荒郊野嶺守林人離奇死亡晕换,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,945評論 3 336
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了站宗。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片闸准。...
    茶點故事閱讀 40,096評論 1 350
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖梢灭,靈堂內(nèi)的尸體忽然破棺而出恕汇,到底是詐尸還是另有隱情,我是刑警寧澤或辖,帶...
    沈念sama閱讀 35,789評論 5 346
  • 正文 年R本政府宣布瘾英,位于F島的核電站,受9級特大地震影響颂暇,放射性物質(zhì)發(fā)生泄漏缺谴。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點故事閱讀 41,437評論 3 331
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望湿蛔。 院中可真熱鬧膀曾,春花似錦、人聲如沸阳啥。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,993評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽察迟。三九已至斩狱,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間扎瓶,已是汗流浹背所踊。 一陣腳步聲響...
    開封第一講書人閱讀 33,107評論 1 271
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機(jī)就差點兒被人妖公主榨干…… 1. 我叫王不留概荷,地道東北人秕岛。 一個月前我還...
    沈念sama閱讀 48,308評論 3 372
  • 正文 我出身青樓,卻偏偏與公主長得像误证,于是被迫代替她去往敵國和親继薛。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 45,037評論 2 355

推薦閱讀更多精彩內(nèi)容