tf.train.SessionRunHook 讓 estimator 訓(xùn)練過程可以個性化定制

estimator

estimator 是 tensorflow 提供的使用非常方便的模型封裝。estimator 中提供了許多內(nèi)置的模型涵但,例如 LinearClassifier姿鸿、DNNLinearCombinedClassifier犬耻、LinearRegressor等眶蕉。用戶也可以通過 model_fn 定制模型結(jié)構(gòu)。在 estimator 對象的基礎(chǔ)上任何模型都可以直接調(diào)用 train 和 eval 函數(shù)進行訓(xùn)練和測試蛛淋,用戶無需手動地創(chuàng)建 session 和 run session咙好。estimator 的具體使用方式可以參考[1]。
estimator.png

dataset

tensorflow 底層 API 中都是使用 placeholder 和 feed_dict 向模型輸入數(shù)據(jù)的铣鹏,這樣的方式效率較低敷扫。我們可以利用 dataset 庫哀蘑,這里提供了高效讀取數(shù)據(jù)并且輸入給模型訓(xùn)練的方式诚卸。

可以直接用 numpy 數(shù)組創(chuàng)建 dataset。直接用數(shù)組創(chuàng)建 dataset 的一個問題是 tensorflow 會直接把 dataset 中的數(shù)據(jù)寫到 graph 中绘迁,當數(shù)據(jù)量較大時會報錯合溺,因為 graph 在序列化到 pb 文件時現(xiàn)在最大2GB。

def input_fn():
  features, labels = (np.random.sample((100,2)), np.random.sample((100,1)))
  dataset = tf.data.Dataset.from_tensor_slices((features,labels))
  dataset = dataset.shuffle(100000).repeat().batch(batch_size)
  return dataset

...
estimator.train(input_fn)

為了在大數(shù)據(jù)量時使用 dataset缀台,我們可以用 placeholder 創(chuàng)建 dataset棠赛。這時數(shù)據(jù)就不會直接寫到 graph 中,graph 中只有一個 placeholder 占位符膛腐。但是睛约,用了 placeholder 就需要我們在一開始對它進行初始化填數(shù)據(jù),需要調(diào)用 sess.run(iter.initializer, feed_dict={ x: data })哲身。更多關(guān)于 dataset 的使用介紹可以參考文獻[2]辩涝。

def input_fn():
  x = tf.placeholder(tf.float32, shape=[None,2])
  dataset = tf.data.Dataset.from_tensor_slices(x)
  dataset = dataset.shuffle(100000).repeat().batch(batch_size)
  iter = dataset.make_initializable_iterator()
  return iter.get_next()

SessionRunHook

既然前面說到 estimator 是 tensorflow 對模型的一種封裝,我們不需要也無法拿到訓(xùn)練和測試時創(chuàng)建的 session勘天,那么我們?nèi)绾卧?estimator 中對上一節(jié)使用 placeholder 的 dataset 的 initializeble_iterator 調(diào)用 sess.run 進行初始化呢怔揩?這時候就要用到 SessionRunHook 了。
先從字面意思理解一下 SessionRunHook 這個類脯丝。Session 就是 tensorflow 運行模型計算時的會話商膊,Run就是整個 session 運行過程,Hook 是掛鉤的意思即把某些事情掛在這個對象上可以理解為回調(diào)宠进。

再看一下 SessionRunHook 源碼[3]中的定義:
A SessionRunHook extends session.run() calls for the MonitoredSession.
SessionRunHooks are useful to track training, report progress, request early
stopping and more. SessionRunHooks use the observer pattern and notify at the
following points:

  • when a session starts being used
  • before a call to the session.run()
  • after a call to the session.run()
  • when the session closed
class SessionRunHook(object):
  """Hook to extend calls to MonitoredSession.run()."""

  def begin(self):
    """Called once before using the session.
    When called, the default graph is the one that will be launched in the
    session.  The hook can modify the graph by adding new operations to it.
    After the `begin()` call the graph will be finalized and the other callbacks
    can not modify the graph anymore. Second call of `begin()` on the same
    graph, should not change the graph.
    """
    pass

  def after_create_session(self, session, coord):  # pylint: disable=unused-argument
    """Called when new TensorFlow session is created.
    This is called to signal the hooks that a new session has been created. This
    has two essential differences with the situation in which `begin` is called:
    * When this is called, the graph is finalized and ops can no longer be added
        to the graph.
    * This method will also be called as a result of recovering a wrapped
        session, not only at the beginning of the overall session.
    Args:
      session: A TensorFlow Session that has been created.
      coord: A Coordinator object which keeps track of all threads.
    """
    pass

  def before_run(self, run_context):  # pylint: disable=unused-argument
    """Called before each call to run().
    You can return from this call a `SessionRunArgs` object indicating ops or
    tensors to add to the upcoming `run()` call.  These ops/tensors will be run
    together with the ops/tensors originally passed to the original run() call.
    The run args you return can also contain feeds to be added to the run()
    call.
    The `run_context` argument is a `SessionRunContext` that provides
    information about the upcoming `run()` call: the originally requested
    op/tensors, the TensorFlow Session.
    At this point graph is finalized and you can not add ops.
    Args:
      run_context: A `SessionRunContext` object.
    Returns:
      None or a `SessionRunArgs` object.
    """
    return None

  def after_run(self,
                run_context,  # pylint: disable=unused-argument
                run_values):  # pylint: disable=unused-argument
    """Called after each call to run().
    The `run_values` argument contains results of requested ops/tensors by
    `before_run()`.
    The `run_context` argument is the same one send to `before_run` call.
    `run_context.request_stop()` can be called to stop the iteration.
    If `session.run()` raises any exceptions then `after_run()` is not called.
    Args:
      run_context: A `SessionRunContext` object.
      run_values: A SessionRunValues object.
    """
    pass

  def end(self, session):  # pylint: disable=unused-argument
    """Called at the end of session.
    The `session` argument can be used in case the hook wants to run final ops,
    such as saving a last checkpoint.
    If `session.run()` raises exception other than OutOfRangeError or
    StopIteration then `end()` is not called.
    Note the difference between `end()` and `after_run()` behavior when
    `session.run()` raises OutOfRangeError or StopIteration. In that case
    `end()` is called but `after_run()` is not called.
    Args:
      session: A TensorFlow Session that will be soon closed.
    """
    pass

我們看到 SessionRunHook 源碼中為 5 中不同的事件提供了回調(diào)函數(shù)晕拆,用戶只需要繼承 SessionRunHook 這個類并且具體實現(xiàn)想要的回調(diào)函數(shù)即可,具體用法看下一節(jié)材蹬。

estimator 結(jié)合 SessionRunHook 實現(xiàn) placeholder 初始化

仔細看一下 estimator 的 train 和 evaluate 函數(shù)定義可以發(fā)現(xiàn)它們都接收 hooks 參數(shù)潦匈,這個參數(shù)的定義是:List of tf.train.SessionRunHook subclass instances. Used for callbacks inside the training loop. 就是上一節(jié)提到的用戶繼承自 SessionRunHook 的類的實例對象列表。

train(
    input_fn,
    hooks=None,
    steps=None,
    max_steps=None,
    saving_listeners=None
)

我們現(xiàn)在想要在訓(xùn)練之前初始化 dataset 的 placeholder赚导,那么我們就應(yīng)該具體實現(xiàn) SessionRunHook 的after_create_session 成員函數(shù):

class IteratorInitializerHook(tf.train.SessionRunHook):
   def __init__(self):
       super(IteratorInitializerHook, self).__init__()
       self.iterator_initializer_fn = None

   def after_create_session(self, session, coord):
       del coord
       self.iterator_initializer_fn(session)

def make_input_fn():
   iterator_initializer_hook = IteratorInitializerHook()

   def input_fn():
       x = tf.placeholder(tf.float32, shape=[None,2])
       dataset = tf.data.Dataset.from_tensor_slices(x)
       dataset = dataset.shuffle(100000).repeat().batch(batch_size)
       iter = dataset.make_initializable_iterator()
       data = np.random.sample((100,2))
       iterator_initializer_hook.iterator_initializer_fn = (
           lambda sess: sess.run(iter.initializer, feed_dict={x: data})
       )
       return iter.get_next()
   return input_fn, iterator_initializer_hook

...
input_fn, iterator_initializer_hook = make_input_fn()
estimator.train(input_fn, hooks=[iterator_initializer_hook])

當然茬缩,SessionRunHook 不光能用在初始化上,還有許多應(yīng)用場景吼旧,可以參考源碼[3]中提供的幾個內(nèi)置 Hook 和文獻[4]凰锡。

[1] https://github.com/tensorflow/models/tree/master/samples/core/get_started
[2] https://www.jiqizhixin.com/articles/03137
[3] https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/session_run_hook.py
[4] https://blog.csdn.net/mrr1ght/article/details/81011280

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子掂为,更是在濱河造成了極大的恐慌裕膀,老刑警劉巖,帶你破解...
    沈念sama閱讀 219,539評論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件勇哗,死亡現(xiàn)場離奇詭異昼扛,居然都是意外死亡,警方通過查閱死者的電腦和手機欲诺,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,594評論 3 396
  • 文/潘曉璐 我一進店門抄谐,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人扰法,你說我怎么就攤上這事蛹含。” “怎么了塞颁?”我有些...
    開封第一講書人閱讀 165,871評論 0 356
  • 文/不壞的土叔 我叫張陵浦箱,是天一觀的道長。 經(jīng)常有香客問我祠锣,道長酷窥,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,963評論 1 295
  • 正文 為了忘掉前任伴网,我火速辦了婚禮蓬推,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘是偷。我一直安慰自己拳氢,他們只是感情好,可當我...
    茶點故事閱讀 67,984評論 6 393
  • 文/花漫 我一把揭開白布蛋铆。 她就那樣靜靜地躺著馋评,像睡著了一般。 火紅的嫁衣襯著肌膚如雪刺啦。 梳的紋絲不亂的頭發(fā)上留特,一...
    開封第一講書人閱讀 51,763評論 1 307
  • 那天,我揣著相機與錄音玛瘸,去河邊找鬼蜕青。 笑死,一個胖子當著我的面吹牛糊渊,可吹牛的內(nèi)容都是我干的右核。 我是一名探鬼主播,決...
    沈念sama閱讀 40,468評論 3 420
  • 文/蒼蘭香墨 我猛地睜開眼渺绒,長吁一口氣:“原來是場噩夢啊……” “哼贺喝!你這毒婦竟也來了菱鸥?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 39,357評論 0 276
  • 序言:老撾萬榮一對情侶失蹤躏鱼,失蹤者是張志新(化名)和其女友劉穎氮采,沒想到半個月后,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體染苛,經(jīng)...
    沈念sama閱讀 45,850評論 1 317
  • 正文 獨居荒郊野嶺守林人離奇死亡鹊漠,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 38,002評論 3 338
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了茶行。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片躯概。...
    茶點故事閱讀 40,144評論 1 351
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖拢军,靈堂內(nèi)的尸體忽然破棺而出楞陷,到底是詐尸還是另有隱情怔鳖,我是刑警寧澤茉唉,帶...
    沈念sama閱讀 35,823評論 5 346
  • 正文 年R本政府宣布,位于F島的核電站结执,受9級特大地震影響度陆,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜献幔,卻給世界環(huán)境...
    茶點故事閱讀 41,483評論 3 331
  • 文/蒙蒙 一懂傀、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧蜡感,春花似錦蹬蚁、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,026評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至情连,卻和暖如春叽粹,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背却舀。 一陣腳步聲響...
    開封第一講書人閱讀 33,150評論 1 272
  • 我被黑心中介騙來泰國打工虫几, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人挽拔。 一個月前我還...
    沈念sama閱讀 48,415評論 3 373
  • 正文 我出身青樓辆脸,卻偏偏與公主長得像,于是被迫代替她去往敵國和親螃诅。 傳聞我的和親對象是個殘疾皇子啡氢,可洞房花燭夜當晚...
    茶點故事閱讀 45,092評論 2 355

推薦閱讀更多精彩內(nèi)容