Keras 2.2.x learning_phase 機制源碼級解讀

由于Keras 2.3.0開始適配tf2.0導致代碼大規(guī)模重構啼器,因此我們討論和tf1.x適配的Keras2.1.x與2.2.x版本。
Learning_phase雖然看似簡單端逼,實則非常重要暑认,標志著模型的運行狀態(tài)(訓練還是推理)柏肪。在pytorch中可以使用model.train()與model.eval()切換模型的狀態(tài),因為會關系到BN與dropout層的計算改變。而在Keras中贿肩,為了使模型api簡潔峦椰,且適配多種backend,運用了learning_phase機制去解決這種問題汰规。而這種機制意外帶來了巨大的復雜性汤功,邏輯變得混亂,從而間接引發(fā)了一些問題:
Keras中的BN層錯誤:https://zhuanlan.zhihu.com/p/56225304
固化Keras模型后輸出參數(shù)變了一點:https://stackoverflow.com/questions/61619032/got-small-output-value-error-between-h5-model-and-pb-model

本文試圖徹底理清learning_phase機制溜哮,從而對Keras有著更全面的認知滔金,使得下次遇到相關問題可以輕松解決。

一茂嗓、定義

首先可以發(fā)現(xiàn)餐茵,在tensorflow_backend.py中有它的定義:

def learning_phase():
    """Returns the learning phase flag.

    The learning phase flag is a bool tensor (0 = test, 1 = train)
    to be passed as input to any Keras function
    that uses a different behavior at train time and test time.

    # Returns
        Learning phase (scalar integer tensor or Python integer).
    """
    graph = tf.get_default_graph()
    if graph not in _GRAPH_LEARNING_PHASES:
        phase = tf.placeholder_with_default(False,
                                            shape=(),
                                            name='keras_learning_phase')
        _GRAPH_LEARNING_PHASES[graph] = phase
    return _GRAPH_LEARNING_PHASES[graph]

可以看到,K.learning_phase()是全局量述吸,且依附于當前graph中的唯一量忿族,當調(diào)用它的時候,會去找_GRAPH_LEARNING_PHASES字典刚梭,若有當前graph的K.learning_phase()則取出肠阱,否則新建一個中placeholder_with_default(False)存入字典,下次取出的就是它了朴读。
注意這是一個dtype=bool的placeholder屹徘,代表運行網(wǎng)絡時,我們可以使用feed_dict={K.learning_phase(): 0 or 1}喂入衅金,指定它的取值(前提是當前graph中的K.learning_phase()還是一個placeholder)噪伊。若不指定,則默認為False(0)氮唯。

  • 手動賦值
def set_learning_phase(value):
    """Sets the learning phase to a fixed value.

    # Arguments
        value: Learning phase value, either 0 or 1 (integers).

    # Raises
        ValueError: if `value` is neither `0` nor `1`.
    """
    global _GRAPH_LEARNING_PHASES
    if value not in {0, 1}:
        raise ValueError('Expected learning phase to be '
                         '0 or 1.')
    _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = value

可以通過K.learning_phase()方法鉴吹,為該全局量賦值,0為test惩琉,1為train豆励。
賦值必須為這兩個int之一。一旦手動賦值瞒渠,則會覆蓋掉原先dict里默認的placeholder良蒸。因此手動set值之后建立模型后,graph里的該值就被定死為一個int了伍玖,無法更改嫩痰,也無法再feed_dict里喂入K.learning_phase()了。因為當前graph的模型建立時窍箍,用到的K.learning_phase()就是一個int串纺。
建立模型后丽旅,再使用set_learning_phase更改值,對原模型無效纺棺。因為此時改的K.learning_phase()和graph里的建立好的模型無關了榄笙。

因此訓練時候不能用,否則訓練過程中預測驗證集的時候也會強制使用train時候的配置祷蝌。這與pytorch動態(tài)圖能隨意切換train/eval狀態(tài)不同(畢竟tf是靜態(tài)圖)办斑。

專門進行測試的時候可以用,但是其實沒必要杆逗,一是因為Keras.model.predict的時候會傳入數(shù)值0,二是因為placeholder_with_default默認就是False鳞疲。

模型中不存在BN或dropout層的時候無效罪郊,因為train/eval都是同一套運算流程和參數(shù)配置。

那這個哪里可用到尚洽?可在建立模型的時候用悔橄,手動控制該layer使用哪種配置,適合折騰(比如開篇知乎那個解決BN問題的時候腺毫,就可通過創(chuàng)建layer時給某些BN層強制配置值0使得BN一直處于推理階段癣疟,一直使用一開始遷移學習初始狀態(tài)的移動平均值)

二、使用地點

在建立模型時潮酒,當遇到dropout或bn層時睛挚,以簡單的keras.layers.Dropout()為例:

    def call(self, inputs, training=None):
        if 0. < self.rate < 1.:
            noise_shape = self._get_noise_shape(inputs)

            def dropped_inputs():
                return K.dropout(inputs, self.rate, noise_shape,
                                 seed=self.seed)
            return K.in_train_phase(dropped_inputs, inputs,
                                    training=training)
        return inputs

call()方法是所有Layer的邏輯實現(xiàn)層,調(diào)用該層的時候就會調(diào)用此方法急黎。 K.dropout本質(zhì)是為了適配不同后端扎狱,tf就會在該方法中調(diào)用tf.nn.dropout。

重點是 K.in_train_phase(dropped_inputs, inputs,training=training)函數(shù)勃教。

def in_train_phase(x, alt, training=None):
    """Selects `x` in train phase, and `alt` otherwise.

    Note that `alt` should have the *same shape* as `x`.

    # Arguments
        x: What to return in train phase
            (tensor or callable that returns a tensor).
        alt: What to return otherwise
            (tensor or callable that returns a tensor).
        training: Optional scalar tensor
            (or Python boolean, or Python integer)
            specifying the learning phase.

    # Returns
        Either `x` or `alt` based on the `training` flag.
        the `training` flag defaults to `K.learning_phase()`.
    """
    if training is None:
        training = learning_phase()
        uses_learning_phase = True
    else:
        uses_learning_phase = False

    if training is 1 or training is True:
        if callable(x):
            return x()
        else:
            return x

    elif training is 0 or training is False:
        if callable(alt):
            return alt()
        else:
            return alt

    # else: assume learning phase is a placeholder tensor.
    x = switch(training, x, alt)
    if uses_learning_phase:
        x._uses_learning_phase = True
    return x

即通過K.in_train_phase判斷該返回哪個值淤击,訓練的時候應該返回drop后的值,測試的時候應該不丟棄值(tf的dropout在訓練時已經(jīng)除以keep_prob系數(shù)了故源,所以測試時直接輸出input即可)污抬。

  1. 模型建立階段:(即構建graph)layer未設置training參數(shù),則默認為None绳军。之前沒手動set_learning_phase:
    則K.learning_phase()在training = learning_phase()這一步第一次被調(diào)用印机,然后初始化,故K.learning_phase()是:(BN的example)
    Tensor("bn_conv1/keras_learning_phase:0", shape=(), dtype=bool)
    當然若在模型建立前調(diào)用過删铃,則name里沒bn_conv1這個前綴耳贬。
    然后uses_learning_phase = True,代表這層使用到了learning_phase猎唁。返回的tensor一定會設置_uses_learning_phase = True這個屬性值咒劲。graph建立完畢顷蟆。
  • 若是訓練階段運行graph,Keras.model的fit等方法會自動判斷模型是否有的layer._uses_learning_phase為True腐魂,即是否用到了learning_phase()帐偎,若用到了則feed_dict中多一個K.learning_phase()這個placeholder,并傳入一個值:1. 蛔屹,因此訓練運行時削樊,通過switch函數(shù)返回dropout的返回值。
  • 若測試階段運行graph兔毒,Keras.model.predict()會自動為K.learning_phase()這個placeholder傳入一個值:0.漫贞,或自行調(diào)用sess.run方法的時候不傳值,由于placeholder_with_default(False)默認就是False值故不影響結果育叁。

由于上文提到過迅脐,訓練時一般不手動調(diào)用set_learning_phase,因此我們討論:

  1. 模型建立階段:(即構建graph)layer設置training參數(shù)為True豪嗽,無論有沒有手動set_learning_phase:
    我們發(fā)現(xiàn)graph搭建的時候直接跳過了swith函數(shù)分支谴蔑,直接返回dropout之后的值。同時x._uses_learning_phase也未設置龟梦,可以說是完全拋棄了K.learning_phase()隐锭。K.learning_phase()都沒初始化的機會。
  • 若是訓練/測試階段運行graph计贰,都會直接運行graph里的返回dropout值钦睡,問題很大!
  1. 模型建立階段:(即構建graph)layer設置training參數(shù)為False蹦玫,無論之前有沒有手動set_learning_phase:
    和2原理一樣赎婚,都會強行返回輸入值,graph中失去了對訓練or測試階段的選擇性樱溉。即dropout層等于失效挣输。

總結:
training參數(shù)手動設置True還是False都會帶來巨變,使網(wǎng)絡拋棄了K.learning_phase()福贞,也等于是graph中失去了對訓練or測試階段的選擇性撩嚼。因為Graph可以應對兩種階段,本質(zhì)是由于K.learning_phase()是個placeholder使輸入有兩種可能挖帘,然后K.in_train_phase中存在switch方法分支根據(jù)placeholder輸出兩種可能性完丽。

那training參數(shù)保留None,手動set_learning_phase會咋樣拇舀?

  1. 模型建立階段:(即構建graph)layer未設置training參數(shù)逻族,則默認為None。之前手動set_learning_phase=1:
    此時training=1骄崩,還是會進入分支2聘鳞,只保留強行使用dropout這一條路薄辅。

  2. 模型建立階段:(即構建graph)layer未設置training參數(shù),則默認為None抠璃。之前手動set_learning_phase=0:
    此時training=0站楚,還是會進入分支3,直接砍了dropout這一條路搏嗡。即模型中dropout層沒作用了窿春。

總結:
構建模型時不應該手動設置training參數(shù)。那么采盒,手動set_learning_phase=1構建模型旧乞,會使模型只留下訓練配置一條路,和training=True一樣磅氨,測試必定會強行使用dropout層一定輸出會出錯良蛮。而手動set_learning_phase=0構建模型,會使模型中dropout層失效悍赢,直接訓練就廢了。
那么還要手動設這些參數(shù)干嘛货徙?因為有特殊情況左权,即遷移訓練、或加載別人訓練好的模型痴颊。例如keras.application.ResNet50()赏迟,我們不需要自己訓練,那么可以提前set_learning_phase=0蠢棱,這樣構建出來的resnet會少了很多節(jié)點锌杀,相當于把graph里BN層訓練的路子給砍了。load_weight的時候只會載入部分weights泻仙。model.predict()的時候結果也和原來一致糕再。
同時,和知乎里說的一樣玉转,遷移訓練時給BN設置training=False突想,或臨時給BN層set_learning_phase=0,(別的層雖然set_learning_phase=1但由于他們訓練or測試時行為一致所以其實無所謂)究抓,然后load_weights之后再訓練猾担,這樣BN層只會輸出舊模型的滑動平均值作為參數(shù),都不會參與訓練了刺下。

三绑嘹、訓練或推理時

以keras.model.Model().predict()為例:

        # Prepare inputs, delegate logic to `_predict_loop`.
        if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
            ins = x + [0.]
        else:
            ins = x
        self._make_predict_function()
        f = self.predict_function
        return self._predict_loop(f, ins, batch_size=batch_size,
                                  verbose=verbose, steps=steps)

而其中的self.uses_learning_phase來自model類繼承的container類:

    @property
    def uses_learning_phase(self):
        return any([x._uses_learning_phase for x in self.outputs])

而這個方法是去判斷container里的layer是否有含有x._uses_learning_phase屬性。(BN橘茉、dropout層就有這個屬性工腋,上面提到了)
因此if self.uses_learning_phase and not isinstance(K.learning_phase(), int):這句的含義是若model(里的某些layer)用到了learning_phase姨丈,且當前graph的learning_phase()不是int,后面半句判斷實際上是在判斷是否model在創(chuàng)建之前使用過set_learning_phase,因為一旦set過0或1夷蚊,那么模型實際上就會被剪掉train或test的分支构挤,指定不同的learning_phase也就沒了意義,實際上此時根本不能指定learning_phase()了惕鼓,因為此時模型內(nèi)部的K.learning_phase()就是一個int筋现。
因此只有滿足這兩個條件(模型里用到了dropout或BN層、且模型創(chuàng)建前未set_learning_phase)才會給輸入多添加一個值:0. 這個值代表的是feed_dict中喂給K.learning_phase()這個bool placeholder的值為0箱歧,代表測試階段矾飞。然后graph運行的時候會進入推理分支。
self._make_predict_function()函數(shù)會動態(tài)創(chuàng)建實際上的預測函數(shù)呀邢,根據(jù)需不需要傳入learning_phase創(chuàng)建不同的需要喂入的feed_dict洒沦。

訓練階段同理。

因此我們可以同樣使用如下tf函數(shù)价淌,來獲取輸出的值申眼,分別是train分支與推理分支,結果與model.predict相同:

model_ = ResNet50(include_top=False, pooling='avg', weights='imagenet')
print(K.learning_phase())  # Tensor("bn_conv1/keras_learning_phase:0", shape=(), dtype=bool)
sess = K.get_session()
preds = sess.run(net_model.get_output_at(0), feed_dict={net_model.get_input_at(0): x_input, 
              sess.graph.get_tensor_by_name('bn_conv1/keras_learning_phase:0':1)})
print('before constantize output:', np.array(preds).squeeze()[:10])

preds = sess.run(net_model.get_output_at(0), feed_dict={net_model.get_input_at(0): x_input, 
              sess.graph.get_tensor_by_name('bn_conv1/keras_learning_phase:0':0)})
print('before constantize output:', np.array(preds).squeeze()[:10])  # 與model.predict相同

當然推理時可省略learning_phase的傳入蝉衣,因為這個placeholder默認就是False括尸。
因此我們可通過下面語句獲取中間某些node的輸出值來調(diào)試網(wǎng)絡:

preds = sess.run(sess.graph.get_tensor_by_name('bn_conv1/batchnorm/add_1:0'),     
                    feed_dict={net_model.get_input_at(0): x_input})
print('before constantize bn_conv1/batchnorm/add_1:0:', np.array(preds).squeeze()[0,0,:10])

注意,K.learning_phase()的name不是固定的病毡,而是看第一次在哪里調(diào)用它濒翻,name的前綴會不同。這個案例中啦膜,在模型創(chuàng)建前并未調(diào)用過它有送,因此它是在第一個BN里才用到,那里的name_scope下第一次初始化創(chuàng)建僧家,因此全名里帶prefix是bn_conv1/keras_learning_phase:0雀摘。若在model創(chuàng)建前就調(diào)用過K.learning_phase()則模型里存儲的該tensor.name=keras_learning_phase:0,模型創(chuàng)建的時候直接就去調(diào)用它了八拱。

四届宠、總結

Keras默認情況下K.learning_phase()返回一個全局的placeholder_with_default(False),Keras使用這個輸入量來控制模型到底是train/eval階段乘粒,關系到dropout和BN層的狀態(tài)豌注。

  • Keras.model.fit()等方法默認會構造一個feed_dict:{K.learning_phase():1}喂入模型,而predict等方法同理會喂入0,這樣就告訴了BN層或dropout層此時應該使用graph里的哪個分支。
  • tf的靜態(tài)圖特性決定了必須使用placeholder這種機制創(chuàng)建模型后蚤告,才能根據(jù)輸入量切換train/eval階段。一旦提前set_learning_phase(1)將使創(chuàng)建出來的模型永遠只擁有train這一個分支齿风,后續(xù)將無法更改模型的分支药薯,因為創(chuàng)建的模型里只有那一個分支。有分支的前提是K.learning_phase()是一個待輸入量placeholder救斑。
  • 不僅如此童本,將BN或dropout層的training參數(shù)設為True或False,同樣也會引發(fā)這種現(xiàn)象脸候,即設為True后構建的graph就只存在使用dropout層這一條路子穷娱,設為False則表示完全不用,將會忽略K.learning_phase()的取值运沦。而BN的情況更加復雜泵额,可參考文章開頭的知乎鏈接。
最后編輯于
?著作權歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末携添,一起剝皮案震驚了整個濱河市嫁盲,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌烈掠,老刑警劉巖羞秤,帶你破解...
    沈念sama閱讀 219,270評論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異左敌,居然都是意外死亡锥腻,警方通過查閱死者的電腦和手機,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,489評論 3 395
  • 文/潘曉璐 我一進店門母谎,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人京革,你說我怎么就攤上這事奇唤。” “怎么了匹摇?”我有些...
    開封第一講書人閱讀 165,630評論 0 356
  • 文/不壞的土叔 我叫張陵咬扇,是天一觀的道長。 經(jīng)常有香客問我廊勃,道長懈贺,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,906評論 1 295
  • 正文 為了忘掉前任坡垫,我火速辦了婚禮梭灿,結果婚禮上,老公的妹妹穿的比我還像新娘冰悠。我一直安慰自己堡妒,他們只是感情好,可當我...
    茶點故事閱讀 67,928評論 6 392
  • 文/花漫 我一把揭開白布溉卓。 她就那樣靜靜地躺著皮迟,像睡著了一般搬泥。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上伏尼,一...
    開封第一講書人閱讀 51,718評論 1 305
  • 那天忿檩,我揣著相機與錄音,去河邊找鬼爆阶。 笑死燥透,一個胖子當著我的面吹牛,可吹牛的內(nèi)容都是我干的扰她。 我是一名探鬼主播兽掰,決...
    沈念sama閱讀 40,442評論 3 420
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼徒役!你這毒婦竟也來了孽尽?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 39,345評論 0 276
  • 序言:老撾萬榮一對情侶失蹤忧勿,失蹤者是張志新(化名)和其女友劉穎杉女,沒想到半個月后,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體鸳吸,經(jīng)...
    沈念sama閱讀 45,802評論 1 317
  • 正文 獨居荒郊野嶺守林人離奇死亡熏挎,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,984評論 3 337
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了晌砾。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片坎拐。...
    茶點故事閱讀 40,117評論 1 351
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖养匈,靈堂內(nèi)的尸體忽然破棺而出哼勇,到底是詐尸還是另有隱情,我是刑警寧澤呕乎,帶...
    沈念sama閱讀 35,810評論 5 346
  • 正文 年R本政府宣布积担,位于F島的核電站,受9級特大地震影響猬仁,放射性物質(zhì)發(fā)生泄漏帝璧。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點故事閱讀 41,462評論 3 331
  • 文/蒙蒙 一湿刽、第九天 我趴在偏房一處隱蔽的房頂上張望的烁。 院中可真熱鬧,春花似錦诈闺、人聲如沸撮躁。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,011評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽把曼。三九已至杨帽,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間嗤军,已是汗流浹背注盈。 一陣腳步聲響...
    開封第一講書人閱讀 33,139評論 1 272
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留叙赚,地道東北人老客。 一個月前我還...
    沈念sama閱讀 48,377評論 3 373
  • 正文 我出身青樓,卻偏偏與公主長得像震叮,于是被迫代替她去往敵國和親胧砰。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 45,060評論 2 355