基于Pytorch訓練CIFAR-10數(shù)據(jù)集神經(jīng)網(wǎng)絡分類器

什么是CIFAR-10數(shù)據(jù)集?

CIFAR-10 是一個包含了10類震捣,60000 張 32x32像素彩色圖像的數(shù)據(jù)集直秆。
CIFAR-10數(shù)據(jù)集

每類圖像有6000張蔫劣;分為50000張訓練數(shù)據(jù)和10000張測試數(shù)據(jù)。CIFAR-10 數(shù)據(jù)網(wǎng)址:http://www.cs.toronto.edu/~kriz/cifar.html
數(shù)據(jù)集分為5個訓練數(shù)據(jù)集和1個測試數(shù)據(jù)集吭产,每個批次10000張圖像

cifar10數(shù)據(jù)分批

第一步:下載數(shù)據(jù)集并加載到內(nèi)存侣监。圖像數(shù)據(jù)會經(jīng)過標準化(Normalize)和歸一化處理。對數(shù)據(jù)集進行標準化處理臣淤,就是讓數(shù)據(jù)集的均值為0橄霉,方差為1,把數(shù)據(jù)集映射到(-1,1)之間邑蒋,這樣可以加速訓練過程姓蜂,提高模型泛化能力按厘。

為什么要標準化輸入數(shù)據(jù)

歸一化將像素值從0~255已經(jīng)轉化為0~1之間,加快訓練網(wǎng)絡的收斂性钱慢。圖像的像素處于0-1范圍時逮京,由于仍然介于0-255之間,所以圖像依舊是有效的束莫,并且可以正常查看圖像

import torch
import torchvision # 圖像處理工具包
import torchvision.transforms as transforms 
N = 64
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=N, shuffle=True, num_workers=0)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

testloader = torch.utils.data.DataLoader(testset, batch_size=N, shuffle=False, num_workers=0)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

運行結果:

Using downloaded and verified file: ./data\cifar-10-python.tar.gz
Extracting ./data\cifar-10-python.tar.gz to ./data
Files already downloaded and verified

第二步:隨機查看一批圖片

import matplotlib.pyplot as plt 
import numpy as np 
#圖像的像素處于0-1范圍時懒棉,由于仍然介于0-255之間,所以圖像依舊是有效的览绿,并且可以正常查看圖像
def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# get some random training images
dataiter = iter(trainloader)
print(type(dataiter))
images, labels = dataiter.next()
print(dataiter.next())
# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(N)))

運行結果

第三步:定義卷積神經(jīng)網(wǎng)絡策严。需要注意的是,開發(fā)者必須對圖像像素變化負責饿敲,要非常清楚圖像經(jīng)過每個神經(jīng)網(wǎng)絡層處理后妻导,輸出的像素尺寸,例如怀各,經(jīng)過一個5x5, stride=1的卷積后倔韭,一個32x32輸入的圖像會變成28x28。

import torch.nn as nn
import torch.nn.functional as F 

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)  # 32 -> 28
        self.pool1 = nn.MaxPool2d(2)     # 28 -> 14
        self.conv2 = nn.Conv2d(6, 16, 5) # 14 -> 10
        self.pool2 = nn.MaxPool2d(2)     # 10 -> 5
        self.fc1   = nn.Linear(16*5*5, 120) # 展平
        self.fc2   = nn.Linear(120, 84)
        self.fc3   = nn.Linear(84, 10)   # 10類

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 16*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)

輸出:

Net(
(conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
(pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(fc1): Linear(in_features=400, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)

第四步:定義損失函數(shù)并訓練網(wǎng)絡渠啤。作為分類應用狐肢,選擇交叉熵損失函數(shù)添吗;優(yōu)化方案選擇adam沥曹。

import torch.optim as optim
criterion = nn.CrossEntropyLoss() #分類應用,選擇交叉熵損失函數(shù)
# torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
optimizer = optim.Adam(net.parameters()) #其余參數(shù)默認

for epoch in range(3):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data # inputs類型和尺寸:<class 'torch.Tensor'> torch.Size([N, 3, 32, 32])
        optimizer.zero_grad() # 將上一次的梯度值清零
        output = net(inputs)  # 前向計算forward()
        loss = criterion(output, labels) # 計算損失值
        loss.backward()       # 反向計算backward()
        running_loss += loss.item() #累積loss值
        optimizer.step()      # 更新神經(jīng)網(wǎng)絡參數(shù)

        if i % 2000 == 1999:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000)) #計算平均loss值
            running_loss = 0.0
print('Finished Training')

輸出:

[1, 2000] loss: 1.644
[2, 2000] loss: 1.421
[3, 2000] loss: 1.211
Finished Training

第五步 保存訓練的模型碟联,Pytorch支持兩種保存方式

  • 僅保存模型參數(shù)
  • 保存完整模型(包含參數(shù))
WEIGHT = './cifar_net_weights.pth'
MODEL  = './cifar_net_model.pth'
torch.save(net.state_dict(), WEIGHT) # 僅保存模型參數(shù)
torch.save(net, MODEL)               # 保存整個模型(包含參數(shù))
保存模型

netron分別打開模型文件和權重文件可以看到區(qū)別

打開模型文件 vs 權重文件

第六步 基于模型文件做推理計算

import torch,torchvision
import torch.nn as nn
import torch.nn.functional as F 
import torchvision.transforms as transforms 

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)  # 32 -> 28
        self.pool1 = nn.MaxPool2d(2)     # 28 -> 14
        self.conv2 = nn.Conv2d(6, 16, 5) # 14 -> 10
        self.pool2 = nn.MaxPool2d(2)     # 10 -> 5
        self.fc1   = nn.Linear(16*5*5, 120) # 展平
        self.fc2   = nn.Linear(120, 84)
        self.fc3   = nn.Linear(84, 10)   # 10類

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 16*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

MODEL  = './cifar_net_model.pth'
net = torch.load(MODEL)
print(net)

N = 16
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=N, shuffle=False, num_workers=0)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(4):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1

for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))

輸出結果:

Net(
(conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
(pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(fc1): Linear(in_features=400, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
Files already downloaded and verified
Accuracy of the network on the 10000 test images: 58 %
Accuracy of plane : 53 %
Accuracy of car : 70 %
Accuracy of bird : 44 %
Accuracy of cat : 41 %
Accuracy of deer : 42 %
Accuracy of dog : 40 %
Accuracy of frog : 81 %
Accuracy of horse : 66 %
Accuracy of ship : 80 %
Accuracy of truck : 69 %

迷思:加載模型文件妓美,還需要Net類的定義?不符合常理袄鸱酢壶栋!~~

第七步 用GPU加速訓練

  • net.to(device) # 把網(wǎng)絡送入GPU
  • inputs, labels = data[0].to(device), data[1].to(device) # 把數(shù)據(jù)送到GPU
    測試下來,GPU訓練并沒有提升多少速度普监,是因為本例神經(jīng)網(wǎng)絡很淺很窄贵试。把神經(jīng)網(wǎng)絡加寬加深后,GPU的加速效果就會出來
?著作權歸作者所有,轉載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末凯正,一起剝皮案震驚了整個濱河市毙玻,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌廊散,老刑警劉巖桑滩,帶你破解...
    沈念sama閱讀 219,589評論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異允睹,居然都是意外死亡运准,警方通過查閱死者的電腦和手機幌氮,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,615評論 3 396
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來胁澳,“玉大人该互,你說我怎么就攤上這事【禄” “怎么了慢洋?”我有些...
    開封第一講書人閱讀 165,933評論 0 356
  • 文/不壞的土叔 我叫張陵,是天一觀的道長陆盘。 經(jīng)常有香客問我普筹,道長,這世上最難降的妖魔是什么隘马? 我笑而不...
    開封第一講書人閱讀 58,976評論 1 295
  • 正文 為了忘掉前任太防,我火速辦了婚禮,結果婚禮上酸员,老公的妹妹穿的比我還像新娘蜒车。我一直安慰自己,他們只是感情好幔嗦,可當我...
    茶點故事閱讀 67,999評論 6 393
  • 文/花漫 我一把揭開白布酿愧。 她就那樣靜靜地躺著,像睡著了一般邀泉。 火紅的嫁衣襯著肌膚如雪嬉挡。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,775評論 1 307
  • 那天汇恤,我揣著相機與錄音庞钢,去河邊找鬼。 笑死因谎,一個胖子當著我的面吹牛基括,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播财岔,決...
    沈念sama閱讀 40,474評論 3 420
  • 文/蒼蘭香墨 我猛地睜開眼风皿,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了匠璧?” 一聲冷哼從身側響起桐款,我...
    開封第一講書人閱讀 39,359評論 0 276
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎患朱,沒想到半個月后鲁僚,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,854評論 1 317
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 38,007評論 3 338
  • 正文 我和宋清朗相戀三年冰沙,在試婚紗的時候發(fā)現(xiàn)自己被綠了瘪撇。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片叼屠。...
    茶點故事閱讀 40,146評論 1 351
  • 序言:一個原本活蹦亂跳的男人離奇死亡丝格,死狀恐怖算芯,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情侥啤,我是刑警寧澤当叭,帶...
    沈念sama閱讀 35,826評論 5 346
  • 正文 年R本政府宣布,位于F島的核電站盖灸,受9級特大地震影響蚁鳖,放射性物質發(fā)生泄漏。R本人自食惡果不足惜赁炎,卻給世界環(huán)境...
    茶點故事閱讀 41,484評論 3 331
  • 文/蒙蒙 一醉箕、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧徙垫,春花似錦讥裤、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,029評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至吴旋,卻和暖如春损肛,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背邮府。 一陣腳步聲響...
    開封第一講書人閱讀 33,153評論 1 272
  • 我被黑心中介騙來泰國打工荧关, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人褂傀。 一個月前我還...
    沈念sama閱讀 48,420評論 3 373
  • 正文 我出身青樓,卻偏偏與公主長得像加勤,于是被迫代替她去往敵國和親仙辟。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 45,107評論 2 356