RNN入門:利用TF的API(二)

上一篇中手工構(gòu)建了RNN網(wǎng)絡(luò)坎弯,這里介紹如何利用TensorFlow(r1.1)的API簡(jiǎn)化這一代碼。

構(gòu)建模型

將建模部分的代碼替換為:

cell = tf.contrib.rnn.BasicRNNCell(state_size)
current_state = init_state
states_series = []
for current_input in inputs_series:
    with tf.variable_scope('rnn') as vs:
        try:
            output, current_state = cell(current_input, current_state)
        except:
            vs.reuse_variables()
            output, current_state = cell(current_input, current_state)
    states_series.append(current_state)

需要注意的是译暂,tf.contrib.rnn.BasicRNNCell函數(shù)每次都會(huì)聲明一次變量抠忘,這會(huì)導(dǎo)致第二次調(diào)用失敗。所以需要加入try語(yǔ)句秧秉,在異常時(shí)聲明vs.reuse_variables()褐桌。

全部代碼

from __future__ import print_function, division
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

num_epochs = 100
total_series_length = 50000
truncated_backprop_length = 15
state_size = 4
num_classes = 2
echo_step = 3
batch_size = 5
num_batches = total_series_length//batch_size//truncated_backprop_length

def generateData():
    x = np.array(np.random.choice(2, total_series_length, p=[0.5, 0.5]))
    y = np.roll(x, echo_step)
    y[0:echo_step] = 0

    x = x.reshape((batch_size, -1))  # The first index changing slowest, subseries as rows
    y = y.reshape((batch_size, -1))

    return (x, y)

batchX_placeholder = tf.placeholder(tf.float32, [batch_size, truncated_backprop_length])
batchY_placeholder = tf.placeholder(tf.int32, [batch_size, truncated_backprop_length])

init_state = tf.placeholder(tf.float32, [batch_size, state_size])

W2 = tf.Variable(np.random.rand(state_size, num_classes),dtype=tf.float32)
b2 = tf.Variable(np.zeros((1,num_classes)), dtype=tf.float32)

# Unpack columns
inputs_series = tf.split(batchX_placeholder, truncated_backprop_length, axis=1)
labels_series = tf.unstack(batchY_placeholder, axis=1)

# Forward passes
cell = tf.contrib.rnn.BasicRNNCell(state_size)
current_state = init_state
states_series = []
for current_input in inputs_series:
    with tf.variable_scope('rnn') as vs:
        try:
            output, current_state = cell(current_input, current_state)
        except:
            vs.reuse_variables()
            output, current_state = cell(current_input, current_state)
    states_series.append(current_state)

logits_series = [tf.matmul(state, W2) + b2 for state in states_series] #Broadcasted addition
predictions_series = [tf.nn.softmax(logits) for logits in logits_series]

losses = [tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels) for logits, labels in zip(logits_series,labels_series)]
total_loss = tf.reduce_mean(losses)

train_step = tf.train.AdagradOptimizer(0.3).minimize(total_loss)

def plot(loss_list, predictions_series, batchX, batchY):
    plt.subplot(2, 3, 1)
    plt.cla()
    plt.plot(loss_list)

    for batch_series_idx in range(5):
        one_hot_output_series = np.array(predictions_series)[:, batch_series_idx, :]
        single_output_series = np.array([(1 if out[0] < 0.5 else 0) for out in one_hot_output_series])

        plt.subplot(2, 3, batch_series_idx + 2)
        plt.cla()
        plt.axis([0, truncated_backprop_length, 0, 2])
        left_offset = range(truncated_backprop_length)
        plt.bar(left_offset, batchX[batch_series_idx, :], width=1, color="blue")
        plt.bar(left_offset, batchY[batch_series_idx, :] * 0.5, width=1, color="red")
        plt.bar(left_offset, single_output_series * 0.3, width=1, color="green")

    plt.draw()
    plt.pause(0.0001)


with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    plt.ion()
    plt.figure()
    plt.show()
    loss_list = []

    for epoch_idx in range(num_epochs):
        x,y = generateData()
        _current_state = np.zeros((batch_size, state_size))

        print("New data, epoch", epoch_idx)

        for batch_idx in range(num_batches):
            start_idx = batch_idx * truncated_backprop_length
            end_idx = start_idx + truncated_backprop_length

            batchX = x[:,start_idx:end_idx]
            batchY = y[:,start_idx:end_idx]

            _total_loss, _train_step, _current_state, _predictions_series = sess.run(
                [total_loss, train_step, current_state, predictions_series],
                feed_dict={
                    batchX_placeholder:batchX,
                    batchY_placeholder:batchY,
                    init_state:_current_state
                })

            loss_list.append(_total_loss)

            if batch_idx%100 == 0:
                print("Step",batch_idx, "Loss", _total_loss)
                plot(loss_list, _predictions_series, batchX, batchY)

plt.ioff()
plt.show()

參考文獻(xiàn):

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市象迎,隨后出現(xiàn)的幾起案子荧嵌,更是在濱河造成了極大的恐慌呛踊,老刑警劉巖,帶你破解...
    沈念sama閱讀 216,470評(píng)論 6 501
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件啦撮,死亡現(xiàn)場(chǎng)離奇詭異谭网,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī)赃春,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,393評(píng)論 3 392
  • 文/潘曉璐 我一進(jìn)店門愉择,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái),“玉大人织中,你說(shuō)我怎么就攤上這事锥涕。” “怎么了狭吼?”我有些...
    開封第一講書人閱讀 162,577評(píng)論 0 353
  • 文/不壞的土叔 我叫張陵层坠,是天一觀的道長(zhǎng)。 經(jīng)常有香客問我刁笙,道長(zhǎng)破花,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,176評(píng)論 1 292
  • 正文 為了忘掉前任疲吸,我火速辦了婚禮座每,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘摘悴。我一直安慰自己峭梳,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,189評(píng)論 6 388
  • 文/花漫 我一把揭開白布烦租。 她就那樣靜靜地躺著延赌,像睡著了一般。 火紅的嫁衣襯著肌膚如雪叉橱。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,155評(píng)論 1 299
  • 那天者蠕,我揣著相機(jī)與錄音窃祝,去河邊找鬼。 笑死踱侣,一個(gè)胖子當(dāng)著我的面吹牛粪小,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播抡句,決...
    沈念sama閱讀 40,041評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼探膊,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來(lái)了待榔?” 一聲冷哼從身側(cè)響起逞壁,我...
    開封第一講書人閱讀 38,903評(píng)論 0 274
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤流济,失蹤者是張志新(化名)和其女友劉穎,沒想到半個(gè)月后腌闯,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體绳瘟,經(jīng)...
    沈念sama閱讀 45,319評(píng)論 1 310
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,539評(píng)論 2 332
  • 正文 我和宋清朗相戀三年姿骏,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了糖声。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 39,703評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡分瘦,死狀恐怖蘸泻,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情嘲玫,我是刑警寧澤悦施,帶...
    沈念sama閱讀 35,417評(píng)論 5 343
  • 正文 年R本政府宣布味悄,位于F島的核電站啸驯,受9級(jí)特大地震影響触机,放射性物質(zhì)發(fā)生泄漏铭污。R本人自食惡果不足惜株婴,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,013評(píng)論 3 325
  • 文/蒙蒙 一殃饿、第九天 我趴在偏房一處隱蔽的房頂上張望犁柜。 院中可真熱鬧川梅,春花似錦旺坠、人聲如沸乔遮。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,664評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)蹋肮。三九已至,卻和暖如春璧疗,著一層夾襖步出監(jiān)牢的瞬間坯辩,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 32,818評(píng)論 1 269
  • 我被黑心中介騙來(lái)泰國(guó)打工崩侠, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留漆魔,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 47,711評(píng)論 2 368
  • 正文 我出身青樓却音,卻偏偏與公主長(zhǎng)得像改抡,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子系瓢,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,601評(píng)論 2 353

推薦閱讀更多精彩內(nèi)容

  • 導(dǎo)語(yǔ):本文是TensorFlow實(shí)現(xiàn)流行機(jī)器學(xué)習(xí)算法的教程匯集阿纤,目標(biāo)是讓讀者可以輕松通過清晰簡(jiǎn)明的案例深入了解 T...
    Hebborn_hb閱讀 1,493評(píng)論 0 3
  • 本片文章是對(duì)上一篇文章(iOS版SessionID概念理解及demo教程)的深度解析和繼續(xù)完善迭代...... 上...
    ttdiOS閱讀 1,508評(píng)論 1 6
  • 在以往深秋季節(jié)里,自己一直以來(lái)比較討厭初起的北風(fēng)夷陋,給本已肅殺的天氣又增添了幾分凜冽欠拾,讓久已習(xí)慣溫?zé)岬奈覀兏袊@“時(shí)光...
    黑白無(wú)閱讀 136評(píng)論 0 2
  • Day 8 18-6-2017 親愛的天父胰锌,謝謝你把我?guī)У缴駩壑遥裉旆窒淼男畔⑶迨矗苡懈袆?dòng)匕荸。當(dāng)牧師分享說(shuō),“你們...
    JennyMo閱讀 311評(píng)論 0 0
  • 1.去除ios頁(yè)面的input枷邪、textarea的自帶效果(陰影榛搔、等) input,textarea {-webk...
    執(zhí)著_7a69閱讀 2,174評(píng)論 0 0