I'm trying to implement AlphaZero on a new game using this repository, and I'm not sure the MCTS search tree is handled correctly.
The logic of their MCTS implementation is as follows:
- Get a "canonical form" of the current game state. Essentially this switches player colors, because the neural net always expects its input from the perspective of the player with ID = 1. If the current player is 1, nothing changes; if the current player is -1, the board is inverted.
- Call the MCTS search (source code).
- In the expand-step of the algorithm, a new node is generated like this:
next_s, next_player = self.game.getNextState(canonicalBoard, 1, a)
next_s = self.game.getCanonicalForm(next_s, next_player)
"1" is the current player and "a" is the selected action. Since the input current player is always 1, next_player is always -1 and the board always gets inverted.
The problem occurs once we hit a terminal state:
- Assume that action a ends the game
- The "getNextState" method returns the next state (next_s) and sets next_player to -1. The board is then inverted one last time (1 becomes -1, -1 becomes 1), so we now view the board from the perspective of the losing player. As a result, a call to getGameEnded(canonicalBoard, 1) always returns -1 (or 0.0001 if it's a draw), which means we can never observe a win for the player with ID 1.
- The getGameEnded function is implemented from the perspective of the player with ID = 1: it returns +1 if that player wins and -1 if that player loses.
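To make the terminal-state problem concrete, here is a toy sketch. The win rule (positive signed sum wins) is invented purely for illustration; only the calling pattern mirrors what I described above:

```python
def get_game_ended(board, player):
    # Toy terminal check: `player` wins when the signed sum of the board
    # favours them. Returns +1 for a win, -1 for a loss, 0 otherwise.
    total = sum(board)
    if total == 0:
        return 0  # not a terminal state in this toy game
    return 1 if total * player > 0 else -1

# Player 1 plays the winning move; the resulting board favours player 1:
terminal_board = [1, 1, -1]            # signed sum = +1, player 1 has won
# But the expand step canonicalises for next_player = -1 first:
canonical = [-x for x in terminal_board]
print(get_game_ended(canonical, 1))    # -> -1: the win is seen as a loss
```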
My current understanding of MCTS is that we need to be able to observe all possible terminal outcomes of a two-player zero-sum game. When I used the framework on my game as-is, it didn't learn or improve. I then changed the game logic to explicitly track the current player's ID so that all three outcomes can be returned. Now it at least seems to learn a bit, but I still suspect something is wrong.
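My workaround looks roughly like this. It is a sketch with my own names (State, get_game_ended), and the toy win rule (positive signed sum wins) is invented for illustration:

```python
from dataclasses import dataclass

@dataclass
class State:
    board: tuple   # canonical board, from the perspective of `player`
    player: int    # absolute ID (+1 or -1) of the player whose view this is

def get_game_ended(state):
    # Toy terminal rule: a positive signed sum means the player whose
    # perspective the board is in has won. Multiplying by the stored
    # absolute ID converts that back to "did player +1 win?", so all
    # three outcomes (+1 / -1 / draw) can be reported.
    total = sum(state.board)
    if total == 0:
        return 0  # non-terminal (or a draw) in this toy game
    won_from_perspective = 1 if total > 0 else -1
    return won_from_perspective * state.player

# Player 1 wins, but the final board was canonicalised for player -1.
# With the absolute ID stored, the win is still reported correctly:
print(get_game_ended(State(board=(-1, -1, 1), player=-1)))   # -> 1
```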
Questions:
- Could this implementation theoretically work? Is it a correct implementation of the MCTS algorithm?
- Does MCTS need to observe all possible outcomes of a two player zero-sum game?
- Are there any obvious quick fixes to the code? Am I missing something about the implementation?