MCTS肯骇,即蒙特卡羅樹(shù)搜索步绸,是一類(lèi)搜索算法樹(shù)的統(tǒng)稱(chēng),可以較為有效地解決一些搜索空間巨大的問(wèn)題枢泰。
如一個(gè)8*8的棋盤(pán)狼忱,第一步棋有64種著法膨疏,那么第二步則有63種,依次類(lèi)推钻弄,假如我們把第一步棋作為根節(jié)點(diǎn),那么其子節(jié)點(diǎn)就有63個(gè)者吁,再往下的子節(jié)點(diǎn)就有62個(gè)……
如果不加干預(yù)窘俺,樹(shù)結(jié)構(gòu)將會(huì)繁雜,MCTS采用策略來(lái)對(duì)獲勝性較小的著法不予考慮,如第二步的63種著法中有10種是不可能勝利的瘤泪,那么這十個(gè)子節(jié)點(diǎn)不予再次分配子節(jié)點(diǎn)灶泵。
MCTS的主要步驟分為四個(gè):
1, 選擇(Selection)
即找一個(gè)最好的值得探索的結(jié)點(diǎn)对途,通常是先選擇沒(méi)有探索過(guò)的結(jié)點(diǎn)赦邻,如果都探索過(guò)了,再選擇UCB值最大的進(jìn)行選擇(UCB是由一系列算法計(jì)算得到的值实檀,這里先不詳細(xì)講惶洲,可以簡(jiǎn)單視為value)
2, 擴(kuò)展(Expansion)
已經(jīng)選擇好了需要進(jìn)行擴(kuò)展的結(jié)點(diǎn)膳犹,那么就對(duì)其進(jìn)行擴(kuò)展恬吕,即對(duì)其一個(gè)子節(jié)點(diǎn)最為下一步棋的假設(shè),一般為隨機(jī)取一個(gè)可選的節(jié)點(diǎn)進(jìn)行擴(kuò)展须床。
3铐料, 模擬(Simulation)
擴(kuò)展出了子節(jié)點(diǎn),就可以根據(jù)該子節(jié)點(diǎn)繼續(xù)進(jìn)行模擬了豺旬,我們隨機(jī)選擇一個(gè)可選的位置作為模擬下一步的落子钠惩,將其作為子節(jié)點(diǎn),然后依據(jù)該子節(jié)點(diǎn)族阅,繼續(xù)尋找可選的位置作為子節(jié)點(diǎn)篓跛,依次類(lèi)推,直到博弈已經(jīng)判斷出了勝負(fù)耘分,將勝負(fù)信息作為最終得分举塔。
4, 回溯更新(Backpropagation)
將最終的得分累加到父節(jié)點(diǎn)求泰,不斷從下向上累加更新央渣。
對(duì)于UCB值,計(jì)算方法很簡(jiǎn)單渴频,公式如下:
其中v'表示當(dāng)前樹(shù)節(jié)點(diǎn)芽丹,v表示父節(jié)點(diǎn),Q表示這個(gè)樹(shù)節(jié)點(diǎn)的累計(jì)quality值卜朗,N表示這個(gè)樹(shù)節(jié)點(diǎn)的visit次數(shù)拔第,C是一個(gè) 常量參數(shù),通常值設(shè)為1/√2
接下來(lái)再討論怎么使用Python實(shí)現(xiàn)MCTS樹(shù)场钉。
首先樹(shù)的每個(gè)節(jié)點(diǎn)Node需要記錄其父節(jié)點(diǎn)Node parent蚊俺,和子節(jié)點(diǎn)Node children[],用于計(jì)算UCB的這個(gè)節(jié)點(diǎn)的quality值和visit次數(shù)逛万。
def __init__(self):
self.parent = None
self.children = []
self.visit_times = 0
self.quality_value = 0.0
self.state = None
state中除了需要記錄每一步的選擇泳猬,還需要記錄每一步的層數(shù)round值與reward值。
class State(object):
def __init__(self):
self.value = 0
self.round = 0
self.choices = []
整棵樹(shù)需要實(shí)現(xiàn)的功能則是,在一個(gè)環(huán)境下得封,選擇出一個(gè)最有可能獲勝的策略埋心。選擇的方法則是通過(guò)以上介紹的四個(gè)步驟不停模擬得到每個(gè)選擇的value。
其中忙上,tree_policy函數(shù)實(shí)現(xiàn)了Selection和Expansion拷呆,default_poliy函數(shù)實(shí)現(xiàn)的是Simulation過(guò)程,backup函數(shù)是BackPropagation的實(shí)現(xiàn)疫粥。
def MCTS(node):
computation_budget = 3
for i in range(computation_budget):
# 1\. 找到最合適的可擴(kuò)展子節(jié)點(diǎn)
expand_node = tree_policy(node)
# 2\. 隨機(jī)選擇下一步策略對(duì)此子節(jié)點(diǎn)進(jìn)行模擬
reward = default_policy(expand_node)
# 3\. 將模擬結(jié)果向上回傳
backup(expand_node, reward)
# 最終得到勝利的可能性最大的子節(jié)點(diǎn)
best_next_node = best_child(node, False)
return best_next_node
tree_policy:選擇最合適的子節(jié)點(diǎn)茬斧,選擇策略如下:
1,如果當(dāng)前的根節(jié)點(diǎn)是葉子節(jié)點(diǎn)手形,即沒(méi)有子節(jié)點(diǎn)可以擴(kuò)展啥供,以開(kāi)頭下棋的例子來(lái)講,即是已經(jīng)判斷出了勝負(fù)或者棋盤(pán)已滿(mǎn)的情況下库糠,則直接返回當(dāng)前節(jié)點(diǎn)伙狐。
2,如果還有沒(méi)有選擇過(guò)的葉子節(jié)點(diǎn)(下一步的某個(gè)位置的著法還沒(méi)有被模擬過(guò))瞬欧,就在沒(méi)有選擇過(guò)的方法中選擇一個(gè)返回贷屎。
3,如果所有可選擇的結(jié)點(diǎn)都已經(jīng)選擇過(guò)(當(dāng)前環(huán)境下所有的著法都已經(jīng)試過(guò))艘虎,那么往下選擇UCB值最大的子節(jié)點(diǎn)唉侄,直到滿(mǎn)足1或2的情況,到達(dá)葉子節(jié)點(diǎn)或者出現(xiàn)未選擇過(guò)的結(jié)點(diǎn)野建。
def tree_policy(node):
# 是否是葉子節(jié)點(diǎn)
while not node.get_state().is_terminal():
# 如果全部可選的結(jié)點(diǎn)都選擇過(guò)
if node.is_all_expand():
# 選擇UCB最大的值
node = best_child(node, True)
else:
# 隨機(jī)選擇一個(gè)節(jié)點(diǎn)返回
sub_node = expand(node)
return sub_node
# 返回找到的最佳子節(jié)點(diǎn)
return node
default_policy:對(duì)當(dāng)前情況進(jìn)行模擬属划,直到判斷出勝負(fù)。
策略為:輸入需要擴(kuò)展的結(jié)點(diǎn)候生,隨機(jī)操作后 創(chuàng)建新的結(jié)點(diǎn)同眯,直到最后遇到葉子節(jié)點(diǎn),得到該次模擬的reward唯鸭,然后將reward返回须蜗。
def default_policy(node):
# 獲取當(dāng)前點(diǎn)的環(huán)境狀態(tài)
current_state = node.get_state()
# 如果沒(méi)有遇到葉子節(jié)點(diǎn),就一直循環(huán)
while current_state.is_terminal() == False:
# 隨機(jī)選取一個(gè)子節(jié)點(diǎn)目溉,返回新的環(huán)境參數(shù)
current_state = current_state.get_next_state_with_random_choice()
# 結(jié)束后明肮,根據(jù)當(dāng)前的環(huán)境判斷勝負(fù),即獲得的reward值缭付,并將其返回
final_state_reward = current_state.compute_reward()
return final_state_reward
關(guān)于這個(gè)算法柿估,我簡(jiǎn)單做了一個(gè)實(shí)現(xiàn),每次從數(shù)組[1, -1, 2, -2]之間隨機(jī)取一個(gè)數(shù)做累加陷猫,共累計(jì)MAX_DEPTH層官份,使最終的和最大只厘,我們根據(jù)運(yùn)行結(jié)果可以看到烙丛,開(kāi)始-1舅巷, -2的概率比較大,但是隨著訓(xùn)練層數(shù)的增大河咽,越來(lái)越小钠右,而1,2的比例會(huì)越來(lái)越大忘蟹。
import sys
import math
import random
MAX_CHOICE = 4
MAX_DEPTH = 50
CHOICES = [1, -1, 2, -2]
class State(object):
def __init__(self):
self.value = 0
self.round = 0
self.choices = []
def new_state(self):
choice = random.choice(CHOICES)
state = State()
state.value = self.value + choice
state.round = self.round + 1
state.choices = self.choices + [choice]
return state
def __repr__(self):
return "State: {}, value: {}, choices: {}".format(
hash(self), self.value, self.choices)
class Node(object):
def __init__(self):
self.parent = None
self.children = []
self.quality = 0.0
self.visit = 0
self.state = None
def add_child(self, node):
self.children.append(node)
node.parent = self
def __repr__(self):
return "Node: {}, Q/N: {}/{}, state: {}".format(
hash(self), self.quality, self.visit, self.state)
def expand(node):
states = [nodes.state for nodes in node.children]
state = node.state.new_state()
while state in states:
state = node.state.new_state()
child_node = Node()
child_node.state = state
node.add_child(child_node)
return child_node
# 選擇飒房, 擴(kuò)展
def tree_policy(node):
# 選擇是否是葉子節(jié)點(diǎn),
while node.state.round < MAX_DEPTH:
if len(node.children) < MAX_CHOICE:
node = expand(node)
return node
else:
node = best_child(node)
return node
# 模擬
def default_policy(node):
now_state = node.state
while now_state.round < MAX_DEPTH:
now_state = now_state.new_state()
return now_state.value
def backup(node, reward):
while node != None:
node.visit += 1
node.quality += reward
node = node.parent
def best_child(node):
best_score = -sys.maxsize
best = None
for sub_node in node.children:
C = 1 / math.sqrt(2.0)
left = sub_node.quality / sub_node.visit
right = 2.0 * math.log(node.visit) / sub_node.visit
score = left + C * math.sqrt(right)
if score > best_score:
best = sub_node
best_score = score
return best
def mcts(node):
times = 5
for i in range(times):
expand = tree_policy(node)
reward = default_policy(expand)
backup(expand, reward)
best = best_child(node)
return best
def main():
init_state = State()
init_node = Node()
init_node.state = init_state
current_node = init_node
for i in range(MAX_DEPTH):
a = 0.0
b = 0.0
c = 0.0
d = 0.0
current_node = mcts(current_node)
for j in range(len(current_node.state.choices)):
if current_node.state.choices[j] == -2:
a += 1
if current_node.state.choices[j] == -1:
b += 1
if current_node.state.choices[j] == 1:
c += 1
if current_node.state.choices[j] == 2:
d += 1
print("-2的概率為", round(a/(i + 1.0), 2),
"-1的概率為", round(b/(i + 1.0), 2),
"1的概率為", round(c/(i + 1.0), 2),
"2的概率為", round(d/(i + 1.0), 2))
if __name__ == "__main__":
main()
運(yùn)行結(jié)果: