人人都能懂的機器學習——用Keras搭建人工神經網絡09

微調神經網絡超參數

神經網絡的靈活性其實也是它的缺點:有太多的超參數需要調整。神經網絡的靈活性可以讓我們使用任何想象中的網絡架構儿普,但是即使一個簡單的MLP崎逃,我們也要考慮層的數量,每層神經元的數量眉孩,每層使用的激活函數的類型个绍,權重初始化邏輯等等。那么我們要怎么知道那種超參數組合最適合解決問題呢浪汪?

一個解決方式就是嘗試各種超參數組合方式然后看那種組合在驗證集上的表現最好(或者使用K-fold交叉驗證)巴柿。為此,我們可以使用網格搜索或者隨機搜索方法死遭,探索超參數空間广恢。我們需要將Keras模型打包成一個對象,就像一般的Scikit-Learn回歸器一樣呀潭。第一步我們需要創(chuàng)建一個函數钉迷,用于創(chuàng)建和編譯Keras模型,然后輸入一系列超參數:

def build_model(n_hidden=1, n_neurons=30, learning_rate=3e-3, input_shape=[8]):
    model = keras.models.Sequential()
    model.add(keras.layers.InputLayer(input_shape=input_shape))
    for layer in range(n_hidden):
        model.add(keras.layers.Dense(n_neurons, activation="relu"))
    model.add(keras.layers.Dense(1))
    optimizer = keras.optimizers.SGD(lr=learning_rate)
    model.compile(loss="mse", optimizer=optimizer)
    return model

這個函數創(chuàng)建了一個簡單的單變量回歸Sequential模型蜗侈,并向其傳遞了輸入形狀篷牌,層數和每層的神經元數,然后用SGD優(yōu)化器編譯并設置了學習率踏幻。代碼中向模型提供了盡可能多的超參數默認值,這通常是比較好的做法戳杀。

接下來该面,我們創(chuàng)建一個基于上面的build_model()函數的KerasRegressor:

keras_reg = keras.wrappers.scikit_learn.KerasRegressor(build_model)

KerasRegressor對象是使用build_model()構建的Keras模型的一個簡單包裝夭苗。因為我們在創(chuàng)建它時沒有指定任何超參數,所以它將使用我們在build_model()中定義的默認超參數值「糇海現在我們可以像使用常規(guī)的Scikit-Learn回歸器一樣使用這個對象:我們可以使用fit()方法來訓練它题造,然后使用score()方法來評估,并使用predict()方法進行預測猾瘸。具體操作如下所示:

keras_reg.fit(X_train, y_train, epochs=100,
              validation_data=(X_valid, y_valid),
              callbacks=[keras.callbacks.EarlyStopping(patience=10)])
mse_test = keras_reg.score(X_test, y_test)
y_pred = keras_reg.predict(X_new)

注意界赔,傳遞給fit()方法的任何額外的參數都將傳遞給底層的Keras模型。另外牵触,score()方法的結果與MSE的相反淮悼,因為Scikit-Learn的理念是計算分數,而不是損失(即得分越高越好)

當然我們的目的并不是訓練和評估一個模型揽思,而是訓練成百上千的超參數組合然后看哪種超參數組合在驗證集上的表現最好袜腥。那么既然有那么多超參數需要調整,選用隨機搜索的方法比網格搜索更佳钉汗。讓我們嘗試探索一下隱藏層和神經元的數量以及學習率:

from scipy.stats import reciprocal
from sklearn.model_selection import RandomizedSearchCV
param_distribs = {
    "n_hidden": [0, 1, 2, 3],
    "n_neurons": np.arange(1, 100),
    "learning_rate": reciprocal(3e-4, 3e-2),
}
rnd_search_cv = RandomizedSearchCV(keras_reg, param_distribs, n_iter=10, cv=3)
rnd_search_cv.fit(X_train, y_train, epochs=100,
                  validation_data=(X_valid, y_valid),
                  callbacks=[keras.callbacks.EarlyStopping(patience=10)])

我們將額外的參數傳遞給fit()方法羹令,并將它們傳遞給底層的Keras模型。注意RandomizedSearchCV使用K-fold交叉驗證损痰,因此它不使用X_valid和y_valid福侈,這兩個值只用于早停森篷。

超參數的探索可能需要幾個小時兼吓,具體取決于硬件,數據量步绸,模型復雜度尝丐,n_itercv的值显拜。在運行結束之后,你就會得到最佳的超參數爹袁,得分远荠,以及訓練好的Keras模型:

>>> rnd_search_cv.best_params_
{'learning_rate': 0.0033625641252688094, 'n_hidden': 2, 'n_neurons': 42}
>>> rnd_search_cv.best_score_
-0.3189529188278931
>>> model = rnd_search_cv.best_estimator_.model

現在你就可以保存這個模型,在測試集上進行評估失息,如果你對這個結果比較滿意譬淳,那么就可以部署生產了。使用隨機搜索并不難盹兢,并且對于許多簡單的問題表現得很好邻梆。但當訓練速度很慢的時候(比如問題很復雜,數據量又很大)绎秒,這個方法只能探索很小范圍的超參數空間浦妄。我們可以手動幫助這個探索過程,讓它稍微有所改善:首先使用大范圍的超參數空間,快速運行一次隨機搜索剂娄,然后在第一次的運行結果中找出最佳的超參數值蠢涝,再在這個超參數值附近更小的空間運行一次隨機搜索。這種方法可以手動縮小至一組較好的超參數阅懦。但是和二,這個方法仍然非常花時間耳胎,而且也不是最值得花時間的地方惯吕。

非常幸運的是,現在已經有很多技術可以幫助我們更高效地探索超參數空間了怕午。這些技術的核心思路也很簡單:當一個區(qū)域的超參數空間的結果不錯废登,那么應該對這個區(qū)域進行更多的探索。這些技術幫助我們處理了縮小搜索空間的事情诗轻,并且用更少的時間得出更好的解決方案钳宪。下面列出了一下可以用來調參的python庫:

  • Hyperopt:用于優(yōu)化各種復雜的搜索空間的Python庫(比如學習率,離散值扳炬,隱藏層數)
  • Hyperas吏颖,kopt,Talos:為Keras模型優(yōu)化超參數而開發(fā)的庫(前兩個是基于Hyperopt)
  • Keras Tuner:谷歌為Keras開發(fā)的很容易使用的超參數優(yōu)化庫恨樟,并且?guī)в锌梢暬头治鐾泄芊?/li>
  • Scikit-Optimize(skopt):是一個通用的優(yōu)化庫半醉。BayesSearchCV類執(zhí)行貝葉斯優(yōu)化,并且其接口與GridSearchCV十分相似
  • Spearmint:一個貝葉斯優(yōu)化庫
  • Hyperband:基于Lisha Li等人的論文《Hyperband: A Novel Bandit-Based Approach to Hyperparameter Optimization》開發(fā)的快速超參數調優(yōu)庫
  • Sklearn-Deap:一個基于進化算法的超參數優(yōu)化庫劝术,其接口也與GridSearchCV十分相似

另外缩多,許多公司還提供超參數優(yōu)化服務。我們在未來的文章中將討論谷歌云AI平臺的超參數調優(yōu)服務养晋。還有公司會提供超參數優(yōu)化的API衬吆,比如Arimo,SifOpt和Oscar等绳泉。

超參數調優(yōu)仍然是一個活躍的研究領域逊抡。進化算法最近又卷土重來了。比如零酪,DeepMind在2017年發(fā)表了一篇優(yōu)秀的論文:《Population Based Training of Neural Networks》冒嫡,在論文中作者綜合優(yōu)化了一些模型以及它們的超參數。谷歌還使用了一種進化方法四苇,不僅用于搜索超參數孝凌,還用于尋找解決問題的最佳神經網絡體系結構。這個進化方法被稱為AutoML月腋,并已經可以在云端使用蟀架。也許這個技術會是人工構建神經網絡的終結瓣赂?有興趣的朋友可以看谷歌關于這個項目的文章:

https://ai.googleblog.com/2018/03/using-evolutionary-automl-to-discover.html

實際上,進化算法已經替代無處不在的梯度下降法成功用于訓練單個神經網絡了辜窑。Uber就于2017年在官網上發(fā)表了一篇介紹他們深度神經進化技術的文章钩述。

雖然我們有了這些令人興奮的技術發(fā)展寨躁,還有各種工具和服務穆碎,但是了解每個超參數的合理范圍仍然是有幫助的。這樣我們就可以構建一個快速的原型并限制超參數搜索空間职恳。

下一篇文章將講述如何在MLP中選擇隱藏層和神經元數量所禀,以及選擇其他重要超參數合適的值。

敬請期待啦放钦!

最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
  • 序言:七十年代末色徘,一起剝皮案震驚了整個濱河市,隨后出現的幾起案子操禀,更是在濱河造成了極大的恐慌褂策,老刑警劉巖,帶你破解...
    沈念sama閱讀 218,451評論 6 506
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件颓屑,死亡現場離奇詭異斤寂,居然都是意外死亡,警方通過查閱死者的電腦和手機揪惦,發(fā)現死者居然都...
    沈念sama閱讀 93,172評論 3 394
  • 文/潘曉璐 我一進店門遍搞,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人器腋,你說我怎么就攤上這事溪猿。” “怎么了纫塌?”我有些...
    開封第一講書人閱讀 164,782評論 0 354
  • 文/不壞的土叔 我叫張陵诊县,是天一觀的道長。 經常有香客問我措左,道長依痊,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,709評論 1 294
  • 正文 為了忘掉前任媳荒,我火速辦了婚禮抗悍,結果婚禮上,老公的妹妹穿的比我還像新娘钳枕。我一直安慰自己缴渊,他們只是感情好,可當我...
    茶點故事閱讀 67,733評論 6 392
  • 文/花漫 我一把揭開白布鱼炒。 她就那樣靜靜地躺著衔沼,像睡著了一般。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上指蚁,一...
    開封第一講書人閱讀 51,578評論 1 305
  • 那天菩佑,我揣著相機與錄音,去河邊找鬼凝化。 笑死稍坯,一個胖子當著我的面吹牛,可吹牛的內容都是我干的搓劫。 我是一名探鬼主播瞧哟,決...
    沈念sama閱讀 40,320評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼枪向!你這毒婦竟也來了勤揩?” 一聲冷哼從身側響起,我...
    開封第一講書人閱讀 39,241評論 0 276
  • 序言:老撾萬榮一對情侶失蹤秘蛔,失蹤者是張志新(化名)和其女友劉穎陨亡,沒想到半個月后,有當地人在樹林里發(fā)現了一具尸體深员,經...
    沈念sama閱讀 45,686評論 1 314
  • 正文 獨居荒郊野嶺守林人離奇死亡负蠕,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內容為張勛視角 年9月15日...
    茶點故事閱讀 37,878評論 3 336
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現自己被綠了辨液。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片虐急。...
    茶點故事閱讀 39,992評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖滔迈,靈堂內的尸體忽然破棺而出止吁,到底是詐尸還是另有隱情,我是刑警寧澤燎悍,帶...
    沈念sama閱讀 35,715評論 5 346
  • 正文 年R本政府宣布敬惦,位于F島的核電站,受9級特大地震影響谈山,放射性物質發(fā)生泄漏俄删。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點故事閱讀 41,336評論 3 330
  • 文/蒙蒙 一奏路、第九天 我趴在偏房一處隱蔽的房頂上張望畴椰。 院中可真熱鬧,春花似錦鸽粉、人聲如沸斜脂。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,912評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽帚戳。三九已至玷或,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間片任,已是汗流浹背偏友。 一陣腳步聲響...
    開封第一講書人閱讀 33,040評論 1 270
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留对供,地道東北人位他。 一個月前我還...
    沈念sama閱讀 48,173評論 3 370
  • 正文 我出身青樓,卻偏偏與公主長得像犁钟,于是被迫代替她去往敵國和親棱诱。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 44,947評論 2 355

推薦閱讀更多精彩內容