Pytorch入門學(xué)習(xí)(四)-training a classifier

未經(jīng)允許,不得轉(zhuǎn)載,謝謝~~

我們現(xiàn)在已經(jīng)知道了:

  • 怎么樣用pytorch定義一個(gè)神經(jīng)網(wǎng)絡(luò)亿鲜;
  • 怎么樣計(jì)算損失值台舱;
  • 怎么樣更新網(wǎng)絡(luò)的權(quán)重;

現(xiàn)在剩下的問題就是怎么樣獲取數(shù)據(jù)了,pytorch除了支持將包含數(shù)據(jù)信息的numpy array轉(zhuǎn)換成Tensor以外,也提供了各個(gè)常見數(shù)據(jù)集的加載方式,并封裝到了torchvision中绪商,本文簡單介紹數(shù)據(jù)獲取的方式,然后訓(xùn)練一個(gè)簡單的分類網(wǎng)絡(luò)作為入門級(jí)的example辅鲸。

數(shù)據(jù)獲取

當(dāng)你想要處理圖像格郁,文本,語音或者視頻信息時(shí)独悴,一般可以用標(biāo)準(zhǔn)的python包將數(shù)據(jù)加載到numpy array中例书,然后將其轉(zhuǎn)換成Tensor.

  • 對(duì)于圖像,常用的有:Pillow刻炒,OpenCV
  • 對(duì)于語音决采,常用的有:scipy, libosa
  • 對(duì)于文本,常用的有:NLTK, SpaCy

Pytorch提供的torchvision包封裝了常見數(shù)據(jù)集的數(shù)據(jù)加載函數(shù)坟奥,比如Imagenet树瞭,CIFAR10拇厢,MNIST等等它都提供了數(shù)據(jù)加載的功能。除此晒喷,它還提供了torchvision.datasetstorch.utils.data.DataLoader用于實(shí)現(xiàn)圖像數(shù)據(jù)轉(zhuǎn)換的功能孝偎。

訓(xùn)練圖像分類器

加載并處理CIFAR10

import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

注:torchvision的輸出結(jié)果為PILImage, 處于[0,1]之間,所以我們將其轉(zhuǎn)換為[-1,1]之間的張量厨埋。

如果沒有CIFAR10的數(shù)據(jù)邪媳,代碼會(huì)自動(dòng)下載捐顷,輸出結(jié)果如下所示:


來來來荡陷,讓我們把訓(xùn)練圖片輸出來看看~

import matplotlib.pyplot as plt
import numpy as np

# functions to show an image
def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))


# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

如下所示:


定義卷積神經(jīng)網(wǎng)絡(luò)

這部分的實(shí)現(xiàn)跟之前定義的神經(jīng)網(wǎng)絡(luò)是一樣的,除了cifar10是三通道輸入的迅涮,代碼如下:

from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()

print net語句可以將網(wǎng)絡(luò)結(jié)構(gòu)打印出來:

定義損失函數(shù)和優(yōu)化器

我們用交叉熵作為損失值废赞,用帶動(dòng)量的SGD隨機(jī)梯度下降法作為網(wǎng)絡(luò)的優(yōu)化器。

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

訓(xùn)練神經(jīng)網(wǎng)絡(luò)

在pytorch中叮姑,我們不需要自己計(jì)算梯度唉地,只要不斷將訓(xùn)練數(shù)據(jù)喂給網(wǎng)絡(luò),然后調(diào)用優(yōu)化器進(jìn)行優(yōu)化就可以了传透。

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data

        # wrap them in Variable
        inputs, labels = Variable(inputs), Variable(labels)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.data[0]
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

如上展示了訓(xùn)練的過程耘沼,實(shí)際中epoch=2還不夠,可以增大epoch來提高精度朱盐。

用測(cè)試數(shù)據(jù)進(jìn)行測(cè)試

現(xiàn)在已經(jīng)在訓(xùn)練數(shù)據(jù)上做了2輪的訓(xùn)練群嗤,我們現(xiàn)在可以檢查一下網(wǎng)絡(luò)是否有學(xué)習(xí)到東西。
我們先展示一些測(cè)試集里面的數(shù)據(jù):

dataiter = iter(testloader)
images, labels = dataiter.next()

# print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))

現(xiàn)在來看看網(wǎng)絡(luò)的預(yù)測(cè)結(jié)果

outputs = net(Variable(images))

# transform from score to label
_, predicted = torch.max(outputs.data, 1)
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
                              for j in range(4)))

可以看到預(yù)測(cè)結(jié)果還是挺準(zhǔn)確的兵琳。

再來看看整個(gè)測(cè)試集的運(yùn)行情況:

correct = 0
total = 0
for data in testloader:
    images, labels = data
    outputs = net(Variable(images))
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

最后得到的結(jié)果為:


以下是輸出了每個(gè)類別的判斷

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
for data in testloader:
    images, labels = data
    outputs = net(Variable(images))
    _, predicted = torch.max(outputs.data, 1)
    c = (predicted == labels).squeeze()
    for i in range(4):
        label = labels[i]
        class_correct[label] += c[i]
        class_total[label] += 1


for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))

結(jié)果如下:


比隨機(jī)識(shí)別的概率10%已經(jīng)要高很多了狂秘,當(dāng)然要精度更高可以增加訓(xùn)練的輪數(shù),改變學(xué)習(xí)率等等躯肌。

以上就完成了一個(gè)簡單分類器網(wǎng)絡(luò)的定義集訓(xùn)練者春,可以用這個(gè)為入門example跑跑看~~

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市清女,隨后出現(xiàn)的幾起案子钱烟,更是在濱河造成了極大的恐慌,老刑警劉巖嫡丙,帶你破解...
    沈念sama閱讀 216,591評(píng)論 6 501
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件拴袭,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡迄沫,警方通過查閱死者的電腦和手機(jī)稻扬,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,448評(píng)論 3 392
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來羊瘩,“玉大人泰佳,你說我怎么就攤上這事盼砍。” “怎么了逝她?”我有些...
    開封第一講書人閱讀 162,823評(píng)論 0 353
  • 文/不壞的土叔 我叫張陵浇坐,是天一觀的道長。 經(jīng)常有香客問我黔宛,道長近刘,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,204評(píng)論 1 292
  • 正文 為了忘掉前任臀晃,我火速辦了婚禮觉渴,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘徽惋。我一直安慰自己案淋,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,228評(píng)論 6 388
  • 文/花漫 我一把揭開白布险绘。 她就那樣靜靜地躺著踢京,像睡著了一般。 火紅的嫁衣襯著肌膚如雪宦棺。 梳的紋絲不亂的頭發(fā)上瓣距,一...
    開封第一講書人閱讀 51,190評(píng)論 1 299
  • 那天,我揣著相機(jī)與錄音代咸,去河邊找鬼蹈丸。 笑死,一個(gè)胖子當(dāng)著我的面吹牛侣背,可吹牛的內(nèi)容都是我干的白华。 我是一名探鬼主播,決...
    沈念sama閱讀 40,078評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼贩耐,長吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼弧腥!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起潮太,我...
    開封第一講書人閱讀 38,923評(píng)論 0 274
  • 序言:老撾萬榮一對(duì)情侶失蹤管搪,失蹤者是張志新(化名)和其女友劉穎,沒想到半個(gè)月后铡买,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體更鲁,經(jīng)...
    沈念sama閱讀 45,334評(píng)論 1 310
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,550評(píng)論 2 333
  • 正文 我和宋清朗相戀三年奇钞,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了澡为。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 39,727評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡景埃,死狀恐怖媒至,靈堂內(nèi)的尸體忽然破棺而出顶别,到底是詐尸還是另有隱情,我是刑警寧澤拒啰,帶...
    沈念sama閱讀 35,428評(píng)論 5 343
  • 正文 年R本政府宣布驯绎,位于F島的核電站,受9級(jí)特大地震影響谋旦,放射性物質(zhì)發(fā)生泄漏剩失。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,022評(píng)論 3 326
  • 文/蒙蒙 一册着、第九天 我趴在偏房一處隱蔽的房頂上張望拴孤。 院中可真熱鬧,春花似錦指蚜、人聲如沸乞巧。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,672評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至蚕冬,卻和暖如春免猾,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背囤热。 一陣腳步聲響...
    開封第一講書人閱讀 32,826評(píng)論 1 269
  • 我被黑心中介騙來泰國打工猎提, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人旁蔼。 一個(gè)月前我還...
    沈念sama閱讀 47,734評(píng)論 2 368
  • 正文 我出身青樓锨苏,卻偏偏與公主長得像,于是被迫代替她去往敵國和親棺聊。 傳聞我的和親對(duì)象是個(gè)殘疾皇子伞租,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,619評(píng)論 2 354

推薦閱讀更多精彩內(nèi)容