tensorflow練習(xí)2-簡(jiǎn)單的RNN預(yù)測(cè)帶周期性的數(shù)據(jù)

tensorflow練習(xí)系列:
tensorflow練習(xí)1-DNN


現(xiàn)在我們嘗試使用lstm來預(yù)測(cè)一個(gè)周期性的數(shù)據(jù)

import tensorflow as tf
from tensorflow.contrib import rnn

import numpy as np
from matplotlib import pyplot as plt
from sklearn import preprocessing, model_selection
from sklearn import preprocessing

import pandas as pd

input_vec_size = 1 # 輸入向量的維度
lstm_size = 10 # size of lstm
time_step_size = 5 # 循環(huán)層長(zhǎng)度

batch_size = 7
test_size = 3

1. 準(zhǔn)備數(shù)據(jù)

我們使用偽造的數(shù)據(jù)

day = (time_step_size + 1)* 200

week_rate = [0.9, 0.85, 0.80, 0.88, 1.1, 1.2, 1.15]
label = [(1 + i * 0.002) * week_rate[i%7] for i in range(day)]
label = np.array(label)

# scaler = preprocessing.StandardScaler()
# label = scaler.fit_transform(label)

label = label.reshape(int(day / (time_step_size + 1)), (time_step_size + 1))

print(label.shape)
_tmp = label
X_ = _tmp[:, :time_step_size]
Y_ = _tmp[:, time_step_size:]

print(_tmp.shape, X_.shape, Y_.shape)

plt.plot(_tmp[:100, 0])
plt.show()

輸出結(jié)果如下
(200, 6)
(200, 6) (200, 5) (200, 1)


周期性的數(shù)據(jù)

2. 準(zhǔn)備網(wǎng)絡(luò)

def init_weights(shape):
    return tf.Variable(tf.random_normal(shape, stddev=0.01))

def model(X, W, B, lstm_size):
    # X, input shape: (batch_size, time_step_size, input_vec_size)
    # XT shape: (time_step_size, batch_size, input_vec_size)
    print(X.shape)
    XT = tf.transpose(X, [1, 0, 2])

    # XR shape: (time_step_size * batch_size, input_vec_size)
    XR = tf.reshape(XT, [-1, input_vec_size]) # each row has input for each lstm cell (lstm_size=input_vec_size)

    # Each array shape: (batch_size, input_vec_size)
    X_split = tf.split(XR, time_step_size, 0) # split them to time_step_size


    # Make lstm with lstm_size (each input vector size). num_units=lstm_size; forget_bias=1.0
    lstm = rnn.BasicLSTMCell(lstm_size, forget_bias=1.0, state_is_tuple=True)

    # Get lstm cell output, time_step_size (28) arrays with lstm_size output: (batch_size, lstm_size)
    # rnn..static_rnn()的輸出對(duì)應(yīng)于每一個(gè)timestep砸琅,如果只關(guān)心最后一步的輸出趋惨,取outputs[-1]即可
    outputs, _states = rnn.static_rnn(lstm, X_split, dtype=tf.float32)  # 時(shí)間序列上每個(gè)Cell的輸出:[... shape=(128, 28)..]

    # Linear activation
    # Get the last output
    return tf.matmul(outputs[-1], W) + B, lstm.state_size # State size to initialize the stat

trX, teX, trY, teY = model_selection.train_test_split(X_, Y_, test_size=0.3)

print(trX.shape, trY.shape, teX.shape, teY.shape)

trX = trX.reshape(-1, time_step_size,  1) 
teX = teX.reshape(-1, time_step_size,  1) 

trY = trY.reshape(-1, 1)
teY = teY.reshape(-1, 1)

X = tf.placeholder("float", [None, time_step_size, 1])
Y = tf.placeholder("float", [None, 1])

# get lstm_size and output 10 labels
W = init_weights([lstm_size, 1])
B = init_weights([1])

py_x, state_size = model(X, W, B, lstm_size)

loss = tf.reduce_mean(tf.square(py_x - Y))
train_op = tf.train.AdamOptimizer(0.01).minimize(loss)
predict_op = py_x

session_conf = tf.ConfigProto()
session_conf.gpu_options.allow_growth = True

3. 開始訓(xùn)練

# Launch the graph in a session
with tf.Session(config=session_conf) as sess:
    # you need to initialize all variables
    tf.global_variables_initializer().run()
    
    for i in range(5000):
        for start, end in zip(range(0, len(trX), batch_size), range(batch_size, len(trX)+1, batch_size)):
            #print("feed:", trX[start:end][0] , trY[start:end][0])
            #print("feed:", trX[start:end].shape, trY[start:end].shape)
            sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end]})

        if(i % 100 == 0):
            print(i, sess.run(loss, feed_dict={X: trX, Y: trY}), sess.run(loss, feed_dict={X: teX, Y: teY}))
    

    predic = sess.run(predict_op, feed_dict={X: X_.reshape(-1, time_step_size,  1), Y: Y_.reshape(-1, 1)})

訓(xùn)練結(jié)果如下:

0 0.708224 0.584213
100 0.000834472 0.000973581
200 0.000377034 0.000441525
300 0.0111581 0.0139003
400 0.000116222 0.000156086
500 0.000468893 0.000493862
600 0.000166236 0.000181855
700 4.5977e-05 7.10702e-05
800 0.0001364 0.000223403
900 5.61442e-05 7.7883e-05
1000 4.72553e-05 6.94315e-05
1100 0.00035655 0.000320759
1200 1.88308e-05 3.37516e-05
1300 1.21966e-05 2.17105e-05
1400 4.73054e-05 7.62947e-05
1500 7.19001e-06 2.29275e-05
1600 0.00010293 0.000149464
1700 3.69292e-05 4.97202e-05
1800 1.507e-05 2.25646e-05
1900 1.81654e-05 3.16758e-05
2000 6.29222e-06 1.25351e-05
2100 7.36748e-05 0.000105863
2200 9.74285e-06 1.90102e-05
2300 5.93037e-05 7.35179e-05
2400 0.000723914 0.000536784
2500 3.20104e-06 8.17306e-06
2600 3.83796e-06 1.01513e-05
2700 5.5799e-06 1.48322e-05
2800 3.92063e-06 1.29322e-05
2900 3.68061e-06 1.5905e-05
3000 2.65767e-06 1.1614e-05
3100 3.09019e-06 1.39719e-05
3200 0.000264612 0.000287502
3300 4.02342e-05 4.46263e-05
3400 2.33401e-05 3.36289e-05
3500 8.55672e-05 0.000107355
3600 3.20796e-05 4.43166e-05
3700 1.7802e-05 2.81231e-05
3800 1.15583e-05 1.70578e-05
3900 1.40946e-05 2.29459e-05
4000 1.16366e-05 1.92793e-05
4100 4.93991e-06 1.29583e-05
4200 2.57865e-05 4.40378e-05
4300 6.99072e-06 1.18626e-05
4400 3.02546e-06 7.17427e-06
4500 6.18525e-06 1.39441e-05
4600 0.000154743 0.000180568
4700 7.82577e-05 7.30797e-05
4800 5.27492e-05 6.85774e-05
4900 1.77607e-05 2.66213e-05

查看一下結(jié)果數(shù)據(jù):

plt.plot(Y_[:100])
plt.plot(predic[:100])
plt.show()

# print(scaler.mean_)


print(np.mean(np.square(Y_ - predic)))

# print(np.hstack([scaler.inverse_transform(Y_), scaler.inverse_transform(predic)]))

print(np.hstack([X_, Y_, predic])[:20])

得到


基本符合預(yù)期
9.2249199079e-05
[[ 0.9         0.8517      0.8032      0.88528     1.1088      1.212
   1.21290922]
 [ 1.1638      0.9126      0.8636      0.8144      0.8976      1.1242
   1.12351131]
 [ 1.2288      1.1799      0.9252      0.8755      0.8256      0.90992
   0.90445042]
 [ 1.1396      1.2456      1.196       0.9378      0.8874      0.8368
   0.82641059]
 [ 0.92224     1.155       1.2624      1.2121      0.9504      0.8993
   0.8717519 ]
 [ 0.848       0.93456     1.1704      1.2792      1.2282      0.963
   0.95084727]
 [ 0.9112      0.8592      0.94688     1.1858      1.296       1.2443
   1.22580171]
 [ 0.9756      0.9231      0.8704      0.9592      1.2012      1.3128
   1.31333256]
 [ 1.2604      0.9882      0.935       0.8816      0.97152     1.2166
   1.2210536 ]
 [ 1.3296      1.2765      1.0008      0.9469      0.8928      0.98384
   0.97711486]
 [ 1.232       1.3464      1.2926      1.0134      0.9588      0.904
   0.89058191]
 [ 0.99616     1.2474      1.3632      1.3087      1.026       0.9707
   0.95494765]
 [ 0.9152      1.00848     1.2628      1.38        1.3248      1.0386
   1.02744293]
 [ 0.9826      0.9264      1.0208      1.2782      1.3968      1.3409
   1.33157599]
 [ 1.0512      0.9945      0.9376      1.03312     1.2936      1.4136
   1.41023409]
 [ 1.357       1.0638      1.0064      0.9488      1.04544     1.309
   1.31301975]
 [ 1.4304      1.3731      1.0764      1.0183      0.96        1.05776
   1.05260563]
 [ 1.3244      1.4472      1.3892      1.089       1.0302      0.9712
   0.95847595]
 [ 1.07008     1.3398      1.464       1.4053      1.1016      1.0421
   1.03257525]
 [ 0.9824      1.0824      1.3552      1.4808      1.4214      1.1142
   1.10492337]]

進(jìn)一步圖形化結(jié)果:

mean_ = np.mean(Y_)
plt.plot((predic - Y_)/mean_ * 100)
plt.show()
print("mean", mean_)
預(yù)測(cè)誤差, 單位%

mean 2.166021

4. 小結(jié)

  1. 之前的代碼存在過擬合,增加樣本數(shù)量后, 得到緩解, (一定要同時(shí)打印出train&test loss)
  2. 之前的代碼擬合效果不佳滤祖, 與樣本的數(shù)據(jù)分布過大有關(guān)陋气, 應(yīng)該在lstm哪個(gè)地方加上norm议忽, 適應(yīng)更多的數(shù)據(jù)分布溉跃。目前取得的結(jié)果與數(shù)據(jù)的分布比較合適有關(guān)村刨。
  3. Adam似乎是一個(gè)不錯(cuò)的選擇。我也嘗試過RMSprop喊积, 但是效果不佳烹困, 貌似動(dòng)量有點(diǎn)過頭的樣子。

思考

  1. 預(yù)測(cè)的結(jié)果成周期性乾吻, 是否是沒有學(xué)習(xí)到一些內(nèi)容造成的?
  2. 怎樣評(píng)判當(dāng)前的結(jié)果呢拟蜻?
print("當(dāng)前的預(yù)測(cè)結(jié)果", np.mean(np.square(predic-Y_)))
#對(duì)比绎签, 假設(shè)網(wǎng)絡(luò)什么都沒有學(xué)習(xí)到, 那么直接用X_最后一個(gè)作為預(yù)測(cè)結(jié)果
print("直接用上一個(gè)值作為預(yù)測(cè)結(jié)果", np.mean(np.square(X_[:, -1] - Y_)))
print("學(xué)習(xí)到了一點(diǎn)點(diǎn)的大勢(shì)", np.mean(np.square(X_[:, -1] + 0.002 -Y_)))

當(dāng)前的預(yù)測(cè)結(jié)果 1.69001122001e-05
直接用上一個(gè)值作為預(yù)測(cè)結(jié)果 1.96738985325
學(xué)習(xí)到了一點(diǎn)點(diǎn)的大勢(shì) 1.96737203849

  • 從上面的結(jié)果看酝锅, 還是學(xué)習(xí)到了一點(diǎn)的诡必。

  • 但是從上面可以看到, 其實(shí)我們是選的周期為6的數(shù)據(jù)在訓(xùn)練搔扁, 我們把周期改成7了爸舒, 結(jié)果會(huì)不會(huì)更好呢?
    答案當(dāng)然是更好稿蹲, loss達(dá)到了1.53918561546e-07扭勉, 原來周期為5+1的時(shí)候, loss為 1.07650830461e-05
    我們猜想原因可能是因?yàn)榫W(wǎng)絡(luò)可以直接學(xué)習(xí)到最后一個(gè)數(shù)字 / 1.2 * 1.15 修正0.02的增產(chǎn)率苛聘, 就可以得到結(jié)果涂炎。
    那么忠聚, 我們是否能通過修正網(wǎng)絡(luò), 讓周期不是6+1的時(shí)候唱捣, 也得到較好的結(jié)果e-7等級(jí)的loss呢两蟀?

  • 把當(dāng)前星期幾作為特征輸入,結(jié)果直到訓(xùn)練到2900步才把loss降到e-6級(jí)別震缭, 是否有優(yōu)化的前景呢赂毯?

  • 現(xiàn)在2007個(gè)數(shù)據(jù), 實(shí)際也只用了200組作為訓(xùn)練和驗(yàn)證數(shù)據(jù)拣宰, 剛好7個(gè)數(shù)據(jù)一組欢瞪, 實(shí)際上我們可以n個(gè)數(shù)據(jù), 取出n - 7 + 1組出來徐裸, 這個(gè)結(jié)果遣鼓, 期待后續(xù)更新, 結(jié)果單純改到多數(shù)組重贺, loss反而更高骑祟, 到了1.6e-3級(jí)別, 標(biāo)準(zhǔn)化輸入數(shù)據(jù)后气笙,結(jié)果loss到了1.5e-3~1.9e-4( 這個(gè)是scaler.scaler_**2 的結(jié)果)次企,loss有所降低, 但是效果不是特別明顯

5. 改進(jìn)

目前其實(shí)效果還是不理想潜圃, 等理解更輸入缸棵, 有時(shí)間后續(xù)改進(jìn)模型。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末谭期,一起剝皮案震驚了整個(gè)濱河市堵第,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌隧出,老刑警劉巖踏志,帶你破解...
    沈念sama閱讀 222,378評(píng)論 6 516
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異胀瞪,居然都是意外死亡针余,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 94,970評(píng)論 3 399
  • 文/潘曉璐 我一進(jìn)店門凄诞,熙熙樓的掌柜王于貴愁眉苦臉地迎上來圆雁,“玉大人,你說我怎么就攤上這事帆谍∥毙啵” “怎么了?”我有些...
    開封第一講書人閱讀 168,983評(píng)論 0 362
  • 文/不壞的土叔 我叫張陵既忆,是天一觀的道長(zhǎng)驱负。 經(jīng)常有香客問我嗦玖,道長(zhǎng),這世上最難降的妖魔是什么跃脊? 我笑而不...
    開封第一講書人閱讀 59,938評(píng)論 1 299
  • 正文 為了忘掉前任宇挫,我火速辦了婚禮,結(jié)果婚禮上酪术,老公的妹妹穿的比我還像新娘器瘪。我一直安慰自己,他們只是感情好绘雁,可當(dāng)我...
    茶點(diǎn)故事閱讀 68,955評(píng)論 6 398
  • 文/花漫 我一把揭開白布橡疼。 她就那樣靜靜地躺著,像睡著了一般庐舟。 火紅的嫁衣襯著肌膚如雪欣除。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 52,549評(píng)論 1 312
  • 那天挪略,我揣著相機(jī)與錄音历帚,去河邊找鬼。 笑死杠娱,一個(gè)胖子當(dāng)著我的面吹牛挽牢,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播摊求,決...
    沈念sama閱讀 41,063評(píng)論 3 422
  • 文/蒼蘭香墨 我猛地睜開眼禽拔,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來了室叉?” 一聲冷哼從身側(cè)響起睹栖,我...
    開封第一講書人閱讀 39,991評(píng)論 0 277
  • 序言:老撾萬榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎太惠,沒想到半個(gè)月后磨淌,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 46,522評(píng)論 1 319
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡凿渊,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 38,604評(píng)論 3 342
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了缚柳。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片埃脏。...
    茶點(diǎn)故事閱讀 40,742評(píng)論 1 353
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖秋忙,靈堂內(nèi)的尸體忽然破棺而出彩掐,到底是詐尸還是另有隱情,我是刑警寧澤灰追,帶...
    沈念sama閱讀 36,413評(píng)論 5 351
  • 正文 年R本政府宣布堵幽,位于F島的核電站狗超,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏朴下。R本人自食惡果不足惜努咐,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 42,094評(píng)論 3 335
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望殴胧。 院中可真熱鬧渗稍,春花似錦、人聲如沸团滥。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,572評(píng)論 0 25
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽灸姊。三九已至拱燃,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間力惯,已是汗流浹背碗誉。 一陣腳步聲響...
    開封第一講書人閱讀 33,671評(píng)論 1 274
  • 我被黑心中介騙來泰國(guó)打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留夯膀,地道東北人诗充。 一個(gè)月前我還...
    沈念sama閱讀 49,159評(píng)論 3 378
  • 正文 我出身青樓,卻偏偏與公主長(zhǎng)得像诱建,于是被迫代替她去往敵國(guó)和親蝴蜓。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,747評(píng)論 2 361

推薦閱讀更多精彩內(nèi)容