基于tensorflow+RNN的MNIST數(shù)據(jù)集手寫數(shù)字分類

2018年9月25日筆記

tensorflow是谷歌google的深度學(xué)習(xí)框架日矫,tensor中文叫做張量,flow叫做流捍掺。
RNN是recurrent neural network的簡稱,中文叫做循環(huán)神經(jīng)網(wǎng)絡(luò)。
MNIST是Mixed National Institue of Standards and Technology database的簡稱课兄,中文叫做美國國家標(biāo)準(zhǔn)與技術(shù)研究所數(shù)據(jù)庫
此文在上一篇文章《基于tensorflow+DNN的MNIST數(shù)據(jù)集手寫數(shù)字分類預(yù)測》的基礎(chǔ)上修改模型為循環(huán)神經(jīng)網(wǎng)絡(luò)模型晨继,模型準(zhǔn)確率從98%提升到98.5%烟阐,錯誤率減少了25%
《基于tensorflow+DNN的MNIST數(shù)據(jù)集手寫數(shù)字分類預(yù)測》文章鏈接:http://www.reibang.com/p/9a4ae5655ca6

0.編程環(huán)境

操作系統(tǒng):Win10
tensorflow版本:1.6
tensorboard版本:1.6
python版本:3.6

1.致謝聲明

本文是作者學(xué)習(xí)《周莫煩tensorflow視頻教程》的成果,感激前輩紊扬;
視頻鏈接:https://morvanzhou.github.io/tutorials/machine-learning/tensorflow/

2.配置環(huán)境

使用循環(huán)神經(jīng)網(wǎng)絡(luò)模型要求有較高的機器配置蜒茄,如果使用CPU版tensorflow會花費大量時間。
讀者在有nvidia顯卡的情況下餐屎,安裝GPU版tensorflow會提高計算速度50倍檀葛。
安裝教程鏈接:https://blog.csdn.net/qq_36556893/article/details/79433298
如果沒有nvidia顯卡,但有visa信用卡腹缩,請閱讀我的另一篇文章《在谷歌云服務(wù)器上搭建深度學(xué)習(xí)平臺》屿聋,鏈接:http://www.reibang.com/p/893d622d1b5a

3.下載并解壓數(shù)據(jù)集

MNIST數(shù)據(jù)集下載鏈接: https://pan.baidu.com/s/1fPbgMqsEvk2WyM9hy5Em6w 密碼: wa9p
下載壓縮文件MNIST_data.rar完成后,選擇解壓到當(dāng)前文件夾藏鹊,不要選擇解壓到MNIST_data润讥。
文件夾結(jié)構(gòu)如下圖所示:

image.png

4.完整代碼

此章給讀者能夠直接運行的完整代碼,使讀者有編程結(jié)果的感性認(rèn)識盘寡。
如果下面一段代碼運行成功象对,則說明安裝tensorflow環(huán)境成功。
想要了解代碼的具體實現(xiàn)細(xì)節(jié)宴抚,請閱讀后面的章節(jié)勒魔。
完整代碼中定義函數(shù)RNN使代碼簡潔,但在后面章節(jié)中為了易于讀者理解菇曲,本文作者在第6章搭建神經(jīng)網(wǎng)絡(luò)將此部分函數(shù)改寫為只針對于該題的順序執(zhí)行代碼冠绢。

import warnings
warnings.filterwarnings('ignore')
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

tf.reset_default_graph()
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
learing_rate = 0.001
batch_size =100
n_steps = 28
n_inputs = 28
n_hidden_units = 128
n_classes = 10
X_holder = tf.placeholder(tf.float32)
Y_holder = tf.placeholder(tf.float32)

def RNN(X_holder):
    reshape_X = tf.reshape(X_holder, [-1, n_steps, n_inputs])
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden_units)
    outputs, states = tf.nn.dynamic_rnn(lstm_cell, reshape_X, dtype=tf.float32)
    cell_list = tf.unstack(tf.transpose(outputs, [1, 0, 2]))
    last_cell = cell_list[-1]
    Weights = tf.Variable(tf.truncated_normal([n_hidden_units, n_classes]))
    biases = tf.Variable(tf.constant(0.1, shape=[n_classes]))
    predict_Y = tf.matmul(last_cell, Weights) + biases
    return predict_Y
predict_Y = RNN(X_holder)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=predict_Y, labels=Y_holder))
optimizer = tf.train.AdamOptimizer(learing_rate)
train = optimizer.minimize(loss)

init = tf.global_variables_initializer()
session = tf.Session()
session.run(init)

isCorrect = tf.equal(tf.argmax(predict_Y, 1), tf.argmax(Y_holder, 1))
accuracy = tf.reduce_mean(tf.cast(isCorrect, tf.float32))
for i in range(1000):
    X, Y = mnist.train.next_batch(batch_size)
    session.run(train, feed_dict={X_holder:X, Y_holder:Y})
    step = i + 1
    if step % 100 == 0:
        test_X, test_Y = mnist.train.next_batch(3000)
        test_accuracy = session.run(accuracy, feed_dict={X_holder:test_X, Y_holder:test_Y})
        print("step:%d test accuracy:%.4f" %(step, test_accuracy))

上面一段代碼的運行結(jié)果如下:

Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
step:100 test accuracy:0.8483
step:200 test accuracy:0.8987
step:300 test accuracy:0.9230
step:400 test accuracy:0.9437
step:500 test accuracy:0.9457
step:600 test accuracy:0.9513
step:700 test accuracy:0.9687
step:800 test accuracy:0.9660
step:900 test accuracy:0.9710
step:1000 test accuracy:0.9740

5.數(shù)據(jù)準(zhǔn)備

第1行代碼導(dǎo)入庫warnings;
第2行代碼表示不打印警告信息常潮;
第3行代碼導(dǎo)入庫tensorflow弟胀,取別名tf;
第4行代碼從tensorflow.examples.tutorials.mnist庫中導(dǎo)入input_data方法喊式;
第6行代碼表示重置tensorflow圖
第7行代碼加載數(shù)據(jù)庫MNIST賦值給變量mnist孵户;
第8-13行代碼定義超參數(shù)學(xué)習(xí)率learning_rate、批量大小batch_size岔留、步數(shù)n_steps夏哭、輸入層大小n_inputs、隱藏層大小n_hidden_units献联、輸出層大小n_classes竖配。
第14何址、15行代碼中placeholder中文叫做占位符,將每次訓(xùn)練的特征矩陣X和預(yù)測目標(biāo)值Y賦值給變量X_holder和Y_holder进胯。

import warnings
warnings.filterwarnings('ignore')
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

tf.reset_default_graph()
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
learing_rate = 0.001
batch_size =100
n_steps = 28
n_inputs = 28
n_hidden_units = 128
n_classes = 10
X_holder = tf.placeholder(tf.float32)
Y_holder = tf.placeholder(tf.float32)

6.搭建神經(jīng)網(wǎng)絡(luò)

本文作者將此章中使用tensorflow庫的所有方法的API鏈接總結(jié)成下表用爪,訪問需要vpn。

方法 鏈接
tf.reshape https://www.tensorflow.org/api_docs/python/tf/manip/reshape
tf.nn.rnn_cell.LSTMCell https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/BasicLSTMCell
tf.nn.dynamic_rnn https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn
tf.transpose https://www.tensorflow.org/api_docs/python/tf/transpose
tf.unstack https://www.tensorflow.org/api_docs/python/tf/unstack
tf.Variable https://www.tensorflow.org/api_docs/python/tf/Variable
tf.truncated_normal https://www.tensorflow.org/api_docs/python/tf/truncated_normal
tf.matmul https://www.tensorflow.org/api_docs/python/tf/matmul
tf.reduce_mean https://www.tensorflow.org/api_docs/python/tf/reduce_mean
tf.nn.softmax_cross_entropy_with_logits https://www.tensorflow.org/api_docs/python/tf/nn/softmax_cross_entropy_with_logits
tf.train.AdamOptimizer https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer

第1行代碼reshape中文叫做重塑形狀胁镐,將輸入數(shù)據(jù)X_holder重塑形狀為模型需要的偎血;
第2行代碼調(diào)用tf.nn.rnn_cell.LSTMCell方法實例化LSTM細(xì)胞對象;
第3行代碼調(diào)用tf.nn.dynamic_rnn方法實例化rnn模型對象盯漂;
第4烁巫、5行代碼取得rnn模型中最后一個細(xì)胞的數(shù)值;
第6宠能、7行代碼定義在訓(xùn)練過程會更新的權(quán)重Weights、偏置biases磁餐;
第8行代碼表示xW+b的計算結(jié)果賦值給變量predict_Y违崇,即預(yù)測值;
第9行代碼表示交叉熵作為損失函數(shù)loss诊霹;
第10行代碼表示AdamOptimizer作為優(yōu)化器optimizer羞延;
第11行代碼定義訓(xùn)練過程,即使用優(yōu)化器optimizer最小化損失函數(shù)loss脾还。

reshape_X = tf.reshape(X_holder, [-1, n_steps, n_inputs])
lstm_cell = tf.nn.rnn_cell.LSTMCell(n_hidden_units)
outputs, state = tf.nn.dynamic_rnn(lstm_cell, reshape_X, dtype=tf.float32)
cell_list = tf.unstack(tf.transpose(outputs, [1, 0, 2]))
last_cell = cell_list[-1]
Weights = tf.Variable(tf.truncated_normal([n_hidden_units, n_classes]))
biases = tf.Variable(tf.constant(0.1, shape=[n_classes]))
predict_Y = tf.matmul(last_cell, Weights) + biases
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=predict_Y, labels=Y_holder))
optimizer = tf.train.AdamOptimizer(learing_rate)
train = optimizer.minimize(loss)

7.參數(shù)初始化

對于神經(jīng)網(wǎng)絡(luò)模型伴箩,重要是其中的W、b這兩個參數(shù)鄙漏。
開始神經(jīng)網(wǎng)絡(luò)模型訓(xùn)練之前嗤谚,這兩個變量需要初始化。
第1行代碼調(diào)用tf.global_variables_initializer實例化tensorflow中的Operation對象怔蚌。


image.png

第2行代碼調(diào)用tf.Session方法實例化會話對象巩步;
第3行代碼調(diào)用tf.Session對象的run方法做變量初始化。

init = tf.global_variables_initializer()
session = tf.Session()
session.run(init)

8.模型訓(xùn)練

第1行代碼tf.argmax方法中的第2個參數(shù)為1桦踊,即求出矩陣中每1行中最大數(shù)的索引椅野;
如果argmax方法中的第1個參數(shù)為0,即求出矩陣中每1列最大數(shù)的索引籍胯;
tf.equal方法可以比較兩個向量的在每個元素上是否相同竟闪,返回結(jié)果為向量,向量中元素的數(shù)據(jù)類型為布爾bool杖狼;
第2行代碼tf.cast方法可以強制轉(zhuǎn)換向量中元素的數(shù)據(jù)類型炼蛤,tf.reduce_mean可以求出向量中元素的均值;
第3行代碼表示迭代訓(xùn)練1000次蝶涩;
第4行代碼表示從mnist數(shù)據(jù)的訓(xùn)練集中選取batch_size數(shù)量的樣本鲸湃;
第5行代碼每運行1次赠涮,即模型訓(xùn)練1次;
第6-10行代碼表示從mnist數(shù)據(jù)的測試集中選取10000個樣本計算模型預(yù)測準(zhǔn)確率暗挑。

isCorrect = tf.equal(tf.argmax(predict_Y, 1), tf.argmax(Y_holder, 1))
accuracy = tf.reduce_mean(tf.cast(isCorrect, tf.float32))
for i in range(1000):
    X, Y = mnist.train.next_batch(batch_size)
    session.run(train, feed_dict={X_holder:X, Y_holder:Y})
    step = i + 1
    if step % 100 == 0:
        test_X, test_Y = mnist.test.next_batch(10000)
        test_accuracy = session.run(accuracy, feed_dict={X_holder:test_X, Y_holder:test_Y})
        print("step:%d test accuracy:%.4f" %(step, test_accuracy))

上面一段代碼的運行結(jié)果如下:

step:100 test accuracy:0.8479
step:200 test accuracy:0.8986
step:300 test accuracy:0.9370
step:400 test accuracy:0.9421
step:500 test accuracy:0.9522
step:600 test accuracy:0.9581
step:700 test accuracy:0.9607
step:800 test accuracy:0.9650
step:900 test accuracy:0.9661
step:1000 test accuracy:0.9685

文章篇幅所限笋除,只打印查看1000次訓(xùn)練的結(jié)果,訓(xùn)練5000次即可達到98.5%的準(zhǔn)確率炸裆。

9.總結(jié)

1.本文是作者寫的第9篇關(guān)于tensorflow編程的博客垃它;
2.在mnist案例中,rnn模型最高可達到98.5%的準(zhǔn)確率烹看,cnn模型最高可達到99.2%的準(zhǔn)確率国拇,因為本文中的rnn模型只考慮的圖像矩陣中每1行的關(guān)系,rnn模型可以提取空間特征惯殊。
3.理解第6章搭建神經(jīng)網(wǎng)絡(luò)的過程中酱吝,雖然代碼只有11行,本文作者花費了2天將近10多個小時土思;
4.周莫煩前輩的視頻當(dāng)中關(guān)于此章的代碼已經(jīng)過時务热,并且github中的代碼的模型準(zhǔn)確率不超過90%,沒有超過單隱藏層的DNN網(wǎng)絡(luò)98%的準(zhǔn)確率己儒,這違背了RNN優(yōu)于DNN的出發(fā)點崎岂。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市闪湾,隨后出現(xiàn)的幾起案子冲甘,更是在濱河造成了極大的恐慌,老刑警劉巖途样,帶你破解...
    沈念sama閱讀 211,042評論 6 490
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件江醇,死亡現(xiàn)場離奇詭異,居然都是意外死亡何暇,警方通過查閱死者的電腦和手機嫁审,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 89,996評論 2 384
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人,你說我怎么就攤上這事葡缰。” “怎么了捂贿?”我有些...
    開封第一講書人閱讀 156,674評論 0 345
  • 文/不壞的土叔 我叫張陵,是天一觀的道長胳嘲。 經(jīng)常有香客問我厂僧,道長,這世上最難降的妖魔是什么了牛? 我笑而不...
    開封第一講書人閱讀 56,340評論 1 283
  • 正文 為了忘掉前任颜屠,我火速辦了婚禮辰妙,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘甫窟。我一直安慰自己密浑,他們只是感情好,可當(dāng)我...
    茶點故事閱讀 65,404評論 5 384
  • 文/花漫 我一把揭開白布粗井。 她就那樣靜靜地躺著尔破,像睡著了一般。 火紅的嫁衣襯著肌膚如雪浇衬。 梳的紋絲不亂的頭發(fā)上懒构,一...
    開封第一講書人閱讀 49,749評論 1 289
  • 那天,我揣著相機與錄音耘擂,去河邊找鬼胆剧。 笑死,一個胖子當(dāng)著我的面吹牛醉冤,可吹牛的內(nèi)容都是我干的秩霍。 我是一名探鬼主播,決...
    沈念sama閱讀 38,902評論 3 405
  • 文/蒼蘭香墨 我猛地睜開眼冤灾,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了辕近?” 一聲冷哼從身側(cè)響起韵吨,我...
    開封第一講書人閱讀 37,662評論 0 266
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎移宅,沒想到半個月后归粉,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 44,110評論 1 303
  • 正文 獨居荒郊野嶺守林人離奇死亡漏峰,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 36,451評論 2 325
  • 正文 我和宋清朗相戀三年糠悼,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片浅乔。...
    茶點故事閱讀 38,577評論 1 340
  • 序言:一個原本活蹦亂跳的男人離奇死亡倔喂,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出靖苇,到底是詐尸還是另有隱情席噩,我是刑警寧澤,帶...
    沈念sama閱讀 34,258評論 4 328
  • 正文 年R本政府宣布贤壁,位于F島的核電站悼枢,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏脾拆。R本人自食惡果不足惜馒索,卻給世界環(huán)境...
    茶點故事閱讀 39,848評論 3 312
  • 文/蒙蒙 一莹妒、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧绰上,春花似錦旨怠、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,726評論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至疯趟,卻和暖如春拘哨,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背信峻。 一陣腳步聲響...
    開封第一講書人閱讀 31,952評論 1 264
  • 我被黑心中介騙來泰國打工倦青, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人盹舞。 一個月前我還...
    沈念sama閱讀 46,271評論 2 360
  • 正文 我出身青樓产镐,卻偏偏與公主長得像,于是被迫代替她去往敵國和親踢步。 傳聞我的和親對象是個殘疾皇子癣亚,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 43,452評論 2 348

推薦閱讀更多精彩內(nèi)容