強(qiáng)化學(xué)習(xí)基礎(chǔ)篇(十九)TD與MC在隨機(jī)游走問題應(yīng)用

強(qiáng)化學(xué)習(xí)基礎(chǔ)篇(十九)TD與MC在隨機(jī)游走問題應(yīng)用

為了比較討論一下TD與MC方法电谣,本位簡(jiǎn)單探索在一個(gè)隨機(jī)游走的示例中勒叠,TD與MC的差異。

1脊串、隨機(jī)行走(Random Walk)問題設(shè)定

狀態(tài)空間:如上圖A凉泄、B躏尉、C、D后众、E為中間狀態(tài)胀糜,C同時(shí)作為起始狀態(tài)稼锅。灰色方格表示終止?fàn)顟B(tài)僚纷;

行為空間:除終止?fàn)顟B(tài)外矩距,任一狀態(tài)可以選擇向左片排、向右兩個(gè)行為之一蒸走;

即時(shí)獎(jiǎng)勵(lì):右側(cè)的終止?fàn)顟B(tài)得到即時(shí)獎(jiǎng)勵(lì)為1,左側(cè)終止?fàn)顟B(tài)得到的即時(shí)獎(jiǎng)勵(lì)為0旗吁,在其他狀態(tài)間轉(zhuǎn)化得到的即時(shí)獎(jiǎng)勵(lì)是0痊臭;

狀態(tài)轉(zhuǎn)移:100%按行為進(jìn)行狀態(tài)轉(zhuǎn)移哮肚,進(jìn)入終止?fàn)顟B(tài)即終止;

衰減系數(shù):1广匙;

給定的策略:隨機(jī)選擇向左允趟、向右兩個(gè)行為。

問題:對(duì)這個(gè)MDP問題進(jìn)行預(yù)測(cè)鸦致,也就是評(píng)估隨機(jī)行走這個(gè)策略的狀態(tài)潮剪。

2、初始化

首先導(dǎo)入需要用到的庫函數(shù)

# 導(dǎo)入庫函數(shù)
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from tqdm import tqdm
## 解決matplotlib中文畫圖亂碼問題
from pylab import mpl
mpl.rcParams['font.sans-serif'] = ['SimHei']

按照問題的設(shè)定分唾,狀態(tài)空間一共為七個(gè):S \in \{left\_terminate,A,B,C,D,E,right\_terminate\}抗碰。

由于這個(gè)任務(wù)沒有折扣,所以每個(gè)狀態(tài)的真實(shí)價(jià)值是從這個(gè)狀態(tài)開始并終止與最右側(cè)的概率绽乔。因此中心狀態(tài)的真實(shí)價(jià)值為v_{\pi}(C)=0.5弧蝇。狀態(tài)A-E的真實(shí)價(jià)值分別為:1/6,2/6,3/6,4/6以及5/6。

# 定義七個(gè)狀態(tài)折砸,初始化A-E的值為0.5看疗,右邊終點(diǎn)值為1.
VALUES = np.zeros(7)
VALUES[1:6] = 0.5
VALUES[6] = 1
# 由于這個(gè)任務(wù)沒有折扣,所以每個(gè)狀態(tài)的真實(shí)價(jià)值是從這個(gè)狀態(tài)開始并終止與最右側(cè)的概率睦授。
# 因此中心狀態(tài)的真實(shí)價(jià)值為0.5两芳。
# 狀態(tài)A-E的真實(shí)價(jià)值分別為:1/6,2/6,3/6,4/6,5/6
TRUE_VALUE = np.zeros(7)
TRUE_VALUE[1:6] = np.arange(1, 6) / 6.0
TRUE_VALUE[6] = 1
# 定義向左與向右兩個(gè)動(dòng)作
ACTION_LEFT = 0
ACTION_RIGHT = 1

3、定義TD方法的實(shí)現(xiàn)

該代碼實(shí)現(xiàn)遵循表格型TD(0)算法偽代碼:

image.png
def temporal_difference(values, alpha=0.1, batch=False):
    # 定義開始點(diǎn)為C睹逃,即state=3
    state = 3
    # 定義軌跡列表
    trajectory = [state]
    # 定義獎(jiǎng)勵(lì)列表
    rewards = [0]
    while True:
        old_state = state
        # 通過一個(gè)二項(xiàng)分布盗扇,隨機(jī)選擇一個(gè)動(dòng)作祷肯,并按照動(dòng)作更新狀態(tài)
        if np.random.binomial(1, 0.5) == ACTION_LEFT:
            state -= 1
        else:
            state += 1
        # 按照問題定義沉填,處理右邊終點(diǎn),其余的獎(jiǎng)勵(lì)都是0佑笋。
        reward = 0
        # 將state狀態(tài)加入trajectory列表中
        trajectory.append(state)
        # 進(jìn)行TD更新
        if not batch:
            values[old_state] += alpha * (reward + values[state] - values[old_state])
        # 遇到終結(jié)點(diǎn)則結(jié)束該次的episode翼闹。
        if state == 6 or state == 0:
            break
        rewards.append(reward)
    return trajectory, rewards
  • 在決定隨機(jī)行走的動(dòng)作過程中,這里使用了二項(xiàng)分布 X \sim b(n,p),這里即選擇的為 X \sim b(1,0.5)

    np.random.binomial(1, 0.5)
    
  • TD的更新過程為:V(S) \leftarrow V(S)+\alpha (R+\gamma V(S')-V(S))
values[old_state] += alpha * (reward + values[state] - values[old_state])

4蒋纬、定義MC方法的實(shí)現(xiàn)

def monte_carlo(values, alpha=0.1, batch=False):
    # 定義開始點(diǎn)為C猎荠,即state=3
    state = 3
    # 定義軌跡列表
    trajectory = [3]

    # 如果最終是在左邊介紹坚弱,那么回報(bào)是0。
    # 如果最終是在右邊介紹关摇,那么回報(bào)是1荒叶。
    while True:
         # 通過一個(gè)二項(xiàng)分布,隨機(jī)選擇一個(gè)動(dòng)作输虱,并按照動(dòng)作更新狀態(tài)
        if np.random.binomial(1, 0.5) == ACTION_LEFT:
            state -= 1
        else:
            state += 1
        trajectory.append(state)

        if state == 6:
            returns = 1.0
            break
        elif state == 0:
            returns = 0.0
            break
    # 在episode完成后進(jìn)行MC更新些楣。
    if not batch:
        for state_ in trajectory[:-1]:
            # MC update
            values[state_] += alpha * (returns - values[state_])
    return trajectory, [returns] * (len(trajectory) - 1)
  • MC更新的方式遵循:V(S_t) \leftarrow V(S_t)+\alpha (G_t-V(S_t))

     values[state_] += alpha * (returns - values[state_])
    

5、計(jì)算價(jià)值函數(shù)

這里考慮在episode為[0,1,10,100]四種情況下的估計(jì)價(jià)值宪睹,我們會(huì)將運(yùn)行一次TD(0)所得到的價(jià)值估計(jì)值和真實(shí)值進(jìn)行比較

def compute_state_value():
    episodes = [0, 1, 10, 100]
    current_values = np.copy(VALUES)
    plt.figure(1)
    for i in range(episodes[-1] + 1):
        if i in episodes:
            plt.plot(current_values, label=str(i) + ' episodes(幕的次數(shù))')
        temporal_difference(current_values)
    plt.plot(TRUE_VALUE, label='true valuesd(真實(shí)價(jià)值)')
    plt.xlabel('state(狀態(tài))')
    plt.ylabel('estimated value(估計(jì)價(jià)值)')
    plt.legend()

以下為運(yùn)行結(jié)果愁茁,中間紫色線條為真實(shí)價(jià)值。我們可以看到在100幕后亭病,估計(jì)值就非常接近于真實(shí)值了鹅很。這里我們使用了默認(rèn)的步長(zhǎng)參數(shù)\alpha =0.1

plt.figure()
compute_state_value()
plt.show()
image.png

6罪帖、不同狀態(tài)下平均經(jīng)驗(yàn)均方根誤差

上面 只是簡(jiǎn)單比較了TD在運(yùn)行過程中估計(jì)價(jià)值的變化促煮,接下來我們考慮不同步長(zhǎng)參數(shù)設(shè)置的情況下,MC與TD在不同步長(zhǎng)參數(shù)下平均經(jīng)驗(yàn)均方根誤差變化情況整袁。

這里我們將比較TD的三種步長(zhǎng)參數(shù) [0.15, 0.1, 0.05]以及MC的四種步長(zhǎng)參數(shù)[0.01, 0.02, 0.03, 0.04]污茵,他們?cè)?00幕運(yùn)行過程

def rms_error():
    # 設(shè)置TD與MC的步長(zhǎng)參數(shù)
    td_alphas = [0.15, 0.1, 0.05]
    mc_alphas = [0.01, 0.02, 0.03, 0.04]
    # 設(shè)定總episode數(shù)量
    episodes = 100 + 1
    runs = 100
    # 遍歷每個(gè)alpha設(shè)置
    for i, alpha in enumerate(td_alphas + mc_alphas):
        total_errors = np.zeros(episodes)
        if i < len(td_alphas):
            method = 'TD'
            linestyle = 'solid'
        else:
            method = 'MC'
            linestyle = 'dashdot'
        # 這里整個(gè)過程一共運(yùn)行100次,每次都是100幕葬项,最后會(huì)對(duì)結(jié)果進(jìn)行平均泞当。
        for r in tqdm(range(runs)):
            errors = []
            current_values = np.copy(VALUES)
            for i in range(0, episodes):
                # 計(jì)算當(dāng)次幕下當(dāng)前估計(jì)值和真實(shí)值之間的均方根誤差。
                errors.append(np.sqrt(np.sum(np.power(TRUE_VALUE - current_values, 2)) / 5.0))
                if method == 'TD':
                    temporal_difference(current_values, alpha=alpha)
                else:
                    monte_carlo(current_values, alpha=alpha)
            total_errors += np.asarray(errors)
        total_errors /= runs
        plt.plot(total_errors, linestyle=linestyle, label=method + ', alpha = %.02f' % (alpha))
    plt.xlabel('episodes')
    plt.ylabel('RMS')
    plt.legend()

運(yùn)行如下代碼:

plt.figure(figsize=(10,5))
rms_error()
plt.tight_layout()

運(yùn)行的最終結(jié)果如下:

image.png

我們可以看到對(duì)于不同的\alpha取值民珍,兩種方法的學(xué)習(xí)曲線襟士。圖中顯示的性能衡量指標(biāo)是學(xué)到的價(jià)值函數(shù)和真實(shí)價(jià)值函數(shù)的均方根(RMS)誤差。圖中顯示的誤差是在5個(gè)狀態(tài)上的平均誤差嚷量,并在100次運(yùn)行中取平均的結(jié)果陋桂。在所有情況下,對(duì)于所有s蝶溶,近似價(jià)值函數(shù)都被初始化為中間值V(s)=0.5嗜历。在這個(gè)任務(wù)中,TD方法一直比MC方法要好抖所。

7梨州、批量更新的隨機(jī)游走

在隨機(jī)游走問題中,批量更新版本的TD(0)和常數(shù)\alpha MC方法的過程是這樣的:每經(jīng)過新的一幕序列之后田轧,之前所有幕的數(shù)據(jù)就放視為一個(gè)批次暴匠。算法TD(0)常數(shù)\alpha MC方法不斷地使用這些批次進(jìn)行逐次更新。這里\alpha要設(shè)置得足夠小以使價(jià)值函數(shù)能夠收斂傻粘。最后將所得的價(jià)值函數(shù)與v_\pi進(jìn)行比較每窖,繪制5個(gè)狀態(tài)下的平均均方根誤差(以整個(gè)實(shí)驗(yàn)的100次的獨(dú)立重復(fù)為基礎(chǔ))的學(xué)習(xí)曲線帮掉。

def batch_updating(method, episodes, alpha=0.001):
    # 整個(gè)實(shí)驗(yàn)進(jìn)行100次獨(dú)立重復(fù)運(yùn)行
    runs = 100
    total_errors = np.zeros(episodes)
    for r in tqdm(range(0, runs)):
        current_values = np.copy(VALUES)
        errors = []
        # trajectories需要記錄所有episode以及獎(jiǎng)勵(lì)
        trajectories = []
        rewards = []
        for ep in range(episodes):
            # 執(zhí)行TD(0)
            if method == 'TD':
                trajectory_, rewards_ = temporal_difference(current_values, batch=True)
            # 執(zhí)行MC
            else:
                trajectory_, rewards_ = monte_carlo(current_values, batch=True)
            trajectories.append(trajectory_)
            rewards.append(rewards_)
            while True:
                # 持續(xù)不斷得將到目前為止所有的trajectories都用于訓(xùn)練。
                updates = np.zeros(7)
                for trajectory_, rewards_ in zip(trajectories, rewards):
                    for i in range(0, len(trajectory_) - 1):
                        if method == 'TD':
                            updates[trajectory_[i]] += rewards_[i] + current_values[trajectory_[i + 1]] - current_values[trajectory_[i]]
                        else:
                            updates[trajectory_[i]] += rewards_[i] - current_values[trajectory_[i]]
                updates *= alpha
                # 當(dāng)接近收斂時(shí)才停止
                if np.sum(np.abs(updates)) < 1e-3:
                    break
                # 進(jìn)行批量更新
                current_values += updates
            # 計(jì)算rms
            errors.append(np.sqrt(np.sum(np.power(current_values - TRUE_VALUE, 2)) / 5.0))
        total_errors += np.asarray(errors)
    total_errors /= runs
    return total_errors

執(zhí)行以下代碼檢查結(jié)果:

episodes = 100 + 1
# 運(yùn)行TD(0)的批量更新
td_erros = batch_updating('TD', episodes)
# 運(yùn)行MC的批量更新
mc_erros = batch_updating('MC', episodes)

# 畫圖
plt.plot(td_erros, label='TD')
plt.plot(mc_erros, label='MC')
plt.xlabel('episodes')
plt.ylabel('RMS error')
plt.legend()
plt.show()

測(cè)試結(jié)果如下:

在隨機(jī)游走任務(wù)中窒典,批量訓(xùn)練下的TD和MC的性能對(duì)比

測(cè)試的實(shí)驗(yàn)結(jié)果可以看出蟆炊,批量TD的rms始終是低于MC方法的,批量TD方法始終優(yōu)于批量蒙特卡洛方法瀑志。其原因在蒙特卡洛方法只是從某些有限的方面來說是最優(yōu)的盅称,而TD方法的最優(yōu)性則與預(yù)測(cè)回報(bào)這個(gè)任務(wù)更為相關(guān)。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末后室,一起剝皮案震驚了整個(gè)濱河市缩膝,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌岸霹,老刑警劉巖疾层,帶你破解...
    沈念sama閱讀 221,635評(píng)論 6 515
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異贡避,居然都是意外死亡痛黎,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 94,543評(píng)論 3 399
  • 文/潘曉璐 我一進(jìn)店門刮吧,熙熙樓的掌柜王于貴愁眉苦臉地迎上來湖饱,“玉大人,你說我怎么就攤上這事杀捻【幔” “怎么了?”我有些...
    開封第一講書人閱讀 168,083評(píng)論 0 360
  • 文/不壞的土叔 我叫張陵致讥,是天一觀的道長(zhǎng)仅仆。 經(jīng)常有香客問我,道長(zhǎng)垢袱,這世上最難降的妖魔是什么墓拜? 我笑而不...
    開封第一講書人閱讀 59,640評(píng)論 1 296
  • 正文 為了忘掉前任,我火速辦了婚禮请契,結(jié)果婚禮上咳榜,老公的妹妹穿的比我還像新娘。我一直安慰自己爽锥,他們只是感情好涌韩,可當(dāng)我...
    茶點(diǎn)故事閱讀 68,640評(píng)論 6 397
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著救恨,像睡著了一般贸辈。 火紅的嫁衣襯著肌膚如雪释树。 梳的紋絲不亂的頭發(fā)上肠槽,一...
    開封第一講書人閱讀 52,262評(píng)論 1 308
  • 那天擎淤,我揣著相機(jī)與錄音,去河邊找鬼秸仙。 笑死嘴拢,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的寂纪。 我是一名探鬼主播席吴,決...
    沈念sama閱讀 40,833評(píng)論 3 421
  • 文/蒼蘭香墨 我猛地睜開眼,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼捞蛋!你這毒婦竟也來了孝冒?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 39,736評(píng)論 0 276
  • 序言:老撾萬榮一對(duì)情侶失蹤拟杉,失蹤者是張志新(化名)和其女友劉穎庄涡,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體搬设,經(jīng)...
    沈念sama閱讀 46,280評(píng)論 1 319
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡穴店,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 38,369評(píng)論 3 340
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了拿穴。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片泣洞。...
    茶點(diǎn)故事閱讀 40,503評(píng)論 1 352
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖默色,靈堂內(nèi)的尸體忽然破棺而出球凰,到底是詐尸還是另有隱情,我是刑警寧澤腿宰,帶...
    沈念sama閱讀 36,185評(píng)論 5 350
  • 正文 年R本政府宣布弟蚀,位于F島的核電站,受9級(jí)特大地震影響酗失,放射性物質(zhì)發(fā)生泄漏义钉。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,870評(píng)論 3 333
  • 文/蒙蒙 一规肴、第九天 我趴在偏房一處隱蔽的房頂上張望捶闸。 院中可真熱鬧,春花似錦拖刃、人聲如沸删壮。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,340評(píng)論 0 24
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽央碟。三九已至,卻和暖如春均函,著一層夾襖步出監(jiān)牢的瞬間亿虽,已是汗流浹背菱涤。 一陣腳步聲響...
    開封第一講書人閱讀 33,460評(píng)論 1 272
  • 我被黑心中介騙來泰國(guó)打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留洛勉,地道東北人粘秆。 一個(gè)月前我還...
    沈念sama閱讀 48,909評(píng)論 3 376
  • 正文 我出身青樓,卻偏偏與公主長(zhǎng)得像收毫,于是被迫代替她去往敵國(guó)和親攻走。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,512評(píng)論 2 359