手把手教你用Tensorflow搭建RNN

自從跟導(dǎo)師吹了波牛之后喻奥,開(kāi)始將自己的研究重心轉(zhuǎn)為NLP方向。那么RNN就成了必須要了解的一個(gè)DL模型(導(dǎo)師暑假去了趟西安赖欣,學(xué)習(xí)一波DL屑彻,回來(lái)神神叨叨地說(shuō)CNN到瓶頸期了验庙,別瞎折騰了)。對(duì)于RNN社牲,我就默認(rèn)大家都懂得其中的原理粪薛,有不明白的可以去看NG 的視頻教學(xué):https://mooc.study.163.com/smartSpec/detail/1001319001.htm。話不多說(shuō)搏恤,開(kāi)搞開(kāi)搞Nナ佟!挑社!


Step1 搭建環(huán)境

系統(tǒng):Windows7+TensorFlow1.9.0(cpu)+Python3.6


Step2 加載數(shù)據(jù)集

MNIST是一個(gè)手寫(xiě)數(shù)字?jǐn)?shù)據(jù)庫(kù),它有55000個(gè)訓(xùn)練樣本集和10000個(gè)測(cè)試樣本集巡揍。它是MNIST數(shù)據(jù)庫(kù)的一個(gè)子集痛阻。其中每張圖片固定大小為28×28的黑白圖片。如下圖所示:

MNIST數(shù)據(jù)集

使用Tensorflow加載內(nèi)置的MNIST數(shù)據(jù)集腮敌,具體方法展示如下阱当。

加載方法

加載完成后,打印訓(xùn)練樣本數(shù)(ntrain)糜工,測(cè)試樣本數(shù)(ntest)弊添,樣本 總維度(dim)以及分類數(shù)(nlasses)

打印結(jié)果

Step3 RNN模型

接下來(lái)我們就需要去看看,我們所要搭建的RNN模型到底長(zhǎng)啥樣捌木。

這是從NG那邊搞來(lái)的圖油坝,相信大家都能明白其中的奧秘吧。但當(dāng)時(shí)我在學(xué)的時(shí)候很困惑刨裆,這東西實(shí)際中咋用澈圈,搞成這副鬼樣子,真有那么神帆啃?em...下面這張圖就是利用MNIST數(shù)據(jù)集來(lái)做的一個(gè)實(shí)驗(yàn)瞬女。

簡(jiǎn)單敘述下該模型的建立過(guò)程,每一個(gè)樣本都可以看出是一個(gè)[28,28]的矩陣努潘,那么將矩陣的每一行作為一個(gè)輸入向量诽偷,大小為[1,28],那么整個(gè)模型就擁有28個(gè)輸入神經(jīng)元疯坤,這28個(gè)神經(jīng)元我們將其統(tǒng)稱為輸入層报慕。完成后我再輸入層與RNN層之間增加一個(gè)含有128個(gè)神經(jīng)元的隱藏層,用于對(duì)輸入層進(jìn)行特征 提取压怠,形成一個(gè)[1卖子,128]的向量傳入RNN中。RNN中的內(nèi)部構(gòu)造參見(jiàn)LSTM刑峡,形成兩個(gè)向量分別為L(zhǎng)STM_O洋闽,LSTM_S玄柠,大小都為[1,128]诫舅。其中LSTM_O為RNN模型的輸出羽利,LSTM_S為RNN模型的內(nèi)部記憶向量,傳遞到下一個(gè)RNN神經(jīng)元刊懈。最后對(duì)LSTM_O進(jìn)行Softmax處理这弧,通過(guò)概率分析出該樣本的類別。

接下來(lái)我們對(duì)模型中所涉及的權(quán)重虚汛、偏置以及各層神經(jīng)元數(shù)量的設(shè)置匾浪。其中W["h1"]為輸入層到隱藏層的權(quán)重,大小為[28,128]卷哩,W["h2"]為隱藏層到RNN的權(quán)重蛋辈,大小為[128,10]。b["b1"]與b["b2"]同理将谊。

權(quán)重冷溶、偏置

Step4 創(chuàng)建RNN模型(關(guān)鍵步驟)

創(chuàng)建RNN

42~43行:對(duì)輸入數(shù)據(jù)進(jìn)行預(yù)處理操作。這里涉及到batch_size的問(wèn)題尊浓,在訓(xùn)練時(shí)我們通常是將一批數(shù)據(jù)導(dǎo)入模型來(lái)提高模型的效率逞频,那么批次的大小就是batch_size。即我們可以理解為我們是將一個(gè)batch_size*28*28的三維矩陣導(dǎo)入了我們的RNN模型栋齿,那么我們就要對(duì)該矩陣進(jìn)行變換從而滿足我們[None,28]的要求苗胀。

44~45行:隱藏層處理好輸入數(shù)據(jù)后形成一個(gè)[None,128]的矩陣。然后對(duì)該矩陣進(jìn)行切割瓦堵,我們的RNN一共有28個(gè)輸入單元柒巫,那就切成28個(gè)咯。

46~50行:將切好的矩陣依次傳入RNN中谷丸。接下來(lái)是對(duì)RNN內(nèi)部的設(shè)置堡掏,這邊使用的是LSTM(tf.nn.rnn_cell.BasicLSTMCell()),當(dāng)然tf.nn也為我們實(shí)現(xiàn)好了其他的內(nèi)部設(shè)置方便我們調(diào)用刨疼。


Step5 超參數(shù)的定義(損泉唁,優(yōu),學(xué)揩慕,準(zhǔn)亭畜,初)

該步驟定義模型中我們所需要的一些超參數(shù)。Tensorflow擁有現(xiàn)成的方法迎卤,方便我們調(diào)用拴鸵。

學(xué)習(xí)率:leraning_rate

損失:cost

優(yōu)化方法:optm

參數(shù)初始化:init

超參數(shù)的定義

Step6 訓(xùn)練測(cè)試

最后一步,設(shè)置好迭代次數(shù)(training_epoch)、批次大芯⒚辍(batch_size)等后八堡,將MNIST的數(shù)據(jù)集加載到Step5中設(shè)置好的X、Y中聘芜,完成訓(xùn)練測(cè)試兄渺。沒(méi)什么好多說(shuō)的,我每個(gè)模型最后的訓(xùn)練測(cè)試都這樣汰现,照著寫(xiě)吧挂谍。

訓(xùn)練測(cè)試

Step7 結(jié)果展示

總共迭代了5次,電腦跑太慢了...

準(zhǔn)確率達(dá)到93.9%瞎饲,還彳亍 口 巴口叙!

結(jié)果

附上所有代碼:

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

from tensorflow.contrib import rnn

# 加載數(shù)據(jù)

mnist = input_data.read_data_sets("MNIST_data",one_hot=True)

trainimgs, trainlabels, testimgs, testlabels \

= mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels

ntrain, ntest, dim, nclasses\

=trainimgs.shape[0],testimgs.shape[0],trainimgs.shape[1],trainlabels.shape[1]

#print(ntrain, ntest, dim, nclasses)

print ("MNIST loaded")

#設(shè)置參數(shù),權(quán)重,偏置

diminput = 28

dimhidden = 128

dimoutput = nclasses

nsteps = 28

W = {"h1" : tf.Variable(tf.random_normal([diminput,dimhidden])),

? ? "h2" : tf.Variable(tf.random_normal([dimhidden,dimoutput]))}

b = {"b1" : tf.Variable(tf.random_normal([dimhidden])),

? ? "b2" : tf.Variable(tf.random_normal([dimoutput]))}

# 創(chuàng)建模型

def RNN(X,W,b,nsteps):

? ? X = tf.transpose(X,[1,0,2])

? ? X = tf.reshape(X,[-1,diminput])

? ? H_1 = tf.matmul(X,W["h1"])+b["b1"]

? ? H_1 = tf.split(H_1,nsteps,0)

? ? lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(dimhidden,forget_bias=1.0)

? ? LSTM_O,LSTM_S = rnn.static_rnn(lstm_cell,H_1,dtype=tf.float32)

? ? O = tf.matmul(LSTM_O[-1],W["h2"])+b["b2"]

? ? return {"X":X,"H_1":H_1,"LSTM_O":LSTM_O,"LSTM_S":LSTM_S,"O":O}?

print ("Network ready")

# 設(shè)置損失嗅战,優(yōu)化,學(xué)習(xí)率妄田,準(zhǔn)確率,參數(shù)初始化

learning_rate = 0.001

x? ? ? = tf.placeholder("float", [None, nsteps, diminput])

y? ? ? = tf.placeholder("float", [None, dimoutput])

myrnn? = RNN(x, W, b, nsteps)

pred? = myrnn['O']

cost? = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=pred))

optm? = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

accr? = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(pred,1), tf.argmax(y,1)), tf.float32))

init? = tf.global_variables_initializer()

print ("Network Ready!")

# 訓(xùn)練仗哨,測(cè)試

#所有樣本迭代(epoch)5次

training_epochs = 5

#每進(jìn)行一次迭代選擇的樣本數(shù)

batch_size? ? ? = 16

#展示

display_step? ? = 1

sess = tf.Session()

sess.run(init)

print ("Start optimization")

for epoch in range(training_epochs):

? ? avg_cost = 0.

? ? total_batch = int(mnist.train.num_examples/batch_size)

? ? #total_batch = 100

? ? # Loop over all batches

? ? for i in range(total_batch):

? ? ? ? batch_xs, batch_ys = mnist.train.next_batch(batch_size)

? ? ? ? batch_xs = batch_xs.reshape((batch_size, nsteps, diminput))

? ? ? ? # Fit training using batch data

? ? ? ? feeds = {x: batch_xs, y: batch_ys}

? ? ? ? sess.run(optm, feed_dict=feeds)

? ? ? ? # Compute average loss

? ? ? ? avg_cost += sess.run(cost, feed_dict=feeds)/total_batch

? ? # Display logs per epoch step

? ? if epoch % display_step == 0:

? ? ? ? print ("Epoch: %03d/%03d cost: %.9f" % (epoch, training_epochs, avg_cost))

? ? ? ? feeds = {x: batch_xs, y: batch_ys}

? ? ? ? train_acc = sess.run(accr, feed_dict=feeds)

? ? ? ? print (" Training accuracy: %.3f" % (train_acc))

? ? ? ? testimgs = testimgs.reshape((ntest, nsteps, diminput))

? ? ? ? feeds = {x: testimgs, y: testlabels}

? ? ? ? test_acc = sess.run(accr, feed_dict=feeds)

? ? ? ? print (" Test accuracy: %.3f" % (test_acc))

print ("Optimization Finished.")

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末形庭,一起剝皮案震驚了整個(gè)濱河市铅辞,隨后出現(xiàn)的幾起案子厌漂,更是在濱河造成了極大的恐慌,老刑警劉巖斟珊,帶你破解...
    沈念sama閱讀 218,386評(píng)論 6 506
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件苇倡,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡囤踩,警方通過(guò)查閱死者的電腦和手機(jī)旨椒,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,142評(píng)論 3 394
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)堵漱,“玉大人综慎,你說(shuō)我怎么就攤上這事∏诼” “怎么了示惊?”我有些...
    開(kāi)封第一講書(shū)人閱讀 164,704評(píng)論 0 353
  • 文/不壞的土叔 我叫張陵,是天一觀的道長(zhǎng)愉镰。 經(jīng)常有香客問(wèn)我米罚,道長(zhǎng),這世上最難降的妖魔是什么丈探? 我笑而不...
    開(kāi)封第一講書(shū)人閱讀 58,702評(píng)論 1 294
  • 正文 為了忘掉前任录择,我火速辦了婚禮,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘隘竭。我一直安慰自己塘秦,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,716評(píng)論 6 392
  • 文/花漫 我一把揭開(kāi)白布货裹。 她就那樣靜靜地躺著嗤形,像睡著了一般。 火紅的嫁衣襯著肌膚如雪弧圆。 梳的紋絲不亂的頭發(fā)上赋兵,一...
    開(kāi)封第一講書(shū)人閱讀 51,573評(píng)論 1 305
  • 那天,我揣著相機(jī)與錄音搔预,去河邊找鬼霹期。 笑死,一個(gè)胖子當(dāng)著我的面吹牛拯田,可吹牛的內(nèi)容都是我干的历造。 我是一名探鬼主播,決...
    沈念sama閱讀 40,314評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼船庇,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼吭产!你這毒婦竟也來(lái)了?” 一聲冷哼從身側(cè)響起鸭轮,我...
    開(kāi)封第一講書(shū)人閱讀 39,230評(píng)論 0 276
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤臣淤,失蹤者是張志新(化名)和其女友劉穎,沒(méi)想到半個(gè)月后窃爷,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體邑蒋,經(jīng)...
    沈念sama閱讀 45,680評(píng)論 1 314
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,873評(píng)論 3 336
  • 正文 我和宋清朗相戀三年按厘,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了医吊。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 39,991評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡逮京,死狀恐怖卿堂,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情懒棉,我是刑警寧澤草描,帶...
    沈念sama閱讀 35,706評(píng)論 5 346
  • 正文 年R本政府宣布,位于F島的核電站漓藕,受9級(jí)特大地震影響陶珠,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜享钞,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,329評(píng)論 3 330
  • 文/蒙蒙 一揍诽、第九天 我趴在偏房一處隱蔽的房頂上張望诀蓉。 院中可真熱鬧,春花似錦暑脆、人聲如沸渠啤。這莊子的主人今日做“春日...
    開(kāi)封第一講書(shū)人閱讀 31,910評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)沥曹。三九已至,卻和暖如春碟联,著一層夾襖步出監(jiān)牢的瞬間妓美,已是汗流浹背。 一陣腳步聲響...
    開(kāi)封第一講書(shū)人閱讀 33,038評(píng)論 1 270
  • 我被黑心中介騙來(lái)泰國(guó)打工鲤孵, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留壶栋,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 48,158評(píng)論 3 370
  • 正文 我出身青樓普监,卻偏偏與公主長(zhǎng)得像贵试,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子凯正,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,941評(píng)論 2 355