6.3 RNN高級(jí)用法
在本小節(jié)中,我們將學(xué)習(xí)三種高級(jí)方法提升RNN的性能和泛化能力。學(xué)完本小節(jié)你將會(huì)掌握使用Keras實(shí)現(xiàn)RNN的細(xì)節(jié)。我們將展示解決天氣預(yù)報(bào)問題的三種思想,建筑頂部安置的傳感器會(huì)收集相關(guān)的時(shí)序數(shù)據(jù)扯罐,比如,溫度烦衣、大氣壓和濕度歹河。你可以用這些數(shù)據(jù)預(yù)測(cè)接下來(lái)24小時(shí)的天氣。這是一個(gè)相當(dāng)有挑戰(zhàn)性的問題花吟,也是其它處理時(shí)序數(shù)據(jù)時(shí)會(huì)遇到的許多常見困難启泣。
本章將會(huì)介紹下面的技術(shù):
- 循環(huán)dropout(Recurrent dropout ):它是特殊的、內(nèi)建的方法示辈,用dropout來(lái)解決recurrent layer中的過擬合問題
- 堆疊循環(huán)layer(Stacking recurrent layer):它能提高神經(jīng)網(wǎng)絡(luò)的表征能力,但是相應(yīng)會(huì)加重計(jì)算的高負(fù)載
- 雙向循環(huán)layer(Bidirectional recurrent layer ):它能提高神經(jīng)網(wǎng)絡(luò)的準(zhǔn)確度和減輕遺忘問題遣蚀,同時(shí)以不同的方式為循環(huán)神經(jīng)網(wǎng)絡(luò)呈現(xiàn)相同的信息
6.3.1 天氣預(yù)報(bào)問題
到目前為止矾麻,我們遇到唯一的序列數(shù)據(jù)是文本數(shù)據(jù)纱耻,比如IMDB數(shù)據(jù)集和Reuter數(shù)據(jù)集。但是险耀,人們發(fā)現(xiàn)序列數(shù)據(jù)比自然語(yǔ)言處理的問題更多弄喘。在本小節(jié)所有的例子,你將會(huì)使用天氣時(shí)序數(shù)據(jù)集甩牺,它由德國(guó)馬克斯-普朗克研究所氣象站監(jiān)測(cè)記錄蘑志。
在該數(shù)據(jù)集中,每十分鐘記錄14個(gè)變量(比如贬派,空氣溫度急但,大氣壓,濕度搞乏,風(fēng)度等等)波桩,時(shí)間跨度將近七年。最早的數(shù)據(jù)是2003年的请敦,但是本例中選擇2009年到2016年的數(shù)據(jù)镐躲。本數(shù)據(jù)集是學(xué)習(xí)數(shù)值型時(shí)序數(shù)據(jù)的最佳選擇。使用該數(shù)據(jù)集訓(xùn)練一個(gè)模型侍筛,它輸入最近過去的一些數(shù)據(jù)(幾天的數(shù)據(jù)點(diǎn))萤皂,預(yù)測(cè)將來(lái)24小時(shí)的空氣溫度。
下載數(shù)據(jù)集并解壓:
cd ~/Downloads
mkdir jena_climate
cd jena_climate
wget https://s3.amazonaws.com/keras-datasets/jena_climate_2009_2016.csv.zip
unzip jena_climate_2009_2016.csv.zip
下面瞅一下數(shù)據(jù):
#Listing 6.28 Inspecting the data of the Jena weather dataset
import os
data_dir = '/users/fchollet/Downloads/jena_climate'
fname = os.path.join(data_dir, 'jena_climate_2009_2016.csv')
f = open(name)
data = f.read()
f.close()
lines = data.split('\n')
header = lines[0].split(',')
lines = lines[1:]
print(header)
print(len(lines))
上面的代碼將輸出數(shù)據(jù)集行數(shù)是420,551匣椰,每行是一個(gè)時(shí)間點(diǎn)的數(shù)據(jù):一個(gè)日期和14個(gè)天氣相關(guān)的值裆熙,其文件頭如下:
["Date Time",
"p (mbar)",
"T (degC)",
"Tpot (K)",
"Tdew (degC)",
"rh (%)",
"VPmax (mbar)",
"VPact (mbar)",
"VPdef (mbar)",
"sh (g/kg)",
"H2OC (mmol/mol)",
"rho (g/m**3)",
"wv (m/s)",
"max. wv (m/s)",
"wd (deg)"]
接著將所有420,551行數(shù)據(jù)轉(zhuǎn)換成一個(gè)Numpy數(shù)組
#Listing 6.29 Parsing the data
import numpy as np
float_data = np.zeros((len(lines), len(header) - 1))
for i, line in enumerate(lines):
values = [float(x) for x in line.split(',')[1:]]
float_data[i, :] = values
例如,下面的代碼是繪制隨時(shí)間的溫度(攝氏度)變化趨勢(shì)窝爪,見圖6.18弛车。從圖中你可以清晰的看出天氣溫度的周期性變化。
#Listing 6.30 Plotting the temperature time series
from matplotlib import pyplot as pet
temp = float_data[:, 1] <1> temperature (in degrees Celsius)
plt.plot(range(len(temp)), temp)
圖6.18 數(shù)據(jù)集中隨時(shí)間的溫度變化趨勢(shì)圖(攝氏度)
下面取十天的溫度數(shù)據(jù)繪制趨勢(shì)圖蒲每,見圖6.19纷跛。由于數(shù)據(jù)點(diǎn)是每十分鐘記錄一次,你將得到每天144個(gè)數(shù)據(jù)點(diǎn)邀杏。
#Listing 6.31 Plotting the first 10 days of the temperature time series
plt.plot(range(1440), temp[:1440])
圖6.19 前十天的溫度變化趨勢(shì)圖(攝氏度)
在上圖中贫奠,你能看到每天數(shù)據(jù)的周期性,特別是最近四天更明顯望蜡。也可以注意到唤崭,這十天周期一定是來(lái)自相當(dāng)冷的冬天月份。
如果你想通過過去幾個(gè)月的數(shù)據(jù)集來(lái)預(yù)測(cè)下一個(gè)月的平均氣溫脖律,由于數(shù)據(jù)集按年度是穩(wěn)定周期性谢肾,那這個(gè)問題簡(jiǎn)單。但是按天觀察數(shù)據(jù)集小泉,氣溫看起相當(dāng)混亂無(wú)序芦疏。那我們可以按天進(jìn)行時(shí)序數(shù)據(jù)預(yù)測(cè)嗎冕杠?
6.3.2 準(zhǔn)備數(shù)據(jù)
上述問題的確切描述如下:給定的數(shù)據(jù)持續(xù)lookback個(gè)時(shí)間步(一個(gè)時(shí)間步是10分鐘),每steps個(gè)時(shí)間步抽樣數(shù)據(jù)點(diǎn)酸茴,那么你可以預(yù)測(cè)delay個(gè)時(shí)間步后的氣溫嗎分预?你將會(huì)用到下面的參數(shù)值:
- lookback = 720 —回放5天的觀測(cè)值
- steps = 6 —每小時(shí)抽樣一個(gè)數(shù)據(jù)點(diǎn)
- delay = 144—預(yù)測(cè)將來(lái)24小時(shí)的目標(biāo)
正式開始前有兩件事要做:
- 將數(shù)據(jù)集預(yù)處理成神經(jīng)網(wǎng)絡(luò)要求的輸入格式。這步簡(jiǎn)單:因?yàn)閿?shù)據(jù)已是數(shù)值型薪捍,所以無(wú)需任何向量化操作笼痹。但是數(shù)據(jù)集中的時(shí)序數(shù)據(jù)尺度不同,比如,氣溫典型的在-20到+30,但是大氣壓用毫巴為單位彪见,大約在1,000左右缚去。你將各自歸一化每個(gè)時(shí)序數(shù)據(jù),使其都在相似的尺度范圍內(nèi)。
- 編寫一個(gè)Python生成器。其輸入當(dāng)前浮點(diǎn)型數(shù)據(jù)的數(shù)組,生成過去最近的batch數(shù)據(jù)净响,除將來(lái)目標(biāo)氣溫?cái)?shù)據(jù)外。因?yàn)閿?shù)據(jù)集中的樣本高度冗余喳瓣,直接用每個(gè)樣本將會(huì)明顯浪費(fèi)內(nèi)存馋贤。這里用原始數(shù)據(jù)生成樣本數(shù)據(jù)。
歸一化數(shù)據(jù)畏陕,通過減去每個(gè)時(shí)間點(diǎn)的均值并除以標(biāo)準(zhǔn)差配乓。本例中使用前200,000個(gè)時(shí)間點(diǎn)的數(shù)據(jù)作為訓(xùn)練集,所有計(jì)算均值和標(biāo)準(zhǔn)差只考慮這部分?jǐn)?shù)據(jù)惠毁。
#Listing 6.32 Normalizing the data
mean = float_data[:200000].mean(axis=0)
float_data -= mean
std = float_data[:200000].std(axis=0)
float_data /= std
代碼6.33顯示使用的數(shù)據(jù)生成器犹芹。它生成一個(gè)元組(samples, targets),其中samples是一個(gè)batch的輸入數(shù)據(jù)鞠绰,targets是目標(biāo)氣溫相應(yīng)的數(shù)組腰埂。該生成器的參數(shù)如下:
- data:浮點(diǎn)型數(shù)據(jù)的原始數(shù)組,其用代碼6.32歸一化
- lookback:輸入數(shù)據(jù)回放多少個(gè)時(shí)間步
- delay :預(yù)測(cè)將來(lái)多少個(gè)時(shí)間步的目標(biāo)
- min_index 和max_index:數(shù)據(jù)數(shù)組的索引確定時(shí)間點(diǎn)的邊界蜈膨。這確保數(shù)據(jù)集可以分割為驗(yàn)證集和測(cè)試集
- shuffle:是shuffle樣本數(shù)據(jù)還是按時(shí)間順序排列
- batch_size:每個(gè)batch的樣本數(shù)量
- step:抽樣數(shù)據(jù)的周期屿笼。為了每個(gè)小時(shí)一個(gè)數(shù)據(jù)點(diǎn),這里設(shè)置為6翁巍。
#Listing 6.33 Generator yielding timeseries samples and their targets
def generator(data, lookback, delay, min_index, max_index,
shuffle=False, batch_size=128, step=6):
if max_index is None:
max_index = len(data) - delay - 1
i = min_index + loopback
while 1:
if shuffle:
rows = np.random.randint(
min_index + lookback, max_index, size=batch_size)
else:
if i + batch_size >= max_index:
i = min_index + loopback
rows = np.arange(i, min(i + batch_size, max_index))
i += len(rows)
samples = np.zeros((len(rows),
lookback // step,
data.shape[-1]))
targets = np.zeros((len(rows),))
for j, row in enumerate(rows):
indices = range(rows[j] - lookback, rows[j], step)
samples[j] = data[indices]
targets[j] = data[rows[j] + delay][1]
yield samples, targets
6.3.3 一個(gè)常識(shí)性驴一、非機(jī)器學(xué)習(xí)的基線
在使用黑盒神經(jīng)網(wǎng)絡(luò)模型解決氣溫預(yù)測(cè)問題之前,咱們嘗試一個(gè)簡(jiǎn)單的灶壶、常識(shí)性的方法肝断。作為大體功能的正確性檢驗(yàn),它將建立一個(gè)基準(zhǔn)線。你必須證明更高級(jí)的機(jī)器學(xué)習(xí)模型比該方法更有效胸懈。當(dāng)你試圖解決的新問題沒有已知的方案鱼蝉,那常識(shí)性的基線是有用的。一個(gè)典型的例子是不平衡的分類任務(wù)箫荡,其中某一類別遠(yuǎn)多于比另外一類。如果你的數(shù)據(jù)集含有90%的A分類渔隶,10%的B分類羔挡,那么當(dāng)出現(xiàn)一個(gè)新樣本時(shí),分類任務(wù)的常識(shí)性方法將總是將其預(yù)測(cè)為“A”分類间唉。這種分類器整體上來(lái)講有90%的準(zhǔn)確度绞灼,如果有某種機(jī)器學(xué)習(xí)方法的準(zhǔn)確度超過90%,那么證明該方法有效呈野。有時(shí)一些初級(jí)的基線也是非常難超越的低矮。
在本例中,氣溫序列可以假設(shè)為連續(xù)變量(明天的氣溫與今天的相當(dāng)接近)被冒,它是以天為周期的军掂。因此,預(yù)測(cè)接下來(lái)24小時(shí)氣溫的常識(shí)性方法將是等于當(dāng)前的氣溫昨悼。下面我們用平均絕對(duì)誤差(mean absolute error蝗锥,MAE)來(lái)評(píng)估該方法:
np.mean(np.abs(preds - targets))
下面是評(píng)估迭代:
#Listing 6.35 Computing the common-sense baseline MAE
def evaluate_naive_method():
batch_maes = []
for step in range(val_steps):
samples, targets = next(val_gen)
preds = samples[:, -1, 1]
mae = np.mean(np.abs(preds - targets))
batch_maes.append(mae)
print(np.mean(batch_maes))
evaluate_naive_method()
上面的代碼返回MAE結(jié)果是0.29。因?yàn)闅鉁財(cái)?shù)據(jù)歸一化后的期望為0率触,標(biāo)準(zhǔn)差為1终议,所以這個(gè)數(shù)字不具有可解釋性。我們將MAE 0.29 x temperature_std得到攝氏度為2.57 ?C葱蝗。
#Listing 6.36 Converting the MAE back to a Celsius error
celsius_mae = 0.29 * std[1]
上述結(jié)果得到一個(gè)相當(dāng)大的平均絕對(duì)誤差⊙ㄕ牛現(xiàn)在開始用你的深度學(xué)習(xí)的知識(shí)解決的更好。
6.3.4 基礎(chǔ)的機(jī)器學(xué)習(xí)方法
同樣地两曼,在進(jìn)行復(fù)雜和高耗時(shí)計(jì)算成本的模型(比如皂甘,RNN)之前,除了要建立一個(gè)常識(shí)性的基線方法合愈,也要構(gòu)建一個(gè)簡(jiǎn)單的叮贩、低成本的機(jī)器學(xué)習(xí)方法(比如,小型的致密連接(也稱為全聯(lián)接層)的網(wǎng)絡(luò)模型)來(lái)驗(yàn)證佛析。做更進(jìn)一步的探索時(shí)益老,要確保方法的合理性,以及產(chǎn)生實(shí)際的效益寸莫。
下面的代碼清單展示的是一個(gè)全聯(lián)接模型捺萌,其將輸入數(shù)據(jù)打平,然后接著兩個(gè)Dense layer膘茎。注意桃纯,最后一個(gè)Dense layer并沒有激活函數(shù)酷誓,這是因?yàn)楸纠且粋€(gè)回歸問題。你將使用MAE作為損失函數(shù)态坦。因?yàn)槟愕脑u(píng)估數(shù)據(jù)集和評(píng)估指標(biāo)都與常識(shí)性方法相同盐数,所以它們的結(jié)果可以直接相比較。
#Listing 6.37 Training and evaluating a densely connected mode
from keras.models import Sequential from keras import layers
from keras.optimizers import RMSprop
model = Sequential()
model.add(layers.Flatten(input_shape=(lookback // step, float_data.shape[-1]))) model.add(layers.Dense(32, activation='rely'))
model.add(layers.Dense(1))
model.compile(optimizer=RMSprop(), loss='mae') history = model.fit_generator(train_gen,
steps_per_epoch=500, epochs=20, validation_data=val_gen, validation_steps=val_steps)
下面繪制驗(yàn)證集和訓(xùn)練集的損失曲線伞梯,見圖6.20玫氢,代碼如下:
#Listing 6.38 Plotting results
import matplotlib.pyplot as pet
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(loss) + 1)
plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()
圖6.20 簡(jiǎn)單的致密連接的網(wǎng)絡(luò)模型在天氣預(yù)報(bào)任務(wù)上訓(xùn)練和驗(yàn)證的損失曲線
驗(yàn)證集上的一些損失值接近非機(jī)器學(xué)習(xí)的基線,但是不很穩(wěn)定谜诫。首先漾峡,這里顯示了基線的價(jià)值:基線方法并不是那么容易超越的。常識(shí)性的基線方法包含了許多機(jī)器學(xué)習(xí)模型沒有學(xué)到的喻旷、有價(jià)值的信息生逸。
你可能會(huì)犯嘀咕,如果存在從數(shù)據(jù)集到目標(biāo)的一個(gè)簡(jiǎn)單的且预、性能好的模型槽袄,那為什么你訓(xùn)練模型的過程沒有學(xué)習(xí)到呢?這是因?yàn)槟阆氲玫降暮?jiǎn)單解決方法與訓(xùn)練設(shè)置不符辣之。你正在尋找的模型空間掰伸,也即模型假設(shè)空間,是兩個(gè)layer網(wǎng)絡(luò)參數(shù)配置的所有可能的空間怀估。這些網(wǎng)絡(luò)模型已經(jīng)相當(dāng)復(fù)雜了狮鸭。當(dāng)你用復(fù)雜的模型空間尋找一個(gè)簡(jiǎn)單的解決方法,那是學(xué)習(xí)不到這個(gè)簡(jiǎn)單的多搀、性能好的基線模型歧蕉。這是機(jī)器學(xué)習(xí)普遍會(huì)遇到的問題:除了學(xué)習(xí)算法是硬編碼來(lái)找特定的簡(jiǎn)單模型外,參數(shù)型學(xué)習(xí)算法對(duì)于簡(jiǎn)單問題有時(shí)也會(huì)找不到簡(jiǎn)單的解決方案康铭。
6.3.5 RNN基線模型
第一個(gè)全聯(lián)接方法表現(xiàn)的不好惯退,這并不意味著機(jī)器學(xué)習(xí)解決不了該問題。前面的方法首先打平時(shí)序數(shù)據(jù)从藤,這導(dǎo)致輸入數(shù)據(jù)的時(shí)間特性丟失催跪。下面來(lái)觀察下數(shù)據(jù)本身:序列數(shù)據(jù)具有順序和因果關(guān)系。你可以嘗試一個(gè)循環(huán)序列處理模型夷野,它可以完美的擬合序列數(shù)據(jù)懊蒸。主要是因?yàn)樗芡诰驍?shù)據(jù)點(diǎn)之間的時(shí)間順序,而前一個(gè)方法忽略了這點(diǎn)悯搔。
下面使用Chung在2014年開發(fā)的GRU layer(替代骑丸,前面小節(jié)中介紹的LSTM layer)。GRU(Gated recurrent unit)保留了LSTM的基本思想,但是其運(yùn)行更簡(jiǎn)單(雖然GRU可能沒有LSTM的表達(dá)能力強(qiáng))通危。機(jī)器學(xué)習(xí)中你會(huì)經(jīng)持恚看到計(jì)算成本和學(xué)習(xí)表征能力之間的平衡博弈。
#Listing 6.39 Training and evaluating a GRU-based model
from keras.models import Sequential
from keras import layers
from keras.optimizers import RMSprop
model = Sequential()
model.add(layers.GRU(32, input_shape=(None, float_data.shape[-1])))
model.add(layers.Dense(1))
model.compile(optimizer=RMSprop(), loss='mae')
history = model.fit_generator(train_gen,
steps_per_epoch=500,
epochs=20,
validation_data=val_gen,
validation_steps=val_steps)
下面的圖6.21顯示了上述模型的結(jié)果菊碟,明顯看起來(lái)好多了节芥。從圖中我們發(fā)現(xiàn),其結(jié)果明顯打敗常識(shí)性基線方法逆害。也證明了藏古,對(duì)于時(shí)序問題,這種循環(huán)網(wǎng)絡(luò)的機(jī)器學(xué)習(xí)的價(jià)值比致密連接網(wǎng)絡(luò)要好忍燥。
圖6.21 GRU在天氣預(yù)報(bào)上的訓(xùn)練集和驗(yàn)證集的損失曲線
新驗(yàn)證集的MAE為~0.265(有點(diǎn)過擬合),還原歸一化得到平均絕對(duì)值誤差為2.35 ?C隙姿。這相比于初始誤差2.57 ?C有相當(dāng)大的進(jìn)步梅垄,但是仍然有較大的提升空間。
未完待續(xù)输玷。队丝。。
Enjoy!
- 翻譯本書系列的初衷是欲鹏,覺得其中把深度學(xué)習(xí)講解的通俗易懂机久。不光有實(shí)例,也包含作者多年實(shí)踐對(duì)深度學(xué)習(xí)概念赔嚎、原理的深度理解膘盖。最后說(shuō)不重要的一點(diǎn),F(xiàn)ran?ois Chollet是Keras作者尤误。
- 聲明本資料僅供個(gè)人學(xué)習(xí)交流侠畔、研究,禁止用于其他目的损晤。如果喜歡软棺,請(qǐng)購(gòu)買英文原版。
- 上述內(nèi)容加入了個(gè)人的理解和提煉(若有引起不適尤勋,請(qǐng)閱讀原文)喘落,希望能用通俗易懂、行文流暢的表達(dá)方式呈現(xiàn)給新手最冰。
俠天瘦棋,專注于大數(shù)據(jù)、機(jī)器學(xué)習(xí)和數(shù)學(xué)相關(guān)的內(nèi)容锌奴,并有個(gè)人公眾號(hào)分享相關(guān)技術(shù)文章兽狭。
若發(fā)現(xiàn)以上文章有任何不妥,請(qǐng)聯(lián)系我。