基于tensorflow的MNIST數(shù)據(jù)集手寫數(shù)字分類

2018年9月16日筆記

MNIST是Mixed National Institue of Standards and Technology database的簡稱蹬耘,中文叫做美國國家標(biāo)準(zhǔn)與技術(shù)研究所數(shù)據(jù)庫综苔。

0.編程環(huán)境

安裝tensorflow命令:pip install tensorflow
操作系統(tǒng):Win10
python版本:3.6
集成開發(fā)環(huán)境:jupyter notebook
tensorflow版本:1.6

1.致謝聲明

1.本文是作者學(xué)習(xí)《周莫煩tensorflow視頻教程》的成果休里,感激前輩妙黍;
視頻鏈接:https://morvanzhou.github.io/tutorials/machine-learning/tensorflow/
2.參考云水木石的文章,鏈接:https://mp.weixin.qq.com/s/DJxY_5pyjOsB70HrsBraOA

2.下載并解壓數(shù)據(jù)集

MNIST數(shù)據(jù)集下載鏈接: https://pan.baidu.com/s/1fPbgMqsEvk2WyM9hy5Em6w 密碼: wa9p
下載壓縮文件MNIST_data.rar完成后可免,選擇解壓到當(dāng)前文件夾浇借,不要選擇解壓到MNIST_data妇垢。
文件夾結(jié)構(gòu)如下圖所示:

image.png

3.完整代碼

此章給讀者能夠直接運(yùn)行的完整代碼闯估,使讀者有編程結(jié)果的感性認(rèn)識(shí)。
如果下面一段代碼運(yùn)行成功骑素,則說明安裝tensorflow環(huán)境成功献丑。
想要了解代碼的具體實(shí)現(xiàn)細(xì)節(jié)侠姑,請(qǐng)閱讀后面的章節(jié)结借。

import warnings
warnings.filterwarnings('ignore')
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
batch_size = 100
X_holder = tf.placeholder(tf.float32)
y_holder = tf.placeholder(tf.float32)

Weights = tf.Variable(tf.zeros([784, 10]))
biases = tf.Variable(tf.zeros([1,10]))
predict_y = tf.nn.softmax(tf.matmul(X_holder, Weights) + biases)
loss = tf.reduce_mean(-tf.reduce_sum(y_holder * tf.log(predict_y), 1))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

session = tf.Session()
init = tf.global_variables_initializer()
session.run(init)

for i in range(500):
    images, labels = mnist.train.next_batch(batch_size)
    session.run(train, feed_dict={X_holder:images, y_holder:labels})
    if i % 25 == 0:
        correct_prediction = tf.equal(tf.argmax(predict_y, 1), tf.argmax(y_holder, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        accuracy_value = session.run(accuracy, feed_dict={X_holder:mnist.test.images, y_holder:mnist.test.labels})
        print('step:%d accuracy:%.4f' %(i, accuracy_value))

上面一段代碼的運(yùn)行結(jié)果如下:

Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
step:0 accuracy:0.4747
step:25 accuracy:0.8553
step:50 accuracy:0.8719
step:75 accuracy:0.8868
step:100 accuracy:0.8911
step:125 accuracy:0.8998
step:150 accuracy:0.8942
step:175 accuracy:0.9050
step:200 accuracy:0.9026
step:225 accuracy:0.9076
step:250 accuracy:0.9071
step:275 accuracy:0.9049
step:300 accuracy:0.9055
step:325 accuracy:0.9101
step:350 accuracy:0.9097
step:375 accuracy:0.9116
step:400 accuracy:0.9102
step:425 accuracy:0.9113
step:450 accuracy:0.9155
step:475 accuracy:0.9151

從上面的運(yùn)行結(jié)果可以看出,經(jīng)過500步訓(xùn)練柳畔,模型準(zhǔn)確率到達(dá)0.9151左右薪韩。

4.數(shù)據(jù)準(zhǔn)備

import warnings
warnings.filterwarnings('ignore')
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
batch_size = 100
X_holder = tf.placeholder(tf.float32)
y_holder = tf.placeholder(tf.float32)

第1行代碼導(dǎo)入warnings庫俘陷,第2行代碼表示不打印警告信息观谦;
第3行代碼導(dǎo)入tensorflow庫豁状,取別名tf;
第4行代碼人從tensorflow.examples.tutorials.mnist庫中導(dǎo)入input_data文件夭禽;
本文作者使用anaconda集成開發(fā)環(huán)境讹躯,input_data文件所在路徑:C:\ProgramData\Anaconda3\Lib\site-packages\tensorflow\examples\tutorials\mnist蜀撑,如下圖所示:

image.png

第6行代碼調(diào)用input_data文件的read_data_sets方法酷麦,需要2個(gè)參數(shù)沃饶,第1個(gè)參數(shù)的數(shù)據(jù)類型是字符串,是讀取數(shù)據(jù)的文件夾名琴昆,第2個(gè)關(guān)鍵字參數(shù)ont_hot數(shù)據(jù)類型為布爾bool业舍,設(shè)置為True舷暮,表示預(yù)測目標(biāo)值是否經(jīng)過One-Hot編碼噩茄;
第7行代碼定義變量batch_size的值為100绩聘;
第8凿菩、9行代碼中placeholder中文叫做占位符,將每次訓(xùn)練的特征矩陣X和預(yù)測目標(biāo)值y賦值給變量X_holder和y_holder叉庐。

5.數(shù)據(jù)觀察

本章內(nèi)容主要是了解變量mnist中的數(shù)據(jù)內(nèi)容陡叠,并掌握變量mnist中的方法使用枉阵。

5.1 查看變量mnist的方法和屬性

dir(mnist)[-10:]

上面一段代碼的運(yùn)行結(jié)果如下:

['_asdict',
'_fields',
'_make',
'_replace',
'_source',
'count',
'index',
'test',
'train',
'validation']

為了節(jié)省篇幅兴溜,只打印最后10個(gè)方法和屬性。
我們會(huì)用到的是其中test刨沦、train想诅、validation這3個(gè)方法来破。

5.2 對(duì)比三個(gè)集合

train對(duì)應(yīng)訓(xùn)練集忘古,validation對(duì)應(yīng)驗(yàn)證集髓堪,test對(duì)應(yīng)測試集干旁。
查看3個(gè)集合中的樣本數(shù)量,代碼如下:

print(mnist.train.num_examples)
print(mnist.validation.num_examples)
print(mnist.test.num_examples)

上面一段代碼的運(yùn)行結(jié)果如下:

55000
5000
10000

對(duì)比3個(gè)集合的方法和屬性


image.png

從上面的運(yùn)行結(jié)果可以看出央拖,3個(gè)集合的方法和屬性基本相同祭阀。
我們會(huì)用到的是其中images、labels鲜戒、next_batch這3個(gè)屬性或方法专控。

5.3 mnist.train.images觀察

查看mnist.train.images的數(shù)據(jù)類型和矩陣形狀。

images = mnist.train.images
type(images), images.shape

上面一段代碼的運(yùn)行結(jié)果如下:

(numpy.ndarray, (55000, 784))

從上面的運(yùn)行結(jié)果可以看出遏餐,在變量mnist.train中總共有55000個(gè)樣本伦腐,每個(gè)樣本有784個(gè)特征。
原圖片形狀為28*28,28*28=784失都,每個(gè)圖片樣本展平后則有784維特征。
選取1個(gè)樣本粹庞,用3種作圖方式查看其圖片內(nèi)容咳焚,代碼如下:

import matplotlib.pyplot as plt

image = mnist.train.images[1].reshape(-1, 28)
plt.subplot(131)
plt.imshow(image)
plt.axis('off')
plt.subplot(132)
plt.imshow(image, cmap='gray')
plt.axis('off')
plt.subplot(133)
plt.imshow(image, cmap='gray_r')
plt.axis('off')
plt.show()

上面一段代碼的運(yùn)行結(jié)果如下圖所示:

image.png

從上面的運(yùn)行結(jié)果可以看出,調(diào)用plt.show方法時(shí)庞溜,參數(shù)cmap指定值為graygray_r符合正常的觀看效果革半。

5.4 查看手寫數(shù)字圖

從訓(xùn)練集mnist.train中選取一部分樣本查看圖片內(nèi)容,即調(diào)用mnist.train的next_batch方法隨機(jī)獲得一部分樣本,代碼如下:

import matplotlib.pyplot as plt
import math
import numpy as np

def drawDigit(position, image, title):
    plt.subplot(*position)
    plt.imshow(image.reshape(-1, 28), cmap='gray_r')
    plt.axis('off')
    plt.title(title)
    
def batchDraw(batch_size):
    images,labels = mnist.train.next_batch(batch_size)
    image_number = images.shape[0]
    row_number = math.ceil(image_number ** 0.5)
    column_number = row_number
    plt.figure(figsize=(row_number, column_number))
    for i in range(row_number):
        for j in range(column_number):
            index = i * column_number + j
            if index < image_number:
                position = (row_number, column_number, index+1)
                image = images[index]
                title = 'actual:%d' %(np.argmax(labels[index]))
                drawDigit(position, image, title)

batchDraw(196)
plt.show()

上面一段代碼的運(yùn)行結(jié)果如下圖所示又官,本文作者對(duì)難以辨認(rèn)的數(shù)字做了紅色方框標(biāo)注:


image.png

6.搭建神經(jīng)網(wǎng)絡(luò)

Weights = tf.Variable(tf.zeros([784, 10]))
biases = tf.Variable(tf.zeros([1,10]))
predict_y = tf.nn.softmax(tf.matmul(X_holder, Weights) + biases)
loss = tf.reduce_mean(-tf.reduce_sum(y_holder * tf.log(predict_y), 1))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

該神經(jīng)網(wǎng)絡(luò)只有輸入層和輸出層延刘,沒有隱藏層。
第1行代碼定義形狀為784*10的權(quán)重矩陣Weights六敬;
第2行代碼定義形狀為1*10的偏置矩陣biases碘赖;
第3行代碼定義先通過矩陣計(jì)算,再使用激活函數(shù)softmax得出的每個(gè)分類的預(yù)測概率predict_y觉阅;
第4行代碼定義損失函數(shù)loss崖疤,多分類問題使用交叉熵作為損失函數(shù)。
交叉熵的函數(shù)如下圖所示典勇,其中p(x)是實(shí)際值劫哼,q(x)是預(yù)測值

image.png

第5行代碼定義優(yōu)化器optimizer割笙,使用梯度下降優(yōu)化器权烧;
第6行代碼定義訓(xùn)練步驟train,即最小化損失伤溉。

7.變量初始化

init = tf.global_variables_initializer()
session = tf.Session()
session.run(init)

對(duì)于神經(jīng)網(wǎng)絡(luò)模型般码,重要是其中的W、b這兩個(gè)參數(shù)乱顾。
開始神經(jīng)網(wǎng)絡(luò)模型訓(xùn)練之前板祝,這兩個(gè)變量需要初始化。
第1行代碼調(diào)用tf.global_variables_initializer實(shí)例化tensorflow中的Operation對(duì)象走净。


image.png

第2行代碼調(diào)用tf.Session方法實(shí)例化會(huì)話對(duì)象券时;
第3行代碼調(diào)用tf.Session對(duì)象的run方法做變量初始化。

8.模型訓(xùn)練

for i in range(500):
    images, labels = mnist.train.next_batch(batch_size)
    session.run(train, feed_dict={X_holder:images, y_holder:labels})
    if i % 25 == 0:
        correct_prediction = tf.equal(tf.argmax(predict_y, 1), tf.argmax(y_holder, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        accuracy_value = session.run(accuracy, feed_dict={X_holder:mnist.test.images, y_holder:mnist.test.labels})
        print('step:%d accuracy:%.4f' %(i, accuracy_value))

第1行代碼表示模型迭代訓(xùn)練500次伏伯;
第2行代碼調(diào)用mnist.train對(duì)象的next_batch方法橘洞,選出數(shù)量為batch_size的樣本;
第3行代碼是模型訓(xùn)練说搅,每運(yùn)行1次此行代碼炸枣,即模型訓(xùn)練1次;
第4-8行代碼是每隔25次訓(xùn)練打印模型準(zhǔn)確率弄唧。
上面一段代碼的運(yùn)行結(jié)果如下:

step:0 accuracy:0.3161
step:25 accuracy:0.8452
step:50 accuracy:0.8668
step:75 accuracy:0.8860
step:100 accuracy:0.8906
step:125 accuracy:0.8948
step:150 accuracy:0.9008
step:175 accuracy:0.9027
step:200 accuracy:0.8956
step:225 accuracy:0.9102
step:250 accuracy:0.9022
step:275 accuracy:0.9097
step:300 accuracy:0.9039
step:325 accuracy:0.9076
step:350 accuracy:0.9137
step:375 accuracy:0.9111
step:400 accuracy:0.9069
step:425 accuracy:0.9097
step:450 accuracy:0.9150
step:475 accuracy:0.9105

9.模型測試

import math
import matplotlib.pyplot as plt
import numpy as np

def drawDigit2(position, image, title, isTrue):
    plt.subplot(*position)
    plt.imshow(image.reshape(-1, 28), cmap='gray_r')
    plt.axis('off')
    if not isTrue:
        plt.title(title, color='red')
    else:
        plt.title(title)
        
def batchDraw2(batch_size):
    images,labels = mnist.test.next_batch(batch_size)
    predict_labels = session.run(predict_y, feed_dict={X_holder:images, y_holder:labels})
    image_number = images.shape[0]
    row_number = math.ceil(image_number ** 0.5)
    column_number = row_number
    plt.figure(figsize=(row_number+8, column_number+8))
    for i in range(row_number):
        for j in range(column_number):
            index = i * column_number + j
            if index < image_number:
                position = (row_number, column_number, index+1)
                image = images[index]
                actual = np.argmax(labels[index])
                predict = np.argmax(predict_labels[index])
                isTrue = actual==predict
                title = 'actual:%d\npredict:%d' %(actual,predict)
                drawDigit2(position, image, title, isTrue)

batchDraw2(100)
plt.show()

上面一段代碼的運(yùn)行結(jié)果如下圖所示:


image.png

10.結(jié)論

1.這是本文作者寫的第4篇關(guān)于tensorflow的文章适肠,加深了對(duì)tensorflow框架的理解;
2.優(yōu)化器必須使用GradientDescentOptimizer候引,使用AdamOptimizer會(huì)出現(xiàn)錯(cuò)誤迂猴;
3.初始化權(quán)重Weights時(shí),全部初始化為0比隨機(jī)正態(tài)初始化效果要好背伴。
4.盡管在多數(shù)的深度學(xué)習(xí)實(shí)踐中不能初始化權(quán)重為0沸毁,但此模型只有輸入層輸出層峰髓,所以可以權(quán)重初始化為0。
5.如何進(jìn)一步提高模型準(zhǔn)確率息尺,請(qǐng)閱讀本文作者的另一篇文章《基于tensorflow+DNN的MNIST數(shù)據(jù)集手寫數(shù)字分類預(yù)測》携兵,鏈接:http://www.reibang.com/p/9a4ae5655ca6

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市搂誉,隨后出現(xiàn)的幾起案子徐紧,更是在濱河造成了極大的恐慌,老刑警劉巖炭懊,帶你破解...
    沈念sama閱讀 211,042評(píng)論 6 490
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件并级,死亡現(xiàn)場離奇詭異,居然都是意外死亡侮腹,警方通過查閱死者的電腦和手機(jī)嘲碧,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 89,996評(píng)論 2 384
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來父阻,“玉大人愈涩,你說我怎么就攤上這事〖用” “怎么了履婉?”我有些...
    開封第一講書人閱讀 156,674評(píng)論 0 345
  • 文/不壞的土叔 我叫張陵,是天一觀的道長斟览。 經(jīng)常有香客問我毁腿,道長,這世上最難降的妖魔是什么苛茂? 我笑而不...
    開封第一講書人閱讀 56,340評(píng)論 1 283
  • 正文 為了忘掉前任已烤,我火速辦了婚禮,結(jié)果婚禮上味悄,老公的妹妹穿的比我還像新娘草戈。我一直安慰自己塌鸯,他們只是感情好侍瑟,可當(dāng)我...
    茶點(diǎn)故事閱讀 65,404評(píng)論 5 384
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著丙猬,像睡著了一般涨颜。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上茧球,一...
    開封第一講書人閱讀 49,749評(píng)論 1 289
  • 那天庭瑰,我揣著相機(jī)與錄音,去河邊找鬼抢埋。 笑死弹灭,一個(gè)胖子當(dāng)著我的面吹牛督暂,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播穷吮,決...
    沈念sama閱讀 38,902評(píng)論 3 405
  • 文/蒼蘭香墨 我猛地睜開眼逻翁,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了捡鱼?” 一聲冷哼從身側(cè)響起八回,我...
    開封第一講書人閱讀 37,662評(píng)論 0 266
  • 序言:老撾萬榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎驾诈,沒想到半個(gè)月后缠诅,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 44,110評(píng)論 1 303
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡乍迄,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,451評(píng)論 2 325
  • 正文 我和宋清朗相戀三年管引,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片就乓。...
    茶點(diǎn)故事閱讀 38,577評(píng)論 1 340
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡汉匙,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出生蚁,到底是詐尸還是另有隱情噩翠,我是刑警寧澤,帶...
    沈念sama閱讀 34,258評(píng)論 4 328
  • 正文 年R本政府宣布邦投,位于F島的核電站伤锚,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏志衣。R本人自食惡果不足惜屯援,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,848評(píng)論 3 312
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望念脯。 院中可真熱鬧狞洋,春花似錦、人聲如沸绿店。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,726評(píng)論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽假勿。三九已至借嗽,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間转培,已是汗流浹背恶导。 一陣腳步聲響...
    開封第一講書人閱讀 31,952評(píng)論 1 264
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留浸须,地道東北人惨寿。 一個(gè)月前我還...
    沈念sama閱讀 46,271評(píng)論 2 360
  • 正文 我出身青樓邦泄,卻偏偏與公主長得像,于是被迫代替她去往敵國和親裂垦。 傳聞我的和親對(duì)象是個(gè)殘疾皇子虎韵,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 43,452評(píng)論 2 348

推薦閱讀更多精彩內(nèi)容

  • 今年上半年可謂是大片云集,各種題材類型的片子一個(gè)接一個(gè)缸废,讓觀眾應(yīng)接不暇包蓝;但是同時(shí)也是爛片最嚴(yán)重的半年,各種爛片不斷...
    影視大亨閱讀 524評(píng)論 0 0
  • 上一章 (三) 午夜狂奔 夏末企量,秋風(fēng)亦探出腦袋测萎,時(shí)不時(shí)的深呼口氣。 過去了盛夏的煩躁届巩,人的心情也跟著活躍了起來硅瞧。于...
    左手陌路閱讀 277評(píng)論 0 0
  • 統(tǒng)編教材解讀 教材核心思想: 立德樹人,德育為本恕汇,能力為重腕唧,基礎(chǔ)為先,創(chuàng)新為上瘾英。對(duì)傳統(tǒng)文化的傳承枣接,留住我們的根。對(duì)...
    蝶化文瀾閱讀 1,088評(píng)論 0 2
  • 昨晚騎車去超市買東西缺谴,車子停在路邊沒上鎖但惶,可巧遇到同樓的同學(xué),提著東西順路就走回去了湿蛔,然后就沒有然后了膀曾。等再次想起...
    寡人有點(diǎn)煩閱讀 147評(píng)論 0 1