Tags: pytorch, random-seed, reproducible-research

How to save and load random number generator state in PyTorch?


I am training a deep learning model in PyTorch and want the training to be deterministic. As written in this official guide, I set the random seeds like this:

import numpy as np
import torch

np.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

Now, my training run is long, and I want to save and later reload everything, including the RNGs. I use torch.save and torch.load, together with load_state_dict, for the model and the optimizer.

How can the random number generators be saved & loaded?


Solution

  • You can use torch.get_rng_state and torch.set_rng_state

    Calling torch.get_rng_state returns the state of the default (CPU) random number generator as a torch.ByteTensor.

    You can save this tensor to a file, load it again later, and pass it to torch.set_rng_state to restore the random number generator state.


    If you also use NumPy, you can do the same there with
    numpy.random.get_state and numpy.random.set_state.
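
Putting the two pieces together, a minimal sketch might look like the following. The helper names save_rng_state and load_rng_state and the file name rng_state.pkl are made up for illustration, and writing the states with pickle is just one option. The CUDA branch is an addition that the answer above does not cover: it assumes you also want the per-GPU generators, which torch.cuda.get_rng_state_all and torch.cuda.set_rng_state_all handle.

```python
import os
import pickle
import tempfile

import numpy as np
import torch


def save_rng_state(path):
    """Hypothetical helper: write the current RNG states to `path`."""
    state = {
        "torch": torch.get_rng_state(),   # CPU generator, a torch.ByteTensor
        "numpy": np.random.get_state(),   # legacy global NumPy RandomState only
    }
    if torch.cuda.is_available():
        # One ByteTensor per visible GPU.
        state["cuda"] = torch.cuda.get_rng_state_all()
    with open(path, "wb") as f:
        pickle.dump(state, f)


def load_rng_state(path):
    """Hypothetical helper: restore the RNG states written by save_rng_state."""
    with open(path, "rb") as f:
        state = pickle.load(f)
    torch.set_rng_state(state["torch"])
    np.random.set_state(state["numpy"])
    if "cuda" in state and torch.cuda.is_available():
        torch.cuda.set_rng_state_all(state["cuda"])


# Small demonstration: draws after restoring repeat the draws after saving.
path = os.path.join(tempfile.mkdtemp(), "rng_state.pkl")
torch.manual_seed(0)
np.random.seed(0)

save_rng_state(path)
first_torch, first_numpy = torch.rand(3), np.random.rand(3)

load_rng_state(path)
second_torch, second_numpy = torch.rand(3), np.random.rand(3)

print(torch.equal(first_torch, second_torch))
print(np.array_equal(first_numpy, second_numpy))
```

The same dictionary could instead be stored next to the model and optimizer state dicts in a single checkpoint. Note that numpy.random.get_state only captures NumPy's global generator; a generator created with numpy.random.default_rng keeps its own state and would have to be saved separately.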