
PyTorch DataLoader uses same random seed for batches run in parallel


There is a bug in PyTorch/NumPy where, when batches are loaded in parallel with a DataLoader (i.e. setting num_workers > 1), each worker process inherits the same NumPy random seed, so any random functions applied in the dataset produce identical results across the parallelized batches.

Minimal example:

import numpy as np
from torch.utils.data import Dataset, DataLoader

class RandomDataset(Dataset):
    def __getitem__(self, index):
        # Draws from NumPy's global RNG; every DataLoader worker process
        # inherits a copy of the same RNG state.
        return np.random.randint(0, 1000, 2)

    def __len__(self):
        return 9
    
dataset = RandomDataset()
dataloader = DataLoader(dataset, batch_size=1, num_workers=3)

for batch in dataloader:
    print(batch)

As you can see, the results are identical within each parallelized set of three batches:

# First 3 batches
tensor([[891, 674]])
tensor([[891, 674]])
tensor([[891, 674]])
# Second 3 batches
tensor([[545, 977]])
tensor([[545, 977]])
tensor([[545, 977]])
# Third 3 batches
tensor([[880, 688]])
tensor([[880, 688]])
tensor([[880, 688]])

What is the recommended/most elegant way to fix this, i.e. to have each batch produce a different randomization, irrespective of the number of workers?


Solution

  • It seems this works, at least in Colab:

    dataloader = DataLoader(dataset, batch_size=1, num_workers=3,
                            worker_init_fn=lambda id: np.random.seed(id))
    

    EDIT:

    This produces identical output (i.e. the same problem) when iterated over multiple epochs. – iacob

    Best fix I have found so far:

    ...
    # `epoch` is resolved when the workers start (inside the loop below),
    # not when the lambda is defined, so each epoch gets fresh seeds.
    dataloader = DataLoader(ds, num_workers=num_w,
                            worker_init_fn=lambda id: np.random.seed(id + epoch * num_w))

    for epoch in range(2):
        for batch in dataloader:
            print(batch)
        print()
    

    I still can't suggest a closed form, since the fix depends on a variable (epoch) that is only bound later, when the loader is iterated. Ideally it would be something like worker_init_fn = lambda id: np.random.seed(id + EAGER_EVAL(np.random.randint(10000))), where EAGER_EVAL evaluates the seed at loader construction, before the lambda is passed as a parameter. Is that possible in Python, I wonder? (A sketch follows below.)
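
    One way to get such eager evaluation in Python is the default-argument trick: a default value is evaluated once, when the lambda expression itself is evaluated (i.e. at DataLoader construction), not when the lambda is later called in each worker. Below is a minimal sketch reusing the RandomDataset from the question; the helper name seed_worker is mine, not part of any API. Two caveats it relies on: a base seed fixed at construction still repeats the same stream every epoch, whereas torch.initial_seed(), called inside a worker, returns a per-worker seed that PyTorch derives from a fresh base seed each time the iterator is created, so seeding NumPy from it varies per worker and per epoch.

    import numpy as np
    import torch
    from torch.utils.data import DataLoader

    # Eager evaluation via a default argument: `base` is computed once,
    # when the lambda is built, i.e. at DataLoader construction. Caveat:
    # the fixed base means workers get the same seeds again every epoch.
    dataloader = DataLoader(dataset, batch_size=1, num_workers=3,
                            worker_init_fn=lambda id, base=np.random.randint(10000):
                                np.random.seed(id + base))

    # Variation per worker *and* per epoch: seed NumPy from PyTorch's
    # per-worker seed, which is re-drawn each time iteration starts.
    def seed_worker(worker_id):  # hypothetical helper name
        np.random.seed(torch.initial_seed() % 2**32)

    dataloader = DataLoader(dataset, batch_size=1, num_workers=3,
                            worker_init_fn=seed_worker)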