Solving TicTacToe with MCTS

This is a simple guide to implementing Monte Carlo Tree Search (MCTS) for a simple game like TicTacToe, which is small enough for MCTS to play perfectly.

First, you need to represent the game in a data structure:

board = [0, 0, 0,
         0, 0, 0,
         0, 0, 0]

We store 0 for an empty cell, 1 for X, and -1 for O, so player X is 1 and player O is -1.

You also need helper functions, for example make_move and check_winner, which are used later in this guide.

The game implementation needs to be correct for MCTS to work.
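A minimal sketch of such helpers, assuming the flat nine-cell list shown above. The names WIN_LINES and legal_moves are illustrative assumptions; make_move and check_winner are named to match how they are called later in this guide. Here check_winner returns the winning player (1 or -1) or 0, so it is also truthy exactly when somebody has won.

```python
# All eight lines of three cells that win the game (rows, columns, diagonals).
WIN_LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8),
             (0, 3, 6), (1, 4, 7), (2, 5, 8),
             (0, 4, 8), (2, 4, 6)]


def check_winner(board):
    """Return 1 if X has three in a row, -1 if O has, and 0 otherwise."""
    for a, b, c in WIN_LINES:
        if board[a] != 0 and board[a] == board[b] == board[c]:
            return board[a]
    return 0


def legal_moves(board):
    """Return the indices of empty cells, or [] once somebody has won."""
    if check_winner(board) != 0:
        return []
    return [i for i, cell in enumerate(board) if cell == 0]


def make_move(board, move, player):
    """Return a new board with `player` placed at `move`; the input is not mutated."""
    new_board = list(board)
    new_board[move] = player
    return new_board
```

Returning a copy from make_move keeps each tree node's board independent of its parent's, at the cost of a nine-element copy per move.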

MCTS implementation

You will need to implement four steps: selection, expansion, simulation, and backpropagation.

First, implement the UCT formula (Upper Confidence bounds applied to Trees), which applies the UCB1 upper confidence bound from multi-armed bandits at each node of the tree.

For any given child node, the formula is stated as follows:

\[UCT = \frac{Q}{N} + c\sqrt{\frac{\ln N_p}{N}}\]

Where,
Q is the total reward accumulated at this node
N is the number of times the node was visited
\(N_p\) is the number of times the parent node was visited
c is an exploration constant, by convention usually \(\sqrt{2}\)

The first term is the average reward (exploitation). The second term is an exploration bonus: it is proportional to \(1/\sqrt{N}\), so it shrinks as the node is visited more often, and it grows slowly as the parent accumulates visits.

The \(\sqrt{2}\) convention comes from the bandit regret bounds for UCB1, which assume rewards in [0, 1]. For all we care, this constant controls the trade-off: if it is large, the search explores more; if it is small, the search is greedy and exploits the moves that currently look best, at the risk of missing better ones.
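As a quick worked example (the numbers are made up for illustration), take a parent with \(N_p = 100\) visits and two children with the same average reward of 0.6: child A with \(Q = 6\), \(N = 10\) and child B with \(Q = 54\), \(N = 90\). With \(c = \sqrt{2}\):

\[UCT_A = 0.6 + \sqrt{2}\sqrt{\frac{\ln 100}{10}} \approx 0.6 + 0.96 = 1.56\]

\[UCT_B = 0.6 + \sqrt{2}\sqrt{\frac{\ln 100}{90}} \approx 0.6 + 0.32 = 0.92\]

The less-visited child A scores higher, so it is selected next even though both children look equally good on average.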

This can be represented in code as:

import math

def uct(child, c=math.sqrt(2)):
    # average reward (exploitation) plus exploration bonus; needs child.visits > 0
    return (child.value / child.visits) + \
           c * math.sqrt(math.log(child.parent.visits) / child.visits)

child is an instance of the Node class; the attributes used here are value, visits, and parent.
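For reference, a minimal sketch of a Node class with the attributes this guide relies on (board, player, parent, move, children, untried_moves, visits, value). The constructor signature is inferred from how Node is called in this guide; the helper _has_winner and the inline computation of untried moves are assumptions, and a board that already has a winner is treated as having no moves left.

```python
WIN_LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8),
             (0, 3, 6), (1, 4, 7), (2, 5, 8),
             (0, 4, 8), (2, 4, 6)]


def _has_winner(board):
    """True if either player already has three in a row (hypothetical helper)."""
    return any(board[a] != 0 and board[a] == board[b] == board[c]
               for a, b, c in WIN_LINES)


class Node:
    def __init__(self, board, player, parent=None, move=None):
        self.board = board      # flat list of 9 cells: 0, 1 (X) or -1 (O)
        self.player = player    # player to move in this position
        self.parent = parent    # None for the root
        self.move = move        # move that led from the parent to this node
        self.children = []
        self.visits = 0
        self.value = 0.0
        # Moves not yet expanded into children; empty once the game is over.
        if _has_winner(board):
            self.untried_moves = []
        else:
            self.untried_moves = [i for i, cell in enumerate(board) if cell == 0]
```

Making finished positions start with an empty untried_moves list lets selection and expansion treat terminal nodes as leaves without a separate check.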

The goal now is to select the best child of the current node. Each node represents a game state, which encodes the board, the player to move, and other bookkeeping. From the current node we branch into child nodes, one for each valid move from the current state. Among those children, we select the one with the highest UCT score.

def select(node):
    # descend while the node is fully expanded (no untried moves) and has children
    while not node.untried_moves and node.children:
        # move to the child with the maximum UCT score
        node = max(node.children, key=uct)
    return node

Expanding a node means creating a new child Node for one of its untried moves. It can be defined as follows:

import random

def expand(node):
    # pick a random untried move
    move = random.choice(node.untried_moves)
    # remove it from the list of untried moves
    node.untried_moves.remove(move)

    # define a new board from the current board with the move
    new_board = make_move(node.board, move, node.player)
    # create a new node now
    child = Node(
        board=new_board,
        player=-node.player,
        parent=node,
        move=move
    )
    # append this new child to the current node
    node.children.append(child)
    return child

Now we need to set up a way to simulate a game: a function of the board and the player to move that plays random legal moves until the game ends and returns the winner. It is a simple loop: keep making random moves until check_winner() reports a winner or the board is full, which is a draw.
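A minimal sketch of that rollout, under the conventions used so far (1 for X, -1 for O, 0 for an empty cell or a draw). The function name simulate and the inlined winner check are assumptions for illustration.

```python
import random

WIN_LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8),
             (0, 3, 6), (1, 4, 7), (2, 5, 8),
             (0, 4, 8), (2, 4, 6)]


def check_winner(board):
    """Return 1 if X has three in a row, -1 if O has, and 0 otherwise."""
    for a, b, c in WIN_LINES:
        if board[a] != 0 and board[a] == board[b] == board[c]:
            return board[a]
    return 0


def simulate(board, player):
    """Play random legal moves with `player` to move first; return 1, -1, or 0 (draw)."""
    board = list(board)  # work on a copy so the caller's board is untouched
    while True:
        winner = check_winner(board)
        if winner != 0:
            return winner
        empty = [i for i, cell in enumerate(board) if cell == 0]
        if not empty:
            return 0  # board is full with no winner: draw
        board[random.choice(empty)] = player
        player = -player
```

Calling simulate on a position that is already finished returns its result immediately, so the same function works for terminal nodes in the tree.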

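The next step is to backpropagate. This becomes insanely simple because we defined X as 1 and O as -1, with a draw being 0. One detail matters: node.player is the player to move at that node, so the move leading into the node was made by -node.player, and the result has to be credited from that player's point of view. Otherwise select would steer a parent toward moves that are good for its opponent.

def backprop(node, result):
    while node is not None:
        node.visits += 1
        # credit the player who moved into this node, i.e. -node.player
        node.value -= result * node.player
        node = node.parent

This means each node accumulates a value that reflects how often it was part of a winning or losing sequence for the player who moved into it.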

Now we just need to iterate: select a node, expand it, simulate from the new child with random moves, and backpropagate the result to update the tree. After many iterations, we pick the child of the root with the most visits:

def best_move(root):
    return max(root.children, key=lambda c: c.visits).move

Now the agent can be structured simply as:

def mcts_move(board, player, iterations):
    root = Node(board, player)
    for _ in range(iterations):
        mcts_iterate(root)

    return best_move(root)

To check the implementation, just make it play against itself. TicTacToe is a draw under perfect play, so with enough iterations per move, self-play should end in a draw.
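One way to run that check is to assemble the pieces into a single script. The sketch below is self-contained. The helper names (legal_moves, simulate, self_play), the Node constructor, and the body of mcts_iterate are assumptions filled in for illustration, and backprop credits each node from the point of view of the player who moved into it. Rollouts are random, so no fixed outcome is claimed for any single run.

```python
import math
import random

WIN_LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8),
             (0, 3, 6), (1, 4, 7), (2, 5, 8),
             (0, 4, 8), (2, 4, 6)]

def check_winner(board):
    for a, b, c in WIN_LINES:
        if board[a] != 0 and board[a] == board[b] == board[c]:
            return board[a]
    return 0

def legal_moves(board):
    # no moves once somebody has won; otherwise every empty cell
    if check_winner(board) != 0:
        return []
    return [i for i, cell in enumerate(board) if cell == 0]

def make_move(board, move, player):
    new_board = list(board)
    new_board[move] = player
    return new_board

class Node:
    def __init__(self, board, player, parent=None, move=None):
        self.board, self.player = board, player
        self.parent, self.move = parent, move
        self.children = []
        self.untried_moves = legal_moves(board)
        self.visits = 0
        self.value = 0.0

def uct(child, c=math.sqrt(2)):
    return (child.value / child.visits
            + c * math.sqrt(math.log(child.parent.visits) / child.visits))

def select(node):
    while not node.untried_moves and node.children:
        node = max(node.children, key=uct)
    return node

def expand(node):
    move = random.choice(node.untried_moves)
    node.untried_moves.remove(move)
    child = Node(make_move(node.board, move, node.player), -node.player, node, move)
    node.children.append(child)
    return child

def simulate(board, player):
    board = list(board)
    while True:
        moves = legal_moves(board)
        if not moves:
            return check_winner(board)  # 1, -1, or 0 for a draw
        board[random.choice(moves)] = player
        player = -player

def backprop(node, result):
    while node is not None:
        node.visits += 1
        node.value -= result * node.player  # credit the player who moved here
        node = node.parent

def mcts_iterate(root):
    node = select(root)
    if node.untried_moves:          # terminal nodes have nothing to expand
        node = expand(node)
    backprop(node, simulate(node.board, node.player))

def mcts_move(board, player, iterations):
    root = Node(board, player)
    for _ in range(iterations):
        mcts_iterate(root)
    return max(root.children, key=lambda c: c.visits).move

def self_play(iterations=2000):
    """Play one game of MCTS against itself; return 1, -1, or 0 (draw)."""
    board, player = [0] * 9, 1
    while legal_moves(board):
        board = make_move(board, mcts_move(board, player, iterations), player)
        player = -player
    return check_winner(board)
```

Running self_play() several times and counting the non-zero results gives a rough sanity check: frequent wins for either side usually point to a sign error in backprop or a bug in the game helpers.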