使用驗(yàn)證集判斷模型效果
為了評(píng)測(cè)神經(jīng)網(wǎng)絡(luò)模型在不同參數(shù)下的效果,一般會(huì)從訓(xùn)練集中抽取一部分作為驗(yàn)證數(shù)據(jù)。除了使用驗(yàn)證數(shù)據(jù)集,還可以采用交叉驗(yàn)證(cross validation )
的方式驗(yàn)證模型效果标锄,但是使用交叉驗(yàn)證會(huì)花費(fèi)大量的時(shí)間。但在海量數(shù)據(jù)情況下茁计,一般采用驗(yàn)證數(shù)據(jù)集的形式評(píng)測(cè)模型的效果料皇。
一般采用的驗(yàn)證數(shù)據(jù)分布越接近測(cè)試數(shù)據(jù)分布,模型在驗(yàn)證數(shù)據(jù)上的表現(xiàn)越可以體現(xiàn)模型在測(cè)試數(shù)據(jù)上的保險(xiǎn)星压。
使用滑動(dòng)平均模型和指數(shù)衰減的學(xué)習(xí)率在一定程度上都是限制神經(jīng)網(wǎng)絡(luò)中參數(shù)更新的速度践剂。
在處理復(fù)雜問題時(shí),使用滑動(dòng)平均模型娜膘、指數(shù)衰減的學(xué)習(xí)率和正則化損失可以明顯提升模型的訓(xùn)練效果逊脯。
變量管理
Tensorflow提供了通過變量名稱來創(chuàng)建或者獲取一個(gè)變量的機(jī)制,避免了復(fù)雜神經(jīng)網(wǎng)絡(luò)頻繁傳遞參數(shù)的情況竣贪。通過該機(jī)制男窟,在不同的函數(shù)中可以直接通過變量的名字來使用變量,而不需要將變量通過參數(shù)的形式到處傳遞贾富。
Tensorflow中通過變量名獲取變量的機(jī)制主要通過tf.get_variable()
和tf.variable_scope()
函數(shù)實(shí)現(xiàn)。
-
tf.get_variable()
該函數(shù)創(chuàng)建變量的方法和tf.Variable()
函數(shù)的用法基本一樣牺六,提供維度信息(shape
)以及初始化方法(initializer
)的參數(shù)颤枪。該函數(shù)的變量名稱是一個(gè)必填參數(shù),函數(shù)會(huì)根據(jù)這個(gè)名字去創(chuàng)建或者獲取變量淑际。當(dāng)已經(jīng)有同名參數(shù)時(shí)畏纲,會(huì)報(bào)錯(cuò)。 -
tf.variable_scope()
該函數(shù)可以控制tf.get_variable()
函數(shù)的語義春缕。當(dāng)tf.variable_scope()
函數(shù)使用參數(shù)reuse=True
生成上下文管理器時(shí)盗胀,這個(gè)上下文管理器內(nèi)所有的tf.get_variable()
函數(shù)會(huì)直接獲取已經(jīng)創(chuàng)建的變量。如果不存在锄贼,則報(bào)錯(cuò)票灰;當(dāng)reuse=False
或者reuse=None
創(chuàng)建上下文管理器時(shí),tf.get_variable()
操作將創(chuàng)建新的變量宅荤,如果同名變量已經(jīng)存在屑迂,則報(bào)錯(cuò)。
同時(shí)tf.variable_scope()
函數(shù)可以嵌套冯键。新建一個(gè)嵌套的上下文管理器但不指定reuse惹盼,這時(shí)的reuse的取值和外面一層保持一致。當(dāng)退出reuse設(shè)置為True的上下文之后reuse的值又回到了False(內(nèi)層reuse不設(shè)置)惫确。
同時(shí)手报,tf.variable_scope()函數(shù)生成的上下文管理器也會(huì)創(chuàng)建一個(gè)Tensorflow中的命名空間蚯舱,在命名空間內(nèi)創(chuàng)建的變量名稱都會(huì)帶上這個(gè)命名空間名作為前綴⊙诟颍可以直接通過帶命名空間名稱的變量名來獲取其它命名空間下的變量(創(chuàng)建一個(gè)名稱為空的命名空間枉昏,并設(shè)置為reuse=True)。
with tf.variable_scope(" ", reuse=True):
v5 = tf.get_variable("foo/bar/v", [1])
print(v5.name)
===>v:0 # 0表示variable這個(gè)運(yùn)算輸出的第一個(gè)結(jié)果
Tensorflow模型持久化
將訓(xùn)練得到的模型保存下來盏档,可以方便下次直接使用(避免重新訓(xùn)練花費(fèi)大量的時(shí)間)凶掰。Tensorflow提供的持久化機(jī)制可以將訓(xùn)練之后的模型保存到文件中。
Tensorflow提供了tf.train.Saver
類來保存和還原神經(jīng)網(wǎng)絡(luò)模型蜈亩。當(dāng)保存模型之后懦窘,目錄下一般會(huì)出現(xiàn)三個(gè)文件,這是因?yàn)門ensorflow會(huì)將計(jì)算圖的結(jié)構(gòu)和圖上參數(shù)值分開保存稚配。
-
model.ckpy.meta
文件畅涂,保存了Tensorflow計(jì)算圖的結(jié)構(gòu)。 -
model.ckpt
文件道川,保存了Tensorflow程序每一個(gè)變量的取值午衰。 -
checkpoint
文件,保存了一個(gè)目錄下所有的模型文件列表冒萄。
保存模型
saver = tf.train.Saver()
saver.save(sess, "path/model.ckpt")
加載模型臊岸,此時(shí)不用進(jìn)行變量的初始化過程
saver.restore(sess, "path/model.ckpt")
sess.run(result)
為了保存和加載部分變量,在聲明tf.train.Saver類時(shí)可以提供一個(gè)列表來指定需要保存或加載的變量尊流,saver = tf.train.Saver([v1])
帅戒。同時(shí),tf.train.Saver類也支持在保存或者加載時(shí)給變量重命名崖技,如果直接加載就會(huì)導(dǎo)致程序報(bào)變量找不到的錯(cuò)誤逻住,Tensorflow提供通過字典將模型保存時(shí)的變量名和要加載的變量聯(lián)系起來。
v = tf.Variable(tf.constant(1.0, shape=[1]), name='other-v1')
saver = tf.train.Saver({"v1": v})
將原先變量名為v1的變量加載到變量v中迎献,變量v的名稱為other-v1瞎访。
這樣做的目的時(shí)為了方便使用變量的滑動(dòng)平均值。因?yàn)槊恳粋€(gè)變量的滑動(dòng)平均值是通過影子變量維護(hù)的吁恍,如果在加載模型時(shí)直接將影子變量映射到變量自身扒秸,就不需要在調(diào)用函數(shù)來獲取變量的平均值了。
為了方便加載重命名滑動(dòng)平均變量践盼,tf.train.ExponentialMovingAverage類提供了variables_to_restore()函數(shù)來生成tf.train.Saver類所需要的變量重命名字典鸦采。
v = tf.Variable(0)
ema = tf.train.ExponentialMovingAverage(0.99)
saver = tf.train.Saver(ema.variable_to_restore())
with tf.Session() as sess:
saver.restore(sess, "path/model.ckpt")
sess.run(v)
有時(shí)候不需要類似于變量初始化、模型保存等輔助節(jié)點(diǎn)的信息咕幻,Tensorflow提供了convert_variables_to_constants()函數(shù)將計(jì)算圖中的變量及其取值通過常量的方式保存渔伯。
持久化原理及數(shù)據(jù)格式
Tensorflow程序中所有計(jì)算都會(huì)被表達(dá)為計(jì)算圖上的節(jié)點(diǎn)。
MetaGraphDef
Tensorflow通過元圖(MetaGraph)
來記錄計(jì)算圖中節(jié)點(diǎn)的信息以及運(yùn)行計(jì)算圖中節(jié)點(diǎn)所需要的元數(shù)據(jù)肄程,元圖是由MetaGraphDef Protocol Buffer
定義的锣吼,MetaGraphDef
中的內(nèi)容構(gòu)成了Tensorflow持久化的第一個(gè)文件选浑,也就是model.ckpt.meta
文件。
-
meta_info_def
屬性玄叠,記錄了Tensorflow計(jì)算圖中的元數(shù)據(jù)以及Tensorflow程序中所有使用到的運(yùn)算方法的信息古徒。元數(shù)據(jù)包括了計(jì)算圖的版本號(hào)以及用戶指定的一些標(biāo)簽,其中meta_info_def
屬性的stripped_op_list
屬性保存了Tensorflow運(yùn)算方法的信息读恃,如果一個(gè)運(yùn)算方法在計(jì)算圖中出現(xiàn)了多次隧膘,在該字段中也只出現(xiàn)一次。stripped_op_list
屬性的類型是OpList
寺惫,OpList
類型是一個(gè)OpDef
類型的列表疹吃,該類型定義了一個(gè)運(yùn)算的所有信息,包括運(yùn)算名西雀、輸入輸出和運(yùn)算的參數(shù)信息萨驶。 -
graph_def
屬性,主要記錄了Tensorflow計(jì)算圖上的節(jié)點(diǎn)信息艇肴,Tensorflow計(jì)算圖的每一個(gè)節(jié)點(diǎn)對(duì)應(yīng)了Tensorflow程序中的一個(gè)運(yùn)算腔呜。meta_info_def
屬性已經(jīng)包含了所有運(yùn)算的具體信息,所以graph_def
屬性只關(guān)注運(yùn)算的連接結(jié)果再悼。
該屬性是通過GraphDef Protocol Buffer定義的核畴,GraphDef主要包含了一個(gè)NodeDef
類型的列表,GraphDef
的versions
屬性存儲(chǔ)了Tensorflow的版本號(hào)冲九,node
屬性記錄了所有的節(jié)點(diǎn)信息膛檀。node
為NodeDef
類型,該類型的op
屬性給出了該節(jié)點(diǎn)使用的運(yùn)算方法名稱娘侍,具體信息可以通過meta_info_def
獲取,input
屬性是一個(gè)字符串列表泳炉,定義了運(yùn)算的輸入憾筏,device
屬性定義了處理該運(yùn)算的設(shè)備,attr
屬性定義了和當(dāng)前運(yùn)算相關(guān)的配置信息花鹅。 -
saver_def
屬性氧腰,記錄了持久化模型所需要用到的一些參數(shù),比如保存到文件的文件名刨肃,保存操作和加載操作的名稱以及保存頻率古拴、清理歷史記錄等。
該屬性主要通過SaverDef
定義真友。 -
collention_def
屬性黄痪,Tensorflow計(jì)算圖中可以維護(hù)不同的集合,底層實(shí)現(xiàn)就是通過collention_def
這個(gè)屬性盔然。collection_def
屬性是一個(gè)從集合名稱到集合內(nèi)容的映射桅打,其中集合名稱為字符串是嗜,集合內(nèi)容為CollentionDef Protocol Buffer
。Tensorflow計(jì)算圖上的集合主要可以維護(hù)4類不同的集合:NodeList
用于維護(hù)計(jì)算圖上的節(jié)點(diǎn)集合挺尾;BytesList
用于維護(hù)字符串或者序列化之后的Protocol Buffer的集合鹅搪;Int64List
用于維護(hù)整數(shù)集合;FloatList
用于維護(hù)實(shí)數(shù)集合遭铺。
SSTable
持久化Tensorflow中變量的取值丽柿,tf.Save
r得到的model.ckpt
文件保存了所有的變量,該文件使用SSTable
格式存儲(chǔ)的魂挂,相當(dāng)于一個(gè)(key, value)
列表甫题。
CheckpointState
持久化的最后一個(gè)文件名叫checkpoint
,這個(gè)文件是tf.train.Saver
類自動(dòng)生成且自動(dòng)維護(hù)的锰蓬。該文件中維護(hù)了一個(gè)由tf.train.Saver
類持久化的所有Tensoflow模型文件的文件名幔睬,當(dāng)某個(gè)模型文件被刪除時(shí),這個(gè)模型對(duì)應(yīng)的文件名也會(huì)被移除芹扭,checkpoint中內(nèi)容的格式為CheckpointState Protocol Buffer
麻顶。