使用TensorFlow訓(xùn)練一個(gè)模型讥脐,可以多次運(yùn)行訓(xùn)練操作攒至,并在完成后保存訓(xùn)練參數(shù)的檢查點(diǎn)(checkpoint)。這對(duì)能夠在幾個(gè)小時(shí)內(nèi)訓(xùn)練的小模型很有效峻黍。但是如果是訓(xùn)練的數(shù)據(jù)量比較大,可能需要訓(xùn)練幾天或者幾個(gè)月拨匆。姆涩。。
那原生的tensorflow的健壯性可能就比較堪憂(yōu)惭每。阵面。。
萬(wàn)一斷電了之類(lèi)洪鸭。。仑扑。
這時(shí)候我們就可以使用supervisor
需要長(zhǎng)時(shí)間訓(xùn)練的較大模型览爵,需要更魯棒(robust)的訓(xùn)練過(guò)程:
- 能處理關(guān)機(jī)以及徹底崩潰的情況。
- 可以在關(guān)機(jī)或崩潰后恢復(fù)镇饮。
- 可以通過(guò)TensorBoard進(jìn)行監(jiān)控蜓竹。
為了能夠在停機(jī)或崩潰后恢復(fù)訓(xùn)練,訓(xùn)練過(guò)程必須周期保存檢查點(diǎn)储藐。在重新啟動(dòng)時(shí)俱济,它必須查找最新的檢查點(diǎn),并在恢復(fù)訓(xùn)練之前加載它钙勃。supervisor可以看做一個(gè)工具蛛碌,或者說(shuō)是對(duì)原生tensorflow的一層封裝,目的主要是通過(guò)定期save的方法增強(qiáng)訓(xùn)練健壯性辖源,
就算程序掛掉了也可以從上一次save的checkpoint恢復(fù)蔚携,而不是從頭再來(lái)(雖然這些也可以手動(dòng)實(shí)現(xiàn),同時(shí)也可以簡(jiǎn)化代碼量
tf.train.Supervisor
提供了一套有助于實(shí)施魯棒的訓(xùn)練過(guò)程的服務(wù)克饶。除了supervisor
,還有tf.learn
庫(kù)酝蜒,里面提供對(duì)原生tensorflow
更高層的封裝,也提供更豐富的功能矾湃。
請(qǐng)注意亡脑,Supervisor對(duì)訓(xùn)練大模型非常有幫助,但也可以用于較小型號(hào)邀跃,不會(huì)有任何不好的地方霉咨。
supervisor可以看做一個(gè)工具,或者說(shuō)是對(duì)原生tensorflow的一層封裝坞嘀,目的主要是通過(guò)定期save的方法增強(qiáng)訓(xùn)練健壯性躯护。
1.一個(gè)簡(jiǎn)單方案
使用Supervisor的最簡(jiǎn)單的方案是:
創(chuàng)建一個(gè)
Supervisor
對(duì)象,將其傳遞到保存檢查點(diǎn)和summary的目錄丽涩。用
tf.train.Supervisor.managed_session
向Supervisor
請(qǐng)求一個(gè)會(huì)話(huà)(session)棺滞。使用會(huì)話(huà)執(zhí)行訓(xùn)練操作裁蚁,如果Supervisor要求訓(xùn)練停止,請(qǐng)檢查每一步继准。
...create graph...
my_train_op = ...
sv = tf.train.Supervisor(logdir="/my/training/directory")
with sv.managed_session() as sess:
for step in range(100000):
if sv.should_stop():
break
sess.run(my_train_op)
開(kāi)始服務(wù)
managed_session()
啟動(dòng)一些服務(wù)枉证,它們?cè)谧约旱木€程中運(yùn)行,并利用managed session在圖中運(yùn)行各種操作移必。
如果圖中包含一個(gè)名為global_step
的整型變量室谚,則服務(wù)使用其值來(lái)測(cè)量執(zhí)行的訓(xùn)練步驟數(shù)量。有關(guān)如何創(chuàng)建global_step
變量崔泵,請(qǐng)參閱MNIST訓(xùn)練教程秒赤。
檢查點(diǎn)服務(wù):在logdir中保存圖形變量的副本。
global_step
如果添加到您的圖中憎瘸,則檢查點(diǎn)文件名將使用該變量的值入篮。默認(rèn)運(yùn)行10分鐘。summary服務(wù):運(yùn)行所有summary操作幌甘,并將其輸出附加到logdir 中的 事件文件中潮售。默認(rèn)情況下每2分鐘運(yùn)行一次。
步驟計(jì)數(shù)器:通過(guò)查看
global_step
變量的更改來(lái)計(jì)算執(zhí)行了多少步锅风。向事件文件追加一個(gè)summary酥诽,報(bào)告每秒鐘的全局步數(shù)。 summary tag 為“global_step / sec”皱埠。這也默認(rèn)每2分鐘運(yùn)行一次肮帐。Queue Runners:如果
tf.train.QueueRunner
添加到圖形中,Supervisor將在自己的線程中啟動(dòng)它們边器。
構(gòu)建Supervisor對(duì)象時(shí)可以更改所有時(shí)間間隔泪姨。有關(guān)詳細(xì)信息,請(qǐng)參閱Supervisor參考饰抒。
檢查停止
在主訓(xùn)練循環(huán)中對(duì)停止的檢查是重要和必要的肮砾。
在服務(wù)線程中引發(fā)的異常報(bào)告給Supervisor,然后將其should_stop()
條件設(shè)置為true袋坑。其他服務(wù)線程告知此情形并合理終止仗处。managed_session()
塊內(nèi)的主訓(xùn)練循環(huán) 還必須檢查停止條件并終止。
請(qǐng)注意managed_session()
捕獲從訓(xùn)練循環(huán)中引發(fā)的異常情況枣宫,將其報(bào)告給Supervisor婆誓。主循環(huán)不需要對(duì)異常做任何特別的處理。它只需要檢查停止條件也颤。
復(fù)蘇
如果訓(xùn)練程序關(guān)閉或崩潰洋幻,其最新的檢查點(diǎn)和事件文件將留在logdir中。當(dāng)重新啟動(dòng)程序時(shí)翅娶, managed_session()
從最近的檢查點(diǎn)恢復(fù)圖形文留,并恢復(fù)停止的訓(xùn)練好唯。
創(chuàng)建一個(gè)新的事件文件。如果啟動(dòng)TensorBoard并將其指向logdir燥翅,它將會(huì)知道如何合并兩個(gè)事件文件的內(nèi)容骑篙,并將在檢查點(diǎn)的最后一個(gè)全局步驟中顯示訓(xùn)練恢復(fù)。
2.較大的模式場(chǎng)景
最簡(jiǎn)單的情景已經(jīng)足以處理大多數(shù)小到中模型的訓(xùn)練森书。更大的模型也許會(huì)在運(yùn)行summary sevice的時(shí)候耗盡內(nèi)存:summary ops是與main loop中的train op一起并行地run的靶端。這會(huì)導(dǎo)致內(nèi)存使用達(dá)到通常使用的兩倍多。
對(duì)于打得模型你可以通知supervisor不要運(yùn)行summary服務(wù)凛膏,作為替代杨名,你在自己的主訓(xùn)練循環(huán)中來(lái)運(yùn)行:創(chuàng)建supervisor的時(shí)候傳遞summary_op=None。
例如猖毫,該代碼在訓(xùn)練循環(huán)中每100個(gè)步驟運(yùn)行摘要:
...create graph...
my_train_op = ...
my_summary_op = tf.summary.merge_all()
sv = tf.train.Supervisor(logdir="/my/training/directory",
summary_op=None) # Do not run the summary service
with sv.managed_session() as sess:
for step in range(100000):
if sv.should_stop():
break
if step % 100 == 0:
_, summ = session.run([my_train_op, my_summary_op])
sv.summary_computed(sess, summ)
else:
session.run(my_train_op)
預(yù)訓(xùn)練的模型情景
managed_session()
調(diào)用很關(guān)心在session
中初始化模型镣煮。模型會(huì)在可能的時(shí)候從一個(gè)checkpoint
中加載,亦或從scratch
中初始化鄙麦。
一個(gè)常見(jiàn)的情景是要用加載的預(yù)訓(xùn)練的checkpoint
來(lái)初始化模型,而該預(yù)訓(xùn)練模型和當(dāng)前模型有些許的不同镊折。
你可以通過(guò)給supervisor
傳遞init function
的方式來(lái)加載預(yù)訓(xùn)練的checkpoint
胯府。這個(gè)函數(shù)只有在模型需要從scratch
初始化時(shí)才被調(diào)用,而模型從logdir
中的checkpoint恢復(fù)的時(shí)候并不會(huì)恨胚。
為了加載預(yù)訓(xùn)練模型骂因,init
函數(shù)需要一個(gè)tf.train.Saver
對(duì)象,所以你應(yīng)該創(chuàng)建一個(gè)saver
赃泡。新模型也許包含一些預(yù)訓(xùn)練的checkpoin
t中不存在的變量寒波,所以這是一個(gè)很好的思想:這個(gè)saver
必須只加載預(yù)訓(xùn)練的變量。如果你正在使用默認(rèn)的saver
升熊,你會(huì)在嘗試加載所有變量的時(shí)候得到一個(gè)錯(cuò)誤俄烁。
...create graph...
my_train_op = ...
my_summary_op = tf.summary.merge_all()
sv = tf.train.Supervisor(logdir="/my/training/directory",
summary_op=None) # Do not run the summary service
with sv.managed_session() as sess:
for step in range(100000):
if sv.should_stop():
break
if step % 100 == 0:
_, summ = session.run([my_train_op, my_summary_op])
sv.summary_computed(sess, summ)
else:
session.run(my_train_op)
運(yùn)行你自己的服務(wù)
Supervisor
服務(wù),比如checkpointing
服務(wù)级野,與主訓(xùn)練循環(huán)并行運(yùn)行页屠。有時(shí)候你想加入你自己的服務(wù),比如取出和 通常的summary
的schedule
不一樣的不同設(shè)置的summaries
蓖柔。
使用supervisor中的tf.train.Supervisor.loop
來(lái)達(dá)成這個(gè)目的辰企。它會(huì)根據(jù)你選擇的定時(shí)器重復(fù)地調(diào)用一個(gè)函數(shù),直到supervisor
的stop condition
為true况鸣,所以它和其他服務(wù)很協(xié)調(diào)牢贸。
例如:每20分鐘調(diào)用一次my_additional_summaries():
def my_additional_sumaries(sv, sess):
...fetch and write summaries, see below...
...
sv = tf.train.Supervisor(logdir="/my/training/directory")
with sv.managed_session() as sess:
# Call my_additional_sumaries() every 1200s, or 20mn,
# passing (sv, sess) as arguments.
sv.loop(1200, my_additional_sumaries, args=(sv, sess))
...main training loop...
寫(xiě)summaries
supervisor
總是在其logdir
中生成一個(gè)事件文件,同時(shí)用一個(gè)tf.summary.FileWriter
將事件和summaries
添加到事件文件镐捧。如果你想寫(xiě)自己的summaries
潜索,也可以將它們添加到同一個(gè)事件文件中去:TensorBoard很喜歡在目錄中只有一個(gè)事件文件臭增。
supervisor
提供了一個(gè)輔助函數(shù)來(lái)添加summaries
:tf.train.Supervisor.summary_computed
:只需要傳遞一份summary_op
的返回輸出函數(shù)。以下是使用該函數(shù)實(shí)現(xiàn)之前例子中my_additional_sumaries()
的例子:
def my_additional_sumaries(sv, sess):
summaries = sess.run(my_additional_summary_op)
sv.summary_computed(sess, summaries)
更多前沿的用法參看tf.train.Supervisor.summary_writer屬性帮辟。
supervisor 參考
在簡(jiǎn)單的情景以及更大的模型方案的情景展示了supervisor的基本用法速址。更高級(jí)的情景可以用supervisor提供的很多選項(xiàng)來(lái)創(chuàng)建。
Checkpointing:何時(shí)何處
managed_session()
調(diào)用開(kāi)啟了checkpointing
服務(wù)由驹,而這可以通過(guò)對(duì)Supervisor()創(chuàng)建時(shí)以下的參數(shù)來(lái)配置:
- logdir: checkpointing服務(wù)床創(chuàng)建checkpoints的目錄路徑芍锚。如果需要,創(chuàng)建該目錄蔓榄。傳遞None禁用checkpointing以及summary服務(wù)并炮。
- checkpoint_basename: 欲創(chuàng)建的checkpoint文件的名稱(chēng),默認(rèn)為”model.ckpt”甥郑。
如果模型包含一個(gè)名為的標(biāo)量整數(shù)變量global_step逃魄,則該變量的值將附加到檢查點(diǎn)文件名。
例如澜搅,在global_step 1234伍俘,checkpoint 文件名就是 “model.ckpt-1234”。
- save_model_secs: 每個(gè)checkpoint之間的秒數(shù)勉躺。默認(rèn)為600癌瘾,即10分鐘。
當(dāng)選擇一個(gè)值時(shí)饵溅,要考慮一旦有crash時(shí)你要丟失多少工作:你永遠(yuǎn)不會(huì)丟失多于save_model_secs秒的工作妨退。設(shè)置為0就禁用了checkpointing服務(wù)。
- saver: 一個(gè)tf.train.Saver對(duì)象蜕企,用來(lái)checkpointing咬荷。
如果不傳遞saver,supervisor會(huì)調(diào)用tf.train.Saver()來(lái)創(chuàng)建一個(gè)轻掩,該saver會(huì)把所有的ops保存幸乒,并加載你模型中所有的變量。你通常也需要這么做唇牧。
示例:每30秒使用自定義保護(hù)程序和檢查點(diǎn)逝变。
...create graph...
my_saver = tf.train.Saver(<only some variables>)
sv = tf.train.Supervisor(logdir="/my/training/directory",
saver=my_saver,
save_model_secs=30)
with sv.managed_session() as sess:
...training loop...
Summaries:何時(shí)何處
類(lèi)似checkpointing,logdir對(duì)summaries的作用也是一樣的奋构。事件文件在此創(chuàng)建壳影,如果None則禁用了summary服務(wù)。
save_summaries_secs:該參數(shù)代表每次運(yùn)行summary sevice服務(wù)的間隔的秒數(shù)弥臼。默認(rèn)為120秒宴咧,即兩分鐘。同樣径缅,設(shè)置為0時(shí)則禁用了summary服務(wù)掺栅。
-
summary_op烙肺,用來(lái)取得summaries的op。
如果沒(méi)指定氧卧,supervisor會(huì)使用
tf.GraphKeys.SUMMARY_OP
圖集合(graph collection)中第一個(gè)op桃笙。如果該集合為空,supervisor則創(chuàng)建一個(gè)op沙绝,它會(huì)將圖中的所有summaries使用tf.summary.merge_all()
聚集在一起搏明。如果給summary_op傳遞None則禁用了summary服務(wù)。
-
global_step:用來(lái)計(jì)算全局步數(shù)的張量闪檬。
如果沒(méi)有指明星著,supervisor使用
tf.GraphKeys.GLOBAL_STEP
圖集合(graph collection)中第一個(gè)tensor,如果該集合為空粗悯,supervisor在圖中尋找一個(gè)name為global_step的整型的變量的標(biāo)量虚循。
如果找到,global step張量被用來(lái)衡量訓(xùn)練步數(shù)執(zhí)行的數(shù)量样傍。注意横缔,你的訓(xùn)練op會(huì)增加global step的值。
模型的初始化和恢復(fù)
managed_session()
調(diào)用野專(zhuān)注于初始化以及恢復(fù)一個(gè)session衫哥。它返回一個(gè)session同時(shí)伴隨一個(gè)全部初始化了的模型茎刚,準(zhǔn)備去訓(xùn)練。如果managed_session()
調(diào)用時(shí)logdir里有一個(gè)checkpoint炕檩,模型會(huì)通過(guò)加載checkpoint初始化,否則會(huì)通過(guò)調(diào)用一個(gè)初始化op或者選擇一個(gè)init function捌斧。
如果沒(méi)有可用的checkpoint笛质,模型的初始化則有下面的參數(shù)傳遞給supervisor()的創(chuàng)建器來(lái)控制:
-
init_op: 需要被運(yùn)行來(lái)初始化模型的op。
如果沒(méi)有指定捞蚂,supervisor會(huì)使用tf.GraphKeys.INIT_OP圖集合( collection)中第一個(gè)op妇押。如果集合是空的,則會(huì)通過(guò)調(diào)用tf.global_variables_initializer()添加一個(gè)初始化所有變量的op姓迅。
傳遞None則不適用初始化op敲霍。
-
init_fn: 調(diào)用它來(lái)初始化模型。
如果指定則這樣調(diào)用 :init_fn(sess)丁存,這里的sess是managed session肩杈。如果init op同時(shí)使用,則init function在init op之后被調(diào)用解寝。
-
local_init_op: 一個(gè)額外的op扩然,用來(lái)初始化圖段一部分,這部分沒(méi)有被保存在checkpoints中聋伦。比如比如tables以及一些local variables夫偶。local init op在init op以及 init function之后運(yùn)行界睁。
如果沒(méi)有指定,supervisor使用tf.GraphKeys.LOCAL_INIT_OP集合里的第一個(gè)op兵拢。如果集合為空翻斟,則通過(guò)調(diào)用tf.tables_initializer() 和 tf.local_variables_initializer()添加一初始化所有tables以及l(fā)ocal variables的op。
傳遞None禁用local init op说铃。
ready_op: 核查模型是否被初始化的op访惜。
運(yùn)行了local init op,init op以及init function之后截汪,supervisor會(huì)通過(guò)執(zhí)行ready op來(lái)驗(yàn)證模型是否被完全初始化疾牲。如果初始化則該op返回空字符串,否則返回模型那部分未被初始化的一個(gè)描述衙解。
如果未指定阳柔,supervisor會(huì)使用tf.GraphKeys.READY_OP 集合中的第一個(gè)op。若集合未空蚓峦,supervisosr通過(guò)調(diào)用tf.report_uninitialized_variables()創(chuàng)建一個(gè)ready op來(lái)確保所有變量都被初始化舌剂。
傳遞None來(lái)禁用ready op。在這種情況下模型初始化之后不進(jìn)行核查暑椰。
checkpoint的恢復(fù)是由以下傳給superfisor()創(chuàng)建器的參數(shù)控制:
-
logdir:尋找checkpoints的路徑霍转。checkpoint服務(wù)保存了一個(gè)metadata文件,名為 “checkpoint”一汽,在這個(gè)checkpoint目錄中指明最近的一個(gè)checkpoint的路徑避消。
這個(gè)文件是文本格式的。你可以手工編輯它來(lái)從一個(gè)不同于最近的checkpoint的checkpoint中恢復(fù)召夹。
ready_op:和上面的一樣岩喷。ready op在加載checkpoint之前和之后運(yùn)行。第一次運(yùn)行檢查模型是否需要被初始化监憎,第二次驗(yàn)證模型完全被初始化纱意。
local_init_op:和上面的一樣。local init op在第一次運(yùn)行ready op之前運(yùn)行鲸阔,來(lái)初始化局部變量以及tables偷霉。
saver:和上面的一樣。用來(lái)加載checkpoint的的Saver對(duì)象褐筛。