代碼解析《Learning Latent Dynamics for Planning from Pixels》

我們以數(shù)據(jù)流向?yàn)橹骶€索谐算,講講論文代碼做了些什么事情熄守。

跑算法就是先收集數(shù)據(jù),然后把它feed到構(gòu)建好的模型中去訓(xùn)練汰寓。這個(gè)代碼還多了一步planning。planning完收到新的數(shù)據(jù)苹粟,于是又開始新的一輪訓(xùn)練有滑,循環(huán)下去。

那么問題來了:

首先數(shù)據(jù)從哪里得到嵌削?

def start(logdir, args):

1.1 #首先config是多層dist(dist嵌套dist嵌套dist...毛好,總之裝各種東西的),所有的參數(shù)和函數(shù)都存放到config里面苛秕。

config = tools.AttrDict()

config = getattr(configs, args.config)(config, args.params)

#即這個(gè)函數(shù):

def default(config, params):

config.zero_step_losses = tools.AttrDict(_unlocked=True)

config = _data_processing(config, params)??? # data config

config = _model_components(config, params)?? # model config

config = _tasks(config, params)????????????? # task config

config = _loss_functions(config, params)???? # loss config

config = _training_schedule(config, params)? # training config

return config

1.2 #開始去獲取數(shù)據(jù)肌访,

training.utility.collect_initial_episodes(config)

def random_episodes(env_ctor, num_episodes, output_dir=None):

#層層wrap,每層上都處理一點(diǎn)操作艇劫,直達(dá)最后的核心吼驶。

# 其實(shí)這一系列封裝的env是在子進(jìn)程中進(jìn)行的,是由子進(jìn)程產(chǎn)生真正的環(huán)境互動(dòng)店煞。

env = env_ctor()

#原進(jìn)程發(fā)出命令蟹演,子進(jìn)程產(chǎn)生了互動(dòng),原進(jìn)程收到后再將episode寫入outdir指定的位置顷蟀。

env = wrappers.CollectGymDataset(env, output_dir)

#起了一個(gè)子進(jìn)程跑env sever酒请,與原進(jìn)程的通過pipe通信:

env = control.wrappers.ExternalProcess(env_ctor)

#當(dāng)下面函數(shù)運(yùn)行時(shí),實(shí)際是去call-> pipe.send

obs = env.reset()

#其實(shí)reset,step,close 都是去call鸣个,然后block到receive返回值羞反,這個(gè)返回值就是具體observation

2。通過與子進(jìn)程的環(huán)境互動(dòng)得到的episode保存到哪里囤萤?

當(dāng)層層封裝的env調(diào)用函數(shù)時(shí)苟弛,env.step會(huì)遞歸地深入最里層,然后執(zhí)行最后一行self._process_step阁将。這一行將episode寫入到outdir膏秫。

def step(self, action, *args, **kwargs):

if kwargs.get('blocking', True):

transition = self._env.step(action, *args, **kwargs)

return self._process_step(action, *transition)

3.從這個(gè)函數(shù)開始將數(shù)據(jù)從硬盤讀到內(nèi)存來使用:

def numpy_episodes:

1 #三個(gè)參數(shù):1 生成數(shù)據(jù)的函數(shù) 2 數(shù)據(jù)類型 3 tensor shape

train = tf.data.Dataset.from_generator(

functools.partial(loader, train_dir, shape[0], **kwargs), dtypes, shapes):

loader實(shí)際是這個(gè)函數(shù),即:將硬盤中的npz文件讀入

def _read_episodes_reload

將讀入的數(shù)據(jù)chunking(就是將讀入的數(shù)據(jù)x做盅,切成固定的長度chunk_length缤削,這樣數(shù)據(jù)就是以chunk為單位了)

train = train.flat_map(chunking)

好,數(shù)據(jù)準(zhǔn)備好了吹榴,就是構(gòu)造網(wǎng)絡(luò)亭敢,計(jì)算出loss,再optimization就好了图筹。

整個(gè)loss函數(shù)就是兩部分帅刀,construction部分和KL部分让腹,KL部分用到了overshooting。

那么什么是overshooting扣溺?

1. 首先當(dāng)length=50骇窍,即50個(gè)time step。由這個(gè)函數(shù)得出的post 和 posterior 锥余,將posterior作為prev_state 經(jīng)過第一次cell腹纳,輸出prior。用這個(gè)prior和posterior做KL驱犹,就是d=0的overshooting嘲恍。這個(gè)之所以叫zero_step,因?yàn)檫@里KL的prior確實(shí)是由posterior直接生成的,且都屬于同一個(gè)state雄驹。

有一個(gè)認(rèn)知特別重要,一個(gè)state表示下圖虛線框佃牛,一個(gè)state可以是posterior,也可以是prior医舆。 一個(gè)state如圖有5個(gè)元素俘侠,也就是posterior和prior有5個(gè)元素。且彬向,同一個(gè)state的posterior和prior的belif和rnn_state相同兼贡。

(prior, posterior), _ = tf.nn.dynamic_rnn(

cell, (embedded, prev_action, use_obs), length, dtype=tf.float32,??? # cell, inputs:shape(batchsize,max_time,?):(40,50,?), sequence_length:shape(batchsize,):(40,)

swap_memory=True)

上面這個(gè)函數(shù)很關(guān)鍵攻冷,有兩層cell娃胆,外層cell先傳入dynamic_rnn. 兩層cell分別做了什么呢?如下圖:

2.繼續(xù)講overshooting等曼。因?yàn)閐不光=0里烦,不光是固定步長,d可以= 1,2,3....amount 禁谦。所有overshooting就是求每一列的posterior(在圖中每一列的底部)與這列的其他所有priors做KL胁黑。

3. (說來說去overshooting本質(zhì)上就做了這樣一件事)將1.中所得posterior 作為prev_state 放入dynamic_rnn()求出每一斜行的priors,與這一斜行對(duì)應(yīng)的投影posterior做KL州泊。

做完overshooting把posterior和priors等都準(zhǔn)備好了丧蘸,可以計(jì)算loss了,loss由zero_step loss 和 overshooting loss組成:

1.KL遥皂,global KL loss就是求兩個(gè)分布的距離力喷,而共有(50,50)個(gè)這樣的分布

2.去調(diào)用相應(yīng)函數(shù)去計(jì)算出output(reward,image演训,state) :

output = heads[key](features)?? # decoder is used.

output與對(duì)應(yīng)的target做交叉熵:

loss = -tools.mask(output.log_prob(target[key]), mask)

最后再將loss求個(gè)均值弟孟,再根據(jù)key存放到一個(gè)字典losses中。

loss = -tools.mask(output.log_prob(target[key]), mask)

losses[key]

如下圖:

這樣zero_step loss就計(jì)算完了样悟,現(xiàn)在來看overshooting loss拂募,調(diào)用的函數(shù)都是compute_losses庭猩,區(qū)別只是準(zhǔn)備好的數(shù)據(jù)不同。

loss計(jì)算完了陈症,接下來是優(yōu)化部分:

config.optimizers由state和main兩個(gè)元素蔼水,main函數(shù)一執(zhí)行生成tools.CustomOptimizer對(duì)象,里面什么配置都有爬凑,包括lr徙缴,用那個(gè)優(yōu)化函數(shù)等。state與main同理

_define_optimizers:

optimizers[name] = functools.partial(

tools.CustomOptimizer, include=r'.*/head_{}/.*'.format(name), **kwargs)

優(yōu)化設(shè)置好了嘁信,就開始訓(xùn)練模型N步于样。這輪訓(xùn)練好了,要用這輪的模型做planning了:

其實(shí)無論是計(jì)算loss潘靖,model還是simulation所有這些工作穿剖,代碼中都按時(shí)間先后分為兩大步驟。以simulation工作為例卦溢,它的第一步都是在config階段完成配置糊余,第二步在define_model()去具體執(zhí)行。

1.配置階段单寂,把planning會(huì)用到cem函數(shù)贬芥,參數(shù),各種參數(shù)都封裝好宣决,放入config.sim_collects

config.sim_collects = _active_collection(config, params) -> _define_simulation

2.在loss計(jì)算完蘸劈,optimization 配置好后,在define_model()中開始一系列調(diào)用:

注:-> 指調(diào)用

這條線首先會(huì)判斷該不該should_collect尊沸,如果應(yīng)該威沫,則進(jìn)入這條長長的函數(shù)調(diào)用線

define_model() -> simulate_episodes() -> simulate() ->

-> collect_rollouts()此函數(shù)里生成了MPCAgent()(即生成algo對(duì)象)-> simulate_step模擬出一步的總?cè)肟冢鴗f.scan()會(huì)讓這個(gè)函數(shù)執(zhí)行200次洼专,即走200步棒掠。->_define_begin_episode()這里把環(huán)境reset了 -> _define_summaries() -> _define_step() -> algo.perform(agent_indices, prevob) 首先embedded,再求posterior屁商,用這個(gè)state開始做planning烟很,輸出一個(gè)action。

上面是不斷往里層調(diào)函數(shù)蜡镶,返回值中有一個(gè)關(guān)鍵的score從哪里來的雾袱,返回到哪里去,如下:

add_score = score_var.assign_add(batch_env.reward) 返回給 step, score, length = _define_step() 返回給 define_summaries

最后的總結(jié)帽哑,如下圖谜酒,整個(gè)過程中的一圈就完成了,再來就是新的一輪fit model再planning妻枕。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末僻族,一起剝皮案震驚了整個(gè)濱河市粘驰,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌述么,老刑警劉巖蝌数,帶你破解...
    沈念sama閱讀 207,248評(píng)論 6 481
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異度秘,居然都是意外死亡顶伞,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 88,681評(píng)論 2 381
  • 文/潘曉璐 我一進(jìn)店門剑梳,熙熙樓的掌柜王于貴愁眉苦臉地迎上來唆貌,“玉大人,你說我怎么就攤上這事垢乙∠橇” “怎么了?”我有些...
    開封第一講書人閱讀 153,443評(píng)論 0 344
  • 文/不壞的土叔 我叫張陵追逮,是天一觀的道長酪刀。 經(jīng)常有香客問我,道長钮孵,這世上最難降的妖魔是什么骂倘? 我笑而不...
    開封第一講書人閱讀 55,475評(píng)論 1 279
  • 正文 為了忘掉前任,我火速辦了婚禮巴席,結(jié)果婚禮上历涝,老公的妹妹穿的比我還像新娘。我一直安慰自己情妖,他們只是感情好睬关,可當(dāng)我...
    茶點(diǎn)故事閱讀 64,458評(píng)論 5 374
  • 文/花漫 我一把揭開白布诱担。 她就那樣靜靜地躺著毡证,像睡著了一般。 火紅的嫁衣襯著肌膚如雪蔫仙。 梳的紋絲不亂的頭發(fā)上料睛,一...
    開封第一講書人閱讀 49,185評(píng)論 1 284
  • 那天,我揣著相機(jī)與錄音摇邦,去河邊找鬼恤煞。 笑死,一個(gè)胖子當(dāng)著我的面吹牛施籍,可吹牛的內(nèi)容都是我干的居扒。 我是一名探鬼主播,決...
    沈念sama閱讀 38,451評(píng)論 3 401
  • 文/蒼蘭香墨 我猛地睜開眼丑慎,長吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼喜喂!你這毒婦竟也來了瓤摧?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 37,112評(píng)論 0 261
  • 序言:老撾萬榮一對(duì)情侶失蹤玉吁,失蹤者是張志新(化名)和其女友劉穎照弥,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體进副,經(jīng)...
    沈念sama閱讀 43,609評(píng)論 1 300
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡这揣,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,083評(píng)論 2 325
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了影斑。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片给赞。...
    茶點(diǎn)故事閱讀 38,163評(píng)論 1 334
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖矫户,靈堂內(nèi)的尸體忽然破棺而出塞俱,到底是詐尸還是另有隱情,我是刑警寧澤吏垮,帶...
    沈念sama閱讀 33,803評(píng)論 4 323
  • 正文 年R本政府宣布障涯,位于F島的核電站,受9級(jí)特大地震影響膳汪,放射性物質(zhì)發(fā)生泄漏唯蝶。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,357評(píng)論 3 307
  • 文/蒙蒙 一遗嗽、第九天 我趴在偏房一處隱蔽的房頂上張望粘我。 院中可真熱鬧,春花似錦痹换、人聲如沸征字。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,357評(píng)論 0 19
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽匙姜。三九已至,卻和暖如春冯痢,著一層夾襖步出監(jiān)牢的瞬間氮昧,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 31,590評(píng)論 1 261
  • 我被黑心中介騙來泰國打工浦楣, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留袖肥,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 45,636評(píng)論 2 355
  • 正文 我出身青樓振劳,卻偏偏與公主長得像椎组,于是被迫代替她去往敵國和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子历恐,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 42,925評(píng)論 2 344

推薦閱讀更多精彩內(nèi)容