TensorFlow practice series:
TensorFlow practice 1 - DNN
Now we try to use an LSTM to predict a periodic series.
import tensorflow as tf
from tensorflow.contrib import rnn
import numpy as np
from matplotlib import pyplot as plt
from sklearn import preprocessing, model_selection
import pandas as pd
input_vec_size = 1 # dimension of the input vector at each time step
lstm_size = 10 # number of units in the LSTM cell
time_step_size = 5 # number of time steps fed to the recurrent layer
batch_size = 7
test_size = 3 # not used below; train_test_split is called with test_size=0.3
1. Prepare the data
We use synthetic data: a weekly pattern (week_rate) multiplied by a slow linear growth of 0.2% per day. Every (time_step_size + 1) = 6 consecutive days form one sample: the first 5 values are the input sequence and the 6th is the label.
day = (time_step_size + 1)* 200
week_rate = [0.9, 0.85, 0.80, 0.88, 1.1, 1.2, 1.15]
label = [(1 + i * 0.002) * week_rate[i%7] for i in range(day)]
label = np.array(label)
# scaler = preprocessing.StandardScaler()
# label = scaler.fit_transform(label)
label = label.reshape(int(day / (time_step_size + 1)), (time_step_size + 1))
print(label.shape)
_tmp = label
X_ = _tmp[:, :time_step_size]
Y_ = _tmp[:, time_step_size:]
print(_tmp.shape, X_.shape, Y_.shape)
plt.plot(_tmp[:100, 0])
plt.show()
The output is as follows:
(200, 6)
(200, 6) (200, 5) (200, 1)
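As a quick sanity check (not part of the original post), the first entries of the series follow directly from the generating formula, and they match the first printed row shown further below:
# day 0: (1 + 0 * 0.002) * week_rate[0] = 1.000 * 0.90 = 0.9
# day 1: (1 + 1 * 0.002) * week_rate[1] = 1.002 * 0.85 = 0.8517
print(label[0, 0], label[0, 1])  # 0.9 0.8517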
2. Build the network
def init_weights(shape):
    return tf.Variable(tf.random_normal(shape, stddev=0.01))
def model(X, W, B, lstm_size):
    # X, input shape: (batch_size, time_step_size, input_vec_size)
    print(X.shape)
    # XT shape: (time_step_size, batch_size, input_vec_size)
    XT = tf.transpose(X, [1, 0, 2])
    # XR shape: (time_step_size * batch_size, input_vec_size); each row is the input of one step
    XR = tf.reshape(XT, [-1, input_vec_size])
    # split into time_step_size arrays, each of shape (batch_size, input_vec_size)
    X_split = tf.split(XR, time_step_size, 0)
    # build the LSTM cell with num_units=lstm_size and forget_bias=1.0
    lstm = rnn.BasicLSTMCell(lstm_size, forget_bias=1.0, state_is_tuple=True)
    # rnn.static_rnn() returns one output per time step, each of shape (batch_size, lstm_size);
    # since we only care about the last step, outputs[-1] is all we need
    outputs, _states = rnn.static_rnn(lstm, X_split, dtype=tf.float32)
    # linear activation on the last output
    return tf.matmul(outputs[-1], W) + B, lstm.state_size  # state size to initialize the state
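As a side note (not from the original post), the same model can be written more compactly with tf.nn.dynamic_rnn, which consumes the (batch_size, time_step_size, input_vec_size) tensor directly and avoids the transpose/reshape/split steps. A minimal sketch, assuming the same TF 1.x APIs used above:
def model_dynamic(X, W, B, lstm_size):
    # X shape: (batch_size, time_step_size, input_vec_size); no manual re-layout needed
    lstm = rnn.BasicLSTMCell(lstm_size, forget_bias=1.0, state_is_tuple=True)
    # outputs shape: (batch_size, time_step_size, lstm_size)
    outputs, _states = tf.nn.dynamic_rnn(lstm, X, dtype=tf.float32)
    # take the last time step and apply the linear layer
    return tf.matmul(outputs[:, -1, :], W) + B, lstm.state_size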
trX, teX, trY, teY = model_selection.train_test_split(X_, Y_, test_size=0.3)
print(trX.shape, trY.shape, teX.shape, teY.shape)
trX = trX.reshape(-1, time_step_size, 1)
teX = teX.reshape(-1, time_step_size, 1)
trY = trY.reshape(-1, 1)
teY = teY.reshape(-1, 1)
X = tf.placeholder("float", [None, time_step_size, 1])
Y = tf.placeholder("float", [None, 1])
# map the lstm_size-dimensional LSTM output to a single predicted value
W = init_weights([lstm_size, 1])
B = init_weights([1])
py_x, state_size = model(X, W, B, lstm_size)
loss = tf.reduce_mean(tf.square(py_x - Y))
train_op = tf.train.AdamOptimizer(0.01).minimize(loss)
predict_op = py_x
session_conf = tf.ConfigProto()
session_conf.gpu_options.allow_growth = True
3. Start training
# Launch the graph in a session
with tf.Session(config=session_conf) as sess:
    # you need to initialize all variables
    tf.global_variables_initializer().run()

    for i in range(5000):
        for start, end in zip(range(0, len(trX), batch_size), range(batch_size, len(trX) + 1, batch_size)):
            # print("feed:", trX[start:end][0], trY[start:end][0])
            # print("feed:", trX[start:end].shape, trY[start:end].shape)
            sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end]})
        if i % 100 == 0:
            print(i, sess.run(loss, feed_dict={X: trX, Y: trY}), sess.run(loss, feed_dict={X: teX, Y: teY}))

    predic = sess.run(predict_op, feed_dict={X: X_.reshape(-1, time_step_size, 1), Y: Y_.reshape(-1, 1)})
The training output is as follows:
0 0.708224 0.584213
100 0.000834472 0.000973581
200 0.000377034 0.000441525
300 0.0111581 0.0139003
400 0.000116222 0.000156086
500 0.000468893 0.000493862
600 0.000166236 0.000181855
700 4.5977e-05 7.10702e-05
800 0.0001364 0.000223403
900 5.61442e-05 7.7883e-05
1000 4.72553e-05 6.94315e-05
1100 0.00035655 0.000320759
1200 1.88308e-05 3.37516e-05
1300 1.21966e-05 2.17105e-05
1400 4.73054e-05 7.62947e-05
1500 7.19001e-06 2.29275e-05
1600 0.00010293 0.000149464
1700 3.69292e-05 4.97202e-05
1800 1.507e-05 2.25646e-05
1900 1.81654e-05 3.16758e-05
2000 6.29222e-06 1.25351e-05
2100 7.36748e-05 0.000105863
2200 9.74285e-06 1.90102e-05
2300 5.93037e-05 7.35179e-05
2400 0.000723914 0.000536784
2500 3.20104e-06 8.17306e-06
2600 3.83796e-06 1.01513e-05
2700 5.5799e-06 1.48322e-05
2800 3.92063e-06 1.29322e-05
2900 3.68061e-06 1.5905e-05
3000 2.65767e-06 1.1614e-05
3100 3.09019e-06 1.39719e-05
3200 0.000264612 0.000287502
3300 4.02342e-05 4.46263e-05
3400 2.33401e-05 3.36289e-05
3500 8.55672e-05 0.000107355
3600 3.20796e-05 4.43166e-05
3700 1.7802e-05 2.81231e-05
3800 1.15583e-05 1.70578e-05
3900 1.40946e-05 2.29459e-05
4000 1.16366e-05 1.92793e-05
4100 4.93991e-06 1.29583e-05
4200 2.57865e-05 4.40378e-05
4300 6.99072e-06 1.18626e-05
4400 3.02546e-06 7.17427e-06
4500 6.18525e-06 1.39441e-05
4600 0.000154743 0.000180568
4700 7.82577e-05 7.30797e-05
4800 5.27492e-05 6.85774e-05
4900 1.77607e-05 2.66213e-05
Let's look at the resulting data:
plt.plot(Y_[:100])
plt.plot(predic[:100])
plt.show()
# print(scaler.mean_)
print(np.mean(np.square(Y_ - predic)))
# print(np.hstack([scaler.inverse_transform(Y_), scaler.inverse_transform(predic)]))
print(np.hstack([X_, Y_, predic])[:20])
We get:
9.2249199079e-05
[[ 0.9 0.8517 0.8032 0.88528 1.1088 1.212
1.21290922]
[ 1.1638 0.9126 0.8636 0.8144 0.8976 1.1242
1.12351131]
[ 1.2288 1.1799 0.9252 0.8755 0.8256 0.90992
0.90445042]
[ 1.1396 1.2456 1.196 0.9378 0.8874 0.8368
0.82641059]
[ 0.92224 1.155 1.2624 1.2121 0.9504 0.8993
0.8717519 ]
[ 0.848 0.93456 1.1704 1.2792 1.2282 0.963
0.95084727]
[ 0.9112 0.8592 0.94688 1.1858 1.296 1.2443
1.22580171]
[ 0.9756 0.9231 0.8704 0.9592 1.2012 1.3128
1.31333256]
[ 1.2604 0.9882 0.935 0.8816 0.97152 1.2166
1.2210536 ]
[ 1.3296 1.2765 1.0008 0.9469 0.8928 0.98384
0.97711486]
[ 1.232 1.3464 1.2926 1.0134 0.9588 0.904
0.89058191]
[ 0.99616 1.2474 1.3632 1.3087 1.026 0.9707
0.95494765]
[ 0.9152 1.00848 1.2628 1.38 1.3248 1.0386
1.02744293]
[ 0.9826 0.9264 1.0208 1.2782 1.3968 1.3409
1.33157599]
[ 1.0512 0.9945 0.9376 1.03312 1.2936 1.4136
1.41023409]
[ 1.357 1.0638 1.0064 0.9488 1.04544 1.309
1.31301975]
[ 1.4304 1.3731 1.0764 1.0183 0.96 1.05776
1.05260563]
[ 1.3244 1.4472 1.3892 1.089 1.0302 0.9712
0.95847595]
[ 1.07008 1.3398 1.464 1.4053 1.1016 1.0421
1.03257525]
[ 0.9824 1.0824 1.3552 1.4808 1.4214 1.1142
1.10492337]]
Further visualize the results (the prediction error as a percentage of the mean):
mean_ = np.mean(Y_)
plt.plot((predic - Y_)/mean_ * 100)
plt.show()
print("mean", mean_)
mean 2.166021
4. Summary
- The earlier version of this code overfit; adding more samples relieved that (always print both the train and the test loss).
- The earlier version also fit poorly, which was related to the wide spread of the sample values; some normalization should be added around the LSTM input so the model can handle a wider range of distributions (see the sketch after this list). The current result partly benefits from the data distribution happening to be well behaved.
- Adam seems to be a good choice here. I also tried RMSprop, but the results were worse; the momentum seemed to overshoot.
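A minimal sketch of the normalization idea from the second bullet above, assuming we standardize the flat series before slicing it into windows and map the predictions back afterwards (the variable names below are illustrative, not from the code above):
scaler = preprocessing.StandardScaler()
# fit on the raw series as a column vector, then restore the (200, 6) window layout
label_scaled = scaler.fit_transform(label.reshape(-1, 1)).reshape(label.shape)
X_s, Y_s = label_scaled[:, :time_step_size], label_scaled[:, time_step_size:]
# ... train on X_s / Y_s exactly as before, then undo the scaling on the predictions:
# predic_orig = scaler.inverse_transform(predic)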
Thoughts
- The predictions come out periodic; is that a sign that the network failed to learn something?
- How should we judge the current result?
print("當(dāng)前的預(yù)測(cè)結(jié)果", np.mean(np.square(predic-Y_)))
#對(duì)比绎签, 假設(shè)網(wǎng)絡(luò)什么都沒有學(xué)習(xí)到, 那么直接用X_最后一個(gè)作為預(yù)測(cè)結(jié)果
print("直接用上一個(gè)值作為預(yù)測(cè)結(jié)果", np.mean(np.square(X_[:, -1] - Y_)))
print("學(xué)習(xí)到了一點(diǎn)點(diǎn)的大勢(shì)", np.mean(np.square(X_[:, -1] + 0.002 -Y_)))
Current prediction MSE 1.69001122001e-05
Repeating the last value as the prediction 1.96738985325
Adding a bit of the upward trend on top of the last value 1.96737203849
From the numbers above, the network did learn something.
But notice that we have been training on windows with a period of 6 (5 inputs + 1 label), while the data repeats with a period of 7. Would the result be better if we changed the window to 7?
The answer is yes: the loss reaches 1.53918561546e-07, whereas with the original 5+1 window the loss was 1.07650830461e-05.
Our guess is that with a 6+1 window each sample covers exactly one week, so the network only has to learn "take the last input, divide by 1.2, multiply by 1.15, and apply the small 0.2% daily growth correction" to reproduce the label.
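A quick numeric check of this guess (hypothetical, not from the original post): with a 6+1 window, row r has its last input on day 7r+5 (week_rate 1.2) and its label on day 7r+6 (week_rate 1.15), so their ratio is almost the same for every row:
r = 3                                         # any row works
last   = (1 + (7 * r + 5) * 0.002) * 1.20     # last input of row r
target = (1 + (7 * r + 6) * 0.002) * 1.15     # label of row r
print(target / last)                          # ~0.96, nearly constant across rows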
So, can we adjust the network so that it also reaches an e-7 level loss when the window is not 6+1? Feeding the current day of week in as an extra feature, the loss only drops to the e-6 level after about 2900 training steps; is there room for further optimization? (A sketch of this feature follows.)
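A hypothetical sketch of the day-of-week feature (the names dow and X2_ are illustrative, not from the original code): each time step then carries two channels, the value and the scaled day of week, so input_vec_size becomes 2 and the placeholder shape becomes [None, time_step_size, 2]:
dow = np.arange(day).reshape(-1, time_step_size + 1)[:, :time_step_size] % 7
X2_ = np.stack([X_, dow / 7.0], axis=-1)   # shape: (num_rows, time_step_size, 2)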
We now have 2007 data points but actually use only 200 non-overlapping groups of 7 as training and validation data; from n points we could instead extract n - 7 + 1 overlapping windows (a sketch follows). Simply switching to the overlapping windows made the loss worse, up at the 1.6e-3 level; after standardizing the input data the loss came down to roughly 1.5e-3 ~ 1.9e-4 (these numbers are in units of scaler.scale_ ** 2), so the loss improved, but not dramatically.
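A minimal sketch of the sliding-window construction (assuming the flat 1-D series is rebuilt before windowing; the names series, X_all, Y_all are illustrative):
series = np.array([(1 + i * 0.002) * week_rate[i % 7] for i in range(day)])
window = time_step_size + 1
# every start position yields one sample, so we get len(series) - window + 1 overlapping rows
samples = np.array([series[i:i + window] for i in range(len(series) - window + 1)])
X_all, Y_all = samples[:, :time_step_size], samples[:, time_step_size:]
print(samples.shape)   # (day - window + 1, window)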
5. Improvements
The results are still not ideal; I will keep improving the model once I understand the problem more deeply and have the time.