PyTorch DataLoader shuffle


I did an experiment and I did not get the result I was expecting.

For the first part, I am using

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, 
                                          shuffle=False, num_workers=0)

I save trainloader.dataset.targets to the variable a, and trainloader.dataset.data to the variable b before training my model. Then, I train the model using trainloader.
After the training is finished, I save trainloader.dataset.targets to the variable c, and trainloader.dataset.data to the variable d. Finally, I check a == c and b == d and they both give True, which was expected because the shuffle parameter of the DataLoader is False.

For the second part, I am using

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, 
                                          shuffle=True, num_workers=0)

I save trainloader.dataset.targets to the variable e, and trainloader.dataset.data to the variable f before training my model. Then, I train the model using trainloader. After the training is finished, I save trainloader.dataset.targets to the variable g, and trainloader.dataset.data to the variable h. I expect e == g and f == h to both be False since shuffle=True, but they give True again. What am I missing about the DataLoader class?


Solution

  • I believe that the data stored directly in trainloader.dataset.data and trainloader.dataset.targets is never shuffled. Shuffling only affects the order in which samples are drawn when you iterate over the DataLoader.

    You can check this by calling next(iter(trainloader)) on a loader with shuffle=False and on one with shuffle=True: the batches they return should differ.

    import torch
    import torchvision

    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
    ])

    # Loader without shuffling
    MNIST_dataset = torchvision.datasets.MNIST('~/Desktop/intern/', download=True,
                                               train=False, transform=transform)
    dataLoader = torch.utils.data.DataLoader(MNIST_dataset,
                                             batch_size=128,
                                             shuffle=False,
                                             num_workers=10)
    target = dataLoader.dataset.targets

    # Loader with shuffling, built on a second copy of the same dataset
    MNIST_dataset = torchvision.datasets.MNIST('~/Desktop/intern/', download=True,
                                               train=False, transform=transform)
    dataLoader_shuffled = torch.utils.data.DataLoader(MNIST_dataset,
                                                      batch_size=128,
                                                      shuffle=True,
                                                      num_workers=10)
    target_shuffled = dataLoader_shuffled.dataset.targets

    # Labels stored on the datasets: same order for both loaders
    print(target == target_shuffled)

    # Labels of the first batch served by each loader: the order differs
    _, target = next(iter(dataLoader))
    _, target_shuffled = next(iter(dataLoader_shuffled))

    print(target == target_shuffled)
    

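    This will give something like the following (the second tensor varies from run to run because the shuffle is random):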

    tensor([True, True, True,  ..., True, True, True])
    tensor([False, False, False, False, False, False, False, False, False, False,
            False, False, False, False, False, False, False, False, False, False,
            False, False, False, False, False, False, False, False, False,  True,
            False, False, False, False, False, False, False, False, False, False,
            False, False, False, False, False, False, False, False, False, False,
            False, False, False, False, False, False, False, False, False, False,
            False, False, False, False,  True, False, False, False, False, False,
            False,  True, False, False, False, False, False, False, False, False,
            False, False, False, False, False, False, False, False, False, False,
            False, False, False, False,  True,  True, False, False, False, False,
            False, False, False, False, False, False, False, False, False, False,
            False, False, False, False, False,  True, False, False,  True, False,
            False, False, False, False, False, False, False, False])
    

    However, the data and labels stored in dataset.data and dataset.targets are kept in a fixed order, and since you are accessing them directly rather than through the DataLoader, you will never see them shuffled.
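
The same point can be checked without downloading MNIST. The following is a minimal sketch, assuming a small toy TensorDataset as a stand-in for trainset (the names data, targets, served and served_again are illustrative and not from the question). It snapshots the stored labels with clone(), since assigning dataset.targets to a variable only keeps another reference to the same tensor, then makes two full passes over a shuffling loader.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for trainset (hypothetical): 10 samples with labels 0..9.
data = torch.arange(10, dtype=torch.float32).unsqueeze(1)
targets = torch.arange(10)
dataset = TensorDataset(data, targets)

# Real copy of the stored labels, taken before any iteration.
targets_before = dataset.tensors[1].clone()

torch.manual_seed(0)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

# Two full passes ("epochs"); record the labels in the order they are served.
served = torch.cat([labels for _, labels in loader])
served_again = torch.cat([labels for _, labels in loader])

targets_after = dataset.tensors[1]

# The stored labels are untouched by iterating the loader.
print(torch.equal(targets_before, targets_after))

# What the loader served is a reordering of the same labels,
# and a new order is drawn on each pass.
print(served.tolist())
print(served_again.tolist())

# shuffle=True makes the loader draw indices from a RandomSampler.
print(type(loader.sampler).__name__)
```

The first print should be True, while the two served lists contain the labels 0 to 9 in a random order. In other words, shuffle=True changes which indices the loader asks the dataset for, not how the dataset stores its samples.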