Solving TicTacToe with MCTS
This is a simple guide to implementing MCTS (Monte Carlo Tree Search) for a small game that MCTS can play perfectly, like TicTacToe.
First you need to represent the game as a data structure:
board = [0, 0, 0,
         0, 0, 0,
         0, 0, 0]
We store 0 as empty, 1 as X and -1 as O, so player X is 1 and player O is -1.
You need a few helper functions:
- legal_moves (returns the list of playable cells on a board)
- make_move (returns a new board with the player's mark placed at the chosen cell, if it is legal)
- check_winner (checks whether a player has met a win condition)
The game implementation needs to be correct for MCTS to work.
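A minimal sketch of these helpers (the exact signatures are my assumption, chosen to match how they are used later in this guide):

```python
# winning lines: rows, columns and the two diagonals
WIN_LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8),
             (0, 3, 6), (1, 4, 7), (2, 5, 8),
             (0, 4, 8), (2, 4, 6)]

def legal_moves(board):
    # indices of the empty cells
    return [i for i, cell in enumerate(board) if cell == 0]

def make_move(board, move, player):
    # return a new board with the player's mark placed at `move`
    new_board = list(board)
    new_board[move] = player
    return new_board

def check_winner(board):
    # 1 if X has a line, -1 if O has one, 0 otherwise
    for a, b, c in WIN_LINES:
        if board[a] != 0 and board[a] == board[b] == board[c]:
            return board[a]
    return 0
```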
MCTS implementation
you will need to define a node structure. Each node represents a game state and includes:
- the board
- player to move
- parent node
- children nodes
- the move that led to this state
- the list of untried legal moves in this state
- the visit count
- the accumulated reward value
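One way to write this node structure (a sketch; the attribute names are my choice, but they match the code snippets in the rest of this guide):

```python
class Node:
    def __init__(self, board, player, parent=None, move=None):
        self.board = board          # game state at this node
        self.player = player        # player to move: 1 (X) or -1 (O)
        self.parent = parent        # parent node (None for the root)
        self.move = move            # move that led to this state
        self.children = []          # expanded child nodes
        # legal moves not yet expanded (inline version of legal_moves)
        self.untried_moves = [i for i, cell in enumerate(board) if cell == 0]
        self.visits = 0             # how many times the node was visited
        self.value = 0.0            # accumulated reward
```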
Now you need to implement the UCT formula: the Upper Confidence bound applied to Trees (the UCB1 bandit formula applied to tree search).
For any given child node, the formula is

\( UCT = \frac{Q}{N} + c \sqrt{\frac{\ln N_p}{N}} \)

Where,
Q is the total reward accumulated at this node
N is the number of times the node was visited
\(N_p\) is the number of times the parent node was visited
c is an exploration constant, by convention it is usually \(\sqrt{2}\)
The first term is the average reward (exploitation); the second term is an exploration bonus that is inversely proportional to the square root of N:
- if N is small, the UCT value gets a larger bonus
- if N is large, the bonus shrinks
This ensures that moves which haven't been tried much get priority.
The reason \(\sqrt{2}\) is the conventional choice is that it is derived from the bandit regret bounds (UCB1). In practice this constant controls the exploration/exploitation trade-off: if it is large, the search explores more; if it is small, it is greedy and exploits the current best-looking moves.
This can be represented in code as
import math

def uct(child, c=math.sqrt(2)):
    return (child.value / child.visits) + \
        c * math.sqrt(math.log(child.parent.visits) / child.visits)
child is an instance of the Node class we defined earlier, with the key attributes value and visits.
The goal now is to select the best child of the current node. As said earlier, a node represents the game state, which encodes the board, the player to move and other information. From the current node we branch into children: for each valid move from the current state there is a child node, and among those children we select the one with the highest UCT value.
def select(node):
    # descend while the node is fully expanded and has children
    while node.untried_moves == [] and node.children:
        # pick the child with the maximum UCT value
        node = max(node.children, key=uct)
    return node
Expanding a node means creating a new child Node for one of its untried moves:
import random

def expand(node):
    # pick a random untried move
    move = random.choice(node.untried_moves)
    # remove it from the list of untried moves
    node.untried_moves.remove(move)
    # build a new board by applying the move to the current board
    new_board = make_move(node.board, move, node.player)
    # create the child node; it is now the other player's turn
    child = Node(
        board=new_board,
        player=-node.player,
        parent=node,
        move=move
    )
    # append the new child to the current node
    node.children.append(child)
    return child
Now we need to set up a way to simulate a game: a function of the board and player that plays random legal moves and returns the winner. It is a simple loop that keeps making random moves until check_winner() reports a winner or the board is full (a draw).
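Such a simulation function (often called a rollout) might look like the following sketch; check_winner is repeated inline here so the snippet runs on its own:

```python
import random

WIN_LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8),
             (0, 3, 6), (1, 4, 7), (2, 5, 8),
             (0, 4, 8), (2, 4, 6)]

def check_winner(board):
    # 1 if X has a line, -1 if O has one, 0 otherwise
    for a, b, c in WIN_LINES:
        if board[a] != 0 and board[a] == board[b] == board[c]:
            return board[a]
    return 0

def simulate(board, player):
    # play random legal moves until a win or a full board (draw)
    board = list(board)  # copy so the caller's board is untouched
    while True:
        winner = check_winner(board)
        moves = [i for i, cell in enumerate(board) if cell == 0]
        if winner != 0 or not moves:
            return winner  # 1 (X wins), -1 (O wins) or 0 (draw)
        board[random.choice(moves)] = player
        player = -player
```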
The next step is to backpropagate. This becomes very simple because we defined X as 1, O as -1 and a draw as 0:
def backprop(node, result):
    while node is not None:
        node.visits += 1
        # score from the perspective of the player who made the move
        # into this node: since node.player is the player *to move*,
        # that is the opponent, -node.player
        node.value += result * -node.player
        node = node.parent
This way each node accumulates a value reflecting how often it appeared in winning or losing sequences.
Now we just iterate: select a node, expand it, simulate with random moves, and backpropagate to update the tree. After many iterations, we pick the child of the root with the maximum visit count:
def best_move(root):
    return max(root.children, key=lambda c: c.visits).move
Now the agent can be structured as:
def mcts_move(board, player, iterations):
    root = Node(board, player)
    for _ in range(iterations):
        mcts_iterate(root)
    return best_move(root)
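mcts_iterate, used above, is the function that runs one full MCTS pass. Here is a sketch that inlines the four phases (selection, expansion, simulation, backpropagation) in a single function; the Node class, uct and check_winner from earlier are repeated so the snippet runs on its own:

```python
import math
import random

WIN_LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8),
             (0, 3, 6), (1, 4, 7), (2, 5, 8),
             (0, 4, 8), (2, 4, 6)]

def check_winner(board):
    for a, b, c in WIN_LINES:
        if board[a] != 0 and board[a] == board[b] == board[c]:
            return board[a]
    return 0

class Node:
    def __init__(self, board, player, parent=None, move=None):
        self.board, self.player = board, player
        self.parent, self.move = parent, move
        self.children = []
        self.untried_moves = [i for i, c in enumerate(board) if c == 0]
        self.visits, self.value = 0, 0.0

def uct(child, c=math.sqrt(2)):
    return (child.value / child.visits) + \
        c * math.sqrt(math.log(child.parent.visits) / child.visits)

def mcts_iterate(root):
    # 1. selection: descend via UCT while fully expanded
    node = root
    while not node.untried_moves and node.children:
        node = max(node.children, key=uct)
    # 2. expansion: add one child for a random untried move
    if node.untried_moves and check_winner(node.board) == 0:
        move = random.choice(node.untried_moves)
        node.untried_moves.remove(move)
        board = list(node.board)
        board[move] = node.player
        child = Node(board, -node.player, parent=node, move=move)
        node.children.append(child)
        node = child
    # 3. simulation: random playout to the end of the game
    board, player = list(node.board), node.player
    result = check_winner(board)
    moves = [i for i, c in enumerate(board) if c == 0]
    while result == 0 and moves:
        board[random.choice(moves)] = player
        player = -player
        result = check_winner(board)
        moves = [i for i, c in enumerate(board) if c == 0]
    # 4. backpropagation: credit the result to each node from the
    #    perspective of the player who moved into it (-node.player)
    while node is not None:
        node.visits += 1
        node.value += result * -node.player
        node = node.parent
```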
To check the implementation, just make it play against itself. TicTacToe is a draw under perfect play, and MCTS with enough iterations plays it near-perfectly, so self-play should always end in a draw.