Subclass of PyTorch dataset class cannot find dataset files


I'm trying to create a subclass of the PyTorch MNIST dataset class, which I call CustomMNISTDataset, as follows:

import torchvision.datasets as datasets

class CustomMNISTDataset(datasets.MNIST):

    def __init__(self, root='/home/psando'):
        super().__init__(root=root,
                         download=False)

but when I execute:

dataset = CustomMNISTDataset()

it fails with error: "RuntimeError: Dataset not found. You can use download=True to download it".

However, when I run the following in the same file:

dataset = datasets.MNIST(root='/home/psando', download=False)
print(len(dataset))

it succeeds and prints "60000", as expected.

Since CustomMNISTDataset subclasses datasets.MNIST, why is the behavior different? I've verified that the path '/home/psando' contains the MNIST directory with raw and processed subdirectories (otherwise, explicitly calling the constructor for datasets.MNIST() would have failed). The current behavior seems to imply that the call to super().__init__() within CustomMNISTDataset is not calling the constructor for datasets.MNIST, which is very strange!

Other details: I'm using Python 3.6.8 with torch==1.6.0 and torchvision==0.7.0. Any help would be appreciated!


Solution

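  • This requires some source-diving, but the problem is in how MNIST builds its folder paths. In torchvision 0.7.0 the path to the dataset depends on the name of the class (it is built from self.__class__.__name__), so when you subclass MNIST the dataset folder under root changes from /home/psando/MNIST to /home/psando/CustomMNISTDataset. The call to super().__init__() does reach the datasets.MNIST constructor; it just looks in a different directory.

    So if you rename /home/psando/MNIST to /home/psando/CustomMNISTDataset, it works.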
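
    If renaming the directory is not an option, another possibility is to override the folder lookup in the subclass. The following is only a sketch of the mechanism: the MNIST class here is a small stand-in written for illustration, not torchvision's class, and PinnedMNISTDataset is a hypothetical name. It assumes the folder is derived from self.__class__.__name__ through a processed_folder property, as in torchvision 0.7.0.

```python
import os


class MNIST:
    """Stand-in for datasets.MNIST (NOT the torchvision class).

    Assumption: the folder is built from the runtime class name,
    which is how torchvision 0.7.0 derives it.
    """

    def __init__(self, root):
        self.root = root

    @property
    def processed_folder(self):
        # self.__class__ is the subclass when called on a subclass instance
        return os.path.join(self.root, self.__class__.__name__, 'processed')


class CustomMNISTDataset(MNIST):
    """Plain subclass: inherits the lookup, so the folder name changes."""


class PinnedMNISTDataset(MNIST):
    """Hypothetical subclass that pins the folder to the parent's name."""

    @property
    def processed_folder(self):
        return os.path.join(self.root, 'MNIST', 'processed')


for cls in (MNIST, CustomMNISTDataset, PinnedMNISTDataset):
    print(cls.__name__, '->', cls('/home/psando').processed_folder)
```

    With the real datasets.MNIST in torchvision 0.7.0, the same idea would presumably mean overriding both raw_folder and processed_folder in CustomMNISTDataset; check the source of your installed version first, since other releases may build these paths differently.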