python-3.x, pytorch

Get the next state of a generator in PyTorch


I initialized a generator in PyTorch with a manual seed. How do I obtain the current state/seed of the generator? For example, if I want to check the state in a loop, how can I do that? This is what I want to achieve:

import torch

gen = torch.Generator("cpu").manual_seed(42)

for _ in range(5):
    # print the next state of the generator here
    pass  # placeholder so the loop is valid Python

PS: I tried to use gen.seed() but it gives different results on different runs.
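For context on the PS, here is a minimal sketch contrasting Generator.seed() with Generator.initial_seed() (the latter is not mentioned in the original post). initial_seed() reports the seed the generator was last seeded with, while seed() re-seeds the generator from a non-deterministic source and returns that new seed, which is why its value changes from run to run.

```python
import torch

gen = torch.Generator("cpu").manual_seed(42)

# initial_seed() returns the seed the generator was last seeded with.
first = gen.initial_seed()
print(first)  # 42

# Drawing numbers advances the generator's state, but not its seed.
torch.rand(3, generator=gen)
after_draw = gen.initial_seed()
print(after_draw)  # 42

# seed() picks a fresh non-deterministic seed, re-seeds gen with it and
# returns it; this is why it differs on every run, and it also replaces
# the earlier manual_seed(42) setting.
new_seed = gen.seed()
print(new_seed == gen.initial_seed())  # True
```

Note that the seed alone does not tell you where in the random sequence the generator currently is; that is what the state is for.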


Solution

  • You can use torch.Generator.get_state to retrieve the current state of the generator.

    Per the docs:

    get_state() → Tensor

    Returns the Generator state as a torch.ByteTensor.

    Returns: A torch.ByteTensor which contains all the necessary bits to restore a Generator to a specific point in time.

    You can later pass the tensor returned by get_state to set_state to restore the generator to that earlier state.
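A minimal sketch of using get_state inside the loop from the question. It assumes the loop actually draws random numbers with generator=gen (the question's loop body is not shown); if nothing is drawn, the state stays the same on every iteration. The full CPU state is a ByteTensor several thousand bytes long, so the sketch prints its size instead of its contents.

```python
import torch

# Seed a dedicated CPU generator, as in the question.
gen = torch.Generator("cpu").manual_seed(42)

states = []   # snapshots returned by get_state()
samples = []  # the value drawn right after each snapshot

for step in range(5):
    # Snapshot the state before drawing; get_state() returns a uint8 tensor.
    state = gen.get_state()
    states.append(state)

    # Drawing a number with generator=gen advances gen's internal state.
    sample = torch.rand(1, generator=gen)
    samples.append(sample)
    print(f"step {step}: state has {state.numel()} bytes, drew {sample.item():.4f}")

# Rewind to the snapshot taken before step 2 and draw again.
gen.set_state(states[2])
replay = torch.rand(1, generator=gen)
print(torch.equal(replay, samples[2]))  # True: same state, same draw
```

To check whether the generator has moved between two points, compare snapshots with torch.equal rather than printing them.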