MNIST 數(shù)據(jù)集已經(jīng)是一個(gè)被"嚼爛"了的數(shù)據(jù)集, 很多教程都會(huì)對(duì)它"下手", 幾乎成為一個(gè) "典范". 不過(guò)有些人可能對(duì)它還不是很了解, 下面來(lái)介紹一下.
MNIST 數(shù)據(jù)集可在 http://yann.lecun.com/exdb/mnist/ 獲取, 它包含了四個(gè)部分:
- Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解壓后 47 MB, 包含 60,000 個(gè)樣本)
- Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解壓后 60 KB, 包含 60,000 個(gè)標(biāo)簽)
- Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解壓后 7.8 MB, 包含 10,000 個(gè)樣本)
- Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解壓后 10 KB, 包含 10,000 個(gè)標(biāo)簽)
MNIST 數(shù)據(jù)集來(lái)自美國(guó)國(guó)家標(biāo)準(zhǔn)與技術(shù)研究所, National Institute of Standards and Technology (NIST). 訓(xùn)練集 (training set) 由來(lái)自 250 個(gè)不同人手寫(xiě)的數(shù)字構(gòu)成, 其中 50% 是高中學(xué)生, 50% 來(lái)自人口普查局 (the Census Bureau) 的工作人員. 測(cè)試集(test set) 也是同樣比例的手寫(xiě)數(shù)字?jǐn)?shù)據(jù).
不妨新建一個(gè)文件夾 -- mnist, 將數(shù)據(jù)集下載到 mnist 以后, 解壓即可:
圖片是以字節(jié)的形式進(jìn)行存儲(chǔ), 我們需要把它們讀取到 NumPy array 中, 以便訓(xùn)練和測(cè)試算法.
import os
import struct
import numpy as np
def load_mnist(path, kind='train'):
"""Load MNIST data from `path`"""
labels_path = os.path.join(path,
'%s-labels-idx1-ubyte'
% kind)
images_path = os.path.join(path,
'%s-images-idx3-ubyte'
% kind)
with open(labels_path, 'rb') as lbpath:
magic, n = struct.unpack('>II',
lbpath.read(8))
labels = np.fromfile(lbpath,
dtype=np.uint8)
with open(images_path, 'rb') as imgpath:
magic, num, rows, cols = struct.unpack('>IIII',
imgpath.read(16))
images = np.fromfile(imgpath,
dtype=np.uint8).reshape(len(labels), 784)
return images, labels
load_mnist
函數(shù)返回兩個(gè)數(shù)組, 第一個(gè)是一個(gè) n x m 維的 NumPy array(images
), 這里的 n 是樣本數(shù)(行數(shù)), m 是特征數(shù)(列數(shù)). 訓(xùn)練數(shù)據(jù)集包含 60,000 個(gè)樣本, 測(cè)試數(shù)據(jù)集包含 10,000 樣本. 在 MNIST 數(shù)據(jù)集中的每張圖片由 28 x 28 個(gè)像素點(diǎn)構(gòu)成, 每個(gè)像素點(diǎn)用一個(gè)灰度值表示. 在這里, 我們將 28 x 28 的像素展開(kāi)為一個(gè)一維的行向量, 這些行向量就是圖片數(shù)組里的行(每行 784 個(gè)值, 或者說(shuō)每行就是代表了一張圖片). load_mnist
函數(shù)返回的第二個(gè)數(shù)組(labels
) 包含了相應(yīng)的目標(biāo)變量, 也就是手寫(xiě)數(shù)字的類標(biāo)簽(整數(shù) 0-9).
第一次見(jiàn)的話, 可能會(huì)覺(jué)得我們讀取圖片的方式有點(diǎn)奇怪:
magic, n = struct.unpack('>II', lbpath.read(8))
labels = np.fromfile(lbpath, dtype=np.uint8)
為了理解這兩行代碼, 我們先來(lái)看一下 MNIST 網(wǎng)站上對(duì)數(shù)據(jù)集的介紹:
TRAINING SET LABEL FILE (train-labels-idx1-ubyte):
[offset] [type] [value] [description]
0000 32 bit integer 0x00000801(2049) magic number (MSB first)
0004 32 bit integer 60000 number of items
0008 unsigned byte ?? label
0009 unsigned byte ?? label
........
xxxx unsigned byte ?? label
The labels values are 0 to 9.
通過(guò)使用上面兩行代碼, 我們首先讀入 magic number, 它是一個(gè)文件協(xié)議的描述, 也是在我們調(diào)用 fromfile
方法將字節(jié)讀入 NumPy array 之前在文件緩沖中的 item 數(shù)(n). 作為參數(shù)值傳入 struct.unpack
的 >II
有兩個(gè)部分:
-
>
: 這是指大端(用來(lái)定義字節(jié)是如何存儲(chǔ)的); 如果你還不知道什么是大端和小端, Endianness 是一個(gè)非常好的解釋. (關(guān)于大小端, 更多內(nèi)容可見(jiàn)<<深入理解計(jì)算機(jī)系統(tǒng) -- 2.1 節(jié)信息存儲(chǔ)>>) -
I
: 這是指一個(gè)無(wú)符號(hào)整數(shù).
通過(guò)執(zhí)行下面的代碼, 我們將會(huì)從剛剛解壓 MNIST 數(shù)據(jù)集后的 mnist 目錄下加載 60,000 個(gè)訓(xùn)練樣本和 10,000 個(gè)測(cè)試樣本.
為了了解 MNIST 中的圖片看起來(lái)到底是個(gè)啥, 讓我們來(lái)對(duì)它們進(jìn)行可視化處理. 從 feature matrix 中將 784-像素值 的向量 reshape 為之前的 28*28 的形狀, 然后通過(guò) matplotlib 的 imshow
函數(shù)進(jìn)行繪制:
import matplotlib.pyplot as plt
fig, ax = plt.subplots(
nrows=2,
ncols=5,
sharex=True,
sharey=True, )
ax = ax.flatten()
for i in range(10):
img = X_train[y_train == i][0].reshape(28, 28)
ax[i].imshow(img, cmap='Greys', interpolation='nearest')
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()
我們現(xiàn)在應(yīng)該可以看到一個(gè) 2*5 的圖片, 里面分別是 0-9 單個(gè)數(shù)字的圖片.
此外, 我們還可以繪制某一數(shù)字的多個(gè)樣本圖片, 來(lái)看一下這些手寫(xiě)樣本到底有多不同:
fig, ax = plt.subplots(
nrows=5,
ncols=5,
sharex=True,
sharey=True, )
ax = ax.flatten()
for i in range(25):
img = X_train[y_train == 7][i].reshape(28, 28)
ax[i].imshow(img, cmap='Greys', interpolation='nearest')
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()
執(zhí)行上面的代碼后, 我們應(yīng)該看到數(shù)字 7 的 25 個(gè)不同形態(tài):
另外, 我們也可以選擇將 MNIST 圖片數(shù)據(jù)和標(biāo)簽保存為 CSV 文件, 這樣就可以在不支持特殊的字節(jié)格式的程序中打開(kāi)數(shù)據(jù)集. 但是, 有一點(diǎn)要說(shuō)明, CSV 的文件格式將會(huì)占用更多的磁盤(pán)空間, 如下所示:
- train_img.csv: 109.5 MB
- train_labels.csv: 120 KB
- test_img.csv: 18.3 MB
- test_labels: 20 KB
如果我們打算保存這些 CSV 文件, 在將 MNIST 數(shù)據(jù)集加載入 NumPy array 以后, 我們應(yīng)該執(zhí)行下列代碼:
np.savetxt('train_img.csv', X_train,
fmt='%i', delimiter=',')
np.savetxt('train_labels.csv', y_train,
fmt='%i', delimiter=',')
np.savetxt('test_img.csv', X_test,
fmt='%i', delimiter=',')
np.savetxt('test_labels.csv', y_test,
fmt='%i', delimiter=',')
一旦將數(shù)據(jù)集保存為 CSV 文件, 我們也可以用 NumPy 的 genfromtxt
函數(shù)重新將它們加載入程序中:
X_train = np.genfromtxt('train_img.csv',
dtype=int, delimiter=',')
y_train = np.genfromtxt('train_labels.csv',
dtype=int, delimiter=',')
X_test = np.genfromtxt('test_img.csv',
dtype=int, delimiter=',')
y_test = np.genfromtxt('test_labels.csv',
dtype=int, delimiter=',')
不過(guò), 從 CSV 文件中加載 MNIST 數(shù)據(jù)將會(huì)顯著發(fā)給更長(zhǎng)的時(shí)間, 因此如果可能的話, 還是建議你維持?jǐn)?shù)據(jù)集原有的字節(jié)格式.
參考:
- Book , Python Machine Learning.