PyTorch學(xué)習(xí)筆記(四):構(gòu)建神經(jīng)網(wǎng)絡(luò)

Tensor和自動(dòng)求導(dǎo)屬于PyTorch中較為底層的特性,如果要實(shí)現(xiàn)一個(gè)神經(jīng)網(wǎng)絡(luò)我們不需要從Tensor開(kāi)始腻要,PyTorch已經(jīng)為我們封裝了專門(mén)為深度學(xué)習(xí)而設(shè)計(jì)的模塊范咨,這個(gè)模塊就是torch.nn

NN工具箱

為方便用戶使用掏觉,PyTorch實(shí)現(xiàn)了神經(jīng)網(wǎng)絡(luò)中絕大多數(shù)的layer,這些layer都繼承于nn.Module榴鼎。這個(gè)類(lèi)封裝了可學(xué)習(xí)參數(shù)伯诬,并實(shí)現(xiàn)了forward函數(shù),且很多都專門(mén)針對(duì)GPU運(yùn)算進(jìn)行了CuDNN優(yōu)化巫财,其速度和性能都十分優(yōu)異盗似。

torch.nn.Module:PyTorch中神經(jīng)網(wǎng)絡(luò)的基本類(lèi),模型和網(wǎng)絡(luò)層都是它的子類(lèi)平项。

class Model(nn.Module): 
      # 繼承nn.Module
    def __init__(self, in_features, out_features):
        super(Linear, self).__init__() 
        ...

    def forward(self, x):
        ...

torch.nn:PyTorch中的神經(jīng)網(wǎng)絡(luò)工具箱赫舒,所有的常用神經(jīng)網(wǎng)絡(luò)模型都在這個(gè)模塊中,包括全連接層闽瓢、卷積層接癌、RNN等。

 torch.nn.Linear()
 torch.nn.Conv2d()
 torch.nn.BatchNorm2d()
...

torch.nn.functional:提供了一些功能的函數(shù)化接口扣讼,torch.nn中的大多數(shù)layer在其中都有一個(gè)與之相對(duì)應(yīng)的函數(shù)缺猛。

# 含參數(shù)
 torch.nn.functional.linear()

# 不含參數(shù)
 torch.nn.functional.relu()
 torch.nn.functional.max_pool2d()

構(gòu)建簡(jiǎn)易的神經(jīng)網(wǎng)絡(luò)

torch.nn的核心數(shù)據(jù)結(jié)構(gòu)是Module,它是一個(gè)抽象概念椭符,既可以表示神經(jīng)網(wǎng)絡(luò)中的個(gè)層(layer)荔燎,也可以表示一個(gè)包含很多層的神經(jīng)網(wǎng)絡(luò)。要構(gòu)造一個(gè)神經(jīng)網(wǎng)絡(luò)模型销钝,首先我們需要?jiǎng)?chuàng)建一個(gè)nn.Module的子類(lèi)有咨,然后在這個(gè)子類(lèi)中構(gòu)造我們的模型。

網(wǎng)絡(luò)

定義一個(gè)簡(jiǎn)單的三層神經(jīng)網(wǎng)絡(luò)曙搬。

NN

__init__():初始化父類(lèi)摔吏,定義各個(gè)層的結(jié)構(gòu)。
forward():根據(jù)定義的層纵装,構(gòu)建向前傳播的流程。

from torch import nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.fc1 = nn.Linear(20, 120)
        self.fc2 = nn.Linear(120, 64)
        self.fc3 = nn.Linear(64, 1)

        self.drop = nn.Dropout(0.3)

    def forward(self, x):

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.drop(x)
        x = self.fc3(x)

        return x

if __name__ == '__main__':
    net = Net()
    print(net)

可以打印模型的結(jié)構(gòu):

Net(
  (fc1): Linear(in_features=20, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=1, bias=True)
  (drop): Dropout(p=0.3)
)

訓(xùn)練

配置損失函數(shù)据某,因?yàn)槭腔貧w任務(wù)橡娄,這里選擇的是MSE。

loss_function = nn.MSELoss()

通過(guò)torch.optim配置優(yōu)化器癣籽,這里使用的是SGD挽唉。optim接受兩個(gè)參數(shù),第一個(gè)是模型的可訓(xùn)練參數(shù)net.parameters()筷狼,第二個(gè)是學(xué)習(xí)率lr瓶籽。

import torch.optim as optim

optimizer = optim.SGD(net.parameters(), lr=0.001)

對(duì)模型進(jìn)行優(yōu)化的過(guò)程如下所示:

  1. 首先通過(guò)optimizer.zero_grad()把梯度置零,也就是把loss關(guān)于weight的導(dǎo)數(shù)變成0埂材。這是因?yàn)镻yTorch中梯度是累加的塑顺,在每個(gè)batch中我們不需要前面batch的梯度。
  2. 根據(jù)輸入數(shù)據(jù)得到模型的輸出。
  3. 通過(guò)之前定義的loss_function來(lái)計(jì)算loss严拒。
  4. 通過(guò)loss.backward()對(duì)loss進(jìn)行反向傳播扬绪。
  5. 通過(guò)optimizer.step()使用之前定義的優(yōu)化器優(yōu)化網(wǎng)絡(luò)。
optimizer.zero_grad()
    
output = net(inputs)
loss = loss_function(output, target)
    
loss.backward()
optimizer.step()

保存與載入

PyTorch可以把數(shù)據(jù)保存為.pth或者.pt等文件裤唠。

保存和加載整個(gè)模型:

torch.save(net, 'net.pth')
net = torch.load('net.pth')

僅保存和加載模型參數(shù)(推薦使用挤牛,需要提前手動(dòng)構(gòu)建模型):

torch.save(net.state_dict(), 'net.pth')
net.load_state_dict(torch.load('net.pth'))

完整的訓(xùn)練流程

if __name__ == '__main__':
    net = Net()
    print(net)

    x = torch.randn(100, 20)
    y = torch.randn(100, 1)
    
    optimizer = optim.SGD(net.parameters(), lr=0.001)
    loss_function = nn.MSELoss()

    running_loss = 0.0
    for i in range(10):
        index = torch.randperm(100);
        x = x[index]
        y = y[index]
        
        b = list(range(0, 100, 10))
        for j, b_index in enumerate(b):
            inputs = x[b_index: b_index + 10, :]
            target = y[b_index: b_index + 10, :]
            
            if torch.cuda.is_available():
                inputs = inputs.cuda()
                target = target.cuda()

            optimizer.zero_grad()
    
            output = net(inputs)
            loss = loss_function(output, target)
    
            loss.backward()
    
            optimizer.step()
    
            running_loss += loss.item()
            if i != 0 and i % 2 == 0:
                print('epoch:{} | batch:{}| loss:{:.5f}'.format(i, j, running_loss / 2))
                running_loss = 0.0
    
            torch.save(net.state_dict(), 'net.pth')

訓(xùn)練結(jié)果:

...
epoch:6 | batch:7| loss:0.56948
epoch:6 | batch:8| loss:0.50152
epoch:6 | batch:9| loss:0.18269
epoch:8 | batch:0| loss:4.45438
epoch:8 | batch:1| loss:0.20041
epoch:8 | batch:2| loss:0.69559
epoch:8 | batch:3| loss:0.17017
epoch:8 | batch:4| loss:0.29025
epoch:8 | batch:5| loss:0.23748
epoch:8 | batch:6| loss:0.40649
epoch:8 | batch:7| loss:0.36893
epoch:8 | batch:8| loss:0.32532
epoch:8 | batch:9| loss:0.57664

序列模型

nn.Sequential是一個(gè)有序的容器,神經(jīng)網(wǎng)絡(luò)模塊將按照在傳入構(gòu)造器的順序依次被添加到計(jì)算圖中執(zhí)行种蘸,同時(shí)以神經(jīng)網(wǎng)絡(luò)模塊為元素的有序字典也可以作為傳入?yún)?shù)墓赴。對(duì)于沒(méi)有分支的網(wǎng)絡(luò)結(jié)構(gòu),使用序列的方法構(gòu)建模型更加容易航瞭。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.seq = nn.Sequential(
            nn.Conv2d(3, 64, 7, 2, 3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1)
        )
        
    def forward(self, x):
        x = self.seq(x)
        
        return x


if __name__ == '__main__':
    net = Net()
    print(net)

可以看出竣蹦,多個(gè)不同的網(wǎng)絡(luò)層被一個(gè)Sequential類(lèi)包裹了。

Net(
  (seq): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末沧奴,一起剝皮案震驚了整個(gè)濱河市痘括,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌滔吠,老刑警劉巖纲菌,帶你破解...
    沈念sama閱讀 222,104評(píng)論 6 515
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異疮绷,居然都是意外死亡翰舌,警方通過(guò)查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 94,816評(píng)論 3 399
  • 文/潘曉璐 我一進(jìn)店門(mén)冬骚,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)椅贱,“玉大人,你說(shuō)我怎么就攤上這事只冻”勇螅” “怎么了?”我有些...
    開(kāi)封第一講書(shū)人閱讀 168,697評(píng)論 0 360
  • 文/不壞的土叔 我叫張陵喜德,是天一觀的道長(zhǎng)山橄。 經(jīng)常有香客問(wèn)我,道長(zhǎng)舍悯,這世上最難降的妖魔是什么航棱? 我笑而不...
    開(kāi)封第一講書(shū)人閱讀 59,836評(píng)論 1 298
  • 正文 為了忘掉前任,我火速辦了婚禮萌衬,結(jié)果婚禮上饮醇,老公的妹妹穿的比我還像新娘。我一直安慰自己秕豫,他們只是感情好朴艰,可當(dāng)我...
    茶點(diǎn)故事閱讀 68,851評(píng)論 6 397
  • 文/花漫 我一把揭開(kāi)白布。 她就那樣靜靜地躺著,像睡著了一般呵晚。 火紅的嫁衣襯著肌膚如雪蜘腌。 梳的紋絲不亂的頭發(fā)上,一...
    開(kāi)封第一講書(shū)人閱讀 52,441評(píng)論 1 310
  • 那天饵隙,我揣著相機(jī)與錄音撮珠,去河邊找鬼。 笑死金矛,一個(gè)胖子當(dāng)著我的面吹牛芯急,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播驶俊,決...
    沈念sama閱讀 40,992評(píng)論 3 421
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼娶耍,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來(lái)了饼酿?” 一聲冷哼從身側(cè)響起榕酒,我...
    開(kāi)封第一講書(shū)人閱讀 39,899評(píng)論 0 276
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎故俐,沒(méi)想到半個(gè)月后想鹰,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 46,457評(píng)論 1 318
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡药版,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 38,529評(píng)論 3 341
  • 正文 我和宋清朗相戀三年辑舷,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片槽片。...
    茶點(diǎn)故事閱讀 40,664評(píng)論 1 352
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡何缓,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出还栓,到底是詐尸還是另有隱情碌廓,我是刑警寧澤,帶...
    沈念sama閱讀 36,346評(píng)論 5 350
  • 正文 年R本政府宣布蝙云,位于F島的核電站氓皱,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏勃刨。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 42,025評(píng)論 3 334
  • 文/蒙蒙 一股淡、第九天 我趴在偏房一處隱蔽的房頂上張望身隐。 院中可真熱鬧,春花似錦唯灵、人聲如沸贾铝。這莊子的主人今日做“春日...
    開(kāi)封第一講書(shū)人閱讀 32,511評(píng)論 0 24
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)垢揩。三九已至玖绿,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間叁巨,已是汗流浹背斑匪。 一陣腳步聲響...
    開(kāi)封第一講書(shū)人閱讀 33,611評(píng)論 1 272
  • 我被黑心中介騙來(lái)泰國(guó)打工, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留锋勺,地道東北人蚀瘸。 一個(gè)月前我還...
    沈念sama閱讀 49,081評(píng)論 3 377
  • 正文 我出身青樓,卻偏偏與公主長(zhǎng)得像庶橱,于是被迫代替她去往敵國(guó)和親贮勃。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,675評(píng)論 2 359