Pytorch深度學(xué)習(xí)-用GoogleNet訓(xùn)練MNIST數(shù)據(jù)集

全連接限制圖像的尺寸弟孟,而卷積則不關(guān)心圖像尺寸大小汁咏,只需要接受輸入的通道數(shù)亚斋,輸出的通道數(shù)和卷積核大小即可確定圖像尺寸的變換過(guò)程,即

H_{out} = {H_{in}+2*padding-kernalsize\over stride}+1.
W_{out} = {W_{in}+2*padding-kernalsize\over stride}+1.
padding:對(duì)輸入圖片進(jìn)行填充攘滩,一般用0填充帅刊,padding=1,代表填充一圈漂问,保證卷積前后的圖像尺寸大小一致赖瞒,padding計(jì)算公式如下:
padding = {kernalsize-1\over 2}.

stride步長(zhǎng):指的是卷積核每次滑動(dòng)的距離大小

本文采用GoogleNet來(lái)構(gòu)建深度網(wǎng)絡(luò)模型

1. 數(shù)據(jù)集構(gòu)建

每個(gè)像素點(diǎn)即每條數(shù)據(jù)中的值范圍為0-255,有的數(shù)字過(guò)大不利于訓(xùn)練且難以收斂蚤假,故將其歸一化到(0-1)之間

# 數(shù)據(jù)集處理

transform = transforms.Compose([
    transforms.ToTensor(),  # 轉(zhuǎn)化成Tensor張量
    transforms.Normalize((0.1307,), (0.3081,))  # 歸一化處理,將其(0-255)映射到(0-1)
])
# 1.準(zhǔn)備數(shù)據(jù)集
train_dataset = datasets.MNIST(root="../DataSet/mnist",
                               train=True,
                               transform=transform,
                               download=False)
test_dataset = datasets.MNIST(root="../DataSet/mnist",
                              train=False,
                              transform=transform,
                              download=False)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)


2. 構(gòu)建GoogleNet---構(gòu)造Inception單元

# 定義GoogleNet---構(gòu)造 Inception 單元---GoogleNet不改變圖片的尺寸大小即 w 和 h 不變栏饮,只改變其 channel 大小
class GoogleNet(torch.nn.Module):
    def __init__(self, input_channels):
        super(GoogleNet, self).__init__()
        # 第一個(gè)分支
        self.branch_pool1 = torch.nn.Conv2d(input_channels, 16, kernel_size=1)
        # 第二個(gè)分支
        self.branch_pool2_1 = torch.nn.Conv2d(input_channels, 16, kernel_size=1)
        self.branch_pool2_2 = torch.nn.Conv2d(16, 24, kernel_size=5, padding=2)
        # 第三個(gè)分支
        # ---padding = (k-1)/2 k為卷積核大小, 即可保證卷積后圖片大小不變
        # ---padding作用: 保證卷積后的圖片大小與原圖片一致即作用于 w 和 h
        # ---同時(shí)用0填充,保證足夠多的信息量也不存在噪音問(wèn)題
        self.branch_pool3_1 = torch.nn.Conv2d(input_channels, 16, kernel_size=1)
        self.branch_pool3_2 = torch.nn.Conv2d(16, 24, kernel_size=3, padding=1)
        self.branch_pool3_3 = torch.nn.Conv2d(24, 24, kernel_size=3, padding=1)
        # 第四個(gè)分支
        self.branch_pool4 = torch.nn.Conv2d(input_channels, 24, kernel_size=1)

    # 不管 input_channels 是多少,輸出的 channels 是24*3+16=88
    def forward(self, x):
        branch1 = self.branch_pool1(x)  # torch.Size([64, 16, 28, 28])
        branch2 = self.branch_pool2_2(self.branch_pool2_1(x))  # torch.Size([64, 24, 28, 28])
        branch3 = self.branch_pool3_3(self.branch_pool3_2(self.branch_pool3_1(x)))  # torch.Size([64, 24, 28, 28])
        branch4 = self.branch_pool4(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))  # torch.Size([64, 24, 28, 28])

        outputs = [branch1, branch2, branch3, branch4]
        # GoogleNet 要求每個(gè)分支輸出的圖片通道數(shù)可以不一樣, 但其他維度的尺寸必須一樣,這樣才可以保證能做 cat 連接
        # 即(batch, channel_branch, width, height)----->(batch, channel_branch, width, height)
        # (batch, channels, width, height): 沿著 channel 通道的方向?qū)⑦@些張量連接起來(lái)
        return torch.cat(outputs, dim=1)

3.采用GoogleNet的神經(jīng)網(wǎng)絡(luò)來(lái)構(gòu)建模型

# 2.構(gòu)建網(wǎng)絡(luò)模型---模型是針對(duì)批次樣本的處理情況
class Module(torch.nn.Module):
    def __init__(self):
        super(Module, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5, bias=False)
        self.googleNet1 = GoogleNet(input_channels=10)
        self.googleNet2 = GoogleNet(input_channels=20)
        self.conv2 = torch.nn.Conv2d(88, 20, kernel_size=5, bias=False)
        # 下采樣并不改變 channel 數(shù)量,只改變圖片大小
        self.maxPooling = torch.nn.MaxPool2d(2)

        self.fc = torch.nn.Linear(1408, 10)

    # 卷積---池化---激活函數(shù)---GoogleNet---數(shù)據(jù)扁平化處理---全連接層
    def forward(self, x):
        size = x.size(0)  # torch.Size([64, 1, 28, 28])
        x = F.relu(self.maxPooling(self.conv1(x)))  # torch.Size([64, 10, 12, 12])
        x = self.googleNet1(x)  # torch.Size([64, 88, 12, 12])
        x = F.relu(self.maxPooling(self.conv2(x)))  # torch.Size([64, 20, 4, 4])
        x = self.googleNet2(x)  # torch.Size([64, 88, 4, 4])
        # 數(shù)據(jù)扁平化處理,為接下來(lái)的全連接測(cè)做準(zhǔn)備
        # Flatten data from (64, 88, 4, 4) to (64,1408)
        x = x.view(size, -1)
        x = self.fc(x)
        # 全連接層之后不需要跟激活函數(shù),因?yàn)榧せ詈瘮?shù) softmax 的作用包含在 CrossEntropyLoss 中
        # softmax 函數(shù)的作用包含在 CrossEntropyLoss 中
        return x

4. 構(gòu)建損失函數(shù)和優(yōu)化器

損失函數(shù)采用CrossEntropyLoss
優(yōu)化器采用 SGD 隨機(jī)梯度優(yōu)化算法

# 構(gòu)造損失器和優(yōu)化器
# softmax 函數(shù)的作用包含在 CrossEntropyLoss 中,交叉熵算法
criterion = torch.nn.CrossEntropyLoss()
opt = optim.SGD(params=model.parameters(), lr=0.01, momentum=0.5)
# 動(dòng)態(tài)更新學(xué)習(xí)率------每隔step_size : lr = lr * gamma
schedule = optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5, last_epoch=-1)

5.完整代碼

# -*- codeing = utf-8 -*-
# @Time : 2022/4/12 8:57
# @Software : PyCharm

# 超參數(shù): 訓(xùn)練之前設(shè)置的參數(shù)
# 模型參數(shù): 訓(xùn)練過(guò)程中得到的參數(shù)
# Average Pooling: 均值池化---

# network in network:
# 1*1 的卷積核: 單個(gè)1*1的卷積后的圖片大小不變即:c*w*h---------->1*w*h
# 若需要輸出16個(gè)通道的圖片磷仰,則只需要將輸出通道設(shè)置為16,pytorch自動(dòng)構(gòu)建16個(gè)通道,1*1的卷積核
# 1*1 的卷積不改變圖片尺寸,即c1*w*h---------->c2*w*h
# 1*1 的卷積可以有效降低通道數(shù)量,大大減少網(wǎng)絡(luò)模型的浮點(diǎn)數(shù)運(yùn)算量

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim

# Maxpooling: 最大池化,尋找每個(gè)空間的最大值然后組成一個(gè)新的圖像
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),  # 轉(zhuǎn)化成Tensor張量
    transforms.Normalize((0.1307,), (0.3081,))  # 歸一化處理,將其(0-255)映射到(0-1)
])
# 1.準(zhǔn)備數(shù)據(jù)集
train_dataset = datasets.MNIST(root="../DataSet/mnist",
                               train=True,
                               transform=transform,
                               download=False)
test_dataset = datasets.MNIST(root="../DataSet/mnist",
                              train=False,
                              transform=transform,
                              download=False)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)


# 定義GoogleNet---構(gòu)造 Inception 單元---GoogleNet不改變圖片的尺寸大小即 w 和 h 不變袍嬉,只改變其 channel 大小
class GoogleNet(torch.nn.Module):
    def __init__(self, input_channels):
        super(GoogleNet, self).__init__()
        # 第一個(gè)分支
        self.branch_pool1 = torch.nn.Conv2d(input_channels, 16, kernel_size=1)
        # 第二個(gè)分支
        self.branch_pool2_1 = torch.nn.Conv2d(input_channels, 16, kernel_size=1)
        self.branch_pool2_2 = torch.nn.Conv2d(16, 24, kernel_size=5, padding=2)
        # 第三個(gè)分支
        # ---padding = (k-1)/2 k為卷積核大小, 即可保證卷積后圖片大小不變
        # ---padding作用: 保證卷積后的圖片大小與原圖片一致即作用于 w 和 h
        # ---同時(shí)用0填充,保證足夠多的信息量也不存在噪音問(wèn)題
        self.branch_pool3_1 = torch.nn.Conv2d(input_channels, 16, kernel_size=1)
        self.branch_pool3_2 = torch.nn.Conv2d(16, 24, kernel_size=3, padding=1)
        self.branch_pool3_3 = torch.nn.Conv2d(24, 24, kernel_size=3, padding=1)
        # 第四個(gè)分支
        self.branch_pool4 = torch.nn.Conv2d(input_channels, 24, kernel_size=1)

    # 不管 input_channels 是多少,輸出的 channels 是24*3+16=88
    def forward(self, x):
        branch1 = self.branch_pool1(x)  # torch.Size([64, 16, 28, 28])
        branch2 = self.branch_pool2_2(self.branch_pool2_1(x))  # torch.Size([64, 24, 28, 28])
        branch3 = self.branch_pool3_3(self.branch_pool3_2(self.branch_pool3_1(x)))  # torch.Size([64, 24, 28, 28])
        branch4 = self.branch_pool4(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))  # torch.Size([64, 24, 28, 28])

        outputs = [branch1, branch2, branch3, branch4]
        # GoogleNet 要求每個(gè)分支輸出的圖片通道數(shù)可以不一樣, 但其他維度的尺寸必須一樣,這樣才可以保證能做 cat 連接
        # 即(batch, channel_branch, width, height)----->(batch, channel_branch, width, height)
        # (batch, channels, width, height): 沿著 channel 通道的方向?qū)⑦@些張量連接起來(lái)
        return torch.cat(outputs, dim=1)


# 2.構(gòu)建網(wǎng)絡(luò)模型---模型是針對(duì)批次樣本的處理情況
class Module(torch.nn.Module):
    def __init__(self):
        super(Module, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5, bias=False)
        self.googleNet1 = GoogleNet(input_channels=10)
        self.googleNet2 = GoogleNet(input_channels=20)
        self.conv2 = torch.nn.Conv2d(88, 20, kernel_size=5, bias=False)
        # 下采樣并不改變 channel 數(shù)量,只改變圖片大小
        self.maxPooling = torch.nn.MaxPool2d(2)

        self.fc = torch.nn.Linear(1408, 10)

    # 卷積---池化---激活函數(shù)---GoogleNet---數(shù)據(jù)扁平化處理---全連接層
    def forward(self, x):
        size = x.size(0)  # torch.Size([64, 1, 28, 28])
        x = F.relu(self.maxPooling(self.conv1(x)))  # torch.Size([64, 10, 12, 12])
        x = self.googleNet1(x)  # torch.Size([64, 88, 12, 12])
        x = F.relu(self.maxPooling(self.conv2(x)))  # torch.Size([64, 20, 4, 4])
        x = self.googleNet2(x)  # torch.Size([64, 88, 4, 4])
        # 數(shù)據(jù)扁平化處理,為接下來(lái)的全連接測(cè)做準(zhǔn)備
        # Flatten data from (64, 88, 4, 4) to (64,1408)
        x = x.view(size, -1)
        x = self.fc(x)
        # 全連接層之后不需要跟激活函數(shù),因?yàn)榧せ詈瘮?shù) softmax 的作用包含在 CrossEntropyLoss 中
        # softmax 函數(shù)的作用包含在 CrossEntropyLoss 中
        return x


model = Module().to(device)

# 3.構(gòu)造損失器和優(yōu)化器
# softmax 函數(shù)的作用包含在 CrossEntropyLoss 中,交叉熵算法
criterion = torch.nn.CrossEntropyLoss()
opt = optim.SGD(params=model.parameters(), lr=0.01, momentum=0.5)
# 動(dòng)態(tài)更新學(xué)習(xí)率------每隔step_size : lr = lr * gamma
schedule = optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5, last_epoch=-1)


# 4.訓(xùn)練數(shù)據(jù)集
def train():
    running_loss = 0
    for batch_idx, (inputs, target) in enumerate(train_loader, 0):
        inputs, target = inputs.to(device), target.to(device)
        opt.zero_grad()
        y_pred_data = model(inputs)
        loss = criterion(y_pred_data, target)
        loss.backward()
        opt.step()

        running_loss += loss.item()
        if batch_idx % 300 == 299:
            print("[%5d, %5d] loss: %.5f" % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss == 0.0


# 5.測(cè)試數(shù)據(jù)集
def verify():
    correct = 0
    total = 0
    with torch.no_grad():  # 該語(yǔ)句下的所有tensor在進(jìn)行反向傳播時(shí),不會(huì)被計(jì)算梯度
        for (images, labels) in test_loader:
            images, labels = images.to(device), labels.to(device)
            # 數(shù)據(jù)進(jìn)入模型進(jìn)行計(jì)算
            outputs = model(images)
            # 沿著維度為1的方向(行方向) 尋找每行最大元素的值與其下標(biāo)
            _, predicted = torch.max(outputs.data, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print("==============================")
    print("Accuracy on test set: %d%%" % (100 * correct / total))
    print("==============================")


if __name__ == '__main__':
    for epoch in range(15):
        train()
        verify()
        # GoogleNet: 分支卷積--->cat(dim=1:channels)---全連接
        # 使用 卷積 + GoogleNet + 全連接 的神經(jīng)網(wǎng)絡(luò)的準(zhǔn)確率在 99% 左右, 同時(shí)減少了參數(shù)量和計(jì)算量




6.結(jié)果展示

result.png
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末芒划,一起剝皮案震驚了整個(gè)濱河市冬竟,隨后出現(xiàn)的幾起案子欧穴,更是在濱河造成了極大的恐慌,老刑警劉巖泵殴,帶你破解...
    沈念sama閱讀 211,042評(píng)論 6 490
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件涮帘,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡笑诅,警方通過(guò)查閱死者的電腦和手機(jī)调缨,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 89,996評(píng)論 2 384
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)吆你,“玉大人弦叶,你說(shuō)我怎么就攤上這事「径啵” “怎么了伤哺?”我有些...
    開(kāi)封第一講書(shū)人閱讀 156,674評(píng)論 0 345
  • 文/不壞的土叔 我叫張陵,是天一觀的道長(zhǎng)者祖。 經(jīng)常有香客問(wèn)我立莉,道長(zhǎng),這世上最難降的妖魔是什么七问? 我笑而不...
    開(kāi)封第一講書(shū)人閱讀 56,340評(píng)論 1 283
  • 正文 為了忘掉前任蜓耻,我火速辦了婚禮,結(jié)果婚禮上械巡,老公的妹妹穿的比我還像新娘刹淌。我一直安慰自己,他們只是感情好讥耗,可當(dāng)我...
    茶點(diǎn)故事閱讀 65,404評(píng)論 5 384
  • 文/花漫 我一把揭開(kāi)白布有勾。 她就那樣靜靜地躺著,像睡著了一般葛账。 火紅的嫁衣襯著肌膚如雪柠衅。 梳的紋絲不亂的頭發(fā)上,一...
    開(kāi)封第一講書(shū)人閱讀 49,749評(píng)論 1 289
  • 那天籍琳,我揣著相機(jī)與錄音菲宴,去河邊找鬼。 笑死趋急,一個(gè)胖子當(dāng)著我的面吹牛喝峦,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播呜达,決...
    沈念sama閱讀 38,902評(píng)論 3 405
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼谣蠢,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來(lái)了?” 一聲冷哼從身側(cè)響起眉踱,我...
    開(kāi)封第一講書(shū)人閱讀 37,662評(píng)論 0 266
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤挤忙,失蹤者是張志新(化名)和其女友劉穎,沒(méi)想到半個(gè)月后谈喳,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體册烈,經(jīng)...
    沈念sama閱讀 44,110評(píng)論 1 303
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,451評(píng)論 2 325
  • 正文 我和宋清朗相戀三年婿禽,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了赏僧。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 38,577評(píng)論 1 340
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡扭倾,死狀恐怖淀零,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情膛壹,我是刑警寧澤驾中,帶...
    沈念sama閱讀 34,258評(píng)論 4 328
  • 正文 年R本政府宣布,位于F島的核電站恢筝,受9級(jí)特大地震影響哀卫,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜撬槽,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,848評(píng)論 3 312
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望趾撵。 院中可真熱鬧侄柔,春花似錦、人聲如沸占调。這莊子的主人今日做“春日...
    開(kāi)封第一講書(shū)人閱讀 30,726評(píng)論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)究珊。三九已至薪者,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間剿涮,已是汗流浹背言津。 一陣腳步聲響...
    開(kāi)封第一講書(shū)人閱讀 31,952評(píng)論 1 264
  • 我被黑心中介騙來(lái)泰國(guó)打工, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留取试,地道東北人悬槽。 一個(gè)月前我還...
    沈念sama閱讀 46,271評(píng)論 2 360
  • 正文 我出身青樓,卻偏偏與公主長(zhǎng)得像瞬浓,于是被迫代替她去往敵國(guó)和親初婆。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 43,452評(píng)論 2 348

推薦閱讀更多精彩內(nèi)容