PyTorch實(shí)現(xiàn)經(jīng)典網(wǎng)絡(luò)之LeNet5

簡介

本文是使用PyTorch來實(shí)現(xiàn)經(jīng)典神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)LeNet5,并將其用于處理MNIST數(shù)據(jù)集。LeNet5出自論文Gradient-Based Learning Applied to Document Recognition娃肿,是由圖靈獎獲得者Yann LeCun等提出的一種用于手寫體字符識別的非常高效的卷積神經(jīng)網(wǎng)絡(luò)。它曾經(jīng)被應(yīng)用于識別美國郵政服務(wù)提供的手寫郵政編碼數(shù)字蕉堰,錯誤率僅1%泽台。

LetNet5

預(yù)備知識

本文的重點(diǎn)是分析LeNet5的網(wǎng)絡(luò)結(jié)構(gòu),并且給出基于PyTorch的簡易版本實(shí)現(xiàn)访娶,因此需要讀者具有基本的卷積商虐、池化操作相關(guān)的知識。下面是對這些預(yù)備知識的一個簡單補(bǔ)充,詳細(xì)的可以參考斯坦福CS231n秘车。

卷積操作

單通道卷積操作
單通道卷積操作動態(tài)示意圖

上圖的Image大小是5x5典勇,卷積核大小為3x3,步長為1叮趴,最后的輸出大小是3x3割笙。假如輸入圖像大小是n \times n,卷積核的大小是f \times f眯亦,步長是s伤溉,o是最后輸出的feature map大小,則o可由以下公式計(jì)算得到:

多通道卷積操作

多通道卷積操作示意圖

上圖輸入有3個通道妻率,但是只有一個卷積核乱顾,故在計(jì)算的時(shí)候,每個通道都要通過卷積計(jì)算最后累加宫静,最終的輸出的通道數(shù)跟卷積核的數(shù)量一致糯耍。這里只有一個卷積核,故最后輸出是一個通道囊嘉。

池化操作

池化的定義比較簡單,最直觀的作用便是降維革为,常見的池化有最大池化扭粱、平均池化和隨機(jī)池化。池化層不需要訓(xùn)練參數(shù)震檩。


三種池化操作

LeNet5網(wǎng)絡(luò)結(jié)構(gòu)

LetNet5是一個簡單的CNN結(jié)構(gòu)琢蛤,整體框圖如下:
LeNet5結(jié)構(gòu)

整個網(wǎng)絡(luò)一共包含7層(不算輸入層),分別是C1抛虏、S2博其、C3、S4迂猴、C5慕淡、F6、Output沸毁,其中Cx代表的是卷積層峰髓,Sx代表的是下采樣層,接下來分別介紹每一層的作用息尺。

1. 輸入層

網(wǎng)絡(luò)的輸入是32x32大小的圖像數(shù)據(jù)携兵。

2. C1卷積層

C1層的輸入是32x32的原始圖像,卷積核的大小是5x5搂誉,深度為6徐紧,即有6個卷積核,不需要使用0填充,步長為1并级。由上述內(nèi)容可知拂檩,輸出的圖像大小是28x28,又卷積核的深度決定了輸出尺寸的深度死遭,因?yàn)檫@里使用了6個卷積核广恢,故C1層的輸出尺寸是28x28x6。C1層的總共參數(shù)個數(shù)為(5x5+1)x6=156個參數(shù)呀潭,其中+1代表的是每個卷積操作之后需要有一個額外的偏置參數(shù)钉迷。
又C1層一共包含28x28x6=4704個像素點(diǎn),而本層的每一個像素點(diǎn)都是由一個5x5的卷積操作外加一個偏置項(xiàng)操作得到的钠署,故一個像素點(diǎn)的計(jì)算會產(chǎn)生5x5+1=26條連接糠聪,總共會產(chǎn)生4704x26=122304條連接。

3. S2池化層

S2層的輸入是C1卷積層的輸出谐鼎,即28x28x6的特征圖舰蟆。這里使用的是核大小是2x2,步長為2狸棍,這意味著輸入矩陣的每4個相鄰元素經(jīng)過池化操作之后只會輸出1個元素身害,即大小變成了原先的四分之一,故輸出大小為14x14x6草戈。池化操作一般分為最大池化和平均池化塌鸯,這里的池化操作稍微有點(diǎn)不同,它是對輸入矩陣中2x2的區(qū)域中的全部元素先求和唐片,接著乘上一個可訓(xùn)練的系數(shù)w_i丙猬,再加上一個偏置項(xiàng)b_i,最后通過一個sigmoid函數(shù)费韭,得到最終的輸出茧球。因此在經(jīng)過這樣的操作之后,S2的輸出的行和列分別變?yōu)榱溯斎氲囊话胄浅郑?4x14抢埋。
對一張?zhí)卣鲌D進(jìn)行上述池化操作需要的參數(shù)只有2個,即系數(shù)w和偏置b督暂,故總共需要6x2=12個參數(shù)羹令。S2池化層的輸出大小是14x14x6,其中每一個像素點(diǎn)都需要經(jīng)過一次池化操作损痰,又一次池化操作需要產(chǎn)生4+1條連接福侈,故總共產(chǎn)生(4+1)x14x14x6=5880條連接。

4. C3卷積層

C3卷積層的輸入是S2的輸出卢未,即14x14x6的特征圖肪凛。C3卷積層使用的卷積核大小是5x5堰汉,深度為16,即包含了16個卷積核伟墙,不需要使用0填充翘鸭,步長為1。故輸出尺寸為10x10x16戳葵。但是這16個特征圖是如何得到的呢就乓?請看下圖:
S2層中特征圖的組合表

其中縱軸代表的是S2池化層輸出的6張?zhí)卣鲌D,橫軸代表的是C3卷積層的16個卷積核拱烁。這張表按照列可以分為4組生蚁,我分別用不同顏色的方框框出來了。其中綠色部分代表的是C3層中的前6個卷積與S2層中的連續(xù)的3張?zhí)卣鲌D相連戏自,藍(lán)色部分代表的是C3層中的6邦投、7、8號卷積核與S2層中連續(xù)的4張?zhí)卣鲌D相連擅笔,紅色部分代表的是C3層中的9志衣、10、11猛们、12念脯、13、14號卷積核與S2層中不連續(xù)的4張?zhí)卣鲌D相連弯淘,黃色部分代表的是C3層中的最后一個卷積核與S2層中所有特征圖相連和二。
為什么S2中的所有特征圖不直接與C3中的每一個卷積核全部相連呢?作者認(rèn)為有2點(diǎn)原因:第一是因?yàn)椴皇褂萌B接能夠保證有連接的數(shù)量保持在一個合理的界限范圍內(nèi)可以減少參數(shù)耳胎。第二是通過這種方式可以打破對稱性,不同的卷積核通過輸入不同的特征圖以期望得到互補(bǔ)的特征惕它。
同樣我們再來計(jì)算一下參數(shù)數(shù)量怕午。對于綠色部分,C3中一個卷積核要對3張?zhí)卣鲌D進(jìn)行卷積操作淹魄,一共有6個卷積核郁惜,故總共包含(5x5x3+1)x6=456個參數(shù),同理甲锡,藍(lán)色和紅色部分總共(5x5x4+1)x9 = 909個參數(shù)兆蕉,黃色部分(5x5x6+1)x1=151個參數(shù)$吐伲總共456+909+151=1516個參數(shù)虎韵。總共包含10x10x1516=151600個連接缸废。

5. S4池化層

S4池化層與S2池化層方式相同。把輸出降為輸入的四分之一 慌闭,即由C3層的輸出尺寸10x10x16降到5x5x16大小奔脐。核大小為2x2,步長為2亡电。S4層一共包含16x2=32個參數(shù),與S3層一共有(4+1)x5x5x16=2000個連接硅瞧。

6. C5卷積層

C5卷積層包含了120個卷積核份乒,核大小為5x5,填充為0腕唧,步長為1或辖。其中每一個卷積核與S4層的全部輸入相連,故每一個卷積核的輸出大小是1x1四苇,即C5層的輸出是一個120維的向量孝凌。C5層與S4層之間一共包含120x(5x5x16+1)=48120個連接。

7. F6全連接層

F6全連接層包含了84個節(jié)點(diǎn)月腋,故一共包含了(120+1)x84=10164個參數(shù)蟀架。F6層通過將輸入向量與權(quán)重向量求點(diǎn)積,然后在加上偏置項(xiàng)榆骚,最后通過一個sigmoid函數(shù)輸出片拍。

8. OutPut層

Output層也是全連接層,共有10個節(jié)點(diǎn)妓肢,分別代表數(shù)字0到9捌省,且如果節(jié)點(diǎn)i的值為0,則網(wǎng)絡(luò)識別的結(jié)果是數(shù)字i碉钠。采用的是徑向基函數(shù)(RBF)的網(wǎng)絡(luò)連接方式纲缓。假設(shè)x是上一層的輸入,y是RBF的輸出喊废,則RBF輸出的計(jì)算方式是:


上式w_ij的值由i的比特圖編碼確定祝高,i從0到9,j取值從0到7*12-1污筷。RBF輸出的值越接近于0工闺,則越接近于i,即越接近于i的ASCII編碼圖瓣蛀,表示當(dāng)前網(wǎng)絡(luò)輸入的識別結(jié)果是字符i陆蟆。該層有84x10=840個參數(shù)和連接。

LeNet5識別數(shù)字3

LeNet5 網(wǎng)絡(luò)識別數(shù)字3過程

代碼實(shí)踐

論文中的LeNet5結(jié)構(gòu)會稍微復(fù)雜點(diǎn)惋增,尤其是C3卷積層的操作叠殷,我們這里實(shí)現(xiàn)的是一個簡化版本。即不考慮卷積核之間的組合诈皿,直接利用PyTorch中內(nèi)置的卷積操作來進(jìn)行溪猿;同理钩杰,池化層的操作也是使用PyTorch內(nèi)置的操作來進(jìn)行。
總共代碼一共包含3個文件诊县,分別是模型文件LeNet5.py讲弄、模型訓(xùn)練文件LeNet5_Train.py、以及測試文件LeNet5_Test.py依痊。數(shù)據(jù)集來自kaggle避除。
依賴環(huán)境:

  • python3
  • PyTorch
  • pandas
  • matplotlib
  • numpy

模型部分代碼

LeNet5.py代碼如下:

import torch.nn as nn

class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        # 包含一個卷積層和池化層,分別對應(yīng)LeNet5中的C1和S2胸嘁,
        # 卷積層的輸入通道為1瓶摆,輸出通道為6,設(shè)置卷積核大小5x5性宏,步長為1
        # 池化層的kernel大小為2x2
        self._conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1),
            nn.MaxPool2d(kernel_size=2)
        )
        # 包含一個卷積層和池化層群井,分別對應(yīng)LeNet5中的C3和S4,
        # 卷積層的輸入通道為6毫胜,輸出通道為16书斜,設(shè)置卷積核大小5x5,步長為1
        # 池化層的kernel大小為2x2
        self._conv2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
            nn.MaxPool2d(kernel_size=2)
        )
        # 對應(yīng)LeNet5中C5卷積層酵使,由于它跟全連接層類似荐吉,所以這里使用了nn.Linear模塊
        # 卷積層的輸入通特征為4x4x16,輸出特征為120x1
        self._fc1 = nn.Sequential(
            nn.Linear(in_features=4*4*16, out_features=120)
        )
        # 對應(yīng)LeNet5中的F6口渔,輸入是120維向量样屠,輸出是84維向量
        self._fc2 = nn.Sequential(
            nn.Linear(in_features=120, out_features=84)
        )
        # 對應(yīng)LeNet5中的輸出層,輸入是84維向量缺脉,輸出是10維向量
        self._fc3 = nn.Sequential(
            nn.Linear(in_features=84, out_features=10)
        )

    def forward(self, input):
        # 前向傳播
        # MNIST DataSet image's format is 28x28x1
        # [28,28,1]--->[24,24,6]--->[12,12,6]
        conv1_output = self._conv1(input)
        # [12,12,6]--->[8,8,,16]--->[4,4,16]
        conv2_output = self._conv2(conv1_output)
        # 將[n,4,4,16]維度轉(zhuǎn)化為[n,4*4*16]
        conv2_output = conv2_output.view(-1, 4 * 4 * 16)
        # [n,256]--->[n,120]
        fc1_output = self._fc1(conv2_output)
        # [n,120]-->[n,84]
        fc2_output = self._fc2(fc1_output)
        # [n,84]-->[n,10]
        fc3_output = self._fc3(fc2_output)
        return fc3_output

模型訓(xùn)練部分

本文代碼使用了交叉熵?fù)p失函數(shù)痪欲,SGD優(yōu)化算法,設(shè)置學(xué)習(xí)率為0.001攻礼,動量設(shè)置為0.9业踢,小批量數(shù)據(jù)集大小設(shè)置為30,迭代次數(shù)為1000次秘蛔。
LeNet5_Train.py代碼如下:

import torch
import torch.nn as nn
import torch.optim as optim

import pandas as pd
import matplotlib.pyplot as plt
from PyTorchVersion.Networks.LeNet5 import LeNet5

train_data = pd.DataFrame(pd.read_csv("../Data/mnist_train.csv"))

model = LeNet5()
print(model)

# 定義交叉熵?fù)p失函數(shù)
loss_fc = nn.CrossEntropyLoss()
# 用model的參數(shù)初始化一個隨機(jī)梯度下降優(yōu)化器
optimizer = optim.SGD(params=model.parameters(),lr=0.001, momentum=0.78)
loss_list = []
x = []

# 迭代次數(shù)1000次
for i in range(1000):
    # 小批量數(shù)據(jù)集大小設(shè)置為30
    batch_data = train_data.sample(n=30, replace=False)
    # 每一條數(shù)據(jù)的第一個值是標(biāo)簽數(shù)據(jù)
    batch_y = torch.from_numpy(batch_data.iloc[:,0].values).long()
    #圖片信息,一條數(shù)據(jù)784維將其轉(zhuǎn)化為通道數(shù)為1傍衡,大小28*28的圖片深员。
    batch_x = torch.from_numpy(batch_data.iloc[:,1::].values).float().view(-1,1,28,28)

    # 前向傳播計(jì)算輸出結(jié)果
    prediction = model.forward(batch_x)
    # 計(jì)算損失值
    loss = loss_fc(prediction, batch_y)
    # Clears the gradients of all optimized
    optimizer.zero_grad()
    # back propagation algorithm
    loss.backward()
    # Performs a single optimization step (parameter update).
    optimizer.step()
    print("第%d次訓(xùn)練,loss為%.3f" % (i, loss.item()))
    loss_list.append(loss)
    x.append(i)

# Saves an object to a disk file.
torch.save(model.state_dict(),"../TrainedModel/LeNet5.pkl")
print('Networks''s keys: ', model.state_dict().keys())

plt.figure()
plt.xlabel("number of epochs")
plt.ylabel("loss")
plt.plot(x,loss_list,"r-")
plt.show()

模型訓(xùn)練過程中迭代次數(shù)與損失之間的變化關(guān)系圖:
迭代次數(shù)與損失之間的變化關(guān)系圖

可以看到大概經(jīng)過30次訓(xùn)練之后蛙埂,損失就已經(jīng)降到一個較低的水平了倦畅。

模型測試部分

總共進(jìn)行了100次測試,每次測試從測試集中隨機(jī)挑選50個樣本绣的,然后計(jì)算網(wǎng)絡(luò)的識別準(zhǔn)確率叠赐。
模型測試代碼LeNet5_Train.py如下:

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PyTorchVersion.Networks.LeNet5 import LeNet5

model = LeNet5()
test_data = pd.DataFrame(pd.read_csv("../Data/mnist_test.csv"))
#Load model parameters
model.load_state_dict(torch.load("../TrainedModel/LeNet5.pkl"))

accuracy_list = []
testList = []

with torch.no_grad():
    # 進(jìn)行一百次測試
    for i in range(100):
        # 每次從測試集中隨機(jī)挑選50個樣本
        batch_data = test_data.sample(n=50,replace=False)
        batch_x = torch.from_numpy(batch_data.iloc[:,1::].values).float().view(-1,1,28,28)
        batch_y = batch_data.iloc[:,0].values
        prediction = np.argmax(model(batch_x).numpy(), axis=1)
        acccurcy = np.mean(prediction==batch_y)
        print("第%d組測試集欲账,準(zhǔn)確率為%.3f" % (i,acccurcy))
        accuracy_list.append(acccurcy)
        testList.append(i)

plt.figure()
plt.xlabel("number of tests")
plt.ylabel("accuracy rate")
plt.ylim(0,1)
plt.plot(testList, accuracy_list,"r-")
plt.legend()
plt.show()

測試結(jié)果:
測試準(zhǔn)確率

平均準(zhǔn)確率大概在96%左右。

參考

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末芭概,一起剝皮案震驚了整個濱河市赛不,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌罢洲,老刑警劉巖踢故,帶你破解...
    沈念sama閱讀 217,185評論 6 503
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異惹苗,居然都是意外死亡殿较,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,652評論 3 393
  • 文/潘曉璐 我一進(jìn)店門桩蓉,熙熙樓的掌柜王于貴愁眉苦臉地迎上來淋纲,“玉大人,你說我怎么就攤上這事院究∏⑺玻” “怎么了?”我有些...
    開封第一講書人閱讀 163,524評論 0 353
  • 文/不壞的土叔 我叫張陵儡首,是天一觀的道長片任。 經(jīng)常有香客問我,道長蔬胯,這世上最難降的妖魔是什么对供? 我笑而不...
    開封第一講書人閱讀 58,339評論 1 293
  • 正文 為了忘掉前任,我火速辦了婚禮氛濒,結(jié)果婚禮上产场,老公的妹妹穿的比我還像新娘。我一直安慰自己舞竿,他們只是感情好京景,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,387評論 6 391
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著骗奖,像睡著了一般确徙。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上执桌,一...
    開封第一講書人閱讀 51,287評論 1 301
  • 那天鄙皇,我揣著相機(jī)與錄音,去河邊找鬼仰挣。 笑死伴逸,一個胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的膘壶。 我是一名探鬼主播错蝴,決...
    沈念sama閱讀 40,130評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼洲愤,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了顷锰?” 一聲冷哼從身側(cè)響起柬赐,我...
    開封第一講書人閱讀 38,985評論 0 275
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎馍惹,沒想到半個月后躺率,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,420評論 1 313
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡万矾,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,617評論 3 334
  • 正文 我和宋清朗相戀三年悼吱,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片良狈。...
    茶點(diǎn)故事閱讀 39,779評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡后添,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出薪丁,到底是詐尸還是另有隱情遇西,我是刑警寧澤,帶...
    沈念sama閱讀 35,477評論 5 345
  • 正文 年R本政府宣布严嗜,位于F島的核電站粱檀,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏漫玄。R本人自食惡果不足惜茄蚯,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,088評論 3 328
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望睦优。 院中可真熱鬧渗常,春花似錦、人聲如沸汗盘。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,716評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽隐孽。三九已至癌椿,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間菱阵,已是汗流浹背踢俄。 一陣腳步聲響...
    開封第一講書人閱讀 32,857評論 1 269
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留送粱,地道東北人褪贵。 一個月前我還...
    沈念sama閱讀 47,876評論 2 370
  • 正文 我出身青樓掂之,卻偏偏與公主長得像抗俄,于是被迫代替她去往敵國和親脆丁。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,700評論 2 354

推薦閱讀更多精彩內(nèi)容