I'm trying to create a subclass of the PyTorch MNIST dataset class, which I call `CustomMNISTDataset`, as follows:
```python
import torchvision.datasets as datasets

class CustomMNISTDataset(datasets.MNIST):
    def __init__(self, root='/home/psando'):
        super().__init__(root=root,
                         download=False)
```
But when I execute:
```python
dataset = CustomMNISTDataset()
```
it fails with the error `RuntimeError: Dataset not found. You can use download=True to download it`.
However, when I run the following in the same file:
```python
dataset = datasets.MNIST(root='/home/psando', download=False)
print(len(dataset))
```
it succeeds and prints `60000`, as expected.
Since `CustomMNISTDataset` subclasses `datasets.MNIST`, why is the behavior different? I've verified that `/home/psando` contains the `MNIST` directory with `raw` and `processed` subdirectories (otherwise, explicitly calling the `datasets.MNIST()` constructor would have failed). The current behavior implies that the call to `super().__init__()` within `CustomMNISTDataset` is not calling the constructor for `datasets.MNIST`, which is very strange!
Other details: I'm using Python 3.6.8 with `torch==1.6.0` and `torchvision==0.7.0`. Any help would be appreciated!
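This requires some source-diving, but your problem is how `MNIST` builds its data paths: the `raw_folder` and `processed_folder` properties join `root` with `self.__class__.__name__`. Because the path depends on the name of the class, `super().__init__()` is being called correctly, but once you subclass `MNIST` the folder it looks in changes to `/home/psando/CustomMNISTDataset`. So if you rename `/home/psando/MNIST` to `/home/psando/CustomMNISTDataset`, it works.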
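To make the mechanism concrete without needing torchvision or the data on disk, here is a minimal sketch. The `MNIST` class below is a hypothetical stand-in, not the real `torchvision.datasets.MNIST`; it only mimics how torchvision 0.7 derives the folder names from `self.__class__.__name__`. `FixedCustomMNISTDataset` is also a hypothetical name: it shows an alternative to renaming the directory, namely overriding both properties so they keep pointing at the `MNIST` folder.

```python
import os


class MNIST:
    """Hypothetical stand-in for torchvision.datasets.MNIST (not the real class).
    It only mimics how the data folders are derived from the class name."""

    def __init__(self, root):
        self.root = root

    @property
    def raw_folder(self):
        return os.path.join(self.root, self.__class__.__name__, 'raw')

    @property
    def processed_folder(self):
        return os.path.join(self.root, self.__class__.__name__, 'processed')


class CustomMNISTDataset(MNIST):
    # Inherits the properties unchanged, but self.__class__.__name__ is now
    # 'CustomMNISTDataset', so both folders move under that name.
    pass


class FixedCustomMNISTDataset(MNIST):
    # Alternative to renaming the directory: pin the folder name to 'MNIST'.
    @property
    def raw_folder(self):
        return os.path.join(self.root, 'MNIST', 'raw')

    @property
    def processed_folder(self):
        return os.path.join(self.root, 'MNIST', 'processed')


root = '/home/psando'
print(MNIST(root).processed_folder)                    # /home/psando/MNIST/processed
print(CustomMNISTDataset(root).processed_folder)       # /home/psando/CustomMNISTDataset/processed
print(FixedCustomMNISTDataset(root).processed_folder)  # /home/psando/MNIST/processed
```

The printed paths assume a POSIX system. The same two property overrides should carry over to a real `datasets.MNIST` subclass, though that is an assumption worth checking against your torchvision version, since the internals of `MNIST` have changed between releases.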