自定義gym環(huán)境并使用RL訓(xùn)練--尋找寶石


完整代碼已上傳到 github

result_polyDL.mp4.gif


最近有項(xiàng)目需要用到RL相關(guān)的一些東西纳像,于是就開(kāi)始嘗試自己搭建一個(gè)自定義的gym環(huán)境病毡,并使用入門的DQN網(wǎng)絡(luò)對(duì)這個(gè)環(huán)境進(jìn)行訓(xùn)練计雌,這個(gè)是我入門的第一個(gè)項(xiàng)目同窘,可能有一些地方理解的不夠的或者有問(wèn)題的焙贷,希望見(jiàn)諒并能指正镜廉。

其中環(huán)境的自定義部分參考了csdn extremebingo的文章贺喝,模型建立與訓(xùn)練過(guò)程參考了: pytorch official tutorials,訓(xùn)練結(jié)果的展示參考了:tensorflow org tutorials

尋找寶石游戲

綠色的小圓圈代表機(jī)器人斗幼,紅色圈圈表示火坑澎蛛,藍(lán)色圓圈表示寶石,褐色圈圈表示石柱孟岛,其中環(huán)境每次重置機(jī)器人便會(huì)出生在任意一個(gè)空白的格子中瓶竭,機(jī)器人需要找到含有寶石的格子獲得獎(jiǎng)勵(lì)結(jié)束游戲。在尋找的過(guò)程中如果踩入火坑游戲結(jié)束獲得負(fù)獎(jiǎng)勵(lì)渠羞,機(jī)器人無(wú)法移動(dòng)到石柱所在的格子中斤贰。

自定義gym環(huán)境

自定義gym環(huán)境模塊主要參考了csdn extremebingo的文章,可以直接點(diǎn)擊查看自定義的具體流程介紹次询,也可以參考github Readme 的gym Env set up模塊介紹中的操作流程荧恍。這里就不再贅述,下面主要介紹下使用這個(gè)流程中可能有的坑:

  1. 將自定義的文件拷貝到環(huán)境中可能不生效屯吊,可以嘗試在這個(gè)路徑同樣進(jìn)行一遍操作:
    C:\Users\xxx\AppData\Roaming\Python\Python37\site-packages\gym\envs

  2. extremebingo 構(gòu)建的環(huán)境中有部分代碼存在一些筆誤還有一些bug送巡,這里進(jìn)行了一些修改,修改后的環(huán)境代碼

模型構(gòu)建與訓(xùn)練

數(shù)據(jù)收集

訓(xùn)練數(shù)據(jù)主要有:(state, action, next_state, reward)

  • state 當(dāng)前環(huán)境的狀態(tài)
  • action 當(dāng)前狀態(tài)時(shí)盒卸,機(jī)器人執(zhí)行的動(dòng)作
  • next_state 執(zhí)行該動(dòng)作后的狀態(tài)
  • reward 執(zhí)行該動(dòng)作后獲得的激勵(lì)

(這里用環(huán)境render的圖表示state 見(jiàn)get_screen,actions = ['n', 'e', 's', 'w'] 含意為:n 上 s下 w左 e 右 reward 找到寶石+1骗爆,踩到火坑-1,增加步數(shù)在訓(xùn)練的過(guò)程中進(jìn)行適度的懲罰

數(shù)據(jù)收集過(guò)程中的action 根據(jù)當(dāng)前訓(xùn)練的狀態(tài)按照概率選擇使用模型結(jié)果或者隨機(jī)選擇動(dòng)作執(zhí)行下一步操作
這個(gè)概率值由EPS_END EPS_STAR EPS_DECAY 還有steps_done 共同控制 結(jié)果按照指數(shù)進(jìn)行衰減
這里我使用的值為:

    EPS_START = 0.9
    EPS_END = 0.05
    EPS_DECAY = 20000

選擇隨機(jī)策略的概率隨著訓(xùn)練次數(shù)steps_done 的變化如下圖所示:


eps.png

這里eps_decay 改為了20000而不是torch offical tutorials里的200,主要是因?yàn)檫@個(gè)環(huán)境比小車的稍微復(fù)雜蔽介,因此前期需要更多的隨機(jī)策略的樣本訓(xùn)練摘投,offical turorials 里概率的變化曲線如下:

eps.jpg

,當(dāng)我們?cè)趖est模型時(shí)虹蓄,主要應(yīng)選取模型的輸出作為下一個(gè)action 因此 我在代碼中增加了eval時(shí)eps_threshold=0.001

def select_action(state, eval=False):
    global steps_done
    sample = random.random()

    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
                    math.exp(-1. * steps_done / EPS_DECAY)
    if eval:
        eps_threshold = 0.001
    print("eps_threshold:{} ,steps_done:{}".format(eps_threshold, steps_done))
    steps_done += 1

    if sample > eps_threshold:
        print("select Model")
        with torch.no_grad():
            # t.max(1) will return largest column value of each row.
            # second column on max result is index of where max element was
            # found, so we pick action with the larger expected reward.
            if eval:
                return target_net(state).max(1)[1].view(1, 1)
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        print("select random")
        return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)

運(yùn)行過(guò)程中生產(chǎn)的數(shù)據(jù)放到一個(gè)存儲(chǔ)類中犀呼,每次隨機(jī)采樣batchSize條數(shù)據(jù)訓(xùn)練:

class ReplayMemory(object):

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

get_screen 函數(shù)主要獲取環(huán)境狀態(tài)改變時(shí)的圖像

def get_screen():
    # Returned screen requested by gym is 400x600x3, but is sometimes larger
    # such as 800x1200x3. Transpose it into torch order (CHW).
    screen = env.render(mode='rgb_array').transpose((2, 0, 1))
    # Cart is in the lower half, so strip off the top and bottom of the screen
    _, screen_height, screen_width = screen.shape
    # print("screen_height {}, screen_width {}".format(screen_height,screen_width))
    screen = screen[:, int(screen_height * 0):int(screen_height * 0.9)]
    view_width = int(screen_width * 0.6)

    # Strip off the edges, so that we have a square image centered on a cart
    # screen = screen[:, :, slice_range]
    # Convert to float, rescale, convert to torch tensor
    # (this doesn't require a copy)
    screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
    screen = torch.from_numpy(screen)
    # Resize, and add a batch dimension (BCHW)
    return resize(screen).unsqueeze(0).to(device)

模型構(gòu)建

DQN 網(wǎng)絡(luò)使用三層卷積,根據(jù)狀態(tài) 預(yù)測(cè)下一步采取各個(gè)行動(dòng)的收益


class DQN(nn.Module):

    def __init__(self, h, w, outputs):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
        self.bn3 = nn.BatchNorm2d(32)

        # Number of Linear input connections depends on output of conv2d layers
        # and therefore the input image size, so compute it.
        def conv2d_size_out(size, kernel_size=5, stride=2):
            return (size - (kernel_size - 1) - 1) // stride + 1

        convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
        convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
        linear_input_size = convw * convh * 32
        self.head = nn.Linear(linear_input_size, outputs)

    # Called with either one element to determine next action, or a batch
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        return self.head(x.view(x.size(0), -1))

訓(xùn)練過(guò)程模型參數(shù)更新

通過(guò)policy_net (參數(shù)實(shí)時(shí)更新的net)根據(jù)batch數(shù)據(jù)中的state信息預(yù)測(cè)下一步采取的每個(gè)行動(dòng)的收益薇组,生成bx4(action 可選擇的個(gè)數(shù)4)的矩陣外臂,根據(jù)batch 中 的action 的index 選擇 這一action 模型預(yù)測(cè)的值(Q(s_t, a) - model computes Q(s_t)):

 state_action_values = policy_net(state_batch).gather(1, action_batch)

使用target_net (參數(shù)更新copy from policy net延遲的net) 使用next state信息(過(guò)濾掉 狀態(tài)為none)預(yù)測(cè)最大收益的行動(dòng):next_state_values

當(dāng)前狀態(tài)的收益期望值 = 下一狀態(tài)預(yù)測(cè)的行動(dòng)最大收益(next_state_values)*GAMMA + 當(dāng)前狀態(tài)行為的實(shí)際收益 reward_batch 如下所示:

expected_state_action_values = (next_state_values * GAMMA) + reward_batch

根據(jù)當(dāng)前網(wǎng)絡(luò)預(yù)測(cè)的動(dòng)作收益 state_action_values 與實(shí)際期望的收益的誤差作為模型的loss 更新整個(gè)策略網(wǎng)絡(luò)

    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))
    print("loss:{}".format(loss.item()))

    # Optimize the model
    optimizer.zero_grad()  
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()

該函數(shù)optimize_model完整代碼如下:

def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                            batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                       if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1)[0].
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    # Compute the expected Q values

    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))
    print("loss:{}".format(loss.item()))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市律胀,隨后出現(xiàn)的幾起案子宋光,更是在濱河造成了極大的恐慌,老刑警劉巖炭菌,帶你破解...
    沈念sama閱讀 217,406評(píng)論 6 503
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件罪佳,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡娃兽,警方通過(guò)查閱死者的電腦和手機(jī)菇民,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,732評(píng)論 3 393
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái),“玉大人第练,你說(shuō)我怎么就攤上這事阔馋。” “怎么了娇掏?”我有些...
    開(kāi)封第一講書人閱讀 163,711評(píng)論 0 353
  • 文/不壞的土叔 我叫張陵呕寝,是天一觀的道長(zhǎng)。 經(jīng)常有香客問(wèn)我婴梧,道長(zhǎng)下梢,這世上最難降的妖魔是什么? 我笑而不...
    開(kāi)封第一講書人閱讀 58,380評(píng)論 1 293
  • 正文 為了忘掉前任塞蹭,我火速辦了婚禮孽江,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘番电。我一直安慰自己岗屏,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,432評(píng)論 6 392
  • 文/花漫 我一把揭開(kāi)白布漱办。 她就那樣靜靜地躺著这刷,像睡著了一般。 火紅的嫁衣襯著肌膚如雪娩井。 梳的紋絲不亂的頭發(fā)上暇屋,一...
    開(kāi)封第一講書人閱讀 51,301評(píng)論 1 301
  • 那天,我揣著相機(jī)與錄音洞辣,去河邊找鬼咐刨。 笑死,一個(gè)胖子當(dāng)著我的面吹牛屋彪,可吹牛的內(nèi)容都是我干的所宰。 我是一名探鬼主播绒尊,決...
    沈念sama閱讀 40,145評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼畜挥,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來(lái)了婴谱?” 一聲冷哼從身側(cè)響起蟹但,我...
    開(kāi)封第一講書人閱讀 39,008評(píng)論 0 276
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎谭羔,沒(méi)想到半個(gè)月后华糖,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,443評(píng)論 1 314
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡瘟裸,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,649評(píng)論 3 334
  • 正文 我和宋清朗相戀三年客叉,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 39,795評(píng)論 1 347
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡兼搏,死狀恐怖卵慰,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情佛呻,我是刑警寧澤裳朋,帶...
    沈念sama閱讀 35,501評(píng)論 5 345
  • 正文 年R本政府宣布,位于F島的核電站吓著,受9級(jí)特大地震影響鲤嫡,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜绑莺,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,119評(píng)論 3 328
  • 文/蒙蒙 一暖眼、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧纺裁,春花似錦罢荡、人聲如沸。這莊子的主人今日做“春日...
    開(kāi)封第一講書人閱讀 31,731評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)。三九已至浪南,卻和暖如春笼才,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背络凿。 一陣腳步聲響...
    開(kāi)封第一講書人閱讀 32,865評(píng)論 1 269
  • 我被黑心中介騙來(lái)泰國(guó)打工骡送, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人絮记。 一個(gè)月前我還...
    沈念sama閱讀 47,899評(píng)論 2 370
  • 正文 我出身青樓摔踱,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國(guó)和親怨愤。 傳聞我的和親對(duì)象是個(gè)殘疾皇子派敷,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,724評(píng)論 2 354

推薦閱讀更多精彩內(nèi)容