
Why isn't this Monte Carlo tree search algorithm working correctly?

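Problem

I am writing a Monte Carlo tree search algorithm to play chess in Python. I replaced the simulation phase with a custom evaluation function. My code looks perfect, but for some reason it behaves strangely. It easily recognizes immediate wins, but it fails to recognize mate-in-2 and mate-in-3 positions. Any ideas?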
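
The loop described here (selection, expansion, a static evaluation where a random playout would normally go, then backpropagation) can be illustrated on a much smaller game. This is a hypothetical, self-contained stand-in rather than the author's code: a pile of stones where each move takes 1 or 2 stones and whoever takes the last stone wins. `ToyNode`, `evaluate`, `uct` and `mcts` are illustrative names; `evaluate` plays the role of a custom evaluation and returns a neutral 0.5 for non-terminal positions, which is an assumption of this sketch. Scores are stored per node from the point of view of that node's side to move, the same convention the question's `Node.colour` uses.

```python
import math
import random


class ToyNode:
    """Node for a toy game (hypothetical stand-in for chess):
    a pile of stones, each move takes 1 or 2, whoever takes the last stone wins."""

    def __init__(self, pile, to_move, parent=None, action=None):
        self.pile = pile            # stones left in the pile
        self.to_move = to_move      # player to move here: 0 or 1
        self.parent = parent
        self.action = action        # stones taken to reach this node
        self.unexplored = [a for a in (1, 2) if a <= pile]
        self.children = []
        self.w = 0.0                # accumulated score from the view of `to_move`
        self.n = 0                  # visit count


def evaluate(node):
    """Static evaluation used instead of a random playout.

    Returns the expected score for player 0. An empty pile means the player
    who just moved (not `to_move`) took the last stone and won.
    Non-terminal positions get a neutral 0.5 in this sketch.
    """
    if node.pile == 0:
        return 1.0 if node.to_move == 1 else 0.0
    return 0.5


def uct(child, c=math.sqrt(2)):
    """UCB1 from the parent's point of view (the player who chose this move)."""
    exploit = (child.n - child.w) / child.n
    explore = c * math.sqrt(math.log(child.parent.n) / child.n)
    return exploit + explore


def mcts(pile, iterations=2000, seed=0):
    rng = random.Random(seed)
    root = ToyNode(pile, to_move=0)
    for _ in range(iterations):
        # 1. Selection: descend while the node is fully expanded and has children.
        node = root
        while not node.unexplored and node.children:
            node = max(node.children, key=uct)
        # 2. Expansion: add one child for a randomly chosen untried action.
        if node.unexplored:
            action = node.unexplored.pop(rng.randrange(len(node.unexplored)))
            child = ToyNode(node.pile - action, 1 - node.to_move, node, action)
            node.children.append(child)
            node = child
        # 3. Evaluation: a static score instead of a random playout.
        score_p0 = evaluate(node)
        # 4. Backpropagation: each node stores the score for its own side to move.
        while node is not None:
            node.n += 1
            node.w += score_p0 if node.to_move == 0 else 1.0 - score_p0
            node = node.parent
    # Final choice: the most-visited child of the root.
    return max(root.children, key=lambda ch: ch.n).action
```

With these rules, piles that are multiples of 3 are lost for the side to move, so from 4 stones the winning move is to take 1 and from 5 stones it is to take 2. Like the question's `mcts`, the sketch picks the most-visited root child.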

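What I tried

I tried giving it more time to search, but it still cannot find the best move, even when it is guaranteed to win in two moves. However, I noticed that the results improve when I turn off the custom evaluation and use classic Monte Carlo tree search simulation. (To turn off the custom evaluation, just don't pass any argument to the Agent constructor.) But I really need it to work with the custom evaluation, because I am researching machine learning techniques for board evaluation.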
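
For comparison, the classic simulation that `_simulate` performs (used when `Agent()` receives no custom evaluation) scores a node by playing random moves to the end of the game. Averaging many such playouts gives a noisy estimate of the expected result. A minimal sketch of that estimate, using a hypothetical toy game instead of chess so it runs without python-chess (a pile of stones, each move takes 1 or 2, whoever takes the last stone wins); `random_playout_score` and `rollout_estimate` are illustrative names:

```python
import random


def random_playout_score(pile, to_move, rng):
    """Play uniformly random moves until the pile is empty.

    Returns 1.0 if player 0 took the last stone, else 0.0, mirroring the idea of
    `_simulate`: random play to the end, then score the terminal state.
    """
    while pile > 0:
        pile -= rng.choice([a for a in (1, 2) if a <= pile])
        to_move = 1 - to_move
    # The player who made the last move is the one *not* to move now.
    return 1.0 if to_move == 1 else 0.0


def rollout_estimate(pile, to_move=0, samples=10000, seed=0):
    """Average many random playouts into an expected score for player 0."""
    rng = random.Random(seed)
    total = sum(random_playout_score(pile, to_move, rng) for _ in range(samples))
    return total / samples
```

From a pile of 2 with player 0 to move, random play averages about 0.5 even though taking both stones wins outright, and from a pile of 3 (a lost position) it averages about 0.25. Individual playouts are crude; the search relies on many of them.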

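I tried printing out the search results to see which moves the algorithm considers good. It consistently ranks the best move in the mate-in-2 and mate-in-3 positions as the worst. The ranking is based on how many times each move was explored (this is how MCTS picks the best move).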
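
Two numbers matter when reading such a ranking: the visit count used for the final choice, and the UCB1 score that decided which child was visited during the search. A small sketch of both, following the convention in the question's `_uct` (each child's `w` is stored from that child's own side-to-move point of view, so the parent uses `n - w`). The `(move, w, n)` tuples and the helper names `ucb1` and `rank_by_visits` are hypothetical, for illustration only:

```python
import math


def ucb1(w, n, parent_n, c=math.sqrt(2)):
    """UCB1 as computed in the question's `_uct`.

    `w` is the accumulated score from the child's own side to move, so the
    parent (who chose the move) scores it as (n - w) / n.
    """
    if n == 0:
        return float("inf")  # unvisited children are tried first
    return (n - w) / n + c * math.sqrt(math.log(parent_n) / n)


def rank_by_visits(stats):
    """Sort (move, w, n) tuples the way MCTS picks its final move: most visits first."""
    return sorted(stats, key=lambda s: s[2], reverse=True)


if __name__ == "__main__":
    # Hypothetical statistics for three root moves: (move, w, n).
    stats = [("a", 10.0, 40), ("b", 30.0, 50), ("c", 20.0, 30)]
    parent_n = sum(n for _, _, n in stats)
    for move, w, n in rank_by_visits(stats):
        # move, visits, parent's win rate, UCB1 score
        print(move, n, round((n - w) / n, 3), round(ucb1(w, n, parent_n), 3))
```

A child whose own side keeps scoring well (`w` close to `n`) gets a small exploitation term at its parent and is therefore visited less.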

My code

I have included the entire code because all of it is relevant to the problem. To run it, you may need to install python-chess (pip install python-chess).

I have been struggling with this for over a week and it is getting more and more frustrating. Any ideas?

import math
import random
import time

import chess
import chess.engine

class Node:

    def __init__(self,state,parent,action):
        """Initializes a node structure for a Monte-Carlo search tree."""
        self.state = state
        self.parent = parent
        self.action = action

        self.unexplored_actions = list(self.state.legal_moves)
        random.shuffle(self.unexplored_actions)
        self.colour = self.state.turn
        self.children = []
        
        self.w = 0 # number of wins
        self.n = 0 # number of simulations

class Agent:
    
    def __init__(self,custom_evaluation=None):
        """Initializes a Monte-Carlo tree search agent."""
        
        if custom_evaluation:
            self._evaluate = custom_evaluation

    def mcts(self,state,time_limit=float('inf'),node_limit=float('inf')):
        """Runs Monte-Carlo tree search from `state` and returns the most-visited child of the root."""

        nodes_searched = 0
        start_time = time.time()

        # Initialize the root node.
        root = Node(state,None,None)

        while (time.time() - start_time) < time_limit and nodes_searched < node_limit:
            
            # Select a leaf node.
            leaf = self._select(root)

            # Add a new child node to the tree.
            if leaf.unexplored_actions:
                child = self._expand(leaf)
            else:
                child = leaf

            # Evaluate the node.
            result = self._evaluate(child)

            # Backpropagate the results.
            self._backpropagate(child,result)

            nodes_searched += 1

        result = max(root.children,key=lambda node: node.n) 

        return result

    def _uct(self,node):
        """Returns the Upper Confidence Bound 1 of a node."""
        c = math.sqrt(2)

        # Scores stored in each node are relative to that node's side to move, so
        # the parent scores this child from the opposite side: every WHITE node
        # picks the child that is worst for BLACK, and vice versa.
        w = node.n - node.w

        n = node.n
        N = node.parent.n

        try:
            ucb = (w / n) + (c * math.sqrt(math.log(N) / n))
        except ZeroDivisionError:
            ucb = float('inf')

        return ucb

    def _select(self,node):
        """Returns a leaf node that either has unexplored actions or is a terminal node."""
        while (not node.unexplored_actions) and node.children:
            # Pick the child node with highest UCB.
            selection = max(node.children,key=self._uct)
            # Move to the next node.
            node = selection
        return node

    def _expand(self,node):
        """Adds one child node to the tree."""
        # Pick an unexplored action.
        action = node.unexplored_actions.pop()
        # Create a copy of the node state.
        state_copy = node.state.copy()
        # Carry out the action on the copy.
        state_copy.push(action)
        # Create a child node.
        child = Node(state_copy,node,action)
        # Add the child node to the list of children.
        node.children.append(child)
        # Return the child node.
        return child

    def _evaluate(self,node):
        """Returns an evaluation of a given node."""
        # If no custom evaluation function was passed into the object constructor,
        # use classic simulation.
        return self._simulate(node)

    def _simulate(self,node):
        """Randomly plays out to the end and returns a static evaluation of the terminal state."""
        board = node.state.copy()
        while not board.is_game_over():
            # Pick a random action.
            move = random.choice(list(board.legal_moves))
            # Perform the action.
            board.push(move)
        return self._calculate_static_evaluation(board)

    def _backpropagate(self,node,result):
        """Updates a node's values and subsequent parent values."""
        # Update the node's values.
        node.w += result.pov(node.colour).expectation()
        node.n += 1
        # Back up values to parent nodes.
        while node.parent is not None:
            node.parent.w += result.pov(node.parent.colour).expectation()
            node.parent.n += 1
            node = node.parent

    def _calculate_static_evaluation(self,board):
        """Returns a static evaluation of a *terminal* board state."""
        result = board.result(claim_draw=True)

        if result == '1-0':
            wdl = chess.engine.Wdl(wins=1000,draws=0,losses=0)
        elif result == '0-1':
            wdl = chess.engine.Wdl(wins=0,draws=0,losses=1000)
        else:
            wdl = chess.engine.Wdl(wins=0,draws=1000,losses=0)

        return chess.engine.PovWdl(wdl,chess.WHITE)


def custom_evaluation(node):
    """Returns a static evaluation of a board state."""

    board = node.state
    
    # Evaluate terminal states.
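    if board.is_game_over(claim_draw=True):
        result = board.result(claim_draw=True)
        if result == '1-0':
            wdl = chess.engine.Wdl(wins=1000,draws=0,losses=0)
        elif result == '0-1':
            wdl = chess.engine.Wdl(wins=0,draws=0,losses=1000)
        else:
            wdl = chess.engine.Wdl(wins=0,draws=1000,losses=0)
        return chess.engine.PovWdl(wdl,chess.WHITE)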
    
    # Evaluate material.
    material_balance = 0
    material_balance += len(board.pieces(chess.PAWN,chess.WHITE)) * +100
    material_balance += len(board.pieces(chess.PAWN,chess.BLACK)) * -100
    material_balance += len(board.pieces(chess.ROOK,chess.WHITE)) * +500
    material_balance += len(board.pieces(chess.ROOK,chess.BLACK)) * -500
    material_balance += len(board.pieces(chess.KNIGHT,chess.WHITE)) * +300
    material_balance += len(board.pieces(chess.KNIGHT,chess.BLACK)) * -300
    material_balance += len(board.pieces(chess.BISHOP,chess.WHITE)) * +300
    material_balance += len(board.pieces(chess.BISHOP,chess.BLACK)) * -300
    material_balance += len(board.pieces(chess.QUEEN,chess.WHITE)) * +900
    material_balance += len(board.pieces(chess.QUEEN,chess.BLACK)) * -900

    # Todo: Evaluate mobility.
    mobility = 0

    # Aggregate values.
    centipawn_evaluation = material_balance + mobility

    # Convert evaluation from centipawns to wdl.
    wdl = chess.engine.Cp(centipawn_evaluation).wdl(model='lichess')
    static_evaluation = chess.engine.PovWdl(wdl,chess.WHITE)

    return static_evaluation


m1 = chess.Board('8/8/7k/8/8/8/5R2/6R1 w - - 0 1') # f2h2
# WHITE can win in one move. Best move is f2-h2.

m2 = chess.Board('8/6k1/8/8/8/8/1K2R3/5R2 w - - 0 1')
# WHITE can win in two moves. Best move is e2-g2.

m3 = chess.Board('8/8/5k2/8/8/8/3R4/4R3 w - - 0 1')
# WHITE can win in three moves. Best move is d2-f2.

agent = Agent(custom_evaluation)

result = agent.mcts(m2,time_limit=30)
print(result.action)  # best move found: the most-visited child of the root
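
Because `custom_evaluation` scores every non-terminal position by material alone, it may help to see how far apart a mate and an ordinary two-rooks-up position end up once centipawns are mapped to an expected score. This sketch assumes a plain logistic curve with coefficient 0.00368208 (the value Lichess publishes for its win-percentage formula); it approximates, and is not, python-chess's exact `model='lichess'` conversion, which may clamp or round differently. `expected_score_from_cp` is a hypothetical helper:

```python
import math


def expected_score_from_cp(cp, k=0.00368208):
    """Map a White-relative centipawn score to White's expected score in [0, 1].

    ASSUMPTION: a plain logistic curve with the coefficient Lichess uses for its
    win-percentage formula; python-chess's model='lichess' may differ slightly.
    """
    return 1.0 / (1.0 + math.exp(-k * cp))


if __name__ == "__main__":
    # In m2, White has two rooks and Black has none: material_balance = 2 * 500.
    two_rooks_up = expected_score_from_cp(2 * 500)
    mate = 1.0  # a terminal 1-0 result scores a full win
    print(f"two rooks up: {two_rooks_up:.3f}  mate: {mate:.3f}  gap: {mate - two_rooks_up:.3f}")
```

Under this assumption, any non-terminal position where White keeps both rooks scores about 0.975, so reaching a mate raises a single evaluation by only about 0.025.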
