訓(xùn)練中的騷操作——數(shù)據(jù)增強(qiáng)凌节、模型微調(diào)

先說說數(shù)據(jù)增強(qiáng)
大規(guī)模數(shù)據(jù)集是成功應(yīng)用深度神經(jīng)網(wǎng)絡(luò)的前提蕉斜。圖像增廣(image augmentation)技術(shù)通過對訓(xùn)練圖像做一系列隨機(jī)改變,來產(chǎn)生相似但又不同的訓(xùn)練樣本怕膛,從而擴(kuò)大訓(xùn)練數(shù)據(jù)集的規(guī)模熟嫩。圖像增廣的另一種解釋是,隨機(jī)改變訓(xùn)練樣本可以降低模型對某些屬性的依賴褐捻,從而提高模型的泛化能力掸茅。例如椅邓,我們可以對圖像進(jìn)行不同方式的裁剪,使感興趣的物體出現(xiàn)在不同位置昧狮,從而減輕模型對物體出現(xiàn)位置的依賴性景馁。我們也可以調(diào)整亮度、色彩等因素來降低模型對色彩的敏感度逗鸣『献。可以說,在當(dāng)年AlexNet的成功中撒璧,圖像增廣技術(shù)功不可沒透葛。
顯示圖像:

def show_images(imgs, num_rows, num_cols, scale=2):
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    for i in range(num_rows):
        for j in range(num_cols):
            axes[i][j].imshow(imgs[i * num_cols + j])
            axes[i][j].axes.get_xaxis().set_visible(False)
            axes[i][j].axes.get_yaxis().set_visible(False)
    return axes

構(gòu)造輔助函數(shù):

def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):
    Y = [aug(img) for _ in range(num_rows * num_cols)]
    show_images(Y, num_rows, num_cols, scale)

翻轉(zhuǎn)和裁剪

左右翻轉(zhuǎn)圖像通常不改變物體的類別。它是最早也是最廣泛使用的一種圖像增廣方法卿樱。下面我們通過torchvision.transforms模塊創(chuàng)建RandomHorizontalFlip實例來實現(xiàn)一半概率的圖像水平(左右)翻轉(zhuǎn)僚害。
apply(img, torchvision.transforms.RandomHorizontalFlip())

上下翻轉(zhuǎn)不如左右翻轉(zhuǎn)通用。但是至少對于樣例圖像繁调,上下翻轉(zhuǎn)不會造成識別障礙萨蚕。下面我們創(chuàng)建RandomVerticalFlip實例來實現(xiàn)一半概率的圖像垂直(上下)翻轉(zhuǎn)。
apply(img, torchvision.transforms.RandomVerticalFlip())


變化顏色蹄胰。我們可以從4個方面改變圖像的顏色:亮度(brightness)岳遥、對比度(contrast)飽和度(saturation)色調(diào)(hue)烤送。在下面的例子里寒随,我們將圖像的亮度隨機(jī)變化為原圖亮度的()()。
apply(img, torchvision.transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0, hue=0))

img, torchvision.transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0.5)
另外帮坚,我們還可以疊加多個圖像增廣方法

augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(), color_aug, shape_aug])
apply(img, augs)
  • 為了在預(yù)測時得到確定的結(jié)果妻往,我們通常只將圖像增廣應(yīng)用在訓(xùn)練樣本上,而不在預(yù)測時使用含隨機(jī)操作的圖像增廣试和。在這里我們只使用最簡單的隨機(jī)左右翻轉(zhuǎn)讯泣。此外,我們使用ToTensor將小批量圖像轉(zhuǎn)成PyTorch需要的格式阅悍,即形狀為(批量大小, 通道數(shù), 高, 寬)好渠、值域在0到1之間且類型為32位浮點數(shù)。
flip_aug = torchvision.transforms.Compose([
     torchvision.transforms.RandomHorizontalFlip(),
     torchvision.transforms.ToTensor()])

no_aug = torchvision.transforms.Compose([
     torchvision.transforms.ToTensor()])
num_workers = 0 if sys.platform.startswith('win32') else 4
def load_cifar10(is_train, augs, batch_size, root=CIFAR_ROOT_PATH):
    dataset = torchvision.datasets.CIFAR10(root=root, train=is_train, transform=augs, download=False)
    return DataLoader(dataset, batch_size=batch_size, shuffle=is_train, num_workers=num_workers)
def train_with_data_aug(train_augs, test_augs, lr=0.001):
# 設(shè)計方法使得訓(xùn)練圖像進(jìn)行aug圖像處理
    batch_size, net = 256, d2l.resnet18(10)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = torch.nn.CrossEntropyLoss()
    train_iter = load_cifar10(True, train_augs, batch_size)
    test_iter = load_cifar10(False, test_augs, batch_size)
    train(train_iter, test_iter, net, loss, optimizer, device, num_epochs=10)

模型微調(diào)

在前面的一些章節(jié)中节视,我們介紹了如何在只有6萬張圖像的Fashion-MNIST訓(xùn)練數(shù)據(jù)集上訓(xùn)練模型拳锚。我們還描述了學(xué)術(shù)界當(dāng)下使用最廣泛的大規(guī)模圖像數(shù)據(jù)集ImageNet,它有超過1,000萬的圖像和1,000類的物體寻行。然而霍掺,我們平常接觸到數(shù)據(jù)集的規(guī)模通常在這兩者之間。

假設(shè)我們想從圖像中識別出不同種類的椅子,然后將購買鏈接推薦給用戶杆烁。一種可能的方法是先找出100種常見的椅子牙丽,為每種椅子拍攝1,000張不同角度的圖像,然后在收集到的圖像數(shù)據(jù)集上訓(xùn)練一個分類模型兔魂。這個椅子數(shù)據(jù)集雖然可能比Fashion-MNIST數(shù)據(jù)集要龐大烤芦,但樣本數(shù)仍然不及ImageNet數(shù)據(jù)集中樣本數(shù)的十分之一。這可能會導(dǎo)致適用于ImageNet數(shù)據(jù)集的復(fù)雜模型在這個椅子數(shù)據(jù)集上過擬合析校。同時构罗,因為數(shù)據(jù)量有限,最終訓(xùn)練得到的模型的精度也可能達(dá)不到實用的要求勺良。

為了應(yīng)對上述問題绰播,一個顯而易見的解決辦法是收集更多的數(shù)據(jù)骄噪。然而尚困,收集和標(biāo)注數(shù)據(jù)會花費大量的時間和資金。例如链蕊,為了收集ImageNet數(shù)據(jù)集事甜,研究人員花費了數(shù)百萬美元的研究經(jīng)費。雖然目前的數(shù)據(jù)采集成本已降低了不少滔韵,但其成本仍然不可忽略逻谦。

另外一種解決辦法是應(yīng)用遷移學(xué)習(xí)(transfer learning),將從源數(shù)據(jù)集學(xué)到的知識遷移到目標(biāo)數(shù)據(jù)集上陪蜻。例如邦马,雖然ImageNet數(shù)據(jù)集的圖像大多跟椅子無關(guān),但在該數(shù)據(jù)集上訓(xùn)練的模型可以抽取較通用的圖像特征宴卖,從而能夠幫助識別邊緣滋将、紋理、形狀和物體組成等症昏。這些類似的特征對于識別椅子也可能同樣有效随闽。

本節(jié)我們介紹遷移學(xué)習(xí)中的一種常用技術(shù):微調(diào)(fine tuning)。如圖9.1所示肝谭,微調(diào)由以下4步構(gòu)成掘宪。

  1. 在源數(shù)據(jù)集(如ImageNet數(shù)據(jù)集)上預(yù)訓(xùn)練一個神經(jīng)網(wǎng)絡(luò)模型,即源模型攘烛。
  2. 創(chuàng)建一個新的神經(jīng)網(wǎng)絡(luò)模型魏滚,即目標(biāo)模型。它復(fù)制了源模型上除了輸出層外的所有模型設(shè)計及其參數(shù)坟漱。我們假設(shè)這些模型參數(shù)包含了源數(shù)據(jù)集上學(xué)習(xí)到的知識鼠次,且這些知識同樣適用于目標(biāo)數(shù)據(jù)集。我們還假設(shè)源模型的輸出層跟源數(shù)據(jù)集的標(biāo)簽緊密相關(guān),因此在目標(biāo)模型中不予采用须眷。
  3. 為目標(biāo)模型添加一個輸出大小為目標(biāo)數(shù)據(jù)集類別個數(shù)的輸出層竖瘾,并隨機(jī)初始化該層的模型參數(shù)。
  4. 在目標(biāo)數(shù)據(jù)集(如椅子數(shù)據(jù)集)上訓(xùn)練目標(biāo)模型花颗。我們將從頭訓(xùn)練輸出層捕传,而其余層的參數(shù)都是基于源模型的參數(shù)微調(diào)得到的。
微調(diào)

舉個例子

我們將基于一個小數(shù)據(jù)集對在ImageNet數(shù)據(jù)集上訓(xùn)練好的ResNet模型進(jìn)行微調(diào)扩劝。該小數(shù)據(jù)集含有數(shù)千張包含熱狗和不包含熱狗的圖像庸论。我們將使用微調(diào)得到的模型來識別一張圖像中是否包含熱狗。

首先棒呛,導(dǎo)入實驗所需的包或模塊聂示。torchvision的models包提供了常用的預(yù)訓(xùn)練模型。如果希望獲取更多的預(yù)訓(xùn)練模型簇秒,可以使用使用pretrained-models.pytorch倉庫鱼喉。

完整代碼:
%matplotlib inline
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision import models
import os

import sys

sys.path.append("/home/kesci/input/")
import d2lzh1981 as d2l

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

使用的熱狗數(shù)據(jù)集(點擊下載)是從網(wǎng)上抓取的,它含有1400張包含熱狗的正類圖像趋观,和同樣多包含其他食品的負(fù)類圖像扛禽。各類的1000張圖像被用于訓(xùn)練,其余則用于測試皱坛。

我們首先將壓縮后的數(shù)據(jù)集下載到路徑data_dir之下编曼,然后在該路徑將下載好的數(shù)據(jù)集解壓,得到兩個文件夾hotdog/trainhotdog/test剩辟。這兩個文件夾下面均有hotdognot-hotdog兩個類別文件夾掐场,每個類別文件夾里面是圖像文件。
我們創(chuàng)建兩個ImageFolder實例來分別讀取訓(xùn)練數(shù)據(jù)集和測試數(shù)據(jù)集中的所有圖像文件贩猎。

import os
os.listdir('/home/kesci/input/resnet185352')
data_dir = '/home/kesci/input/hotdog4014'
os.listdir(os.path.join(data_dir, "hotdog"))
train_imgs = ImageFolder(os.path.join(data_dir, 'hotdog/train'))
test_imgs = ImageFolder(os.path.join(data_dir, 'hotdog/test'))
hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4);

在訓(xùn)練時熊户,我們先從圖像中裁剪出隨機(jī)大小和隨機(jī)高寬比的一塊隨機(jī)區(qū)域,然后將該區(qū)域縮放為高和寬均為224像素的輸入融欧。測試時敏弃,我們將圖像的高和寬均縮放為256像素,然后從中裁剪出高和寬均為224像素的中心區(qū)域作為輸入噪馏。此外麦到,我們對RGB(紅、綠欠肾、藍(lán))三個顏色通道的數(shù)值做標(biāo)準(zhǔn)化:每個數(shù)值減去該通道所有數(shù)值的平均值瓶颠,再除以該通道所有數(shù)值的標(biāo)準(zhǔn)差作為輸出。

注: 在使用預(yù)訓(xùn)練模型時刺桃,一定要和預(yù)訓(xùn)練時作同樣的預(yù)處理粹淋。

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_augs = transforms.Compose([
        transforms.RandomResizedCrop(size=224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])

test_augs = transforms.Compose([
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        normalize
    ])

定義和初始化模型

我們使用在ImageNet數(shù)據(jù)集上預(yù)訓(xùn)練的ResNet-18作為源模型。這里指定pretrained=True來自動下載并加載預(yù)訓(xùn)練的模型參數(shù)。在第一次使用時需要聯(lián)網(wǎng)下載模型參數(shù)桃移。

pretrained_net = models.resnet18(pretrained=False)
pretrained_net.load_state_dict(torch.load('/home/kesci/input/resnet185352/resnet18-5c106cde.pth'))

下面打印源模型的成員變量fc屋匕。作為一個全連接層,它將ResNet最終的全局平均池化層輸出變換成ImageNet數(shù)據(jù)集上1000類的輸出借杰。
print(pretrained_net.fc)

注: 如果你使用的是其他模型过吻,那可能沒有成員變量fc(比如models中的VGG預(yù)訓(xùn)練模型),所以正確做法是查看對應(yīng)模型源碼中其定義部分蔗衡,這樣既不會出錯也能加深我們對模型的理解纤虽。pretrained-models.pytorch倉庫貌似統(tǒng)一了接口,但是我還是建議使用時查看一下對應(yīng)模型的源碼绞惦。

可見此時pretrained_net最后的輸出個數(shù)等于目標(biāo)數(shù)據(jù)集的類別數(shù)1000逼纸。所以我們應(yīng)該將最后的fc成修改我們需要的輸出類別數(shù):

pretrained_net.fc = nn.Linear(512, 2)
print(pretrained_net.fc)

Linear(in_features=512, out_features=2, bias=True)

此時,pretrained_netfc層就被隨機(jī)初始化了济蝉,但是其他層依然保存著預(yù)訓(xùn)練得到的參數(shù)杰刽。由于是在很大的ImageNet數(shù)據(jù)集上預(yù)訓(xùn)練的,所以參數(shù)已經(jīng)足夠好堆生,因此一般只需使用較小的學(xué)習(xí)率來微調(diào)這些參數(shù)专缠,而fc中的隨機(jī)初始化參數(shù)一般需要更大的學(xué)習(xí)率從頭訓(xùn)練。PyTorch可以方便的對模型的不同部分設(shè)置不同的學(xué)習(xí)參數(shù)淑仆,我們在下面代碼中將fc的學(xué)習(xí)率設(shè)為已經(jīng)預(yù)訓(xùn)練過的部分的10倍。

output_params = list(map(id, pretrained_net.fc.parameters()))
feature_params = filter(lambda p: id(p) not in output_params, pretrained_net.parameters())

lr = 0.01
optimizer = optim.SGD([{'params': feature_params},
                       {'params': pretrained_net.fc.parameters(), 'lr': lr * 10}],
                       lr=lr, weight_decay=0.001)
# 模型微調(diào)
def train_fine_tuning(net, optimizer, batch_size=128, num_epochs=5):
    train_iter = DataLoader(ImageFolder(os.path.join(data_dir, 'hotdog/train'), transform=train_augs),
                            batch_size, shuffle=True)
    test_iter = DataLoader(ImageFolder(os.path.join(data_dir, 'hotdog/test'), transform=test_augs),
                           batch_size)
    loss = torch.nn.CrossEntropyLoss()
    d2l.train(train_iter, test_iter, net, loss, optimizer, device, num_epochs)
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末哥力,一起剝皮案震驚了整個濱河市蔗怠,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌吩跋,老刑警劉巖寞射,帶你破解...
    沈念sama閱讀 211,376評論 6 491
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異锌钮,居然都是意外死亡桥温,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 90,126評論 2 385
  • 文/潘曉璐 我一進(jìn)店門梁丘,熙熙樓的掌柜王于貴愁眉苦臉地迎上來侵浸,“玉大人,你說我怎么就攤上這事氛谜√途酰” “怎么了?”我有些...
    開封第一講書人閱讀 156,966評論 0 347
  • 文/不壞的土叔 我叫張陵值漫,是天一觀的道長澳腹。 經(jīng)常有香客問我,道長,這世上最難降的妖魔是什么酱塔? 我笑而不...
    開封第一講書人閱讀 56,432評論 1 283
  • 正文 為了忘掉前任沥邻,我火速辦了婚禮,結(jié)果婚禮上羊娃,老公的妹妹穿的比我還像新娘谋国。我一直安慰自己,他們只是感情好迁沫,可當(dāng)我...
    茶點故事閱讀 65,519評論 6 385
  • 文/花漫 我一把揭開白布芦瘾。 她就那樣靜靜地躺著,像睡著了一般集畅。 火紅的嫁衣襯著肌膚如雪近弟。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 49,792評論 1 290
  • 那天挺智,我揣著相機(jī)與錄音祷愉,去河邊找鬼。 笑死赦颇,一個胖子當(dāng)著我的面吹牛二鳄,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播媒怯,決...
    沈念sama閱讀 38,933評論 3 406
  • 文/蒼蘭香墨 我猛地睜開眼订讼,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了扇苞?” 一聲冷哼從身側(cè)響起欺殿,我...
    開封第一講書人閱讀 37,701評論 0 266
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎鳖敷,沒想到半個月后脖苏,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 44,143評論 1 303
  • 正文 獨居荒郊野嶺守林人離奇死亡定踱,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 36,488評論 2 327
  • 正文 我和宋清朗相戀三年棍潘,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片崖媚。...
    茶點故事閱讀 38,626評論 1 340
  • 序言:一個原本活蹦亂跳的男人離奇死亡亦歉,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出至扰,到底是詐尸還是另有隱情鳍徽,我是刑警寧澤,帶...
    沈念sama閱讀 34,292評論 4 329
  • 正文 年R本政府宣布敢课,位于F島的核電站阶祭,受9級特大地震影響绷杜,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜濒募,卻給世界環(huán)境...
    茶點故事閱讀 39,896評論 3 313
  • 文/蒙蒙 一鞭盟、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧瑰剃,春花似錦齿诉、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,742評論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至挥唠,卻和暖如春抵恋,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背宝磨。 一陣腳步聲響...
    開封第一講書人閱讀 31,977評論 1 265
  • 我被黑心中介騙來泰國打工弧关, 沒想到剛下飛機(jī)就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人唤锉。 一個月前我還...
    沈念sama閱讀 46,324評論 2 360
  • 正文 我出身青樓世囊,卻偏偏與公主長得像,于是被迫代替她去往敵國和親窿祥。 傳聞我的和親對象是個殘疾皇子株憾,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 43,494評論 2 348

推薦閱讀更多精彩內(nèi)容