我的實(shí)踐:通過螞蟻罢浇、蜜蜂二分類問題了解如何基于Pytorch構(gòu)建分類模型

1.數(shù)據(jù)集準(zhǔn)備

本例采用了pytorch教程提供的蜜蜂陆赋、螞蟻二分類數(shù)據(jù)集(點(diǎn)擊可直接下載)。該數(shù)據(jù)集的文件夾結(jié)構(gòu)如下圖所示己莺。這里面有些黑白的照片奏甫,我把它們刪掉了,因?yàn)楹诎渍掌耐ǖ罃?shù)是1凌受,會(huì)造成Tensor的維度不一致阵子。可以看出數(shù)據(jù)集分為訓(xùn)練集和測(cè)試集胜蛉,訓(xùn)練集用于訓(xùn)練模型挠进,測(cè)試集用于測(cè)試模型的泛化能力。在訓(xùn)練集和測(cè)試集下又包含了"ants"和"bees"兩個(gè)文件夾誊册,這兩個(gè)文件夾的名稱即圖片的標(biāo)簽领突,在加載數(shù)據(jù)的時(shí)候需要用到這一點(diǎn)。有了數(shù)據(jù)案怯,我們就想辦法把這些數(shù)據(jù)處理成pytorch框架下的Dataset需要的格式君旦。

請(qǐng)?zhí)砑訄D片描述

2.pytorch Dataset 處理圖片數(shù)據(jù)

pytorch為我們處理數(shù)據(jù)提供了一個(gè)模板,這個(gè)模板就是Dataset嘲碱,我們?cè)谔幚頂?shù)據(jù)時(shí)繼承這個(gè)類金砍。在處理數(shù)據(jù)時(shí)要注意以下幾點(diǎn):

  1. 可以用PIL的Image加載圖片,但要將圖片處理成tensor麦锯,而且tensor的維度要一致恕稠。這是因?yàn)閚n模型的輸入都是tensor格式,而且要求一個(gè)batchsize的tensor維度是一樣的扶欣。實(shí)現(xiàn)上述可能可以使用torchvision的transforms鹅巍。由于我用的CPU訓(xùn)練模型,所以對(duì)圖片壓縮的比較厲害料祠,全壓縮成33232的圖片了骆捧。
  2. "ants"和"bees"兩個(gè)文件夾的名稱就是圖片的標(biāo)簽,但是getitem的返回值應(yīng)該是一個(gè)值髓绽。在這里"ants"標(biāo)簽返回0凑懂,"bees"標(biāo)簽返回1。
  3. 看數(shù)據(jù)的預(yù)處理對(duì)不對(duì)梧宫,可以用一段代碼測(cè)試一下接谨,將數(shù)據(jù)加載到DataLoader,然后循環(huán)取出數(shù)據(jù)塘匣,并把這些數(shù)據(jù)及其標(biāo)簽打印出來脓豪,或者記錄到tensorboard上去,看每一次迭代返回的數(shù)據(jù)是否和自己預(yù)想的一樣忌卤。

下面是代碼扫夜,保存在dataProcess.py文件中。

rom torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image
import os
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

class MyData(Dataset):
    # 把圖片所在的文件夾路徑分成兩個(gè)部分驰徊,一部分是根目錄笤闯,一部分是標(biāo)簽?zāi)夸洠@是因?yàn)闃?biāo)簽?zāi)夸浀拿Q我們需要用到
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        # 圖片所在的文件夾路徑由根目錄和標(biāo)簽?zāi)夸浗M成
        self.path = os.path.join(self.root_dir, self.label_dir)
        # 獲取文件夾下所有圖片的名稱
        self.img_names = os.listdir(self.path)

    def __getitem__(self, idx):
        img_name = self.img_names[idx]
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        img = Image.open(img_item_path)
        # 將圖片處理成Tensor格式棍厂,并將維度設(shè)置成32*32的
        # 圖片的維度可能不一致颗味,這里一定要用resize統(tǒng)一一下,否則會(huì)出錯(cuò)
        trans = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((32, 32))
            ])
        img_tensor = trans(img)
        # 根據(jù)標(biāo)簽?zāi)夸浀拿Q來確定圖片是哪一類牺弹,如果是"ants"浦马,標(biāo)簽設(shè)置為0,如果是"bees"张漂,標(biāo)簽設(shè)置為1
        # 這個(gè)地方要注意晶默,我們?cè)谟?jì)算loss的時(shí)候用交叉熵nn.CrossEntropyLoss()
        # 交叉熵的輸入有兩個(gè),一個(gè)是模型的輸出outputs航攒,一個(gè)是標(biāo)簽targets磺陡,注意targets是一維tensor
        # 例如batchsize如果是2,ants的targets的應(yīng)該[0,0]漠畜,而不是[[0][0]]
        # 因此label要返回0币他,而不是[0]
        label = 0 if self.label_dir == "ants" else 1
        return img_tensor,  label

    def __len__(self):
        return len(self.img_names)

# 用下面這段代碼測(cè)試一下加載數(shù)據(jù)有沒有問題
if __name__ == "__main__":
    # 注意hymenoptera_data和代碼在同一級(jí)目錄
    root_dir = "hymenoptera_data/train"
    ants_label = "ants"
    bees_label = "bees"
    # 螞蟻數(shù)據(jù)集
    ants_dataset = MyData(root_dir, ants_label)
    # 蜜蜂數(shù)據(jù)集
    bees_dataset = MyData(root_dir, bees_label)
    # 螞蟻數(shù)據(jù)集和蜜蜂數(shù)據(jù)集合并
    train_dataset = ants_dataset + bees_dataset
    # 利用dataLoader加載數(shù)據(jù)集
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
    # tensorboard的writer
    writer = SummaryWriter("logs")
    for step, train_data in enumerate(train_dataloader):
        imgs, targets = train_data
        # 每迭代一次就把一個(gè)batch的圖片記錄到tensorboard
        writer.add_images("test", imgs, step)
        # 每迭代一次就把一個(gè)batch的圖片標(biāo)簽打印出來
        print(targets)
    writer.close()

在測(cè)試時(shí)tensorboard記錄的信息在logs文件夾,在terminal輸入tensorboard --logdir=logs啟動(dòng)tensorboard盆驹,將tensorboard給出的網(wǎng)址輸入到網(wǎng)頁圆丹,可以看到每一個(gè)batch的圖片。下圖展示了第一個(gè)batch的圖片躯喇”璺猓可以看到,取出了64張圖片廉丽,和batchsize=64是對(duì)應(yīng)的倦微。另外可以看到,把圖片壓縮成32*32后正压,確實(shí)很模糊了欣福,人眼都很難看出哪個(gè)是螞蟻,哪個(gè)是蜜蜂焦履。


請(qǐng)?zhí)砑訄D片描述

下面這個(gè)圖展示了第一個(gè)batch所有圖片的標(biāo)簽拓劝,0表示螞蟻雏逾,1表示蜜蜂,仔細(xì)看一下圖片和標(biāo)簽應(yīng)該是對(duì)應(yīng)的郑临。


請(qǐng)?zhí)砑訄D片描述

3.網(wǎng)絡(luò)模型設(shè)計(jì)

我們把圖片處理成3*32*32的tensor了栖博,用如下圖所示的卷積神經(jīng)網(wǎng)絡(luò)模型。第一層卷積網(wǎng)絡(luò)采用5*5的卷積核厢洞,stride=1仇让,pading=2。第一層卷積的代碼是:nn.Conv2d(3, 32, 5, 1, 2)躺翻,第一個(gè)參數(shù)3是輸入的通道數(shù)丧叽,第二個(gè)參數(shù)32是輸出的通道數(shù),第三個(gè)參數(shù)5是卷積核的大小公你,第四個(gè)參數(shù)1是stride踊淳,第五個(gè)參數(shù)2是padding。


卷積網(wǎng)絡(luò)模型.png

輸出高H,和寬度W計(jì)算公式如下所示(注意dilation默認(rèn)為0)省店。

H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor
W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor
因此嚣崭,通過第一層卷積后,高度H為懦傍,
H_{out}=\frac{32+2 \times 2 -1\times(5-1)-1}{1}+1=32
同理寬度W也為32雹舀。所以輸出的大小就32*32*32。接下來粗俱,再用一個(gè)max-Pooling進(jìn)行一次池化说榆,池化核的大小是2*2。該池化層的代碼是nn.MaxPool2d(2)寸认。池化輸出高H,和寬度W計(jì)算公式和卷積計(jì)算方式一摸一樣签财。在默認(rèn)的情況下,stride和池化和的大小一樣偏塞,pading=0唱蒸,dilation=0。所以第一次池化后灸叼,輸出的高度H為神汹,
H_{out}=\frac{32+2 \times 0 -1\times(2-1)-1}{2}+1=16
同理,輸出的寬度H為16古今。因此屁魏,輸出的維度是32*16*16。
后面的輸出維度計(jì)算方式同上捉腥,不再羅嗦了氓拼。然后再通過兩次卷積和兩次池化,后面的輸出維度計(jì)算方式同上,不再羅嗦了桃漾,最終得到一個(gè)維度為64*4*4的特征坏匪。在做分類之前,首先要把這個(gè)三維Tensor拉直成一維Tensor呈队,代碼是nn.Flatten()剥槐。拉直之后的一維Tensor大小就是64\times4\times4=1024。最后通過一個(gè)全連接層完成分類任務(wù)宪摧,全連接層的輸入大小是1024,輸出的大小是類別的個(gè)數(shù)颅崩,即2几于,代碼是nn.Linear(64 * 4 * 4, 2)。

當(dāng)完成所有模型的構(gòu)建后沿后,可以用一段代碼來測(cè)試一下模型是否有誤沿彭。例如這里模型的輸入在[3,32,32]Tensor的基礎(chǔ)上,還需要再增加一維batchsize尖滚,所以輸入的維度應(yīng)該是[batchsize,3,32,32]喉刘。我們可以生成一個(gè)這樣維度的數(shù)據(jù),例如假設(shè)batchsize=3漆弄,可以這樣生成一個(gè)輸入:x = torch.ones((3, 3, 32, 32))睦裳。然后把x送給模型,看模型是否能正常輸出撼唾,輸出的維度是否是我們預(yù)期的廉邑。我們還可以借助于Tensorboard來將模型可視化,通過界面把模型展開倒谷,看是否正確蛛蒙。
下面是所有的代碼,保存在model.py文件中渤愁。

from torch import nn
import torch
from torch.utils.tensorboard import SummaryWriter

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 2)
        )

    def forward(self, x):
        x = self.model(x)
        return x

# 這段代碼測(cè)試model是否正確
if __name__ == "__main__":
    my_model = MyModel()
    x = torch.ones((3, 3, 32, 32))
    y = my_model(x)
    print(y.shape)
    # 利用tensorboard可視化模型
    writer = SummaryWriter("graph_logs")
    writer.add_graph(my_model, x)
    writer.close()

模型測(cè)試代碼打印的輸出維度是[3,2]牵祟,3是batchsize,2是全連接層最后的輸出維度抖格,和類別的個(gè)數(shù)是一致的诺苹。利用Tensorboard將模型可視化后,如下圖所示他挎,還可以進(jìn)一步展開筝尾。


請(qǐng)?zhí)砑訄D片描述

4.模型的訓(xùn)練與測(cè)試

模型的訓(xùn)練與測(cè)試就不細(xì)講了,和其他模型訓(xùn)練的套路一樣的办桨,基本思路可以看我的第一篇[pytorch入門文章](我的實(shí)踐:通過一個(gè)簡(jiǎn)單線性回歸入門pytorch - 簡(jiǎn)書 (jianshu.com)
)筹淫。下面直接給出代碼棘利。

from model import *
from dataProcess import *
import matplotlib.pyplot as plt
import time

# 加載訓(xùn)練數(shù)據(jù)
train_root_dir = "hymenoptera_data/train"
train_ants_label = "ants"
train_bees_label = "bees"
train_ants_dataset = MyData(train_root_dir, train_ants_label)
train_bees_dataset = MyData(train_root_dir, train_bees_label)
train_dataset = train_ants_dataset + train_bees_dataset
train_data_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
train_data_len = len(train_dataset)
# 加載測(cè)試數(shù)據(jù)
test_root_dir = "hymenoptera_data/val"
test_ants_label = "ants"
test_bees_label = "bees"
test_ants_dataset = MyData(test_root_dir, test_ants_label)
test_bees_dataset = MyData(test_root_dir, test_bees_label)
test_dataset = test_ants_dataset + test_bees_dataset
test_data_loader = DataLoader(dataset=test_dataset, batch_size=256, shuffle=True)
test_data_len = len(test_dataset)
print(f"訓(xùn)練集長(zhǎng)度:{train_data_len}")
print(f"測(cè)試集長(zhǎng)度:{test_data_len}")
# 創(chuàng)建網(wǎng)絡(luò)模型
my_model = MyModel()

# 損失函數(shù)
loss_fn = nn.CrossEntropyLoss()

# 優(yōu)化器
learning_rate = 5e-3
optimizer = torch.optim.SGD(my_model.parameters(), lr=learning_rate)
# Adam 參數(shù)betas=(0.9, 0.99)
# optimizer = torch.optim.Adam(my_model.parameters(), lr=learning_rate, betas=(0.9, 0.99))
# 總共的訓(xùn)練步數(shù)
total_train_step = 0
# 總共的測(cè)試步數(shù)
total_test_step = 0
step = 0
epoch = 500

writer = SummaryWriter("logs")
train_loss_his = []
train_totalaccuracy_his = []
test_totalloss_his = []
test_totalaccuracy_his = []
start_time = time.time()
my_model.train()
for i in range(epoch):
    print(f"-------第{i}輪訓(xùn)練開始-------")
    train_total_accuracy = 0
    for data in train_data_loader:
        imgs, targets = data
        writer.add_images("tarin_data", imgs, total_train_step)
        output = my_model(imgs)
        loss = loss_fn(output, targets)
        train_accuracy = (output.argmax(1) == targets).sum()
        train_total_accuracy = train_total_accuracy + train_accuracy
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_train_step = total_train_step + 1
        train_loss_his.append(loss)
        writer.add_scalar("train_loss", loss.item(), total_train_step)
    train_total_accuracy = train_total_accuracy / train_data_len
    print(f"訓(xùn)練集上的準(zhǔn)確率:{train_total_accuracy}")
    train_totalaccuracy_his.append(train_total_accuracy)
    # 測(cè)試開始
    total_test_loss = 0
    my_model.eval()
    test_total_accuracy = 0
    with torch.no_grad():
        for data in test_data_loader:
            imgs, targets = data
            output = my_model(imgs)
            loss = loss_fn(output, targets)
            total_test_loss = total_test_loss + loss
            test_accuracy = (output.argmax(1) == targets).sum()
            test_total_accuracy = test_total_accuracy + test_accuracy
        test_total_accuracy = test_total_accuracy / test_data_len
        print(f"測(cè)試集上的準(zhǔn)確率:{test_total_accuracy}")
        print(f"測(cè)試集上的loss:{total_test_loss}")
        test_totalloss_his.append(total_test_loss)
        test_totalaccuracy_his.append(test_total_accuracy)
        writer.add_scalar("test_loss", total_test_loss.item(), i)
end_time = time.time()
total_train_time = end_time-start_time
print(f'訓(xùn)練時(shí)間: {total_train_time}秒')
writer.close()
plt.plot(train_loss_his, label='Train Loss')
plt.legend(loc='best')
plt.xlabel('Steps')
plt.show()
plt.plot(test_totalloss_his, label='Test Loss')
plt.legend(loc='best')
plt.xlabel('Steps')
plt.show()

plt.plot(train_totalaccuracy_his, label='Train accuracy')
plt.plot(test_totalaccuracy_his, label='Test accuracy')
plt.legend(loc='best')
plt.xlabel('Steps')
plt.show()

通過上述代碼,訓(xùn)練得到的結(jié)果如下圖所示狮鸭,


請(qǐng)?zhí)砑訄D片描述

結(jié)果雖然不是很好忠荞,但是我覺得已經(jīng)很不多了,在測(cè)試集上的準(zhǔn)確率差不多達(dá)到0.7了摧阅。為了節(jié)省計(jì)算資源汰蓉,我把圖片壓縮成32*32,連我們?nèi)搜鄱己茈y分辨出哪個(gè)是螞蟻棒卷,哪個(gè)是蜜蜂顾孽。另外,我這個(gè)模型是完全從0開始訓(xùn)練的比规,隔壁在預(yù)訓(xùn)練模型的基礎(chǔ)上進(jìn)行訓(xùn)練得到的效果好像沒好多少若厚。。蜒什。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末测秸,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子灾常,更是在濱河造成了極大的恐慌霎冯,老刑警劉巖,帶你破解...
    沈念sama閱讀 216,470評(píng)論 6 501
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件钞瀑,死亡現(xiàn)場(chǎng)離奇詭異沈撞,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī)仔戈,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,393評(píng)論 3 392
  • 文/潘曉璐 我一進(jìn)店門关串,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人监徘,你說我怎么就攤上這事晋修。” “怎么了凰盔?”我有些...
    開封第一講書人閱讀 162,577評(píng)論 0 353
  • 文/不壞的土叔 我叫張陵墓卦,是天一觀的道長(zhǎng)。 經(jīng)常有香客問我户敬,道長(zhǎng)落剪,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,176評(píng)論 1 292
  • 正文 為了忘掉前任尿庐,我火速辦了婚禮忠怖,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘抄瑟。我一直安慰自己凡泣,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,189評(píng)論 6 388
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著鞋拟,像睡著了一般骂维。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上贺纲,一...
    開封第一講書人閱讀 51,155評(píng)論 1 299
  • 那天航闺,我揣著相機(jī)與錄音,去河邊找鬼猴誊。 笑死潦刃,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的稠肘。 我是一名探鬼主播福铅,決...
    沈念sama閱讀 40,041評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼项阴!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起笆包,我...
    開封第一講書人閱讀 38,903評(píng)論 0 274
  • 序言:老撾萬榮一對(duì)情侶失蹤环揽,失蹤者是張志新(化名)和其女友劉穎,沒想到半個(gè)月后庵佣,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體歉胶,經(jīng)...
    沈念sama閱讀 45,319評(píng)論 1 310
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,539評(píng)論 2 332
  • 正文 我和宋清朗相戀三年巴粪,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了通今。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 39,703評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡肛根,死狀恐怖辫塌,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情派哲,我是刑警寧澤臼氨,帶...
    沈念sama閱讀 35,417評(píng)論 5 343
  • 正文 年R本政府宣布,位于F島的核電站芭届,受9級(jí)特大地震影響储矩,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜褂乍,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,013評(píng)論 3 325
  • 文/蒙蒙 一持隧、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧逃片,春花似錦屡拨、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,664評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽层皱。三九已至,卻和暖如春赠潦,著一層夾襖步出監(jiān)牢的瞬間叫胖,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 32,818評(píng)論 1 269
  • 我被黑心中介騙來泰國打工她奥, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留瓮增,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 47,711評(píng)論 2 368
  • 正文 我出身青樓哩俭,卻偏偏與公主長(zhǎng)得像绷跑,于是被迫代替她去往敵國和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子凡资,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,601評(píng)論 2 353