Categorical DQN-一種建模價(jià)值分布的深度強(qiáng)化學(xué)習(xí)方法钻趋!

之前介紹的DQN及其各種變體,網(wǎng)絡(luò)輸出的都是狀態(tài)-動(dòng)作價(jià)值Q的期望預(yù)估值剂习。而本文將介紹的Categorical DQN蛮位,它建模的是狀態(tài)-動(dòng)作價(jià)值Q的分布。這樣的估計(jì)方法使得估計(jì)結(jié)果更加細(xì)致可信鳞绕。

本文的論文名稱(chēng)為《A Distributional Perspective on Reinforcement Learning》失仁,地址為:https://arxiv.org/abs/1707.06887

不過(guò)論文里數(shù)學(xué)公式非常多们何,如果想要快速了解這種方法的原理萄焦,建議大家閱讀《強(qiáng)化學(xué)習(xí)精要:核心算法與Tensorflow實(shí)現(xiàn)》一書(shū)。

1、Categorical DQN

1.1 為什么要輸出價(jià)值分布拂封?

之前介紹的DQN及其各種變體茬射,網(wǎng)絡(luò)輸出的都是狀態(tài)-動(dòng)作價(jià)值Q的期望預(yù)估值。這個(gè)期望值其實(shí)忽略很多信息冒签。比如同一狀態(tài)下的兩個(gè)動(dòng)作在抛,能夠獲得的價(jià)值期望是相同的,比如都是20萧恕,第一個(gè)動(dòng)作在90%的情況下價(jià)值是10刚梭,在10%的情況下是110,另一個(gè)動(dòng)作在50%的情況下是15廊鸥,在50%的情況下是25望浩。那么雖然期望一樣,但如果我們想要減小風(fēng)險(xiǎn)惰说,我們應(yīng)該選擇后一種動(dòng)作磨德。而只有期望值的話(huà),我們是無(wú)法看到動(dòng)作背后所蘊(yùn)含的風(fēng)險(xiǎn)的吆视。

所以從理論上來(lái)說(shuō)典挑,從分布視角(distributional perspective)來(lái)建模我們的深度強(qiáng)化學(xué)習(xí)模型,可以獲得更多有用的信息啦吧,從而得到更好您觉、更穩(wěn)定的結(jié)果。

1.2 Categorical DQN原理

我們首先需要考慮的一個(gè)問(wèn)題是授滓,選擇什么樣的分布呢琳水?一種很自然的想法是一個(gè)正態(tài)分布,我們需要估計(jì)的是動(dòng)作狀態(tài)價(jià)值的期望和方差般堆,但是使用正態(tài)分布有很多限制在孝,這就將狀態(tài)價(jià)值限制為了中間概率大,兩頭概率小的一種分布形式淮摔,如果是兩頭概率大私沮,中間概率小呢?同時(shí)和橙,在訓(xùn)練時(shí)仔燕,我們計(jì)算兩個(gè)分布的差距,選擇正態(tài)分布從計(jì)算的層面也是非常困難的魔招。因此晰搀,我們選擇的分布至少需要滿(mǎn)足兩個(gè)條件:

  1. 可以表示各種各樣的分布形式,不受太多的限制办斑;
  2. 便于損失函數(shù)的計(jì)算和模型參數(shù)的更新厕隧。

基于以上兩點(diǎn),我們選擇用直方圖來(lái)表示一個(gè)分布俄周。同時(shí)我們假設(shè)價(jià)值的最終值落在[Vmin,Vmax]之間吁讨。我們要在這段中均勻找N個(gè)價(jià)值采樣點(diǎn)。要找N個(gè)價(jià)值采樣點(diǎn)峦朗,兩個(gè)點(diǎn)之間的間距計(jì)算為△z = (Vmax - Vmin)/(N-1)建丧,從而采樣點(diǎn)的集合為{zi = Vmin + i △z,i=0,1,...,N-1}。

所以波势,我們的模型要輸出一個(gè)N個(gè)值的向量翎朱,每一個(gè)值代表一個(gè)價(jià)值采樣點(diǎn)出現(xiàn)的概率。而輸入是當(dāng)前的狀態(tài)以及選擇的動(dòng)作尺铣。

接下來(lái)的關(guān)鍵是拴曲,如何進(jìn)行更新?既然是分布凛忿,我們自然地想到使用交叉熵?fù)p失函數(shù)來(lái)刻畫(huà)兩個(gè)分布的差距澈灼。而根據(jù)強(qiáng)化學(xué)習(xí)的思想,我么會(huì)有一個(gè)價(jià)值的估計(jì)分布店溢,以及一個(gè)價(jià)值的實(shí)際分布叁熔。估計(jì)分布的價(jià)值采樣點(diǎn)是z,這沒(méi)問(wèn)題床牧,而實(shí)際分布的價(jià)值采樣點(diǎn)呢荣回?z' = r + gamma * z。舉個(gè)簡(jiǎn)單的例子:

可以看到戈咳,預(yù)估的價(jià)值分布和實(shí)際的價(jià)值分布心软,由于它們的采樣點(diǎn)變的不一樣了,我們不能直接比較兩個(gè)分布的差距著蛙,因此我們需要把實(shí)際的價(jià)值分布的采樣點(diǎn)删铃,變換成跟預(yù)估的價(jià)值分布的采樣點(diǎn)一樣,即將[0.8,1.7,2,6,3.5,4.4,5.3] 投影為[0,1,2,3,4,5]册踩,當(dāng)然泳姐,相應(yīng)的概率也會(huì)發(fā)生變化。為了更方便的解釋?zhuān)覀兎Q(chēng)原有的價(jià)值采樣點(diǎn)為z暂吉,而經(jīng)過(guò)r+gamma*z得到的價(jià)值采樣點(diǎn)為z'胖秒。

為了進(jìn)行投影,我們首先要對(duì)z'的兩頭進(jìn)行裁剪慕的,也就是把小于0的變?yōu)?阎肝,大于5的變?yōu)?,此時(shí)概率不變肮街,所以經(jīng)過(guò)第一步风题,價(jià)值采樣點(diǎn)變?yōu)閦'=[0.8,1.7,2,6,3.5,4.4,5]。

接下來(lái),我們就要進(jìn)行采樣點(diǎn)的投影了沛硅。N個(gè)價(jià)值采樣點(diǎn)共有N-1個(gè)間隔眼刃,我們首先需要判斷z'中每個(gè)采樣點(diǎn)屬于z中的第幾個(gè)間隔,然后把概率按照距離分配給該間隔兩頭的價(jià)值采樣點(diǎn)上摇肌。舉例來(lái)說(shuō)擂红,z'中第一個(gè)價(jià)值采樣點(diǎn)0.8在z的第一個(gè)間隔,其兩頭的價(jià)值采樣點(diǎn)分別是0和1围小。根據(jù)距離昵骤,其對(duì)應(yīng)概率的20%(0.2 *0.2 = 0.04)應(yīng)該分配到0這個(gè)采樣點(diǎn)上,80% (0.2 * 0.8 = 0.16)應(yīng)該分配到1這個(gè)采樣點(diǎn)上肯适。這里你可能沒(méi)有繞過(guò)彎來(lái)变秦,0.8距離1較近,分配的概率應(yīng)該越多框舔。對(duì)z'所有采樣點(diǎn)進(jìn)行相同的操作蹦玫,就可以把對(duì)應(yīng)的概率投影到原有采樣點(diǎn)z上。過(guò)程的示意圖如下:

其中雨饺,1這個(gè)價(jià)值采樣點(diǎn)的概率計(jì)算如下:

z'中有兩個(gè)采樣點(diǎn)的概率要分配到1這個(gè)采樣點(diǎn)上來(lái)钳垮,分別是0.8和1.7。原有的價(jià)值采樣點(diǎn)的間隔是1额港,所以0.8距離1的上一個(gè)價(jià)值采樣點(diǎn)0的是0.8個(gè)間隔饺窿,距離1是0.2個(gè)間隔,所以0.8對(duì)應(yīng)概率的80%應(yīng)該分配到1上面移斩,同理肚医,1.7對(duì)應(yīng)概率的30% 要分配到1上,所以投影后1對(duì)應(yīng)的概率是0.2 * 0.8 + 0.3 * 0.3 = 0.25向瓷。

在進(jìn)行裁剪和投影之后肠套,實(shí)際的價(jià)值分布和預(yù)估的價(jià)值分布的價(jià)值采樣點(diǎn)都統(tǒng)一了,我們就可以計(jì)算交叉熵?fù)p失猖任,并更新模型的參數(shù)了你稚。

2、Categorical DQN的Tensorflow實(shí)現(xiàn)

本文代碼的實(shí)現(xiàn)地址為:https://github.com/princewen/tensorflow_practice/tree/master/RL/Basic-DisRL-Demo

這里我們玩的還是atrai游戲朱躺,只介紹一下模型實(shí)現(xiàn)的最關(guān)鍵的地方刁赖。

首先看模型的輸入,我們這里不用batch的形式了长搀,一次只輸入一個(gè)狀態(tài)動(dòng)作進(jìn)行更新宇弛,m_input是經(jīng)過(guò)投影后實(shí)際的價(jià)值分布:

target_state_shape = [1]
target_state_shape.extend(self.state_shape)
self.state_input = tf.placeholder(tf.float32,target_state_shape)
self.action_input = tf.placeholder(tf.int32,[1,1])
self.m_input = tf.placeholder(tf.float32,[self.atoms])

隨后是我們的價(jià)值采樣點(diǎn):

self.delta_z = (self.v_max - self.v_min) / (self.atoms - 1)
self.z = [self.v_min + i * self.delta_z for i in range(self.atoms)]

接下來(lái)是構(gòu)建我們的dqn網(wǎng)絡(luò)結(jié)構(gòu),一個(gè)是eval-net源请,一個(gè)是target-net枪芒,網(wǎng)絡(luò)的輸入是當(dāng)前的state以及采取的動(dòng)作action彻况,只不過(guò)action是在中間過(guò)程中拼接上去的,而不是最開(kāi)始就輸入進(jìn)去的:

def build_layers(self, state, action, c_names, units_1, units_2, w_i, b_i, reg=None):
    with tf.variable_scope('conv1'):
        conv1 = conv(state, [5, 5, 3, 6], [6], [1, 2, 2, 1], w_i, b_i)
    with tf.variable_scope('conv2'):
        conv2 = conv(conv1, [3, 3, 6, 12], [12], [1, 2, 2, 1], w_i, b_i)
    with tf.variable_scope('flatten'):
        flatten = tf.contrib.layers.flatten(conv2)

    with tf.variable_scope('dense1'):
        dense1 = dense(flatten, units_1, [units_1], w_i, b_i)
    with tf.variable_scope('dense2'):
        dense2 = dense(dense1, units_2, [units_2], w_i, b_i)
    with tf.variable_scope('concat'):
        concatenated = tf.concat([dense2, tf.cast(action, tf.float32)], 1)
    with tf.variable_scope('dense3'):
        dense3 = dense(concatenated, self.atoms, [self.atoms], w_i, b_i) # 返回
    return tf.nn.softmax(dense3)

def build_cate_dqn_net(self):
    with tf.variable_scope('target_net'):
        c_names = ['target_net_arams',tf.GraphKeys.GLOBAL_VARIABLES]
        w_i = tf.random_uniform_initializer(-0.1,0.1)
        b_i = tf.constant_initializer(0.1)
        self.z_target = self.build_layers(self.state_input,self.action_input,c_names,24,24,w_i,b_i)

    with tf.variable_scope('eval_net'):
        c_names = ['eval_net_params',tf.GraphKeys.GLOBAL_VARIABLES]
        w_i = tf.random_uniform_initializer(-0.1,0.1)
        b_i = tf.constant_initializer(0.1)
        self.z_eval = self.build_layers(self.state_input,self.action_input,c_names,24,24,w_i,b_i)

可以看到舅踪,我們這里使用的是兩層卷積和三層全連接操作纽甘,動(dòng)作只在最后一層全連接時(shí)拼接上去。最后的輸出經(jīng)過(guò)softmax變?yōu)槊總€(gè)價(jià)值采樣點(diǎn)的概率硫朦。

因此我們可以根據(jù)分布求出q值:

self.q_eval = tf.reduce_sum(self.z_eval * self.z)
self.q_target = tf.reduce_sum(self.z_target * self.z)

構(gòu)建好了兩個(gè)網(wǎng)絡(luò)贷腕,我們?cè)趺催M(jìn)行訓(xùn)練呢?我們的經(jīng)驗(yàn)池中還是存放了(state,action,reward,next_state)咬展,我們首先根據(jù)把next_state放入到target-net中,遍歷每個(gè)可行的動(dòng)作瞒斩,找到q值最大的動(dòng)作破婆,作為next_action:

list_q_ = [self.sess.run(self.q_target,feed_dict={self.state_input:[s_],self.action_input:[[a]]}) for a in range(self.action_dim)]
a_ = tf.argmax(list_q_).eval()

接下來(lái),我們使用target-net計(jì)算(next_state,next_action)在原始價(jià)值采樣點(diǎn)下的概率分布:

p = self.sess.run(self.z_target,feed_dict = {self.state_input:[s_],self.action_input:[[a_]]})[0]

隨后就是進(jìn)行投影操作了胸囱,過(guò)程我們剛才已經(jīng)介紹過(guò)了祷舀,這里不在贅述:

m = np.zeros(self.atoms)
for j in range(self.atoms):
    Tz = min(self.v_max,max(self.v_min,r+gamma * self.z[j]))
    bj = (Tz - self.v_min) / self.delta_z # 分在第幾個(gè)塊里
    l,u = math.floor(bj),math.ceil(bj) # 上下界
    pj = p[j]
    m[int(l)] += pj * (u - bj)
    m[int(u)] += pj * (bj - l)

這樣,我們就得到了當(dāng)前狀態(tài)動(dòng)作的實(shí)際價(jià)值分布烹笔。然后我們可以將當(dāng)前的state和action輸入到eval-net中裳扯,并通過(guò)交叉熵?fù)p失來(lái)對(duì)模型參數(shù)進(jìn)行更新:

self.sess.run(self.optimizer,feed_dict={self.state_input:[s] , self.action_input:[action], self.m_input: m })

其中,優(yōu)化器的定義如下:

self.cross_entropy_loss = -tf.reduce_sum(self.m_input * tf.log(self.z_eval))
self.optimizer = tf.train.AdamOptimizer(self.config.LEARNING_RATE).minimize(self.cross_entropy_loss)

好了谤职,代碼部分就介紹到這里饰豺,關(guān)于Categorical DQN的更多的知識(shí),大家可以結(jié)合論文和代碼進(jìn)行更深入的理解允蜈!

參考文獻(xiàn)

https://baijiahao.baidu.com/s?id=1573880107529940&wfr=spider&for=pc

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末冤吨,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子饶套,更是在濱河造成了極大的恐慌漩蟆,老刑警劉巖,帶你破解...
    沈念sama閱讀 212,718評(píng)論 6 492
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件妓蛮,死亡現(xiàn)場(chǎng)離奇詭異怠李,居然都是意外死亡,警方通過(guò)查閱死者的電腦和手機(jī)蛤克,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 90,683評(píng)論 3 385
  • 文/潘曉璐 我一進(jìn)店門(mén)捺癞,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái),“玉大人咖耘,你說(shuō)我怎么就攤上這事翘簇。” “怎么了儿倒?”我有些...
    開(kāi)封第一講書(shū)人閱讀 158,207評(píng)論 0 348
  • 文/不壞的土叔 我叫張陵版保,是天一觀的道長(zhǎng)呜笑。 經(jīng)常有香客問(wèn)我,道長(zhǎng)彻犁,這世上最難降的妖魔是什么叫胁? 我笑而不...
    開(kāi)封第一講書(shū)人閱讀 56,755評(píng)論 1 284
  • 正文 為了忘掉前任,我火速辦了婚禮汞幢,結(jié)果婚禮上驼鹅,老公的妹妹穿的比我還像新娘。我一直安慰自己森篷,他們只是感情好输钩,可當(dāng)我...
    茶點(diǎn)故事閱讀 65,862評(píng)論 6 386
  • 文/花漫 我一把揭開(kāi)白布。 她就那樣靜靜地躺著仲智,像睡著了一般买乃。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上钓辆,一...
    開(kāi)封第一講書(shū)人閱讀 50,050評(píng)論 1 291
  • 那天剪验,我揣著相機(jī)與錄音,去河邊找鬼前联。 笑死功戚,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的似嗤。 我是一名探鬼主播啸臀,決...
    沈念sama閱讀 39,136評(píng)論 3 410
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼双谆!你這毒婦竟也來(lái)了壳咕?” 一聲冷哼從身側(cè)響起,我...
    開(kāi)封第一講書(shū)人閱讀 37,882評(píng)論 0 268
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤顽馋,失蹤者是張志新(化名)和其女友劉穎谓厘,沒(méi)想到半個(gè)月后,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體寸谜,經(jīng)...
    沈念sama閱讀 44,330評(píng)論 1 303
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡竟稳,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,651評(píng)論 2 327
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了熊痴。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片他爸。...
    茶點(diǎn)故事閱讀 38,789評(píng)論 1 341
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖果善,靈堂內(nèi)的尸體忽然破棺而出诊笤,到底是詐尸還是另有隱情,我是刑警寧澤巾陕,帶...
    沈念sama閱讀 34,477評(píng)論 4 333
  • 正文 年R本政府宣布讨跟,位于F島的核電站纪他,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏晾匠。R本人自食惡果不足惜茶袒,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 40,135評(píng)論 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望凉馆。 院中可真熱鬧薪寓,春花似錦、人聲如沸澜共。這莊子的主人今日做“春日...
    開(kāi)封第一講書(shū)人閱讀 30,864評(píng)論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)咳胃。三九已至植康,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間展懈,已是汗流浹背。 一陣腳步聲響...
    開(kāi)封第一講書(shū)人閱讀 32,099評(píng)論 1 267
  • 我被黑心中介騙來(lái)泰國(guó)打工供璧, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留存崖,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 46,598評(píng)論 2 362
  • 正文 我出身青樓睡毒,卻偏偏與公主長(zhǎng)得像来惧,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子演顾,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 43,697評(píng)論 2 351

推薦閱讀更多精彩內(nèi)容