Tags: python, deep-learning, pytorch, reinforcement-learning, backpropagation

How to do backpropagation in PyTorch when training AlphaZero?


I'm trying to implement my own version of AlphaZero for Connect Four. I have implemented a convolutional network in PyTorch and can get (random) value and policy outputs from the model for a given board state. Now I would like to simulate some games and train the model on them. However, I have run into a problem:
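
For concreteness, here is a simplified sketch of the kind of two-headed network I mean (the layer sizes, the three input planes, and the dropout rate are only illustrative, not my actual code):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ConnectFourNet(nn.Module):
        # Two-headed network: a policy over the 7 columns and a scalar value in [-1, 1].
        def __init__(self, channels=64):
            super().__init__()
            self.conv1 = nn.Conv2d(3, channels, kernel_size=3, padding=1)
            self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.dropout = nn.Dropout(p=0.3)
            self.policy_head = nn.Linear(channels * 6 * 7, 7)  # one logit per column
            self.value_head = nn.Linear(channels * 6 * 7, 1)

        def forward(self, x):  # x: (batch, 3, 6, 7) board planes
            x = F.relu(self.conv1(x))
            x = F.relu(self.conv2(x))
            x = self.dropout(x)
            x = x.flatten(start_dim=1)
            log_policy = F.log_softmax(self.policy_head(x), dim=1)
            value = torch.tanh(self.value_head(x))
            return value, log_policy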

As far as I understand, training basically consists of two steps: a self-play step in which game data is gathered, followed by a step in which the collected data is used to train the model via backpropagation. In the self-play step, the network is used to get an evaluation of a position and a policy for choosing the next move. The policy is then improved using a version of the MCTS algorithm.

After a game is finished, all the moves and the result are saved.
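
In pseudocode, the loop I have in mind looks roughly like this (the helper names are placeholders, not real functions):

    # Placeholder helpers (play_one_game, sample_batches, train_step) stand in
    # for my own code; the point is only the two-phase structure.
    replay_buffer = []

    for iteration in range(num_iterations):
        # Phase 1: self-play, just data collection (no gradients needed here)
        for _ in range(games_per_iteration):
            states, mcts_policies, result = play_one_game(model)
            for state, pi in zip(states, mcts_policies):
                # the result would be signed from the perspective of the player to move
                replay_buffer.append((state, pi, result))

        # Phase 2: supervised-style training on the collected examples
        for batch in sample_batches(replay_buffer, batch_size):
            train_step(model, optimizer, batch)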

For simplicity, assume that I only play a single game and then want to update the model. If I save the MCTS policies and the network policies, I can now calculate the loss. But I can't backpropagate through the model, since the forward pass happened during the collection step. I could, in theory, forward the same position through the model again, but that is not only inefficient; since my architecture uses dropout layers, I would not even get the same results.
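
Concretely, the loss I want to compute is the usual AlphaZero combination of value MSE and policy cross-entropy, roughly:

    import torch.nn.functional as F

    def alphazero_loss(value_pred, log_policy_pred, z, mcts_policy):
        # value_pred:      (B, 1) value head output
        # log_policy_pred: (B, A) log-probabilities from the policy head
        # z:               (B, 1) game outcome from the current player's perspective
        # mcts_policy:     (B, A) visit-count distribution produced by MCTS
        value_loss = F.mse_loss(value_pred, z)
        # cross-entropy between the MCTS target and the network policy
        policy_loss = -(mcts_policy * log_policy_pred).sum(dim=1).mean()
        return value_loss + policy_loss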

So how can I solve this problem in PyTorch? Can I somehow save the model together with the dropout configuration that was used to create a policy? Then I could at least forward the position again and run backpropagation afterwards, even if that would be inefficient.


Solution

  • In general, it is not standard practice to backpropagate through gradients from self-play during training (for many reasons). It would be rather inefficient to store gradients for later backpropagation, and self-play contains exploration noise. Re-running the forward pass during the training phase is normal in RL.

    In self-play, you will likely use eval mode to stay on-policy; dropout is only used during training, for regularization. In a sense dropout can help with exploration, but a more apt form of exploration is parameter noise. A sketch of the usual eval/train split is shown below.

    I don't know the details of AlphaZero, but IMHO it makes little sense to store the dropout noise. If you really want to do that, use a replay buffer to store the dropout activations, which you can capture with register_forward_hook (see the second sketch below).
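
    A minimal sketch of the usual pattern, assuming an existing model, optimizer, and replay buffer, a model that returns a (value, log-policy) pair, and a placeholder sample_batch helper that returns already-batched tensors:

        import torch
        import torch.nn.functional as F

        # --- self-play: eval mode, no gradient tracking, no dropout noise ---
        model.eval()
        with torch.no_grad():
            value, log_policy = model(board_tensor)   # feed these into MCTS
        # ... play out the game, store (state, mcts_policy, outcome) tuples ...

        # --- training: train mode, fresh forward pass over the stored positions ---
        states, mcts_policies, z = sample_batch(replay_buffer)  # placeholder helper
        model.train()
        optimizer.zero_grad()
        value_pred, log_policy_pred = model(states)             # re-run the forward pass
        loss = (F.mse_loss(value_pred, z)
                - (mcts_policies * log_policy_pred).sum(dim=1).mean())
        loss.backward()
        optimizer.step()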
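
    And if you really do want to capture the dropout activations, a sketch using register_forward_hook (here model.dropout is assumed to be the nn.Dropout module inside your network):

        captured = {}

        def save_dropout_output(module, inputs, output):
            # keep a detached copy of the dropped-out activation for the replay buffer
            captured["dropout_output"] = output.detach()

        handle = model.dropout.register_forward_hook(save_dropout_output)
        value, log_policy = model(board_tensor)   # the hook fires during this forward pass
        handle.remove()
        # captured["dropout_output"] (together with the position) could now be stored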