我們以數(shù)據(jù)流向?yàn)橹骶€索谐算,講講論文代碼做了些什么事情熄守。
跑算法就是先收集數(shù)據(jù),然后把它feed到構(gòu)建好的模型中去訓(xùn)練汰寓。這個(gè)代碼還多了一步planning。planning完收到新的數(shù)據(jù)苹粟,于是又開始新的一輪訓(xùn)練有滑,循環(huán)下去。
那么問題來了:
首先數(shù)據(jù)從哪里得到嵌削?
def start(logdir, args):
1.1 #首先config是多層dist(dist嵌套dist嵌套dist...毛好,總之裝各種東西的),所有的參數(shù)和函數(shù)都存放到config里面苛秕。
config = tools.AttrDict()
config = getattr(configs, args.config)(config, args.params)
#即這個(gè)函數(shù):
def default(config, params):
config.zero_step_losses = tools.AttrDict(_unlocked=True)
config = _data_processing(config, params)??? # data config
config = _model_components(config, params)?? # model config
config = _tasks(config, params)????????????? # task config
config = _loss_functions(config, params)???? # loss config
config = _training_schedule(config, params)? # training config
return config
1.2 #開始去獲取數(shù)據(jù)肌访,
training.utility.collect_initial_episodes(config)
def random_episodes(env_ctor, num_episodes, output_dir=None):
#層層wrap,每層上都處理一點(diǎn)操作艇劫,直達(dá)最后的核心吼驶。
# 其實(shí)這一系列封裝的env是在子進(jìn)程中進(jìn)行的,是由子進(jìn)程產(chǎn)生真正的環(huán)境互動(dòng)店煞。
env = env_ctor()
#原進(jìn)程發(fā)出命令蟹演,子進(jìn)程產(chǎn)生了互動(dòng),原進(jìn)程收到后再將episode寫入outdir指定的位置顷蟀。
env = wrappers.CollectGymDataset(env, output_dir)
#起了一個(gè)子進(jìn)程跑env sever酒请,與原進(jìn)程的通過pipe通信:
env = control.wrappers.ExternalProcess(env_ctor)
#當(dāng)下面函數(shù)運(yùn)行時(shí),實(shí)際是去call-> pipe.send
obs = env.reset()
#其實(shí)reset,step,close 都是去call鸣个,然后block到receive返回值羞反,這個(gè)返回值就是具體observation
2。通過與子進(jìn)程的環(huán)境互動(dòng)得到的episode保存到哪里囤萤?
當(dāng)層層封裝的env調(diào)用函數(shù)時(shí)苟弛,env.step會(huì)遞歸地深入最里層,然后執(zhí)行最后一行self._process_step阁将。這一行將episode寫入到outdir膏秫。
def step(self, action, *args, **kwargs):
if kwargs.get('blocking', True):
transition = self._env.step(action, *args, **kwargs)
return self._process_step(action, *transition)
3.從這個(gè)函數(shù)開始將數(shù)據(jù)從硬盤讀到內(nèi)存來使用:
def numpy_episodes:
1 #三個(gè)參數(shù):1 生成數(shù)據(jù)的函數(shù) 2 數(shù)據(jù)類型 3 tensor shape
train = tf.data.Dataset.from_generator(
functools.partial(loader, train_dir, shape[0], **kwargs), dtypes, shapes):
loader實(shí)際是這個(gè)函數(shù),即:將硬盤中的npz文件讀入
def _read_episodes_reload
將讀入的數(shù)據(jù)chunking(就是將讀入的數(shù)據(jù)x做盅,切成固定的長度chunk_length缤削,這樣數(shù)據(jù)就是以chunk為單位了)
train = train.flat_map(chunking)
好,數(shù)據(jù)準(zhǔn)備好了吹榴,就是構(gòu)造網(wǎng)絡(luò)亭敢,計(jì)算出loss,再optimization就好了图筹。
整個(gè)loss函數(shù)就是兩部分帅刀,construction部分和KL部分让腹,KL部分用到了overshooting。
那么什么是overshooting扣溺?
1. 首先當(dāng)length=50骇窍,即50個(gè)time step。由這個(gè)函數(shù)得出的post 和 posterior 锥余,將posterior作為prev_state 經(jīng)過第一次cell腹纳,輸出prior。用這個(gè)prior和posterior做KL驱犹,就是d=0的overshooting嘲恍。這個(gè)之所以叫zero_step,因?yàn)檫@里KL的prior確實(shí)是由posterior直接生成的,且都屬于同一個(gè)state雄驹。
有一個(gè)認(rèn)知特別重要,一個(gè)state表示下圖虛線框佃牛,一個(gè)state可以是posterior,也可以是prior医舆。 一個(gè)state如圖有5個(gè)元素俘侠,也就是posterior和prior有5個(gè)元素。且彬向,同一個(gè)state的posterior和prior的belif和rnn_state相同兼贡。
(prior, posterior), _ = tf.nn.dynamic_rnn(
cell, (embedded, prev_action, use_obs), length, dtype=tf.float32,??? # cell, inputs:shape(batchsize,max_time,?):(40,50,?), sequence_length:shape(batchsize,):(40,)
swap_memory=True)
上面這個(gè)函數(shù)很關(guān)鍵攻冷,有兩層cell娃胆,外層cell先傳入dynamic_rnn. 兩層cell分別做了什么呢?如下圖:
2.繼續(xù)講overshooting等曼。因?yàn)閐不光=0里烦,不光是固定步長,d可以= 1,2,3....amount 禁谦。所有overshooting就是求每一列的posterior(在圖中每一列的底部)與這列的其他所有priors做KL胁黑。
3. (說來說去overshooting本質(zhì)上就做了這樣一件事)將1.中所得posterior 作為prev_state 放入dynamic_rnn()求出每一斜行的priors,與這一斜行對(duì)應(yīng)的投影posterior做KL州泊。
做完overshooting把posterior和priors等都準(zhǔn)備好了丧蘸,可以計(jì)算loss了,loss由zero_step loss 和 overshooting loss組成:
1.KL遥皂,global KL loss就是求兩個(gè)分布的距離力喷,而共有(50,50)個(gè)這樣的分布
2.去調(diào)用相應(yīng)函數(shù)去計(jì)算出output(reward,image演训,state) :
output = heads[key](features)?? # decoder is used.
output與對(duì)應(yīng)的target做交叉熵:
loss = -tools.mask(output.log_prob(target[key]), mask)
最后再將loss求個(gè)均值弟孟,再根據(jù)key存放到一個(gè)字典losses中。
loss = -tools.mask(output.log_prob(target[key]), mask)
losses[key]
如下圖:
這樣zero_step loss就計(jì)算完了样悟,現(xiàn)在來看overshooting loss拂募,調(diào)用的函數(shù)都是compute_losses庭猩,區(qū)別只是準(zhǔn)備好的數(shù)據(jù)不同。
loss計(jì)算完了陈症,接下來是優(yōu)化部分:
config.optimizers由state和main兩個(gè)元素蔼水,main函數(shù)一執(zhí)行生成tools.CustomOptimizer對(duì)象,里面什么配置都有爬凑,包括lr徙缴,用那個(gè)優(yōu)化函數(shù)等。state與main同理
_define_optimizers:
optimizers[name] = functools.partial(
tools.CustomOptimizer, include=r'.*/head_{}/.*'.format(name), **kwargs)
優(yōu)化設(shè)置好了嘁信,就開始訓(xùn)練模型N步于样。這輪訓(xùn)練好了,要用這輪的模型做planning了:
其實(shí)無論是計(jì)算loss潘靖,model還是simulation所有這些工作穿剖,代碼中都按時(shí)間先后分為兩大步驟。以simulation工作為例卦溢,它的第一步都是在config階段完成配置糊余,第二步在define_model()去具體執(zhí)行。
1.配置階段单寂,把planning會(huì)用到cem函數(shù)贬芥,參數(shù),各種參數(shù)都封裝好宣决,放入config.sim_collects
config.sim_collects = _active_collection(config, params) -> _define_simulation
2.在loss計(jì)算完蘸劈,optimization 配置好后,在define_model()中開始一系列調(diào)用:
注:-> 指調(diào)用
這條線首先會(huì)判斷該不該should_collect尊沸,如果應(yīng)該威沫,則進(jìn)入這條長長的函數(shù)調(diào)用線:
define_model() -> simulate_episodes() -> simulate() ->
-> collect_rollouts()此函數(shù)里生成了MPCAgent()(即生成algo對(duì)象)-> simulate_step模擬出一步的總?cè)肟冢鴗f.scan()會(huì)讓這個(gè)函數(shù)執(zhí)行200次洼专,即走200步棒掠。->_define_begin_episode()這里把環(huán)境reset了 -> _define_summaries() -> _define_step() -> algo.perform(agent_indices, prevob) 首先embedded,再求posterior屁商,用這個(gè)state開始做planning烟很,輸出一個(gè)action。
上面是不斷往里層調(diào)函數(shù)蜡镶,返回值中有一個(gè)關(guān)鍵的score從哪里來的雾袱,返回到哪里去,如下:
add_score = score_var.assign_add(batch_env.reward) 返回給 step, score, length = _define_step() 返回給 define_summaries
最后的總結(jié)帽哑,如下圖谜酒,整個(gè)過程中的一圈就完成了,再來就是新的一輪fit model再planning妻枕。