
Fixing the seed for torch random_split()


Is it possible to fix the seed for torch.utils.data.random_split() when splitting a dataset so that it is possible to reproduce the test results?


Solution

  • You can use the torch.manual_seed function to seed the script globally:

    import torch
    torch.manual_seed(0)
    

    See the reproducibility documentation for more information.

    If you want to seed only the torch.utils.data.random_split call, you can "reset" the seed to its initial value afterwards. Simply use torch.initial_seed() like this:

    torch.manual_seed(torch.initial_seed())
    

    AFAIK, PyTorch does not provide arguments like seed or random_state for random_split (as seen in sklearn, for example).
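As a concrete sketch of the seed-and-reset pattern described above (using a hypothetical toy TensorDataset for illustration):

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical toy dataset: ten single-element samples.
dataset = TensorDataset(torch.arange(10))

# Seed the global RNG just for the split.
torch.manual_seed(0)
train_a, test_a = random_split(dataset, [8, 2])

# Re-seeding with torch.initial_seed() restores the freshly-seeded
# state, so the split itself does not perturb later random draws.
torch.manual_seed(torch.initial_seed())

# Because the RNG is back at the same seeded state, repeating the
# split reproduces the exact same indices.
train_b, test_b = random_split(dataset, [8, 2])
assert list(train_a.indices) == list(train_b.indices)
```

Note that torch.initial_seed() returns the seed most recently set, so this restores the generator to its state right after seeding, not to some pre-seed state.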
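For what it's worth, newer PyTorch releases do accept a generator argument on random_split (check the docs for your installed version), which seeds the split without touching the global RNG. A minimal sketch, again with a hypothetical toy dataset:

```python
import torch
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.arange(10))  # hypothetical toy dataset

# A dedicated, seeded generator makes the split reproducible
# while leaving the global RNG state untouched.
gen = torch.Generator().manual_seed(42)
train, test = random_split(dataset, [8, 2], generator=gen)
```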