1. Dataset Introduction
- The data comes from Kaggle and can be downloaded from the linked page.
- The dataset is split into three folders (train, test, val), each containing a subfolder per image class (PNEUMONIA / NORMAL). There are 5,863 chest X-ray images (JPEG) in 2 classes (pneumonia / normal).
2. Transfer Learning
Training a neural network from scratch takes a long time and demands a large amount of data. In practice, very few people train an entire convolutional network from scratch (with random initialization), because it is relatively rare to have a sufficiently large dataset. Instead, the models that PyTorch provides have already been pretrained on the 1000-class ImageNet dataset, so we can take one of them and train it directly on our own data; this is known as model fine-tuning, or transfer learning.
Transfer learning covers fine-tuning and feature extraction. In fine-tuning, we start from a pretrained model and update all of its parameters for the new task; in essence, we retrain the whole model. In feature extraction, we start from a pretrained model and only update the weights of the final layer that produces the predictions. It is called feature extraction because we use the pretrained CNN as a fixed feature extractor and change only the output layer.
For this image classification task we use feature extraction only: we replace the last fully connected layer and then train the model.
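The difference between the two modes comes down to which parameters keep `requires_grad=True`. A minimal sketch, using a hypothetical tiny `nn.Sequential` stand-in instead of VGG16 so it runs instantly (the helper name `set_parameter_requires_grad` is illustrative, not from the original post):

```python
import torch.nn as nn

def set_parameter_requires_grad(model, feature_extracting):
    # Feature extraction: freeze every pretrained weight.
    # Fine-tuning: leave everything trainable (do nothing).
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

# Tiny stand-in model (hypothetical, not VGG16).
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
set_parameter_requires_grad(model, feature_extracting=True)
model[2] = nn.Linear(4, 2)  # a freshly created layer is trainable by default
print([n for n, p in model.named_parameters() if p.requires_grad])
# → ['2.weight', '2.bias']
```

After freezing and replacing the head, only the new layer's weight and bias are trainable, which is exactly the feature-extraction setup used below.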
3降瞳、建立模型
3.1 了解數(shù)據(jù)
首先先導(dǎo)入需要用到的模塊,定義一下基本參數(shù)。
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
data_dir = './cv/chest_xray'
model_name = 'vgg'
num_classes = 2
batch_size = 16
num_epochs = 10
input_size = 224
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Apply a set of data augmentations, then build the training, validation, and test datasets.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}
print("Initializing Datasets and Dataloaders...")
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val', 'test']}
dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4) for x in ['train', 'val', 'test']}
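Each batch from these loaders is a float tensor of shape (batch_size, 3, 224, 224) plus a vector of integer labels. A quick shape sanity check, using a hypothetical synthetic `TensorDataset` so no image files are needed:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Synthetic stand-in for the image folders: 32 random "images" and binary labels.
fake_ds = TensorDataset(torch.randn(32, 3, 224, 224), torch.randint(0, 2, (32,)))
fake_loader = DataLoader(fake_ds, batch_size=16, shuffle=True)
imgs, labels = next(iter(fake_loader))
print(imgs.shape, labels.shape)
# → torch.Size([16, 3, 224, 224]) torch.Size([16])
```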
Define a helper to display images together with their labels.
def imshow(inp, title=None):
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean  # undo the normalization
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)
classes = image_datasets['train'].classes
imgs, labels = next(iter(dataloaders_dict['train']))
out = torchvision.utils.make_grid(imgs[:8])
imshow(out, title=[classes[x] for x in labels[:8]])
OUT :
We can see X-rays of both classes: PNEUMONIA and NORMAL.
classes = image_datasets['test'].classes
classes
OUT :
['NORMAL', 'PNEUMONIA']
3.2 Building the VGG16 Transfer-Learning Model
model = torchvision.models.vgg16(pretrained=True)
model
With pretrained=True, the pretrained weights are downloaded, which can take a while.
Printing the model shows the VGG16 architecture:
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
VGG16 consists of two main parts, features and classifier. classifier[6] is the last layer, and we change its output size to our number of classes, 2. Since we only want to train that last layer, we first freeze all of the model's parameters.
# Freeze all model parameters first
for param in model.parameters():
    param.requires_grad = False
# Then replace the last layer; only this layer's parameters can now be updated
model.classifier[6] = nn.Linear(4096, num_classes)
3.3 Defining the Training Function
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            running_loss = 0.0
            running_corrects = 0.0
            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                # Only track gradients during the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    _, preds = torch.max(outputs, 1)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects / len(dataloaders[phase].dataset)
            print('{} loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            # Keep the weights of the best-performing epoch on the validation set
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))
    model.load_state_dict(best_model_wts)
    return model
3.4 Defining the Optimizer and Loss Function
model = model.to(device)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()
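Passing `model.parameters()` here also hands the optimizer the frozen VGG weights; since those never receive gradients, SGD simply skips them, so this works. An equivalent, more explicit option is to collect only the trainable parameters first. A minimal sketch, shown on a hypothetical tiny stand-in model rather than the real VGG16:

```python
import torch.nn as nn
import torch.optim as optim

# Hypothetical tiny model standing in for the frozen VGG16 with a new head.
net = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
for p in net.parameters():
    p.requires_grad = False
net[2] = nn.Linear(4, 2)  # fresh final layer, trainable by default

# Hand the optimizer only the parameters that will actually be updated.
params_to_update = [p for p in net.parameters() if p.requires_grad]
opt = optim.SGD(params_to_update, lr=0.001, momentum=0.9)
print(len(params_to_update))  # → 2 (the new layer's weight and bias)
```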
3.5 Training
model_ft = train_model(model, dataloaders_dict, criterion, optimizer, num_epochs)
OUT :
Epoch 0/9
----------
train loss: 0.4240 Acc: 0.8171
val loss: 0.2968 Acc: 0.8125
Epoch 1/9
----------
train loss: 0.4095 Acc: 0.8284
val loss: 0.1901 Acc: 0.9375
Epoch 2/9
----------
train loss: 0.3972 Acc: 0.8424
val loss: 0.2445 Acc: 0.9375
Epoch 3/9
----------
train loss: 0.4145 Acc: 0.8315
val loss: 0.1973 Acc: 0.9375
Epoch 4/9
----------
train loss: 0.4012 Acc: 0.8416
val loss: 0.1253 Acc: 1.0000
Epoch 5/9
----------
train loss: 0.3976 Acc: 0.8489
val loss: 0.1904 Acc: 0.9375
Epoch 6/9
----------
train loss: 0.4025 Acc: 0.8432
val loss: 0.1527 Acc: 1.0000
Epoch 7/9
----------
train loss: 0.3768 Acc: 0.8495
val loss: 0.1761 Acc: 1.0000
Epoch 8/9
----------
train loss: 0.3906 Acc: 0.8472
val loss: 0.1346 Acc: 1.0000
Epoch 9/9
----------
train loss: 0.3847 Acc: 0.8403
val loss: 0.0996 Acc: 1.0000
Training complete in 7m 16s
Best val Acc: 1.0000
After 10 epochs, validation accuracy has reached 100%. Since the validation set is very small (only about a dozen images), this may not say much about how well the network has actually trained. Next, let's see how the model performs on the test set.
3.6 Evaluating on the Test Set
First we take 10 X-ray images and have the model classify them, to see whether it predicts their classes correctly.
imgs, labels = next(iter(dataloaders_dict['test']))
imgs, labels = imgs.to(device), labels.to(device)
outputs = model_ft(imgs)
_, preds = torch.max(outputs, 1)
print('real:' + ' '.join('%9s' % classes[labels[j]] for j in range(10)))
print('pred:' + ' '.join('%9s' % classes[preds[j]] for j in range(10)))
OUT :
real: NORMAL PNEUMONIA PNEUMONIA NORMAL NORMAL NORMAL NORMAL PNEUMONIA PNEUMONIA PNEUMONIA
pred:PNEUMONIA PNEUMONIA PNEUMONIA PNEUMONIA NORMAL NORMAL NORMAL PNEUMONIA PNEUMONIA PNEUMONIA
Of the ten X-rays, the first and fourth were wrongly predicted as pneumonia when they are actually normal; the other eight are correct, i.e. 80% accuracy. Finally, let's check whether accuracy on the full test set is also around 80%.
correct = 0.0
for imgs, labels in dataloaders_dict['test']:
    imgs, labels = imgs.to(device), labels.to(device)
    output = model_ft(imgs)
    _, preds = torch.max(output, 1)
    correct += (preds == labels).sum().item()
print('test accuracy:{:.2f}%'.format(100 * correct / len(dataloaders_dict['test'].dataset)))
OUT :
test accuracy:82.53%
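Both mistakes above were normal images predicted as pneumonia, and the test set has many more pneumonia images than normal ones, so overall accuracy can hide how each class fares. A per-class accuracy sketch, demonstrated on hypothetical synthetic tensors instead of the real test loader (in practice, collect `preds` and `labels` from the loop above):

```python
import torch

classes = ['NORMAL', 'PNEUMONIA']
# Hypothetical labels/predictions for illustration only.
labels = torch.tensor([0, 0, 0, 1, 1, 1, 1, 1])
preds = torch.tensor([0, 1, 1, 1, 1, 1, 1, 1])

per_class = {}
for i, name in enumerate(classes):
    mask = labels == i  # select all samples of this class
    per_class[name] = (preds[mask] == labels[mask]).float().mean().item()
print(per_class)  # NORMAL suffers most here: 1/3 correct vs 5/5 for PNEUMONIA
```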
4. Summary
The model's prediction accuracy is 82.53%, not as high as one might hope.
To improve accuracy further, a few directions seem worth trying:
- Switch to a VGG variant with batch normalization, or to a stronger model altogether, such as ResNet.
- Lower the learning rate, or decay it during training.
- Try more epochs, a different batch_size, and so on.
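The learning-rate-decay suggestion can be implemented with `torch.optim.lr_scheduler`; for example, `StepLR` multiplies the rate by `gamma` every `step_size` epochs. A minimal sketch on a hypothetical dummy parameter (the step sizes here are illustrative, not tuned):

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

param = nn.Parameter(torch.zeros(1))  # dummy parameter for illustration
optimizer = optim.SGD([param], lr=0.001, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)  # halve every 3 epochs

for epoch in range(10):
    # ... run one training epoch (optimizer.step() per batch) ...
    optimizer.step()
    scheduler.step()  # advance the schedule once per epoch
print(optimizer.param_groups[0]['lr'])  # 0.001 * 0.5**3 = 0.000125
```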