3.5 圖像分類數(shù)據(jù)集(Fashion-MNIST)
在介紹softmax回歸的實現(xiàn)前我們先引入一個多類圖像分類數(shù)據(jù)集或详。它將在后面的章節(jié)中被多次使用,以方便我們觀察比較算法之間在模型精度和計算效率上的區(qū)別。圖像分類數(shù)據(jù)集中最常用的是手寫數(shù)字識別數(shù)據(jù)集MNIST[1]筛欢。但大部分模型在MNIST上的分類精度都超過了95%煌茴。為了更直觀地觀察算法之間的差異,我們將使用一個圖像內容更加復雜的數(shù)據(jù)集Fashion-MNIST[2](這個數(shù)據(jù)集也比較小,只有幾十M溉贿,沒有GPU的電腦也能吃得消)看疙。
本節(jié)我們將使用torchvision包豆拨,它是服務于PyTorch深度學習框架的,主要用來構建計算機視覺模型能庆。torchvision主要由以下幾部分構成:
-
torchvision.datasets
: 一些加載數(shù)據(jù)的函數(shù)及常用的數(shù)據(jù)集接口施禾; -
torchvision.models
: 包含常用的模型結構(含預訓練模型),例如AlexNet搁胆、VGG弥搞、ResNet等; -
torchvision.transforms
: 常用的圖片變換渠旁,例如裁剪拓巧、旋轉等; -
torchvision.utils
: 其他的一些有用的方法一死。
3.5.1 獲取數(shù)據(jù)集
首先導入本節(jié)需要的包或模塊肛度。
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
sys.path.append("..") # 為了導入上層目錄的d2lzh_pytorch
import d2lzh_pytorch as d2l
下面,我們通過torchvision的torchvision.datasets
來下載這個數(shù)據(jù)集投慈。第一次調用時會自動從網上獲取數(shù)據(jù)承耿。我們通過參數(shù)train
來指定獲取訓練數(shù)據(jù)集或測試數(shù)據(jù)集(testing data set)。測試數(shù)據(jù)集也叫測試集(testing set)伪煤,只用來評價模型的表現(xiàn)加袋,并不用來訓練模型。
另外我們還指定了參數(shù)transform = transforms.ToTensor()
使所有數(shù)據(jù)轉換為Tensor
抱既,如果不進行轉換則返回的是PIL圖片职烧。transforms.ToTensor()
將尺寸為 (H x W x C) 且數(shù)據(jù)位于[0, 255]的PIL圖片或者數(shù)據(jù)類型為np.uint8
的NumPy數(shù)組轉換為尺寸為(C x H x W)且數(shù)據(jù)類型為torch.float32
且位于[0.0, 1.0]的Tensor
。
注意: 由于像素值為0到255的整數(shù)防泵,所以剛好是uint8所能表示的范圍蚀之,包括
transforms.ToTensor()
在內的一些關于圖片的函數(shù)就默認輸入的是uint8型,若不是捷泞,可能不會報錯但可能得不到想要的結果足删。所以,如果用像素值(0-255整數(shù))表示圖片數(shù)據(jù)锁右,那么一律將其類型設置成uint8失受,避免不必要的bug讶泰。 本人就被這點坑過,詳見我的這個博客2.2.4節(jié)拂到。
mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())
上面的mnist_train
和mnist_test
都是torch.utils.data.Dataset
的子類痪署,所以我們可以用len()
來獲取該數(shù)據(jù)集的大小,還可以用下標來獲取具體的一個樣本兄旬。訓練集中和測試集中的每個類別的圖像數(shù)分別為6,000和1,000狼犯。因為有10個類別,所以訓練集和測試集的樣本數(shù)分別為60,000和10,000辖试。
print(type(mnist_train))
print(len(mnist_train), len(mnist_test))
輸出:
<class 'torchvision.datasets.mnist.FashionMNIST'>
60000 10000
我們可以通過下標來訪問任意一個樣本:
feature, label = mnist_train[0]
print(feature.shape, label) # Channel x Height x Width
輸出:
torch.Size([1, 28, 28]) tensor(9)
變量feature
對應高和寬均為28像素的圖像辜王。由于我們使用了transforms.ToTensor()
劈狐,所以每個像素的數(shù)值為[0.0, 1.0]的32位浮點數(shù)罐孝。需要注意的是,feature
的尺寸是 (C x H x W) 的肥缔,而不是 (H x W x C)莲兢。第一維是通道數(shù),因為數(shù)據(jù)集中是灰度圖像续膳,所以通道數(shù)為1改艇。后面兩維分別是圖像的高和寬。
Fashion-MNIST中一共包括了10個類別坟岔,分別為t-shirt(T恤)谒兄、trouser(褲子)、pullover(套衫)社付、dress(連衣裙)承疲、coat(外套)、sandal(涼鞋)鸥咖、shirt(襯衫)燕鸽、sneaker(運動鞋)、bag(包)和ankle boot(短靴)啼辣。以下函數(shù)可以將數(shù)值標簽轉成相應的文本標簽啊研。
# 本函數(shù)已保存在d2lzh包中方便以后使用
def get_fashion_mnist_labels(labels):
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
下面定義一個可以在一行里畫出多張圖像和對應標簽的函數(shù)。
# 本函數(shù)已保存在d2lzh包中方便以后使用
def show_fashion_mnist(images, labels):
d2l.use_svg_display()
# 這里的_表示我們忽略(不使用)的變量
_, figs = plt.subplots(1, len(images), figsize=(12, 12))
for f, img, lbl in zip(figs, images, labels):
f.imshow(img.view((28, 28)).numpy())
f.set_title(lbl)
f.axes.get_xaxis().set_visible(False)
f.axes.get_yaxis().set_visible(False)
plt.show()
現(xiàn)在鸥拧,我們看一下訓練數(shù)據(jù)集中前10個樣本的圖像內容和文本標簽党远。
X, y = [], []
for i in range(10):
X.append(mnist_train[i][0])
y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))
3.5.2 讀取小批量
我們將在訓練數(shù)據(jù)集上訓練模型,并將訓練好的模型在測試數(shù)據(jù)集上評價模型的表現(xiàn)富弦。前面說過麸锉,mnist_train
是torch.utils.data.Dataset
的子類,所以我們可以將其傳入torch.utils.data.DataLoader
來創(chuàng)建一個讀取小批量數(shù)據(jù)樣本的DataLoader實例舆声。
在實踐中花沉,數(shù)據(jù)讀取經常是訓練的性能瓶頸柳爽,特別當模型較簡單或者計算硬件性能較高時。PyTorch的DataLoader
中一個很方便的功能是允許使用多進程來加速數(shù)據(jù)讀取碱屁。這里我們通過參數(shù)num_workers
來設置4個進程讀取數(shù)據(jù)磷脯。
batch_size = 256
if sys.platform.startswith('win'):
num_workers = 0 # 0表示不用額外的進程來加速讀取數(shù)據(jù)
else:
num_workers = 4
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
我們將獲取并讀取Fashion-MNIST數(shù)據(jù)集的邏輯封裝在d2lzh_pytorch.load_data_fashion_mnist
函數(shù)中供后面章節(jié)調用。該函數(shù)將返回train_iter
和test_iter
兩個變量娩脾。隨著本書內容的不斷深入赵誓,我們會進一步改進該函數(shù)。它的完整實現(xiàn)將在5.6節(jié)中描述柿赊。
最后我們查看讀取一遍訓練數(shù)據(jù)需要的時間俩功。
start = time.time()
for X, y in train_iter:
continue
print('%.2f sec' % (time.time() - start))
輸出:
1.57 sec
小結
參考文獻
[1] LeCun, Y., Cortes, C., & Burges, C. http://yann.lecun.com/exdb/mnist/
[2] Xiao, H., Rasul, K., & Vollgraf, R. (2017). Fashion-mnist: a novel image dataset for benchmarking machine learning algorithms. arXiv preprint arXiv:1708.07747.
注:本節(jié)除了代碼之外與原書基本相同,原書傳送門