深度強(qiáng)化學(xué)習(xí)DQN詳解CartPole(2)

二、 卷積網(wǎng)絡(luò)和訓(xùn)練

接上回 處理環(huán)境圖片禀倔。
python幾處值得關(guān)注的用法(連接)

示例用卷積網(wǎng)絡(luò)來(lái)訓(xùn)練動(dòng)作輸出:

def conv2d_size_out(size, kernel_size = 5, stride = 2):
    return (size - (kernel_size - 1) - 1) // stride  + 1

class DQN(nn.Module):
    def __init__(self, h, w, outputs):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
        self.bn3 = nn.BatchNorm2d(32)

        convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
        convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
        linear_input_size = convw * convh * 32
        self.head = nn.Linear(linear_input_size, outputs)

    # Called with either one element to determine next action, or a batch
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        return self.head(x.view(x.size(0), -1))

還是比較直白的:

  • Conv 3通道 \rightarrow 16通道
  • Conv 16通道 \rightarrow 32通道
  • Conv 32通道 \rightarrow 32通道
  • Linear 512節(jié)點(diǎn) \rightarrow 2節(jié)點(diǎn)

為何第2層最后轉(zhuǎn)為512節(jié)點(diǎn),用到了卷積形狀計(jì)算公式:

conv = \frac{(X - kernel + 2*padding)}{stride}+1

conv 為某維度上卷積后的尺寸恋腕,X為卷積前的尺寸歼捏。

(W - kernel_size + 2 * padding ) // stride + 1

示例中的Conv層沒(méi)有padding处嫌,所以公式變?yōu)椋?/p>

(size - kernel_size) // stride  + 1

但不知為何示例代碼將 - kernel_size 寫(xiě)為 - (kernel_size - 1) - 1啥寇。因?yàn)閮烧咄耆嗟龋?/p>

def conv2d_size_out(size, kernel_size = 5, stride = 2):
    return (size - (kernel_size - 1) - 1) // stride  + 1

這只是某個(gè)維度的一次卷積變化偎球,所以一張圖,完整的尺寸應(yīng)該是2個(gè)維度的乘積辑甜,再經(jīng)過(guò)3層變化甜橱,乘上第三層通道數(shù),就是最終全連接層的大姓淮痢:conv _{height} \times conv _{width} \times channel。代碼寫(xiě)作:

        convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
        convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
        linear_input_size = convw * convh * 32

這個(gè)網(wǎng)絡(luò)的輸出為動(dòng)作值难裆,動(dòng)作值為0或1子檀,但0/1代表的是枚舉類型,并不是值類型乃戈,也就是說(shuō)褂痰,動(dòng)作0并不意味著沒(méi)有,動(dòng)作1也不意味著1和0之間的某種數(shù)值度量關(guān)系症虑,0和1純粹是枚舉缩歪,所以輸出數(shù)為2個(gè),而不是1個(gè)谍憔。應(yīng)為將圖像縮放到40 x 90匪蝙,所以網(wǎng)絡(luò)的參數(shù)就是(40, 90习贫,2)逛球。試一下這個(gè)網(wǎng)絡(luò):

net = DQN(40, 90, 2).to(device)
scr = get_screen()
net(scr)

tensor([[-1.0281, 0.0997]], device='cuda:0', grad_fn=<AddmmBackward>)

OK,返回兩個(gè)值苫昌。


行動(dòng)決策采用 epsilon greedy policy颤绕,就是有一定的比例,選擇隨機(jī)行為(否則按照網(wǎng)絡(luò)預(yù)測(cè)的最佳行為行事)祟身。這個(gè)比例從0.9逐漸降到0.05奥务,按EXP曲線遞減:

EPS_START = 0.9 # 概率從0.9開(kāi)始
EPS_END = 0.05  #     下降到 0.05
EPS_DECAY = 200 #     越小下降越快
steps_done = 0 # 執(zhí)行了多少步

100時(shí)

200時(shí)

隨機(jī)行為是強(qiáng)化學(xué)習(xí)的靈魂,沒(méi)有隨機(jī)行動(dòng)袜硫,就沒(méi)有探索氯葬,沒(méi)有探索就沒(méi)有持續(xù)的成長(zhǎng)。select_action() 的作用就是 選擇網(wǎng)絡(luò)輸出的2個(gè)值中的最大值()或 隨機(jī)數(shù)

def select_action(state):
    global steps_done
    sample = random.random() #[0, 1)
    #epsilon greedy policy婉陷。EPS_END 加上額外部分溢谤,steps_done 越小瞻凤,額外部分越接近0.9
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            #選擇使用網(wǎng)絡(luò)來(lái)做決定。max返回 0:最大值和 1:索引
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        #選擇一個(gè)隨機(jī)數(shù) 0 或 1
        return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)

通常網(wǎng)絡(luò)做枚舉輸出世杀,是需要用到CrossEntropy的阀参。(關(guān)于CrossEntropy的文章),示例代碼在使用網(wǎng)絡(luò)時(shí)瞻坝,簡(jiǎn)單判斷了一下蛛壳,誰(shuí)大就取誰(shuí)的索引,所以就相當(dāng)于做了一個(gè)CrossEntropy所刀。

pytorch 的 tensor.max() 返回所有維度的最大值及其索引衙荐,但如果指定了維度,就會(huì)返回namedtuple浮创,包含各維度最大值及索引 (values=..., indices=...) 忧吟。

max(1)[1] 只取了索引值,也可以用 max(1).indices斩披。view(1,1) 把數(shù)值做成[[1]] 的二維數(shù)組形式溜族。為何返回一個(gè)二維 [[1]] ? 這是因?yàn)楹竺嬉阉械膕tate用torch.cat() 合成batch(cat()說(shuō)明連接)

    return policy_net(state).max(1)[1].view(1, 1)
    # return 0 if value[0] > value[1] else 1

示例中垦沉,訓(xùn)練是用兩次屏幕截圖的差別來(lái)訓(xùn)練網(wǎng)絡(luò):

for t in count():
    # 1. 獲取屏幕 1
    last_screen = get_screen()
    # 2. 選擇行為煌抒、步進(jìn)
    action = select_action(state)
    _, reward, done, _ = env.step(action)
    # 3. 獲取屏幕 2
    current_screen = get_screen()
    # 4. 計(jì)算差別 2-1
    state = current_screen - last_screen
    # 5. 優(yōu)化網(wǎng)絡(luò)
    optimize_model()

當(dāng)前狀態(tài)及兩次狀態(tài)的差,如下所示厕倍,

  • 上邊兩個(gè)分別是step0和step1原圖
  • 中間灰色圖是差值部分寡壮,藍(lán)色是少去的部分,棕色是多出的部分
  • 下面兩圖是原始圖覆蓋差值圖讹弯,step0將完全復(fù)原為step1况既,step1則多出部分顏色加強(qiáng)

可以看出,差值是step0到step1的變化组民。

以下是關(guān)鍵訓(xùn)練循環(huán)代碼坏挠,邏輯是一樣的。只是有一處需要注意邪乍,在循環(huán)的時(shí)候降狠,會(huì)將(state, action, next_state, reward)這四個(gè)值,保存起來(lái)庇楞,循環(huán)存放在一個(gè)叫memory的列表里榜配,湊夠批次后,才會(huì)用數(shù)據(jù)訓(xùn)練網(wǎng)絡(luò)吕晌,否則optimize_model()直接返回蛋褥。

num_episodes = 50
TARGET_UPDATE = 10

for i_episode in range(num_episodes):
    env.reset()
    last_screen = get_screen()
    current_screen = get_screen()
    state = current_screen - last_screen
    
    # [0, 無(wú)限) 直到 done
    for t in count(): 
        action = select_action(state)
        _, reward, done, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        last_screen = current_screen
        current_screen = get_screen()
        next_state = None if done else current_screen - last_screen
        // 保存 state, action, next_state, reward 到列表 memory
       
        state = next_state
        optimize_model()
 
        if done:
            break


關(guān)于optimize_model(),大致過(guò)程是這樣的:

  1. 從memory列表里選取n個(gè) (state, action, next_state, reward)
  2. 用net獲取state的Y[0, 1](net輸出為2個(gè)值)睛驳,再用action選出結(jié)果y
  3. 用net獲取next_state獲取Y'[0,1]烙心,取最大值 y'膜廊。如果state沒(méi)有對(duì)應(yīng)的next_state,則y' = 0
  4. 用公式算出期望y:\hat y = \gamma y' + reward (常量 \gamma = 0.9
  5. 用smooth_l1_loss計(jì)算誤差
  6. 用RMSprop 反向傳導(dǎo)優(yōu)化網(wǎng)絡(luò)

期望y的計(jì)算方法很簡(jiǎn)單淫茵,就是把next_state的net結(jié)果爪瓜,直接乘一個(gè)0.9然后加上獎(jiǎng)勵(lì)。如果有 next_state匙瘪,就是1铆铆,如果next_state為None,獎(jiǎng)勵(lì)是0丹喻。因此薄货,沒(méi)有明天的state,期望y最小碍论。
這里的關(guān)鍵是如何求期望y谅猾,用了Q learning:Q Learning解釋
也就是遺忘率為1的Q learning求值函數(shù)。為何遺忘率是1呢鳍悠?我的想法是税娜,在NN optimize的時(shí)候,本身就是有一個(gè)learning rate的贼涩,就相當(dāng)于
y \leftarrow y + \hat y \times lr,所以 Q Learning 公式中的
Q_{s,a} \leftarrow ( 1- \alpha) \times Q_{s,a} + \alpha \times \hat y
前面的部分就省掉了薯蝎。

示例使用的gamma \gamma為0.99遥倦,效果并不好,幾乎不會(huì)學(xué)習(xí)占锯。我改為0.7后袒哥,訓(xùn)練120次達(dá)到57步,總的來(lái)說(shuō)消略,就小車(chē)環(huán)境而言堡称,示例中的卷積網(wǎng)絡(luò),效果比128節(jié)點(diǎn)的全連接層網(wǎng)絡(luò)差太多艺演。128節(jié)點(diǎn)的全連接層網(wǎng)絡(luò)却紧,訓(xùn)練幾十次就可以達(dá)到滿分200步。


這是訓(xùn)練中持續(xù)時(shí)長(zhǎng)統(tǒng)計(jì)胎撤,橙色為平均值晓殊,最高也就是50多,感覺(jué)示例代碼的效果并不是很好伤提。OpenAI官方的要求是巫俺,連續(xù)跑100次平均持續(xù)時(shí)長(zhǎng)為195。這是\gamma改為0.7后的訓(xùn)練結(jié)果肿男。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末介汹,一起剝皮案震驚了整個(gè)濱河市却嗡,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌嘹承,老刑警劉巖窗价,帶你破解...
    沈念sama閱讀 212,383評(píng)論 6 493
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異赶撰,居然都是意外死亡舌镶,警方通過(guò)查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 90,522評(píng)論 3 385
  • 文/潘曉璐 我一進(jìn)店門(mén)豪娜,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)餐胀,“玉大人,你說(shuō)我怎么就攤上這事瘤载》裨郑” “怎么了?”我有些...
    開(kāi)封第一講書(shū)人閱讀 157,852評(píng)論 0 348
  • 文/不壞的土叔 我叫張陵鸣奔,是天一觀的道長(zhǎng)墨技。 經(jīng)常有香客問(wèn)我,道長(zhǎng)挎狸,這世上最難降的妖魔是什么扣汪? 我笑而不...
    開(kāi)封第一講書(shū)人閱讀 56,621評(píng)論 1 284
  • 正文 為了忘掉前任,我火速辦了婚禮锨匆,結(jié)果婚禮上崭别,老公的妹妹穿的比我還像新娘。我一直安慰自己恐锣,他們只是感情好茅主,可當(dāng)我...
    茶點(diǎn)故事閱讀 65,741評(píng)論 6 386
  • 文/花漫 我一把揭開(kāi)白布。 她就那樣靜靜地躺著土榴,像睡著了一般诀姚。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上玷禽,一...
    開(kāi)封第一講書(shū)人閱讀 49,929評(píng)論 1 290
  • 那天赫段,我揣著相機(jī)與錄音,去河邊找鬼矢赁。 笑死瑞佩,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的坯台。 我是一名探鬼主播炬丸,決...
    沈念sama閱讀 39,076評(píng)論 3 410
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來(lái)了稠炬?” 一聲冷哼從身側(cè)響起焕阿,我...
    開(kāi)封第一講書(shū)人閱讀 37,803評(píng)論 0 268
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎首启,沒(méi)想到半個(gè)月后暮屡,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 44,265評(píng)論 1 303
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡毅桃,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,582評(píng)論 2 327
  • 正文 我和宋清朗相戀三年褒纲,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片钥飞。...
    茶點(diǎn)故事閱讀 38,716評(píng)論 1 341
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡莺掠,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出读宙,到底是詐尸還是另有隱情彻秆,我是刑警寧澤,帶...
    沈念sama閱讀 34,395評(píng)論 4 333
  • 正文 年R本政府宣布结闸,位于F島的核電站唇兑,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏桦锄。R本人自食惡果不足惜扎附,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 40,039評(píng)論 3 316
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望结耀。 院中可真熱鬧留夜,春花似錦、人聲如沸饼记。這莊子的主人今日做“春日...
    開(kāi)封第一講書(shū)人閱讀 30,798評(píng)論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)具则。三九已至,卻和暖如春具帮,著一層夾襖步出監(jiān)牢的瞬間博肋,已是汗流浹背。 一陣腳步聲響...
    開(kāi)封第一講書(shū)人閱讀 32,027評(píng)論 1 266
  • 我被黑心中介騙來(lái)泰國(guó)打工蜂厅, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留匪凡,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 46,488評(píng)論 2 361
  • 正文 我出身青樓掘猿,卻偏偏與公主長(zhǎng)得像病游,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 43,612評(píng)論 2 350