Consider the following piece of code, which fetches a training data set from `torchvision.datasets` and creates a `DataLoader` for it:

```python
import torch
from torchvision import datasets, transforms

training_set_mnist = datasets.MNIST('./mnist_data', train=True, download=True)
train_loader_mnist = torch.utils.data.DataLoader(training_set_mnist, batch_size=128,
                                                 shuffle=True)
```
Assume that several Python processes have access to the folder `./mnist_data` and execute the above piece of code simultaneously; in my case, each process runs on a different machine in a cluster, and the data set is stored in an NFS location accessible to all of them. You may also assume that the data has already been downloaded to this folder, so `download=True` should have no effect. Moreover, each process may use a different seed, as set by `torch.manual_seed()`.
I would like to know whether this scenario is allowed in PyTorch. My main concern is whether the above code can change the data folders or files in `./mnist_data` such that, if run by multiple processes, it could lead to unexpected behavior or other issues. Also, given that `shuffle=True`, I would expect that if two or more processes create the `DataLoader`, each of them will get a different shuffling of the data, assuming their seeds are different. Is this true?
> My main concern is whether the above code can change the data folders or files in `./mnist_data` such that if run by multiple processes it can potentially lead to unexpected behavior or other issues.
You will be fine, as the processes only read the data, they do not modify it (in the case of MNIST, they load tensors with the data into RAM). Note that processes do not share memory addresses, hence the tensor with the data will be loaded multiple times (which shouldn't be a big problem for MNIST).
> Also, given that `shuffle=True` I would expect that if 2 or more processes try to create the DataLoader each of them will get a different shuffling of the data assuming that the seeds are different.
`shuffle=True` has nothing to do with the data itself. What it does is take the `__len__()` of the provided dataset, build the range `[0, __len__())`, shuffle that range, and use it to index the dataset's `__getitem__`. Check out the Sampler section of the `torch.utils.data` documentation for more info about Samplers.
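The index-shuffling mechanism described above can be sketched without torch at all: a map-style dataset only needs `__len__` and `__getitem__`, and shuffling is just a seeded permutation of the index range fed back into `__getitem__`. The class and function names here are illustrative, not PyTorch's actual sampler implementation.

```python
import random

class ToyDataset:
    # Minimal map-style dataset: only __len__ and __getitem__,
    # which is all a sampler needs to know about.
    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]

def shuffled_epoch(dataset, seed):
    # What shuffle=True does conceptually: shuffle the index
    # range [0, len(dataset)) and index __getitem__ with it.
    indices = list(range(len(dataset)))
    random.Random(seed).shuffle(indices)
    return [dataset[i] for i in indices]

data = ToyDataset(list(range(10)))
order_a = shuffled_epoch(data, seed=0)
order_b = shuffled_epoch(data, seed=1)

# Different seeds yield different orderings of the same items...
assert order_a != order_b
assert sorted(order_a) == sorted(order_b) == list(range(10))
# ...while the same seed reproduces the same order.
assert shuffled_epoch(data, seed=0) == order_a
```

This is why processes with different seeds see different shufflings: the data on disk is never touched, only each process's private permutation of the indices differs.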