圖像分類學(xué)習(xí)(3):X光胸片診斷識(shí)別——遷移學(xué)習(xí)

1结窘、數(shù)據(jù)介紹

  • 數(shù)據(jù)源于kaggle错负,可在此鏈接自行下載
  • 數(shù)據(jù)集分為3個(gè)文件夾(train项贺,test,val)寇蚊,并包含每個(gè)圖像類別(Pneumonia / Normal)的子文件夾笔时。有5,863個(gè)X射線圖像(JPEG)和2個(gè)類別(肺炎/正常)

2、遷移學(xué)習(xí)

由于從頭訓(xùn)練一個(gè)神經(jīng)網(wǎng)絡(luò)需要花費(fèi)的時(shí)間較長仗岸,而且對(duì)數(shù)據(jù)量的要求也比較大允耿。在實(shí)踐中,很少有人從頭開始訓(xùn)練整個(gè)卷積網(wǎng)絡(luò)(隨機(jī)初始化)扒怖,因?yàn)閾碛凶銐虼笮〉臄?shù)據(jù)集相對(duì)來說比較少見较锡。相反,pytorch中提供的這些模型都已經(jīng)預(yù)先在1000類的Imagenet數(shù)據(jù)集上訓(xùn)練完成盗痒÷煸蹋可以直接拿來訓(xùn)練自己的數(shù)據(jù)集,即稱為模型微調(diào)或者遷移學(xué)習(xí)俯邓。
遷移學(xué)習(xí)包含微調(diào)和特征提取骡楼。 在微調(diào)中,我們從一個(gè)預(yù)訓(xùn)練模型開始稽鞭,然后為我們的新任務(wù)更新所有的模型參數(shù)鸟整,實(shí)質(zhì)上就是重新訓(xùn)練整個(gè)模型。 在特征提取中朦蕴,我們從預(yù)訓(xùn)練模型開始吃嘿,只更新產(chǎn)生預(yù)測(cè)的最后一層的權(quán)重。它被稱為特征提取是因?yàn)槲覀兪褂妙A(yù)訓(xùn)練的CNN作為固定的特征提取器梦重,并且僅改變輸出層兑燥。
本次圖像分類只對(duì)模型進(jìn)行特征提取,即更改最后一個(gè)全連接層琴拧,然后進(jìn)行模型訓(xùn)練

3降瞳、建立模型

3.1 了解數(shù)據(jù)

首先先導(dǎo)入需要用到的模塊,定義一下基本參數(shù)。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

data_dir = './cv/chest_xray'
model_name = 'vgg'
num_classes = 2
batch_size = 16
num_epochs = 10
input_size = 224
device = torch.device('cuda' if torch.cuda.is_available() else 'gpu')

進(jìn)行一系列數(shù)據(jù)增強(qiáng)挣饥,然后生成訓(xùn)練除师、驗(yàn)證、和測(cè)試數(shù)據(jù)集扔枫。

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}

print("Initializing Datasets and Dataloaders...")


image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val', 'test']}

dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4) for x in ['train', 'val', 'test']}

定義一個(gè)查看圖片和標(biāo)簽的函數(shù)

def imshow(inp, title=None):
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001) 

imgs, labels = next(iter(dataloaders_dict['train']))


out = torchvision.utils.make_grid(imgs[:8])

imshow(out, title=[classes[x] for x in labels[:8]])

OUT :


可以看到圖片有兩個(gè)類別的X光片汛聚,PNEUMONIA(肺炎),NORMAL(正常)

classes = image_datasets['test'].classes
classes

OUT :
['NORMAL', 'PNEUMONIA']

3.2 建立VGG16遷移學(xué)習(xí)模型

model = torchvision.models.vgg16(pretrained=True)
model

pretrained=True,則會(huì)下載預(yù)訓(xùn)練權(quán)重短荐,需要耐心等待一段時(shí)間倚舀。
查看一下VGG16的結(jié)構(gòu):

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

可以看到VGG16主要由features和classifier兩種結(jié)構(gòu)組成,classifier[6]為最后一層忍宋,我們將它的輸出改為我們的類別數(shù)2痕貌。由于我們只需要訓(xùn)練最后一層,再改之前我們先將模型的參數(shù)設(shè)置為不可更新糠排。

# 先將模型參數(shù)改為不可更行
for param in model.parameters():
    param.requires_grad = False
# 再更改最后一層的輸出舵稠,至此網(wǎng)絡(luò)只能更行該層參數(shù)
model.classifier[6] = nn.Linear(4096, num_classes)

3.3 定義訓(xùn)練函數(shù)

def train_model(model, dataloaders, criterion, optimizer, mun_epochs=25):
    since = time.time()
    
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs-1))
        print('-' * 10)
        
        for phase in ['train', 'val']:
            
            if phase == 'train':
                model.train()
            else:
                model.eval()
            
            running_loss = 0.0
            running_corrects = 0.0
            
            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)
                
                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects / len(dataloaders[phase].dataset)
            
            print('{} loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))
    
    model.load_state_dict(best_model_wts)
    return model

3.4 定義優(yōu)化器和損失函數(shù)

model = model.to(device)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()

3.5 開始訓(xùn)練

model_ft = train_model(model, dataloaders_dict, criterion, optimizer, num_epochs)

OUT :

Epoch 0/9
----------
train loss: 0.4240 Acc: 0.8171
val loss: 0.2968 Acc: 0.8125

Epoch 1/9
----------
train loss: 0.4095 Acc: 0.8284
val loss: 0.1901 Acc: 0.9375

Epoch 2/9
----------
train loss: 0.3972 Acc: 0.8424
val loss: 0.2445 Acc: 0.9375

Epoch 3/9
----------
train loss: 0.4145 Acc: 0.8315
val loss: 0.1973 Acc: 0.9375

Epoch 4/9
----------
train loss: 0.4012 Acc: 0.8416
val loss: 0.1253 Acc: 1.0000

Epoch 5/9
----------
train loss: 0.3976 Acc: 0.8489
val loss: 0.1904 Acc: 0.9375

Epoch 6/9
----------
train loss: 0.4025 Acc: 0.8432
val loss: 0.1527 Acc: 1.0000

Epoch 7/9
----------
train loss: 0.3768 Acc: 0.8495
val loss: 0.1761 Acc: 1.0000

Epoch 8/9
----------
train loss: 0.3906 Acc: 0.8472
val loss: 0.1346 Acc: 1.0000

Epoch 9/9
----------
train loss: 0.3847 Acc: 0.8403
val loss: 0.0996 Acc: 1.0000

Training complete in 7m 16s
Best val Acc: 1.0000

經(jīng)過10輪訓(xùn)練,驗(yàn)證集的準(zhǔn)確率已達(dá)到100%入宦,由于驗(yàn)證集的圖片很少(大概只有十多張)哺徊,可能不太能說明網(wǎng)絡(luò)的訓(xùn)練效果。接下來看一下模型在測(cè)試集表現(xiàn)如何乾闰。

3.6 測(cè)試集評(píng)估

首先我們先拿出10張X光片給模型進(jìn)行判斷落追,看看它能否準(zhǔn)確預(yù)測(cè)出X光片的類別。

imgs, labels = next(iter(dataloaders_dict['test']))
imgs, labels = imgs.to(device), labels.to(device)
outputs = model_ft(imgs)
_, preds = torch.max(outputs, 1)
print('real:' + ' '.join('%9s' % classes[labels[j]] for j in range(10)))
print('pred:' + ' '.join('%9s' % classes[preds[j]] for j in range(10)))

OUT :

real:   NORMAL PNEUMONIA PNEUMONIA    NORMAL    NORMAL    NORMAL    NORMAL PNEUMONIA PNEUMONIA PNEUMONIA
pred:PNEUMONIA PNEUMONIA PNEUMONIA PNEUMONIA    NORMAL    NORMAL    NORMAL PNEUMONIA PNEUMONIA PNEUMONIA

十張X片汹忠,其中第一張和第四張錯(cuò)誤的將正常預(yù)測(cè)成了肺炎淋硝,其他八張預(yù)測(cè)正確雹熬。有80%的準(zhǔn)確率宽菜,最后我們查看一下在全部的測(cè)試集中的準(zhǔn)確率是否有80%左右。

correct = 0.0
for imgs, labels in dataloaders_dict['test']:
    imgs, labels = imgs.to(device), labels.to(device)
    output = model_ft(imgs)
    _, preds = torch.max(output, 1)
    correct += (preds == labels).sum().item()
print('test accuracy:{:.2f}%'.format(100 * correct / len(dataloaders_dict['test'].dataset)))

OUT :
test accuracy:82.53%

4竿报、總結(jié)

模型預(yù)測(cè)的準(zhǔn)確率為82.53%铅乡,并沒有想象中的高。
若想進(jìn)一步提升模型準(zhǔn)確率烈菌,我覺得可以在以下幾個(gè)方面改進(jìn):

  • 換一個(gè)帶批標(biāo)準(zhǔn)化的VGG模型或直接換一個(gè)更強(qiáng)大的模型阵幸,如ResNet。
  • 減小學(xué)習(xí)率芽世,或者訓(xùn)練時(shí)進(jìn)行學(xué)習(xí)率衰減挚赊。
  • 嘗試增加epoch,更改batch_size等济瓢。
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末荠割,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌蔑鹦,老刑警劉巖夺克,帶你破解...
    沈念sama閱讀 216,544評(píng)論 6 501
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異嚎朽,居然都是意外死亡铺纽,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,430評(píng)論 3 392
  • 文/潘曉璐 我一進(jìn)店門哟忍,熙熙樓的掌柜王于貴愁眉苦臉地迎上來狡门,“玉大人,你說我怎么就攤上這事魁索∪谧玻” “怎么了?”我有些...
    開封第一講書人閱讀 162,764評(píng)論 0 353
  • 文/不壞的土叔 我叫張陵粗蔚,是天一觀的道長尝偎。 經(jīng)常有香客問我,道長鹏控,這世上最難降的妖魔是什么致扯? 我笑而不...
    開封第一講書人閱讀 58,193評(píng)論 1 292
  • 正文 為了忘掉前任,我火速辦了婚禮当辐,結(jié)果婚禮上抖僵,老公的妹妹穿的比我還像新娘。我一直安慰自己缘揪,他們只是感情好耍群,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,216評(píng)論 6 388
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著找筝,像睡著了一般蹈垢。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上袖裕,一...
    開封第一講書人閱讀 51,182評(píng)論 1 299
  • 那天曹抬,我揣著相機(jī)與錄音,去河邊找鬼急鳄。 笑死谤民,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的疾宏。 我是一名探鬼主播张足,決...
    沈念sama閱讀 40,063評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼坎藐!你這毒婦竟也來了为牍?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 38,917評(píng)論 0 274
  • 序言:老撾萬榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎吵聪,沒想到半個(gè)月后凌那,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,329評(píng)論 1 310
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡吟逝,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,543評(píng)論 2 332
  • 正文 我和宋清朗相戀三年帽蝶,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片块攒。...
    茶點(diǎn)故事閱讀 39,722評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡励稳,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出囱井,到底是詐尸還是另有隱情驹尼,我是刑警寧澤,帶...
    沈念sama閱讀 35,425評(píng)論 5 343
  • 正文 年R本政府宣布庞呕,位于F島的核電站新翎,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏住练。R本人自食惡果不足惜地啰,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,019評(píng)論 3 326
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望讲逛。 院中可真熱鬧亏吝,春花似錦、人聲如沸盏混。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,671評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽许赃。三九已至止喷,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間图焰,已是汗流浹背启盛。 一陣腳步聲響...
    開封第一講書人閱讀 32,825評(píng)論 1 269
  • 我被黑心中介騙來泰國打工蹦掐, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留技羔,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 47,729評(píng)論 2 368
  • 正文 我出身青樓卧抗,卻偏偏與公主長得像藤滥,于是被迫代替她去往敵國和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子社裆,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,614評(píng)論 2 353