Preface
When getting started with deep learning, most people begin with the classic MNIST + LeNet-5 combination: the LeNet-5 architecture is simple and the MNIST dataset is not very large, which makes it convenient and friendly for beginners. As a next step, once you are familiar with the basics of PyTorch, it is natural to want to hand-write a CNN yourself and train and test it on a dataset.
The FashionMNIST dataset is a good exercise for this. In this experiment we will train and test a CNN on FashionMNIST from start to finish.
The FashionMNIST Dataset
Overview
https://github.com/zalandoresearch/fashion-mnist
Fashion-MNIST is a dataset of Zalando's article images—consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes. We intend Fashion-MNIST to serve as a direct drop-in replacement for the original MNIST dataset for benchmarking machine learning algorithms. It shares the same image size and structure of training and testing splits.
Key features of FashionMNIST:
- 60,000 training samples + 10,000 test samples
- Grayscale sample images, 28x28
- 10 classes
Labels
Each training and test example is assigned to one of the following labels:
Label | Description |
---|---|
0 | T-shirt/top |
1 | Trouser |
2 | Pullover |
3 | Dress |
4 | Coat |
5 | Sandal |
6 | Shirt |
7 | Sneaker |
8 | Bag |
9 | Ankle boot |
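When inspecting predictions in code, it is handy to have this table as a Python mapping (the `LABEL_NAMES` list below is my own convenience, not part of torchvision):

```python
# Fashion-MNIST class index -> human-readable name,
# matching the label table above.
LABEL_NAMES = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot',
]

def describe(label):
    """Map an integer class label (0-9) to its description."""
    return LABEL_NAMES[label]

print(describe(9))  # -> Ankle boot
```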
Why we made Fashion-MNIST
The original MNIST dataset contains a lot of handwritten digits. Members of the AI/ML/Data Science community love this dataset and use it as a benchmark to validate their algorithms. In fact, MNIST is often the first dataset researchers try. "If it doesn't work on MNIST, it won't work at all", they said. "Well, if it does work on MNIST, it may still fail on others."
To Serious Machine Learning Researchers
Seriously, we are talking about replacing MNIST. Here are some good reasons:
- MNIST is too easy. Convolutional nets can achieve 99.7% on MNIST. Classic machine learning algorithms can also achieve 97% easily. Check out our side-by-side benchmark for Fashion-MNIST vs. MNIST, and read "Most pairs of MNIST digits can be distinguished pretty well by just one pixel."
- MNIST is overused. In this April 2017 Twitter thread, Google Brain research scientist and deep learning expert Ian Goodfellow calls for people to move away from MNIST.
- MNIST can not represent modern CV tasks, as noted by deep learning expert and Keras author François Chollet in this April 2017 Twitter thread.
Experiment
Getting the Dataset
You could download the data from the website yourself, but PyTorch offers a better way: use the API in torchvision.datasets, which downloads the data automatically (pass download=True).
In CPU mode the batch size is set to 4; in GPU mode, if you have enough GPU memory, you can set a larger batch size. On an NVIDIA 1080 Ti I use batch size = 16.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import net
import utils

# https://blog.csdn.net/weixin_41278720/article/details/80778640
# --------------------------- Dataset -------------------------------------
data_dir = '/media/weipenghui/Extra/FashionMNIST'

transform = transforms.Compose([transforms.ToTensor()])
train_dataset = torchvision.datasets.FashionMNIST(data_dir, train=True, transform=transform, download=True)
val_dataset = torchvision.datasets.FashionMNIST(root=data_dir, train=False, transform=transform, download=True)

train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True, num_workers=4)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=4, num_workers=4, shuffle=False)

# Show one random batch
plt.figure()
utils.imshow_batch(next(iter(train_dataloader)))
plt.show()
The dataset directory after the download completes:
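As a side note on the transform pipeline: torchvision's ToTensor converts each 28x28 uint8 image into a 1x28x28 float tensor scaled to [0, 1]. A minimal pure-Python sketch of that conversion (illustrative only; `to_tensor_sketch` is my own name, not a torchvision API):

```python
def to_tensor_sketch(image):
    """Mimic torchvision's ToTensor for a grayscale image given as a list
    of rows of uint8 pixel values: scale to [0, 1] floats and add a
    leading channel dimension (HxW -> 1xHxW)."""
    scaled = [[pixel / 255.0 for pixel in row] for row in image]
    return [scaled]  # single channel

img = [[0, 128], [255, 64]]               # a tiny 2x2 "image"
tensor = to_tensor_sketch(img)
print(len(tensor), len(tensor[0]), len(tensor[0][0]))  # channel, height, width
```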
Defining a CNN Network
The general pattern for defining a network:
- Inherit from nn.Module
- Define the network's layers in __init__()
- Override the parent class's abstract method forward()
Unlike the earlier LeNet-5 definition, this one uses nn.Sequential with an OrderedDict. BatchNorm and Dropout layers are added, and there is no pooling after the first convolution, which preserves more information for the next layer.
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict


class Net(nn.Module):
    '''
    A custom CNN: 4 convolutional layers, each followed by batch norm,
    3 max-pooling layers, and 3 fully connected layers with Dropout.
    Input: 28x28x1
    '''

    def __init__(self):
        super(Net, self).__init__()
        self.feature = nn.Sequential(
            OrderedDict(
                [
                    # 28x28x1
                    ('conv1', nn.Conv2d(in_channels=1,
                                        out_channels=32,
                                        kernel_size=5,
                                        stride=1,
                                        padding=2)),
                    ('relu1', nn.ReLU()),
                    ('bn1', nn.BatchNorm2d(num_features=32)),

                    # 28x28x32
                    ('conv2', nn.Conv2d(in_channels=32,
                                        out_channels=64,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)),
                    ('relu2', nn.ReLU()),
                    ('bn2', nn.BatchNorm2d(num_features=64)),
                    ('pool1', nn.MaxPool2d(kernel_size=2)),

                    # 14x14x64
                    ('conv3', nn.Conv2d(in_channels=64,
                                        out_channels=128,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)),
                    ('relu3', nn.ReLU()),
                    ('bn3', nn.BatchNorm2d(num_features=128)),
                    ('pool2', nn.MaxPool2d(kernel_size=2)),

                    # 7x7x128
                    ('conv4', nn.Conv2d(in_channels=128,
                                        out_channels=64,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)),
                    ('relu4', nn.ReLU()),
                    ('bn4', nn.BatchNorm2d(num_features=64)),
                    ('pool3', nn.MaxPool2d(kernel_size=2)),
                    # out 3x3x64
                ]
            )
        )
        self.classifier = nn.Sequential(
            OrderedDict(
                [
                    ('fc1', nn.Linear(in_features=3 * 3 * 64,
                                      out_features=128)),
                    # nn.Dropout (not Dropout2d) for flattened FC features
                    ('dropout1', nn.Dropout(p=0.5)),
                    ('fc2', nn.Linear(in_features=128,
                                      out_features=64)),
                    ('dropout2', nn.Dropout(p=0.6)),
                    ('fc3', nn.Linear(in_features=64, out_features=10))
                ]
            )
        )

    def forward(self, x):
        out = self.feature(x)
        out = out.view(-1, 64 * 3 * 3)
        out = self.classifier(out)
        return out
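The `3 * 3 * 64` input size of `fc1` follows from the usual output-size formula for convolution and pooling, floor((n + 2p - k)/s) + 1. A quick sanity check of how 28x28 shrinks to 3x3 (the `out_size` helper is mine, for illustration):

```python
def out_size(n, k, s=1, p=0):
    """Spatial output size of a conv/pool layer: floor((n + 2p - k)/s) + 1."""
    return (n + 2 * p - k) // s + 1

n = 28
n = out_size(n, k=5, p=2)   # conv1, 5x5, pad 2 -> 28
n = out_size(n, k=3, p=1)   # conv2 -> 28
n = out_size(n, k=2, s=2)   # pool1 -> 14
n = out_size(n, k=3, p=1)   # conv3 -> 14
n = out_size(n, k=2, s=2)   # pool2 -> 7
n = out_size(n, k=3, p=1)   # conv4 -> 7
n = out_size(n, k=2, s=2)   # pool3 -> 3 (floor division)
print(n, n * n * 64)        # -> 3 576
```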
Training the CNN
- epoch num is set to 100; on a GPU this actually finishes quite quickly
- Every 20 iterations, a test pass is run and the accuracy is computed; the running loss is printed and both are written to log files for later analysis
- Call net.train() before training to put the model in train() mode, and net.eval() before testing to put it in eval() mode. Otherwise the results are wrong, because the network uses BatchNorm and Dropout, which behave differently in train() and eval() mode; see the PyTorch documentation for details
- After training finishes, save the model
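One point worth quantifying: with 60,000 training samples and a batch size of 4, running the full 10,000-image test set every 20 iterations is a lot of evaluation work. A rough count, assuming the batch sizes used here:

```python
train_samples, test_samples, batch_size = 60000, 10000, 4
iters_per_epoch = train_samples // batch_size   # training iterations per epoch
evals_per_epoch = iters_per_epoch // 20         # one eval pass every 20 iters
eval_batches = test_samples // batch_size       # test batches per eval pass
print(iters_per_epoch, evals_per_epoch, evals_per_epoch * eval_batches)
# -> 15000 750 1875000
```

Evaluating once per epoch, or on a subset of the test set, would reduce this considerably.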
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import net
import utils

# https://blog.csdn.net/weixin_41278720/article/details/80778640
# --------------------------- Dataset -------------------------------------
data_dir = '/media/weipenghui/Extra/FashionMNIST'

transform = transforms.Compose([transforms.ToTensor()])
train_dataset = torchvision.datasets.FashionMNIST(data_dir, train=True, transform=transform, download=True)
val_dataset = torchvision.datasets.FashionMNIST(root=data_dir, train=False, transform=transform, download=True)

train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True, num_workers=4)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=4, num_workers=4, shuffle=False)

# Show one random batch
plt.figure()
utils.imshow_batch(next(iter(train_dataloader)))
plt.show()

# ------------------- Network definition and parameters --------------------
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

net = net.Net()
print(net)
net = net.to(device)

loss_fc = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

# ------------------------------ Training ----------------------------------
file_running_loss = open('./log/running_loss.txt', 'w')
file_test_accuracy = open('./log/test_accuracy.txt', 'w')

epoch_num = 100
for epoch in range(epoch_num):
    running_loss = 0.0
    accuracy = 0.0
    for i, sample_batch in enumerate(train_dataloader):
        inputs = sample_batch[0]
        labels = sample_batch[1]
        inputs = inputs.to(device)
        labels = labels.to(device)

        net.train()
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = loss_fc(outputs, labels)
        loss.backward()
        optimizer.step()

        print(i, loss.item())

        # Statistics: running loss and test accuracy
        running_loss += loss.item()
        if i % 20 == 19:
            correct = 0
            total = 0
            net.eval()
            with torch.no_grad():  # no gradients needed during evaluation
                for inputs, labels in val_dataloader:
                    inputs = inputs.to(device)
                    labels = labels.to(device)
                    outputs = net(inputs)
                    _, prediction = torch.max(outputs, 1)
                    correct += ((prediction == labels).sum()).item()
                    total += labels.size(0)
            accuracy = correct / total
            print('[{},{}] running loss = {:.5f} acc = {:.5f}'.format(epoch + 1, i + 1, running_loss / 20, accuracy))
            file_running_loss.write(str(running_loss / 20) + '\n')
            file_test_accuracy.write(str(accuracy) + '\n')
            running_loss = 0.0
    # Since PyTorch 1.1 the LR scheduler should step after the optimizer,
    # once per epoch
    scheduler.step()

print('\n train finish')
torch.save(net.state_dict(), './model/model_100_epoch.pth')
file_running_loss.close()
file_test_accuracy.close()
Training Results
The training results are decent: accuracy peaks at around 93%.
Testing the Network
Feed in a single batch (batch size = 4) and load the trained model.
Note: the model was trained on a GPU, so the saved state dict is laid out for the GPU. When loading a GPU-trained model in CPU mode, pass map_location:
net.load_state_dict(torch.load('xxx.pth', map_location='cpu'))
import torch
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import net
import utils

data_dir = '/media/weipenghui/Extra/FashionMNIST'

transform = transforms.Compose([transforms.ToTensor()])
test_dataset = torchvision.datasets.FashionMNIST(root=data_dir, train=False, transform=transform, download=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=4, num_workers=4, shuffle=False)

plt.figure()
utils.imshow_batch(next(iter(test_dataloader)))

net = net.Net()
net.load_state_dict(torch.load(f='./model/model_100_epoch.pth', map_location='cpu'))
net.eval()  # switch BatchNorm and Dropout to evaluation mode
print(net)

images, labels = next(iter(test_dataloader))
with torch.no_grad():
    outputs = net(images)
_, prediction = torch.max(outputs, 1)
print('label:', labels)
print('prediction:', prediction)
plt.show()
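To score more than a single batch, the same correct/total bookkeeping used during training applies to the whole test set. The counting logic itself is plain Python (a sketch; the `predictions` and `labels` lists stand in for the tensors produced in the loop):

```python
def accuracy(predictions, labels):
    """Fraction of positions where the prediction equals the label."""
    correct = sum(1 for p, l in zip(predictions, labels) if p == l)
    return correct / len(labels)

# e.g. one batch of 4 predicted class indices vs. ground truth
print(accuracy([9, 2, 1, 1], [9, 2, 1, 3]))  # -> 0.75
```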
Complete Project
- Network definition
net.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict


class Net(nn.Module):
    '''
    A custom CNN: 4 convolutional layers, each followed by batch norm,
    3 max-pooling layers, and 3 fully connected layers with Dropout.
    Input: 28x28x1
    '''

    def __init__(self):
        super(Net, self).__init__()
        self.feature = nn.Sequential(
            OrderedDict(
                [
                    # 28x28x1
                    ('conv1', nn.Conv2d(in_channels=1,
                                        out_channels=32,
                                        kernel_size=5,
                                        stride=1,
                                        padding=2)),
                    ('relu1', nn.ReLU()),
                    ('bn1', nn.BatchNorm2d(num_features=32)),

                    # 28x28x32
                    ('conv2', nn.Conv2d(in_channels=32,
                                        out_channels=64,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)),
                    ('relu2', nn.ReLU()),
                    ('bn2', nn.BatchNorm2d(num_features=64)),
                    ('pool1', nn.MaxPool2d(kernel_size=2)),

                    # 14x14x64
                    ('conv3', nn.Conv2d(in_channels=64,
                                        out_channels=128,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)),
                    ('relu3', nn.ReLU()),
                    ('bn3', nn.BatchNorm2d(num_features=128)),
                    ('pool2', nn.MaxPool2d(kernel_size=2)),

                    # 7x7x128
                    ('conv4', nn.Conv2d(in_channels=128,
                                        out_channels=64,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)),
                    ('relu4', nn.ReLU()),
                    ('bn4', nn.BatchNorm2d(num_features=64)),
                    ('pool3', nn.MaxPool2d(kernel_size=2)),
                    # out 3x3x64
                ]
            )
        )
        self.classifier = nn.Sequential(
            OrderedDict(
                [
                    ('fc1', nn.Linear(in_features=3 * 3 * 64,
                                      out_features=128)),
                    # nn.Dropout (not Dropout2d) for flattened FC features
                    ('dropout1', nn.Dropout(p=0.5)),
                    ('fc2', nn.Linear(in_features=128,
                                      out_features=64)),
                    ('dropout2', nn.Dropout(p=0.6)),
                    ('fc3', nn.Linear(in_features=64, out_features=10))
                ]
            )
        )

    def forward(self, x):
        out = self.feature(x)
        out = out.view(-1, 64 * 3 * 3)
        out = self.classifier(out)
        return out
- Training
train.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import net
import utils

# https://blog.csdn.net/weixin_41278720/article/details/80778640
# --------------------------- Dataset -------------------------------------
data_dir = '/media/weipenghui/Extra/FashionMNIST'

transform = transforms.Compose([transforms.ToTensor()])
train_dataset = torchvision.datasets.FashionMNIST(data_dir, train=True, transform=transform, download=True)
val_dataset = torchvision.datasets.FashionMNIST(root=data_dir, train=False, transform=transform, download=True)

train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True, num_workers=4)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=4, num_workers=4, shuffle=False)

# Show one random batch
plt.figure()
utils.imshow_batch(next(iter(train_dataloader)))
plt.show()

# ------------------- Network definition and parameters --------------------
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

net = net.Net()
print(net)
net = net.to(device)

loss_fc = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

# ------------------------------ Training ----------------------------------
file_running_loss = open('./log/running_loss.txt', 'w')
file_test_accuracy = open('./log/test_accuracy.txt', 'w')

epoch_num = 100
for epoch in range(epoch_num):
    running_loss = 0.0
    accuracy = 0.0
    for i, sample_batch in enumerate(train_dataloader):
        inputs = sample_batch[0]
        labels = sample_batch[1]
        inputs = inputs.to(device)
        labels = labels.to(device)

        net.train()
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = loss_fc(outputs, labels)
        loss.backward()
        optimizer.step()

        print(i, loss.item())

        # Statistics: running loss and test accuracy
        running_loss += loss.item()
        if i % 20 == 19:
            correct = 0
            total = 0
            net.eval()
            with torch.no_grad():  # no gradients needed during evaluation
                for inputs, labels in val_dataloader:
                    inputs = inputs.to(device)
                    labels = labels.to(device)
                    outputs = net(inputs)
                    _, prediction = torch.max(outputs, 1)
                    correct += ((prediction == labels).sum()).item()
                    total += labels.size(0)
            accuracy = correct / total
            print('[{},{}] running loss = {:.5f} acc = {:.5f}'.format(epoch + 1, i + 1, running_loss / 20, accuracy))
            file_running_loss.write(str(running_loss / 20) + '\n')
            file_test_accuracy.write(str(accuracy) + '\n')
            running_loss = 0.0
    # Since PyTorch 1.1 the LR scheduler should step after the optimizer,
    # once per epoch
    scheduler.step()

print('\n train finish')
torch.save(net.state_dict(), './model/model_100_epoch.pth')
file_running_loss.close()
file_test_accuracy.close()
- Visualization utility
utils.py
import torch
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import numpy as np


def imshow_batch(sample_batch):
    images = sample_batch[0]
    labels = sample_batch[1]
    images = make_grid(images, nrow=4, pad_value=255)
    # convert from CHW to HWC for matplotlib
    images_transformed = np.transpose(images.numpy(), (1, 2, 0))
    plt.imshow(images_transformed)
    plt.axis('off')
    labels = labels.numpy()
    plt.title(labels)
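For reference, `make_grid` with `nrow=4` lays the batch out four images per row and pads between and around the tiles, so the pixel size of the grid can be predicted. A small sketch of that arithmetic (`grid_size` is my own helper, assuming torchvision's default `padding=2`):

```python
import math

def grid_size(n_images, h, w, nrow=4, padding=2):
    """Predict the pixel size of torchvision.utils.make_grid output:
    each tile gets `padding` pixels between it and its neighbours,
    plus a border of `padding` around the whole grid."""
    xmaps = min(nrow, n_images)             # columns
    ymaps = math.ceil(n_images / xmaps)     # rows
    height = ymaps * (h + padding) + padding
    width = xmaps * (w + padding) + padding
    return height, width

print(grid_size(4, 28, 28))   # one row of four 28x28 images
```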
- Testing
import torch
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import net
import utils

data_dir = '/media/weipenghui/Extra/FashionMNIST'

transform = transforms.Compose([transforms.ToTensor()])
test_dataset = torchvision.datasets.FashionMNIST(root=data_dir, train=False, transform=transform, download=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=4, num_workers=4, shuffle=False)

plt.figure()
utils.imshow_batch(next(iter(test_dataloader)))

net = net.Net()
net.load_state_dict(torch.load(f='./model/model_100_epoch.pth', map_location='cpu'))
net.eval()  # switch BatchNorm and Dropout to evaluation mode
print(net)

images, labels = next(iter(test_dataloader))
with torch.no_grad():
    outputs = net(images)
_, prediction = torch.max(outputs, 1)
print('label:', labels)
print('prediction:', prediction)
plt.show()