基于PyTorch的CIFAR10小記

CIFAR-10數(shù)據(jù)集介紹

CIFAR-10數(shù)據(jù)集由10個(gè)類的60000個(gè)32x32彩色圖像組成帝蒿,每個(gè)類有6000個(gè)圖像。有50000個(gè)訓(xùn)練圖像和10000個(gè)測試圖像曙咽。
數(shù)據(jù)集分為五個(gè)訓(xùn)練批次和一個(gè)測試批次省店,每個(gè)批次有10000個(gè)圖像。測試批次包含來自每個(gè)類別的恰好1000個(gè)隨機(jī)選擇的圖像宙拉。訓(xùn)練批次以隨機(jī)順序包含剩余圖像宾尚,但一些訓(xùn)練批次可能包含來自一個(gè)類別的圖像比另一個(gè)更多⌒怀海總體來說煌贴,五個(gè)訓(xùn)練集之和包含來自每個(gè)類的正好5000張圖像。
以下是數(shù)據(jù)集中的類锥忿,以及來自每個(gè)類的10個(gè)隨機(jī)圖像:

CIFAR10數(shù)據(jù)集

下載地址:http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

切入正題

這次實(shí)踐主要參考PyTorch官方的教程(https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py)牛郑,在此基礎(chǔ)上進(jìn)行一些修改,由于主要目的是了解PyTorch的編程方法敬鬓,所以在數(shù)據(jù)集那里并沒有從訓(xùn)練集中切分驗(yàn)證集出來淹朋,在訓(xùn)練時(shí)僅觀察了loss的變化,最后使用測試集觀察準(zhǔn)確率钉答。

關(guān)于數(shù)據(jù)集

CIFAR10的數(shù)據(jù)集可以通過torchvision進(jìn)行下載础芍,但是下載速度太慢,建議使用迅雷下載

開搞

導(dǎo)入基礎(chǔ)包

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# 設(shè)置一些參數(shù)
EPOCHS = 20
BATCH_SIZE = 512

創(chuàng)建數(shù)據(jù)集

# 創(chuàng)建一個(gè)轉(zhuǎn)換器数尿,將torchvision數(shù)據(jù)集的輸出范圍[0,1]轉(zhuǎn)換為歸一化范圍的張量[-1,1]者甲。
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 創(chuàng)建訓(xùn)練集
# root -- 數(shù)據(jù)存放的目錄
# train -- 明確是否是訓(xùn)練集
# download -- 是否需要下載
# transform -- 轉(zhuǎn)換器,將數(shù)據(jù)集進(jìn)行轉(zhuǎn)換
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

# 創(chuàng)建測試集
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

創(chuàng)建數(shù)據(jù)加載器

# 創(chuàng)建訓(xùn)練/測試加載器砌创,
# trainset/testset -- 數(shù)據(jù)集
# batch_size -- 不解釋
# shuffle -- 是否打亂順序
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=BATCH_SIZE, shuffle=True)

# 類別標(biāo)簽
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

創(chuàng)建網(wǎng)絡(luò)

我這里定義了兩個(gè)CNN網(wǎng)絡(luò)虏缸,分別保存在CNN_1.pyCNN_2.py文件中

CNN_1

# CNN_1.py
import torch

# 學(xué)習(xí)率
LR = 0.005

class Net(torch.nn.Module):
    """Some Information about Net"""

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(3, 16, 3, padding=1),  # 3*32*32 -> 16*32*32
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, 2)  # 16*32*32 -> 16*16*16
        )
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(16, 32, 3, padding=1),  # 16*16*16 -> 32*16*16
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, 2)  # 32*16*16 -> 32*8*8
        )
        self.conv3 = torch.nn.Sequential(
            torch.nn.Conv2d(32, 64, 3, padding=1),  # 32*8*8 -> 64*8*8
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, 2)  # 64*8*8 -> 64*4*4
        )
        self.fc1 = torch.nn.Sequential(
            torch.nn.Linear(64*4*4, 32),
            torch.nn.ReLU(),
            # torch.nn.Dropout()
        )
        self.fc2 = torch.nn.Linear(32, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(-1, 64*4*4)
        x = self.fc1(x)
        x = self.fc2(x)
        return x

net = Net()
net.cuda()

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=LR)

網(wǎng)絡(luò)結(jié)構(gòu)如下:

Net(
  (conv1): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv3): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Sequential(
    (0): Linear(in_features=1024, out_features=32, bias=True)
    (1): ReLU()
  )
  (fc2): Linear(in_features=32, out_features=10, bias=True)
)

網(wǎng)絡(luò)創(chuàng)建了3個(gè)卷積層鲫懒,1個(gè)全連接層

CNN_2

import torch
LR = 0.005

class Net(torch.nn.Module):
    """Some Information about CNN"""

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(3, 16, 3, padding=1),  # 3*32*32 -> 16*32*32
            torch.nn.ReLU(),
        )
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(16, 32, 3, padding=1),  # 16*32*32 -> 32*32*32
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, 2)  # 32*32*32-> 32*16*16
        )
        self.conv3 = torch.nn.Sequential(
            torch.nn.Conv2d(32, 64, 3, padding=1),  #  32*16*16 -> 64*16*16
            torch.nn.ReLU(),
        )

        self.conv4 = torch.nn.Sequential(
            torch.nn.Conv2d(64, 128, 3, padding=1),  #  64*16*16 -> 128*16*16
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, 2)  # 128*16*16 -> 128*8*8
        )

        self.conv5 = torch.nn.Sequential(
            torch.nn.Conv2d(128, 256, 3, padding=1),  #  128*8*8 -> 256*8*8
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, 2)  # 256*8*8 -> 256*4*4
        )

        self.gap = torch.nn.AvgPool2d(4,4)
        self.fc = torch.nn.Linear(256, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.gap(x)
        x = x.view(-1, 256)
        x = self.fc(x)
        return x

net = Net()
net.cuda()

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=LR)

打印網(wǎng)絡(luò)如下:

Net(
  (conv1): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
  )
  (conv2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv3): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
  )
  (conv4): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv5): Sequential(
    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (gap): AvgPool2d(kernel_size=4, stride=4, padding=0)
  (fc): Linear(in_features=256, out_features=10, bias=True)
)

創(chuàng)建了5個(gè)卷積層,最后使用GAP連接輸出層

定義訓(xùn)練函數(shù)和測試函數(shù)

為了方便對不同的網(wǎng)絡(luò)進(jìn)行訓(xùn)練和測試刽辙,因此窥岩,定義一個(gè)通用的訓(xùn)練函數(shù)和測試函數(shù),方便使用

定義訓(xùn)練函數(shù)


def train(model, criterion, optimizer, trainloader, epochs=5, log_interval=50):
    print('----- Train Start -----')
    for epoch in range(epochs):
        running_loss = 0.0
        for step, (batch_x, batch_y) in enumerate(trainloader):
            batch_x, batch_y = batch_x.cuda(), batch_y.cuda()

            output = model(batch_x)

            optimizer.zero_grad()
            loss = criterion(output, batch_y)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if step % log_interval == (log_interval-1):
                print('[%d, %5d] loss: %.4f' %
                      (epoch + 1, step + 1, running_loss / log_interval))
                running_loss = 0.0
    print('----- Train Finished -----')

定義測試函數(shù)


def test(model, testloader):
    print('------ Test Start -----')

    correct = 0
    total = 0

    with torch.no_grad():
        for test_x, test_y in testloader:
            images, labels = test_x.cuda(), test_y.cuda()
            output = model(images)
            _, predicted = torch.max(output.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print('Accuracy of the network is: %.4f %%' % accuracy)
    return accuracy

在測試集上運(yùn)行網(wǎng)絡(luò)

測試CNN_1網(wǎng)絡(luò)

train(CNN_1.net, CNN_1.criterion, CNN_1.optimizer, trainloader, epochs=EPOCHS)
test(CNN_1.net, testloader)

測試CNN_2網(wǎng)絡(luò)

train(CNN_2.net, CNN_2.criterion, CNN_2.optimizer, trainloader, epochs=EPOCHS)
test(CNN_2.net, testloader)

訓(xùn)練結(jié)果

CNN_1在10代訓(xùn)練后宰缤,在測試集準(zhǔn)確率上能夠達(dá)到71.1100 %

CNN_2在10代訓(xùn)練后颂翼,在測試集準(zhǔn)確率上能夠達(dá)到75.6700 %

代碼

鏈接:https://pan.baidu.com/s/1rOmiE35rQnszmYyyN6h16Q
提取碼:zqf2

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市慨灭,隨后出現(xiàn)的幾起案子朦乏,更是在濱河造成了極大的恐慌,老刑警劉巖氧骤,帶你破解...
    沈念sama閱讀 222,104評論 6 515
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件呻疹,死亡現(xiàn)場離奇詭異,居然都是意外死亡筹陵,警方通過查閱死者的電腦和手機(jī)刽锤,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 94,816評論 3 399
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來朦佩,“玉大人并思,你說我怎么就攤上這事∮锍恚” “怎么了宋彼?”我有些...
    開封第一講書人閱讀 168,697評論 0 360
  • 文/不壞的土叔 我叫張陵,是天一觀的道長仙畦。 經(jīng)常有香客問我宙暇,道長,這世上最難降的妖魔是什么议泵? 我笑而不...
    開封第一講書人閱讀 59,836評論 1 298
  • 正文 為了忘掉前任占贫,我火速辦了婚禮,結(jié)果婚禮上先口,老公的妹妹穿的比我還像新娘型奥。我一直安慰自己,他們只是感情好碉京,可當(dāng)我...
    茶點(diǎn)故事閱讀 68,851評論 6 397
  • 文/花漫 我一把揭開白布厢汹。 她就那樣靜靜地躺著,像睡著了一般谐宙。 火紅的嫁衣襯著肌膚如雪烫葬。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 52,441評論 1 310
  • 那天,我揣著相機(jī)與錄音搭综,去河邊找鬼垢箕。 笑死,一個(gè)胖子當(dāng)著我的面吹牛兑巾,可吹牛的內(nèi)容都是我干的条获。 我是一名探鬼主播,決...
    沈念sama閱讀 40,992評論 3 421
  • 文/蒼蘭香墨 我猛地睜開眼蒋歌,長吁一口氣:“原來是場噩夢啊……” “哼帅掘!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起堂油,我...
    開封第一講書人閱讀 39,899評論 0 276
  • 序言:老撾萬榮一對情侶失蹤修档,失蹤者是張志新(化名)和其女友劉穎,沒想到半個(gè)月后府框,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體吱窝,經(jīng)...
    沈念sama閱讀 46,457評論 1 318
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 38,529評論 3 341
  • 正文 我和宋清朗相戀三年寓免,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片计维。...
    茶點(diǎn)故事閱讀 40,664評論 1 352
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡袜香,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出鲫惶,到底是詐尸還是另有隱情蜈首,我是刑警寧澤,帶...
    沈念sama閱讀 36,346評論 5 350
  • 正文 年R本政府宣布欠母,位于F島的核電站欢策,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏赏淌。R本人自食惡果不足惜踩寇,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 42,025評論 3 334
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望六水。 院中可真熱鬧俺孙,春花似錦、人聲如沸掷贾。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,511評論 0 24
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽想帅。三九已至场靴,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背旨剥。 一陣腳步聲響...
    開封第一講書人閱讀 33,611評論 1 272
  • 我被黑心中介騙來泰國打工咧欣, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人泞边。 一個(gè)月前我還...
    沈念sama閱讀 49,081評論 3 377
  • 正文 我出身青樓该押,卻偏偏與公主長得像,于是被迫代替她去往敵國和親阵谚。 傳聞我的和親對象是個(gè)殘疾皇子蚕礼,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,675評論 2 359