不考慮從tensorflow或者keras等平臺上直接下載轉換好的mnist數據集的方法,直接手動處理mnist官方的數據集。分別有四個文件,對應訓練集圖像,訓練集標簽,測試集圖像,測試集標簽。
官網(wǎng)給的數(shù)據(jù)集并不是圖像數(shù)據(jù)格式,而是編碼后的二進制格式畏线。這是官網(wǎng)的數(shù)據(jù)說明:
image.png
前16個字節分為4個整型數據,每個4字節,分別代表數據信息、圖像數量、行數、列數,之后的數據全部為像素,像素值為0-255。
代碼如下:
import numpy as np
import struct
# Directory (relative to the current working directory) that holds the four
# raw MNIST IDX files; the demo under `__main__` passes it to the loader.
mnist_dir = "./digit/"
def _read_idx_images(path):
    """Load an IDX3 image file as a float64 array of shape (count, rows * cols).

    Raises:
        struct.error: if the header or the pixel payload is shorter than the
            header declares.
    """
    with open(path, 'rb') as f:
        # Header: four big-endian unsigned 32-bit integers
        # (magic number, image count, rows, columns).
        _magic, count, rows, cols = struct.unpack('>IIII', f.read(16))
        expected = count * rows * cols
        raw = f.read(expected)
    if len(raw) != expected:
        # Fail loudly instead of leaving silently zero-filled images behind.
        raise struct.error(
            '%s: expected %d pixel bytes, got %d' % (path, expected, len(raw)))
    # One unsigned byte (0-255) per pixel; float64 keeps the dtype that
    # callers of fetch_mnist have always received.
    pixels = np.frombuffer(raw, dtype=np.uint8).astype(np.float64)
    return pixels.reshape(count, rows * cols)


def _read_idx_labels(path):
    """Load an IDX1 label file as a float64 column vector of shape (count, 1).

    Raises:
        struct.error: if the header or the label payload is shorter than the
            header declares.
    """
    with open(path, 'rb') as f:
        # Header: two big-endian unsigned 32-bit integers (magic number, count).
        _magic, count = struct.unpack('>II', f.read(8))
        raw = f.read(count)
    if len(raw) != count:
        raise struct.error(
            '%s: expected %d label bytes, got %d' % (path, count, len(raw)))
    labels = np.frombuffer(raw, dtype=np.uint8).astype(np.float64)
    return labels.reshape(count, 1)


def fetch_mnist(mnist_dir, data_type):
    """Read the raw MNIST IDX files found in `mnist_dir`.

    Only the files needed for the requested split are opened.

    Args:
        mnist_dir: directory containing 'train-images.idx3-ubyte',
            'train-labels.idx1-ubyte', 't10k-images.idx3-ubyte' and
            't10k-labels.idx1-ubyte' (a trailing separator is optional).
        data_type: 'train', 'test' or 'all'.

    Returns:
        'train' -> (train_x, train_y)
        'test'  -> (test_x, test_y)
        'all'   -> (train_x, train_y, test_x, test_y)
        Image arrays are float64 with shape (count, rows * cols); label arrays
        are float64 with shape (count, 1). For any other `data_type` the
        function prints 'type error' and returns None.

    Raises:
        OSError: if a required file cannot be opened.
        struct.error: if a file is shorter than its header declares.
    """
    import os  # function-scope import: the file's import block is left untouched

    if data_type not in ('train', 'test', 'all'):
        # Same report / None result as before, but decided before any I/O.
        print('type error')
        return None

    loaded = []
    if data_type in ('train', 'all'):
        loaded.append(_read_idx_images(
            os.path.join(mnist_dir, 'train-images.idx3-ubyte')))
        loaded.append(_read_idx_labels(
            os.path.join(mnist_dir, 'train-labels.idx1-ubyte')))
    if data_type in ('test', 'all'):
        loaded.append(_read_idx_images(
            os.path.join(mnist_dir, 't10k-images.idx3-ubyte')))
        loaded.append(_read_idx_labels(
            os.path.join(mnist_dir, 't10k-labels.idx1-ubyte')))
    return tuple(loaded)
if __name__ == '__main__':
    # Demo: load every split, then display one training and one test digit
    # and print their labels.
    tr_x, tr_y, te_x, te_y = fetch_mnist(mnist_dir, 'all')

    import matplotlib.pyplot as plt  # plt is used to display the images

    # Each sample gets its own figure; two consecutive imshow() calls on the
    # same axes would leave only the second image visible.
    plt.figure()
    img_0 = tr_x[59999, :].reshape(28, 28)
    plt.imshow(img_0)
    print(tr_y[59999, :])

    plt.figure()
    img_1 = te_x[500, :].reshape(28, 28)
    plt.imshow(img_1)
    print(te_y[500, :])

    plt.show()
————————————————
https://blog.csdn.net/jinxiaonian11/article/details/78172613