[翻譯]pytorch官方文檔-遷移學習教程

原官方網(wǎng)頁:https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

通過本教程,你將學到如何使用遷移學習訓練你的網(wǎng)絡(luò)孵稽。你可以在cs231n notes了解更多關(guān)于遷移學習

引用一些筆記:

  • 實際中,基本沒有人會從零開始(隨機初始化)訓練一個完整的卷積網(wǎng)絡(luò)莉撇,因為相對于網(wǎng)絡(luò),很難得到一個足夠大的數(shù)據(jù)集[網(wǎng)絡(luò)很深, 需要足夠大數(shù)據(jù)集]譬胎。通常的做法是在一個很大的數(shù)據(jù)集上進行預(yù)訓練得到卷積網(wǎng)絡(luò)ConvNet, 然后將這個ConvNet的參數(shù)作為目標任務(wù)的初始化參數(shù)或者固定這些參數(shù)

以下是應(yīng)用遷移學習的兩種場景:

  • 微調(diào)Convnet:使用預(yù)訓練的網(wǎng)絡(luò)(如在imagenet 1000上訓練而來的網(wǎng)絡(luò))來初始化自己的網(wǎng)絡(luò)磕蛇,而不是隨機初始化。其他的訓練步驟不變囱桨。
  • Convnet看成固定的特征提取器仓犬。首先固定ConvNet除了最后的全連接層外的其他所有層。最后的全連接層被替換成一個新的隨機初始化的層舍肠,只有這個新的層會被訓練[只有這層參數(shù)會在反向傳播時更新]
# License: BSD
# Author: Sasank Chilamkurthy

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

plt.ion()   # interactive mode

1. 數(shù)據(jù)加載

我們通常會使用torchvisiontorch .utils.data包來加載數(shù)據(jù)

今天要解決的問題是訓練一個模型來分類螞蟻ants和蜜蜂bees搀继。ants和bees各有約120張訓練圖片窘面。每個類有75張驗證圖片。從零開始在如此小的數(shù)據(jù)集上進行訓練通常是很難泛化的叽躯。由于我們使用遷移學習财边,模型的泛化能力會相當好

這個數(shù)據(jù)集是imagenet的子集,可以在這里下載

# 訓練集數(shù)據(jù)增廣和歸一化
# 在驗證集上僅僅歸一化
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224), # 隨機裁剪一個area之后再resize
        transforms.RandomHorizontalFlip(), # 隨機水平翻轉(zhuǎn)
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

1.1 可視化一些數(shù)據(jù)

我們可視化了一些訓練圖片來明白數(shù)據(jù)增廣操作

def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])
可視化結(jié)果

2. 訓練網(wǎng)絡(luò)

現(xiàn)在我們寫一個通用的函數(shù)來訓練網(wǎng)絡(luò)点骑。我們將展示:

  • 調(diào)整學習速率
  • 保存最好的模型

如下制圈,參數(shù)scheduler是一個來自torch.optim.lr_scheduler的學習速率調(diào)整類的對象(LR scheduler object)

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                scheduler.step()
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

2.1 可視化模型的預(yù)測結(jié)果

一個通用的展示少量預(yù)測圖片的函數(shù)

def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images//2, 2, images_so_far)
                ax.axis('off')
                ax.set_title('predicted: {}'.format(class_names[preds[j]]))
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)

3. 微調(diào)convnent

加載預(yù)訓練模型并且重置最后一個全連接層

model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

3.1 訓練并評估

在CPU上將耗時大約15-25分鐘,在GPU上將花少于1分鐘的時間

  • 訓練
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)
  • output
Epoch 0/24
----------
train Loss: 0.6849 Acc: 0.6762
val Loss: 0.2146 Acc: 0.9281
.
.
.
Epoch 23/24
----------
train Loss: 0.2282 Acc: 0.9139
val Loss: 0.2709 Acc: 0.8954

Epoch 24/24
----------
train Loss: 0.3081 Acc: 0.8566
val Loss: 0.3045 Acc: 0.9020

Training complete in 0m 58s
Best val Acc: 0.928105
  • 可視化
visualize_model(model_ft)
訓練結(jié)果

4. 將convnent看成特征提取器

這里畔况,我們將凍結(jié)全部網(wǎng)絡(luò),除了最后一層慧库。我們應(yīng)該將需要設(shè)置欲凍結(jié)的參數(shù)的requires_grad == False跷跪,這樣在反向傳播backward()的時候他們的梯度就不會被計算
更多關(guān)于grad的文檔在這里

model_conv = torchvision.models.resnet18(pretrained=True)
# 最重要的一步
for param in model_conv.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that only parameters of final layer are being optimized as
# opoosed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

4.1 訓練和評估

在CPU上,固定參數(shù)相比于之前的作為初始化參數(shù)的做法齐板,會節(jié)約大約一半的時間吵瞻。這是可以預(yù)期的,因為網(wǎng)絡(luò)的絕大部分參數(shù)的梯度不會在反向傳播中計算甘磨。(但是這些參數(shù)是參與前向傳播的)

  • 訓練
model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=25)
  • output
Epoch 0/24
----------
train Loss: 0.6421 Acc: 0.6557
val Loss: 0.4560 Acc: 0.7451

Epoch 1/24
----------
train Loss: 0.4694 Acc: 0.7746
val Loss: 0.1616 Acc: 0.9608

Epoch 2/24
----------
train Loss: 0.4500 Acc: 0.7746
val Loss: 0.3041 Acc: 0.8627
.
.
.
Epoch 24/24
----------
train Loss: 0.3382 Acc: 0.8566
val Loss: 0.1605 Acc: 0.9542

Training complete in 0m 46s
Best val Acc: 0.967320
  • 可視化
visualize_model(model_conv)

plt.ioff()
plt.show()
訓練結(jié)果

5. 文件下載

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末橡羞,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子济舆,更是在濱河造成了極大的恐慌卿泽,老刑警劉巖,帶你破解...
    沈念sama閱讀 221,430評論 6 515
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件滋觉,死亡現(xiàn)場離奇詭異签夭,居然都是意外死亡,警方通過查閱死者的電腦和手機椎侠,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 94,406評論 3 398
  • 文/潘曉璐 我一進店門第租,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人我纪,你說我怎么就攤上這事慎宾。” “怎么了浅悉?”我有些...
    開封第一講書人閱讀 167,834評論 0 360
  • 文/不壞的土叔 我叫張陵趟据,是天一觀的道長。 經(jīng)常有香客問我术健,道長之宿,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 59,543評論 1 296
  • 正文 為了忘掉前任苛坚,我火速辦了婚禮比被,結(jié)果婚禮上色难,老公的妹妹穿的比我還像新娘。我一直安慰自己等缀,他們只是感情好枷莉,可當我...
    茶點故事閱讀 68,547評論 6 397
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著尺迂,像睡著了一般笤妙。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上噪裕,一...
    開封第一講書人閱讀 52,196評論 1 308
  • 那天蹲盘,我揣著相機與錄音,去河邊找鬼膳音。 笑死召衔,一個胖子當著我的面吹牛,可吹牛的內(nèi)容都是我干的祭陷。 我是一名探鬼主播苍凛,決...
    沈念sama閱讀 40,776評論 3 421
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼兵志!你這毒婦竟也來了醇蝴?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 39,671評論 0 276
  • 序言:老撾萬榮一對情侶失蹤想罕,失蹤者是張志新(化名)和其女友劉穎悠栓,沒想到半個月后,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體按价,經(jīng)...
    沈念sama閱讀 46,221評論 1 320
  • 正文 獨居荒郊野嶺守林人離奇死亡闸迷,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 38,303評論 3 340
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了俘枫。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片腥沽。...
    茶點故事閱讀 40,444評論 1 352
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖鸠蚪,靈堂內(nèi)的尸體忽然破棺而出今阳,到底是詐尸還是另有隱情,我是刑警寧澤茅信,帶...
    沈念sama閱讀 36,134評論 5 350
  • 正文 年R本政府宣布盾舌,位于F島的核電站,受9級特大地震影響蘸鲸,放射性物質(zhì)發(fā)生泄漏妖谴。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點故事閱讀 41,810評論 3 333
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望膝舅。 院中可真熱鬧嗡载,春花似錦、人聲如沸仍稀。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,285評論 0 24
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽技潘。三九已至遥巴,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間享幽,已是汗流浹背铲掐。 一陣腳步聲響...
    開封第一講書人閱讀 33,399評論 1 272
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留值桩,地道東北人摆霉。 一個月前我還...
    沈念sama閱讀 48,837評論 3 376
  • 正文 我出身青樓,卻偏偏與公主長得像颠毙,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子砂碉,可洞房花燭夜當晚...
    茶點故事閱讀 45,455評論 2 359

推薦閱讀更多精彩內(nèi)容

  • 翻譯論文匯總:https://github.com/SnailTyan/deep-learning-papers-...
    SnailTyan閱讀 2,206評論 0 7
  • 文章主要分為:一蛀蜜、深度學習概念;二增蹭、國內(nèi)外研究現(xiàn)狀滴某;三、深度學習模型結(jié)構(gòu)滋迈;四霎奢、深度學習訓練算法;五饼灿、深度學習的優(yōu)點...
    艾剪疏閱讀 21,844評論 0 58
  • 摘要 針對時空特征的學習幕侠,我們提出了一個簡單有效的方法,在大規(guī)模有監(jiān)督視頻數(shù)據(jù)集上使用深度3維卷積網(wǎng)絡(luò)(3D Co...
    鐘速閱讀 53,400評論 0 21
  • 包括: 理解卷積神經(jīng)網(wǎng)絡(luò) 使用數(shù)據(jù)增強緩解過擬合 使用預(yù)訓練卷積網(wǎng)絡(luò)做特征提取 微調(diào)預(yù)訓練網(wǎng)絡(luò)模型 可視化卷積網(wǎng)絡(luò)...
    七八音閱讀 1,812評論 0 2
  • 回顧一下之前的人生碍彭,有哪些事或者那些階段是自己認為成功的晤硕,沒有虛度人生的。 在回顧一下有哪些是自己頹廢庇忌,虛度時光的...
    顯微無間閱讀 189評論 0 0