深入淺出PyTorch_3_主要組成部分

[toc]

基本流程

完成一項(xiàng)深度學(xué)習(xí)任務(wù)的基本流程大致如下:

  1. 數(shù)據(jù)預(yù)處理
  2. 模型構(gòu)建
  3. 模型訓(xùn)練
  4. 模型導(dǎo)出及應(yīng)用

模型構(gòu)建則是關(guān)鍵懒闷,選擇適當(dāng)?shù)哪P褪酰⒃O(shè)定損失函數(shù)和優(yōu)化函數(shù),以及對(duì)應(yīng)的超參數(shù)(當(dāng)然可以使用sklearn這樣的機(jī)器學(xué)習(xí)庫(kù)中模型自帶的損失函數(shù)和優(yōu)化器)愤估;再進(jìn)行模型訓(xùn)練帮辟,用模型去擬合訓(xùn)練集數(shù)據(jù),并在驗(yàn)證集/測(cè)試集上計(jì)算模型表現(xiàn)玩焰;最后將訓(xùn)練好的模型進(jìn)行導(dǎo)出由驹,進(jìn)行下一步應(yīng)用。

數(shù)據(jù)預(yù)處理

數(shù)據(jù)預(yù)處理主要包括數(shù)據(jù)讀入昔园、數(shù)據(jù)集劃分及相關(guān)任務(wù)的預(yù)處理蔓榄,該過(guò)程中需要重點(diǎn)關(guān)注數(shù)據(jù)格式的統(tǒng)一和必要的數(shù)據(jù)變換闹炉。

PyTorch數(shù)據(jù)讀入是通過(guò)Dataset+Dataloader的方式完成的,Dataset定義好數(shù)據(jù)的格式和數(shù)據(jù)變換形式润樱,Dataloader用iterative的方式不斷讀入批次數(shù)據(jù)渣触。

我們可以定義自己的Dataset類來(lái)實(shí)現(xiàn)靈活的數(shù)據(jù)讀取,定義的類需要繼承PyTorch自身的Dataset類壹若。主要包含三個(gè)函數(shù):

  • __init__: 用于向類中傳入外部參數(shù)嗅钻,同時(shí)定義樣本集
  • __getitem__: 用于逐個(gè)讀取樣本集合中的元素,可以進(jìn)行一定的變換店展,并將返回訓(xùn)練/驗(yàn)證所需的數(shù)據(jù)
  • __len__: 用于返回?cái)?shù)據(jù)集的樣本數(shù)

構(gòu)建好Dataset后养篓,就可以使用DataLoader來(lái)按批次讀入數(shù)據(jù)了,實(shí)現(xiàn)代碼如下:

train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=4, shuffle=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, num_workers=4, shuffle=False)

其中:

  • batch_size:樣本是按“批”讀入的赂蕴,batch_size就是每次讀入的樣本數(shù)
  • num_workers:有多少個(gè)進(jìn)程用于讀取數(shù)據(jù)
  • shuffle:是否將讀入的數(shù)據(jù)打亂
  • drop_last:對(duì)于樣本最后一部分沒(méi)有達(dá)到批次數(shù)的樣本柳弄,使其不再參與訓(xùn)練

模型構(gòu)建

模型構(gòu)建則是關(guān)鍵,選擇適當(dāng)?shù)哪P透潘担瑯?gòu)建出神經(jīng)網(wǎng)絡(luò)中相應(yīng)的層碧注,并設(shè)定損失函數(shù)和優(yōu)化函數(shù),以及對(duì)應(yīng)的超參數(shù)(當(dāng)然可以使用sklearn這樣的機(jī)器學(xué)習(xí)庫(kù)中模型自帶的損失函數(shù)和優(yōu)化器)

超參數(shù)

常見(jiàn)的超參數(shù)有糖赔,如下:

  • batch size
  • 初始學(xué)習(xí)率(初始)
  • 訓(xùn)練次數(shù)(max_epochs)
  • GPU配置

神經(jīng)網(wǎng)絡(luò)的層

深度學(xué)習(xí)的一個(gè)魅力在于神經(jīng)網(wǎng)絡(luò)中各式各樣的層萍丐,以卷積神經(jīng)網(wǎng)絡(luò)(CNN)為例,其主要涵蓋如下層:

  • 卷積層
  • 池化層
  • 激活函數(shù)層
  • 歸一化層
  • 全連接層

卷積層

卷積層由參數(shù)可學(xué)習(xí)的卷積核組成放典。卷積核的寬度和長(zhǎng)度可改變逝变,深度必須與輸入層的通道數(shù)一致。

比如說(shuō)輸入32x32x3的圖片奋构,一個(gè)卷積核的大小為5x5x3壳影,一個(gè)卷積核在padding=0情況下劃窗生成一個(gè)二維的激活圖(28x28x1)。


image

如果我們有6個(gè)5x5x3的卷積核弥臼,就可以生成28286的激活圖宴咧。輸出層的通道數(shù)與卷積核個(gè)數(shù)一致。

image

三維卷積的Pytorch操作如下:

import torch
import torch.nn as nn


x=torch.randn(5,3,10,224,224)
conv = nn.Conv3d(3, 64, kernel_size=(5,5,3), stride=1, padding=1)
print(conv.weight.size())# torch.Size([64, 3, 5, 5, 3])
out=conv(x)
print(out.size())#torch.Size([5, 64, 8, 222, 224])

池化層

池化層用來(lái)控制圖片的空間尺寸醋火,相當(dāng)于一個(gè)降采樣的過(guò)程悠汽。同時(shí),池化層也有著控制過(guò)擬合的作用芥驳。有maxpooling柿冲,averagepooling等類型。


image

激活函數(shù)層

所謂激活兆旬,實(shí)際上是對(duì)卷積層的輸出結(jié)果做一次非線性映射假抄。激活函數(shù)可以引入非線性因素,解決線性模型所不能解決的問(wèn)題。

常用的激活函數(shù)有sigmoid宿饱,ReLU熏瞄,tanh,leakyReLU等等

image

歸一化層

最常用的歸一化層是Batch Normalization。能使訓(xùn)練速度大大加快谬以。


image

全連接層

全連接層(fully connected layers强饮,F(xiàn)C)指的是神經(jīng)元完全與輸入的變量連接,在整個(gè)卷積神經(jīng)網(wǎng)絡(luò)中起到“分類器”的作用为黎。


image

損失函數(shù)

在PyTorch中邮丰,損失函數(shù)是必不可少的。它是數(shù)據(jù)輸入到模型當(dāng)中铭乾,產(chǎn)生的結(jié)果與真實(shí)標(biāo)簽的評(píng)價(jià)指標(biāo)剪廉,我們的模型可以按照損失函數(shù)的目標(biāo)來(lái)做出改進(jìn)。常見(jiàn)的損失函數(shù)主要有:

  • 二分類交叉熵?fù)p失函數(shù)
  • 交叉熵?fù)p失函數(shù)
  • L1損失函數(shù)
  • MSE損失函數(shù)
  • ...

以二分類交叉熵?fù)p失函數(shù)為例炕檩,在pytorch代碼如下:

torch.nn.BCELoss(weight=None, size_average=None, reduce=None, reduction='mean')


功能:計(jì)算二分類任務(wù)時(shí)的交叉熵(Cross Entropy)函數(shù)斗蒋。在二分類中,label是{0,1}笛质。對(duì)于進(jìn)入交叉熵函數(shù)的input為概率分布的形式泉沾。一般來(lái)說(shuō),input為sigmoid激活層的輸出经瓷,或者softmax的輸出爆哑。

主要參數(shù)
weight:每個(gè)類別的loss設(shè)置權(quán)值

size_average:數(shù)據(jù)為bool,為True時(shí)舆吮,返回的loss為平均值;為False時(shí)队贱,返回的各樣本的loss之和色冀。

reduce:數(shù)據(jù)類型為bool,為True時(shí)柱嫌,loss的返回是標(biāo)量锋恬。

優(yōu)化器

深度學(xué)習(xí)的目標(biāo)是通過(guò)不斷改變網(wǎng)絡(luò)參數(shù),使得參數(shù)能夠?qū)斎胱龈鞣N非線性變換擬合輸出编丘,本質(zhì)上就是一個(gè)函數(shù)去尋找最優(yōu)解与学,只不過(guò)這個(gè)最優(yōu)解使一個(gè)矩陣,而如何快速求得這個(gè)最優(yōu)解是深度學(xué)習(xí)研究的一個(gè)重點(diǎn)嘉抓,以經(jīng)典的resnet-50為例索守,它大約有2000萬(wàn)個(gè)系數(shù)需要進(jìn)行計(jì)算,那么我們?nèi)绾斡?jì)算出來(lái)這么多的系數(shù)抑片,有以下兩種方法:

  1. 第一種是最直接的暴力窮舉一遍參數(shù)卵佛,這種方法的實(shí)施可能性基本為0,堪比愚公移山plus的難度。
  2. 為了使求解參數(shù)過(guò)程更加快截汪,人們提出了第二種辦法疾牲,即就是是BP+優(yōu)化器逼近求解。

因此衙解,優(yōu)化器就是根據(jù)網(wǎng)絡(luò)反向傳播的梯度信息來(lái)更新網(wǎng)絡(luò)的參數(shù)阳柔,以起到降低loss函數(shù)計(jì)算值,使得模型輸出更加接近真實(shí)標(biāo)簽蚓峦。

Pytorch很人性化的給我們提供了一個(gè)優(yōu)化器的庫(kù)torch.optim舌剂,在這里面給我們提供了十種優(yōu)化器。

  • torch.optim.ASGD
  • torch.optim.Adadelta
  • torch.optim.Adagrad
  • torch.optim.Adam
  • torch.optim.AdamW
  • torch.optim.Adamax
  • torch.optim.LBFGS
  • torch.optim.RMSprop
  • torch.optim.Rprop
  • torch.optim.SGD
  • torch.optim.SparseAdam

而以上這些優(yōu)化算法均繼承于Optimizer

模型訓(xùn)練

訓(xùn)練和評(píng)估

完成了上述設(shè)定后就可以加載數(shù)據(jù)開始訓(xùn)練模型了枫匾。首先應(yīng)該設(shè)置模型的狀態(tài):如果是訓(xùn)練狀態(tài)架诞,那么模型的參數(shù)應(yīng)該支持反向傳播的修改;如果是驗(yàn)證/測(cè)試狀態(tài)干茉,則不應(yīng)該修改模型參數(shù)谴忧。在PyTorch中,模型的狀態(tài)設(shè)置非常簡(jiǎn)便角虫,如下的兩個(gè)操作二選一即可:

model.train()   # 訓(xùn)練狀態(tài)
model.eval()   # 驗(yàn)證/測(cè)試狀態(tài)

我們前面在DataLoader構(gòu)建完成后介紹了如何從中讀取數(shù)據(jù)沾谓,在訓(xùn)練過(guò)程中使用類似的操作即可,區(qū)別在于此時(shí)要用for循環(huán)讀取DataLoader中的全部數(shù)據(jù)戳鹅。

for data, label in train_loader:

之后將數(shù)據(jù)放到GPU上用于后續(xù)計(jì)算均驶,此處以.cuda()為例

data, label = data.cuda(), label.cuda()

開始用當(dāng)前批次數(shù)據(jù)做訓(xùn)練時(shí)缘琅,應(yīng)當(dāng)先將優(yōu)化器的梯度置零:

optimizer.zero_grad()

之后將data送入模型中訓(xùn)練:

output = model(data)

根據(jù)預(yù)先定義的criterion計(jì)算損失函數(shù):

loss = criterion(output, label)

將loss反向傳播回網(wǎng)絡(luò):

loss.backward()

使用優(yōu)化器更新模型參數(shù):

optimizer.step()

這樣一個(gè)訓(xùn)練過(guò)程就完成了赁炎,后續(xù)還可以計(jì)算模型準(zhǔn)確率等指標(biāo),這部分會(huì)在下一節(jié)的圖像分類實(shí)戰(zhàn)中加以介紹拱燃。

驗(yàn)證/測(cè)試的流程基本與訓(xùn)練過(guò)程一致隶债,不同點(diǎn)在于:

  • 需要預(yù)先設(shè)置torch.no_grad腾它,以及將model調(diào)至eval模式
  • 不需要將優(yōu)化器的梯度置零
  • 不需要將loss反向回傳到網(wǎng)絡(luò)
  • 不需要更新optimizer

一個(gè)完整的訓(xùn)練過(guò)程如下所示:

def train(epoch):
    model.train()
    train_loss = 0
    for data, label in train_loader:
        data, label = data.cuda(), label.cuda()
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(label, output)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()*data.size(0)
    train_loss = train_loss/len(train_loader.dataset)
        print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch, train_loss))

對(duì)應(yīng)的,一個(gè)完整的驗(yàn)證過(guò)程如下所示:

def val(epoch):       
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data, label in val_loader:
            data, label = data.cuda(), label.cuda()
            output = model(data)
            preds = torch.argmax(output, 1)
            loss = criterion(output, label)
            val_loss += loss.item()*data.size(0)
            running_accu += torch.sum(preds == label.data)
    val_loss = val_loss/len(val_loader.dataset)
    print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch, val_loss))

模型導(dǎo)出及應(yīng)用

簡(jiǎn)單的保存與加載方法:

# 保存整個(gè)網(wǎng)絡(luò)
torch.save(net, PATH) 
# 保存網(wǎng)絡(luò)中的參數(shù), 速度快死讹,占空間少
torch.save(net.state_dict(),PATH)
#--------------------------------------------------
#針對(duì)上面一般的保存方法瞒滴,加載的方法分別是:
model_dict=torch.load(PATH)
model_dict=model.load_state_dict(torch.load(PATH))
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市赞警,隨后出現(xiàn)的幾起案子妓忍,更是在濱河造成了極大的恐慌,老刑警劉巖愧旦,帶你破解...
    沈念sama閱讀 219,188評(píng)論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件世剖,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡忘瓦,警方通過(guò)查閱死者的電腦和手機(jī)搁廓,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,464評(píng)論 3 395
  • 文/潘曉璐 我一進(jìn)店門引颈,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái),“玉大人境蜕,你說(shuō)我怎么就攤上這事蝙场。” “怎么了粱年?”我有些...
    開封第一講書人閱讀 165,562評(píng)論 0 356
  • 文/不壞的土叔 我叫張陵售滤,是天一觀的道長(zhǎng)。 經(jīng)常有香客問(wèn)我台诗,道長(zhǎng)完箩,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,893評(píng)論 1 295
  • 正文 為了忘掉前任拉队,我火速辦了婚禮弊知,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘粱快。我一直安慰自己秩彤,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,917評(píng)論 6 392
  • 文/花漫 我一把揭開白布事哭。 她就那樣靜靜地躺著漫雷,像睡著了一般。 火紅的嫁衣襯著肌膚如雪鳍咱。 梳的紋絲不亂的頭發(fā)上降盹,一...
    開封第一講書人閱讀 51,708評(píng)論 1 305
  • 那天,我揣著相機(jī)與錄音谤辜,去河邊找鬼蓄坏。 笑死,一個(gè)胖子當(dāng)著我的面吹牛丑念,可吹牛的內(nèi)容都是我干的剑辫。 我是一名探鬼主播,決...
    沈念sama閱讀 40,430評(píng)論 3 420
  • 文/蒼蘭香墨 我猛地睜開眼渠欺,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來(lái)了椎眯?” 一聲冷哼從身側(cè)響起挠将,我...
    開封第一講書人閱讀 39,342評(píng)論 0 276
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎编整,沒(méi)想到半個(gè)月后舔稀,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,801評(píng)論 1 317
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡掌测,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,976評(píng)論 3 337
  • 正文 我和宋清朗相戀三年内贮,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 40,115評(píng)論 1 351
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡夜郁,死狀恐怖什燕,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情竞端,我是刑警寧澤屎即,帶...
    沈念sama閱讀 35,804評(píng)論 5 346
  • 正文 年R本政府宣布,位于F島的核電站事富,受9級(jí)特大地震影響技俐,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜统台,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,458評(píng)論 3 331
  • 文/蒙蒙 一雕擂、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧贱勃,春花似錦井赌、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,008評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)。三九已至拔鹰,卻和暖如春仪缸,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背列肢。 一陣腳步聲響...
    開封第一講書人閱讀 33,135評(píng)論 1 272
  • 我被黑心中介騙來(lái)泰國(guó)打工恰画, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人瓷马。 一個(gè)月前我還...
    沈念sama閱讀 48,365評(píng)論 3 373
  • 正文 我出身青樓拴还,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國(guó)和親欧聘。 傳聞我的和親對(duì)象是個(gè)殘疾皇子片林,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,055評(píng)論 2 355

推薦閱讀更多精彩內(nèi)容