pytorch學(xué)習(xí)(十二)—遷移學(xué)習(xí)Transfer Learning

前言

在訓(xùn)練深度學(xué)習(xí)模型時(shí),有時(shí)候我們沒(méi)有海量的訓(xùn)練樣本,只有少數(shù)的訓(xùn)練樣本(比如幾百個(gè)圖片)诲宇,幾百個(gè)訓(xùn)練樣本顯然對(duì)于深度學(xué)習(xí)遠(yuǎn)遠(yuǎn)不夠。這時(shí)候惶翻,我們可以使用別人預(yù)訓(xùn)練好的網(wǎng)絡(luò)模型權(quán)重姑蓝,在此基礎(chǔ)上進(jìn)行訓(xùn)練,這就引入了一個(gè)概念——遷移學(xué)習(xí)(Transfer Learning)吕粗。


遷移學(xué)習(xí)

What(什么是遷移學(xué)習(xí))

遷移學(xué)習(xí)(Transfer Learning,TL)對(duì)于人類來(lái)說(shuō)纺荧,就是掌握舉一反三的學(xué)習(xí)能力。比如我們學(xué)會(huì)騎自行車后颅筋,學(xué)騎摩托車就很簡(jiǎn)單了宙暇;在學(xué)會(huì)打羽毛球之后,再學(xué)打網(wǎng)球也就沒(méi)那么難了议泵。對(duì)于計(jì)算機(jī)而言客给,所謂遷移學(xué)習(xí),就是能讓現(xiàn)有的模型算法稍加調(diào)整即可應(yīng)用于一個(gè)新的領(lǐng)域和功能的一項(xiàng)技術(shù)

How(如何進(jìn)行遷移學(xué)習(xí))

  • 首先需要選擇一個(gè)預(yù)訓(xùn)練好的模型肢簿,需要注意的是該模型的訓(xùn)練過(guò)程最好與我們要進(jìn)行訓(xùn)練的任務(wù)相似靶剑。比如我們要訓(xùn)練一個(gè)Cat,dog圖像分類的模型,最好應(yīng)該選擇一個(gè)圖像分類的預(yù)訓(xùn)練模型池充。

  • 針對(duì)實(shí)際任務(wù)桩引,對(duì)網(wǎng)絡(luò)結(jié)構(gòu)進(jìn)行調(diào)整。比如找到了一個(gè)預(yù)訓(xùn)練好的AlexNet(1000類別)收夸, 但是我們實(shí)際的任務(wù)的2分類坑匠,因此需要把最后一層的全連接輸出改為2.

Why(為何要使用遷移學(xué)習(xí))

https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

In practice, very few people train an entire Convolutional Network from scratch (with random initialization), because it is relatively rare to have a dataset of sufficient size. Instead, it is common to pretrain a ConvNet on a very large dataset (e.g. ImageNet, which contains 1.2 million images with 1000 categories), and then use the ConvNet either as an initialization or a fixed feature extractor for the task of interest.


目的

  • 了解ResNet
  • 基于預(yù)訓(xùn)練好的ResNet-18, 進(jìn)行一個(gè)圖像二分類遷移學(xué)習(xí)

開發(fā)/測(cè)試環(huán)境

  • Ubuntu 18.04
  • pycharm
  • Anaconda3, python3.6
  • pytorch1.0, torchvision

ResNet-18

image.png

實(shí)驗(yàn)內(nèi)容

準(zhǔn)備數(shù)據(jù)集

  • 訓(xùn)練集合
  • 驗(yàn)證集合

數(shù)據(jù)集下載鏈接

下載好之后,復(fù)制到工程 /data/ 路徑下


image.png

訓(xùn)練集合卧惜,驗(yàn)證集合


image.png

訓(xùn)練集厘灼,驗(yàn)證集 分別包含2個(gè)子文件夾,這是一個(gè)2分類問(wèn)題咽瓷。分類對(duì)象:螞蟻设凹,蜜蜂


image.png
  • 代碼
    因?yàn)橛?xùn)練一個(gè)2分類的模型,數(shù)據(jù)集加載直接使用pytorch提供的API——ImageFolder最方便茅姜。原始圖像為jpg格式闪朱,在制作數(shù)據(jù)集時(shí)候進(jìn)行了變換transforms。 加入對(duì)GPU的支持,首先判斷torch.cuda.is_available(),然后決定使用GPU or CPU
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
import torchvision
from torchvision.transforms import transforms
from torchvision import models
from torchvision.models import ResNet
import numpy as np
import matplotlib.pyplot as plt
import os
import utils


data_dir = './data/hymenoptera_data'

train_dataset = torchvision.datasets.ImageFolder(root=os.path.join(data_dir, 'train'),
                                                 transform=transforms.Compose(
                                                     [
                                                         transforms.RandomResizedCrop(224),
                                                         transforms.RandomHorizontalFlip(),
                                                         transforms.ToTensor(),
                                                         transforms.Normalize(
                                                             mean=(0.485, 0.456, 0.406),
                                                             std=(0.229, 0.224, 0.225))
                                                     ]))

val_dataset = torchvision.datasets.ImageFolder(root=os.path.join(data_dir, 'val'),
                                               transform=transforms.Compose(
                                                     [
                                                         transforms.RandomResizedCrop(224),
                                                         transforms.RandomHorizontalFlip(),
                                                         transforms.ToTensor(),
                                                         transforms.Normalize(
                                                             mean=(0.485, 0.456, 0.406),
                                                             std=(0.229, 0.224, 0.225))
                                                     ]))

train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=4)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=4, shuffle=4)

# 類別名稱
class_names = train_dataset.classes
print('class_names:{}'.format(class_names))

# 訓(xùn)練設(shè)備  CPU/GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('trian_device:{}'.format(device.type))

# 隨機(jī)顯示一個(gè)batch
plt.figure()
utils.imshow(next(iter(train_dataloader)))
plt.show()

獲取預(yù)訓(xùn)練模型

torchvision.models
torchvision中包含了一些常見的預(yù)訓(xùn)練模型:

image.png

AlexNet, VGG, SqueezeNet, Resnet奋姿,Inception, DenseNet

此次實(shí)驗(yàn)采用ResNet18網(wǎng)絡(luò)模型锄开。
torchvision.models中包含resnet18,首先會(huì)實(shí)例化一個(gè)ResNet網(wǎng)絡(luò)称诗, 然后model.load_dict()加載預(yù)訓(xùn)練好的模型萍悴。

def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model

torchvision 默認(rèn)將模型保存在/home/.torch/models路徑。

image.png

預(yù)訓(xùn)練模型文件:


image.png
  • 代碼
    加載預(yù)訓(xùn)練模型寓免。需要注意的地方:修改ResNet最后一個(gè)全連接層的輸出個(gè)數(shù)癣诱,二分類問(wèn)題需要將輸出個(gè)數(shù)改為2。
# -------------------------模型選擇再榄,優(yōu)化方法狡刘, 學(xué)習(xí)率策略----------------------
model = models.resnet18(pretrained=True)

# 全連接層的輸入通道in_channels個(gè)數(shù)
num_fc_in = model.fc.in_features

# 改變?nèi)B接層,2分類問(wèn)題困鸥,out_features = 2
model.fc = nn.Linear(num_fc_in, 2)

# 模型遷移到CPU/GPU
model = model.to(device)

# 定義損失函數(shù)
loss_fc = nn.CrossEntropyLoss()

# 選擇優(yōu)化方法
optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)

# 學(xué)習(xí)率調(diào)整策略
# 每7個(gè)epoch調(diào)整一次
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.5)  # step_size


訓(xùn)練嗅蔬,測(cè)試網(wǎng)絡(luò)

Epoch: 訓(xùn)練50個(gè)epoch
注意地方: 訓(xùn)練時(shí)候,需要調(diào)用model.train()將模型設(shè)置為訓(xùn)練模式疾就。測(cè)試時(shí)候澜术,調(diào)用model.eval() 將模型設(shè)置為測(cè)試模型,否則訓(xùn)練和測(cè)試結(jié)果不正確猬腰。

# ----------------訓(xùn)練過(guò)程-----------------
num_epochs = 50

for epoch in range(num_epochs):

    running_loss = 0.0
    exp_lr_scheduler.step()

    for i, sample_batch in enumerate(train_dataloader):
        inputs = sample_batch[0]
        labels = sample_batch[1]

        model.train()

        # GPU/CPU
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # foward
        outputs = model(inputs)

        # loss
        loss = loss_fc(outputs, labels)

        # loss求導(dǎo)鸟废,反向
        loss.backward()

        # 優(yōu)化
        optimizer.step()

        #
        running_loss += loss.item()

        # 測(cè)試
        if i % 20 == 19:
            correct = 0
            total = 0
            model.eval()
            for images_test, labels_test in val_dataloader:
                images_test = images_test.to(device)
                labels_test = labels_test.to(device)

                outputs_test = model(images_test)
                _, prediction = torch.max(outputs_test, 1)
                correct += (torch.sum((prediction == labels_test))).item()
               # print(prediction, labels_test, correct)
                total += labels_test.size(0)
            print('[{}, {}] running_loss = {:.5f} accurcay = {:.5f}'.format(epoch + 1, i + 1, running_loss / 20,
                                                                        correct / total))
            running_loss = 0.0

        # if i % 10 == 9:
        #     print('[{}, {}] loss={:.5f}'.format(epoch+1, i+1, running_loss / 10))
        #     running_loss = 0.0

print('training finish !')
torch.save(model.state_dict(), './model/model_2.pth')

訓(xùn)練輸出結(jié)果

image.png
image.png
image.png
image.png

隨著訓(xùn)練次數(shù)增加,accuracy基本上是上升趨勢(shì)姑荷,最終達(dá)到93%的準(zhǔn)確率盒延。

image.png
image.png
image.png

完整代碼

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
import torchvision
from torchvision.transforms import transforms
from torchvision import models
from torchvision.models import ResNet
import numpy as np
import matplotlib.pyplot as plt
import os
import utils


data_dir = './data/hymenoptera_data'

train_dataset = torchvision.datasets.ImageFolder(root=os.path.join(data_dir, 'train'),
                                                 transform=transforms.Compose(
                                                     [
                                                         transforms.RandomResizedCrop(224),
                                                         transforms.RandomHorizontalFlip(),
                                                         transforms.ToTensor(),
                                                         transforms.Normalize(
                                                             mean=(0.485, 0.456, 0.406),
                                                             std=(0.229, 0.224, 0.225))
                                                     ]))

val_dataset = torchvision.datasets.ImageFolder(root=os.path.join(data_dir, 'val'),
                                               transform=transforms.Compose(
                                                     [
                                                         transforms.RandomResizedCrop(224),
                                                         transforms.RandomHorizontalFlip(),
                                                         transforms.ToTensor(),
                                                         transforms.Normalize(
                                                             mean=(0.485, 0.456, 0.406),
                                                             std=(0.229, 0.224, 0.225))
                                                     ]))

train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=4)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=4, shuffle=4)

# 類別名稱
class_names = train_dataset.classes
print('class_names:{}'.format(class_names))

# 訓(xùn)練設(shè)備  CPU/GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('trian_device:{}'.format(device.type))

# 隨機(jī)顯示一個(gè)batch
#plt.figure()
#utils.imshow(next(iter(train_dataloader)))
#plt.show()

# -------------------------模型選擇,優(yōu)化方法鼠冕, 學(xué)習(xí)率策略----------------------
model = models.resnet18(pretrained=True)

# 全連接層的輸入通道in_channels個(gè)數(shù)
num_fc_in = model.fc.in_features

# 改變?nèi)B接層添寺,2分類問(wèn)題,out_features = 2
model.fc = nn.Linear(num_fc_in, 2)

# 模型遷移到CPU/GPU
model = model.to(device)

# 定義損失函數(shù)
loss_fc = nn.CrossEntropyLoss()

# 選擇優(yōu)化方法
optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)

# 學(xué)習(xí)率調(diào)整策略
# 每7個(gè)epoch調(diào)整一次
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.5)  # step_size


# ----------------訓(xùn)練過(guò)程-----------------
num_epochs = 50

for epoch in range(num_epochs):

    running_loss = 0.0
    exp_lr_scheduler.step()

    for i, sample_batch in enumerate(train_dataloader):
        inputs = sample_batch[0]
        labels = sample_batch[1]

        model.train()

        # GPU/CPU
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # foward
        outputs = model(inputs)

        # loss
        loss = loss_fc(outputs, labels)

        # loss求導(dǎo)懈费,反向
        loss.backward()

        # 優(yōu)化
        optimizer.step()

        #
        running_loss += loss.item()

        # 測(cè)試
        if i % 20 == 19:
            correct = 0
            total = 0
            model.eval()
            for images_test, labels_test in val_dataloader:
                images_test = images_test.to(device)
                labels_test = labels_test.to(device)

                outputs_test = model(images_test)
                _, prediction = torch.max(outputs_test, 1)
                correct += (torch.sum((prediction == labels_test))).item()
               # print(prediction, labels_test, correct)
                total += labels_test.size(0)
            print('[{}, {}] running_loss = {:.5f} accurcay = {:.5f}'.format(epoch + 1, i + 1, running_loss / 20,
                                                                        correct / total))
            running_loss = 0.0

        # if i % 10 == 9:
        #     print('[{}, {}] loss={:.5f}'.format(epoch+1, i+1, running_loss / 10))
        #     running_loss = 0.0

print('training finish !')
torch.save(model.state_dict(), './model/model_2.pth')


End

參考:
https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
https://blog.csdn.net/sunqiande88/article/details/80100891

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末计露,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子憎乙,更是在濱河造成了極大的恐慌票罐,老刑警劉巖,帶你破解...
    沈念sama閱讀 218,284評(píng)論 6 506
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件泞边,死亡現(xiàn)場(chǎng)離奇詭異该押,居然都是意外死亡,警方通過(guò)查閱死者的電腦和手機(jī)繁堡,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,115評(píng)論 3 395
  • 文/潘曉璐 我一進(jìn)店門沈善,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)乡数,“玉大人椭蹄,你說(shuō)我怎么就攤上這事闻牡。” “怎么了绳矩?”我有些...
    開封第一講書人閱讀 164,614評(píng)論 0 354
  • 文/不壞的土叔 我叫張陵罩润,是天一觀的道長(zhǎng)。 經(jīng)常有香客問(wèn)我翼馆,道長(zhǎng)割以,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,671評(píng)論 1 293
  • 正文 為了忘掉前任应媚,我火速辦了婚禮严沥,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘中姜。我一直安慰自己消玄,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,699評(píng)論 6 392
  • 文/花漫 我一把揭開白布丢胚。 她就那樣靜靜地躺著翩瓜,像睡著了一般。 火紅的嫁衣襯著肌膚如雪携龟。 梳的紋絲不亂的頭發(fā)上兔跌,一...
    開封第一講書人閱讀 51,562評(píng)論 1 305
  • 那天,我揣著相機(jī)與錄音峡蟋,去河邊找鬼坟桅。 笑死,一個(gè)胖子當(dāng)著我的面吹牛蕊蝗,可吹牛的內(nèi)容都是我干的仅乓。 我是一名探鬼主播,決...
    沈念sama閱讀 40,309評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼匿又,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼方灾!你這毒婦竟也來(lái)了?” 一聲冷哼從身側(cè)響起碌更,我...
    開封第一講書人閱讀 39,223評(píng)論 0 276
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤裕偿,失蹤者是張志新(化名)和其女友劉穎,沒(méi)想到半個(gè)月后痛单,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體嘿棘,經(jīng)...
    沈念sama閱讀 45,668評(píng)論 1 314
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,859評(píng)論 3 336
  • 正文 我和宋清朗相戀三年旭绒,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了鸟妙。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片焦人。...
    茶點(diǎn)故事閱讀 39,981評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖重父,靈堂內(nèi)的尸體忽然破棺而出花椭,到底是詐尸還是另有隱情,我是刑警寧澤房午,帶...
    沈念sama閱讀 35,705評(píng)論 5 347
  • 正文 年R本政府宣布矿辽,位于F島的核電站,受9級(jí)特大地震影響郭厌,放射性物質(zhì)發(fā)生泄漏袋倔。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,310評(píng)論 3 330
  • 文/蒙蒙 一折柠、第九天 我趴在偏房一處隱蔽的房頂上張望宾娜。 院中可真熱鬧,春花似錦扇售、人聲如沸前塔。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,904評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)嘱根。三九已至,卻和暖如春巷懈,著一層夾襖步出監(jiān)牢的瞬間该抒,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 33,023評(píng)論 1 270
  • 我被黑心中介騙來(lái)泰國(guó)打工顶燕, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留凑保,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 48,146評(píng)論 3 370
  • 正文 我出身青樓涌攻,卻偏偏與公主長(zhǎng)得像欧引,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子恳谎,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,933評(píng)論 2 355

推薦閱讀更多精彩內(nèi)容