Tags: google-cloud-platform, pytorch, jupyter-lab, mnist, torchvision

torchvision.datasets.mnist RuntimeError on JupyterLab


I'm trying to run the following sample code in JupyterLab (through GCP Vertex AI):

import torch
from torchvision import transforms
from torchvision import datasets

train_data = datasets.MNIST(root='data', train=True, download=True, transform=None)
print(train_data)

Versions: torch-1.12.1+cu113 and torchvision-0.13.1+cu113

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_10081/229378695.py in <module>
     11 from torchvision import datasets
     12 
---> 13 train_data = datasets.MNIST(root='data', train=True, download=True, transform=None)
     14 print(train_data)

/opt/conda/lib/python3.7/site-packages/torchvision/datasets/mnist.py in __init__(self, root, train, transform, target_transform, download)
    102             raise RuntimeError("Dataset not found. You can use download=True to download it")
    103 
--> 104         self.data, self.targets = self._load_data()
    105 
    106     def _check_legacy_exist(self):

/opt/conda/lib/python3.7/site-packages/torchvision/datasets/mnist.py in _load_data(self)
    121     def _load_data(self):
    122         image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
--> 123         data = read_image_file(os.path.join(self.raw_folder, image_file))
    124 
    125         label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"

/opt/conda/lib/python3.7/site-packages/torchvision/datasets/mnist.py in read_image_file(path)
    542 
    543 def read_image_file(path: str) -> torch.Tensor:
--> 544     x = read_sn3_pascalvincent_tensor(path, strict=False)
    545     if x.dtype != torch.uint8:
    546         raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")

/opt/conda/lib/python3.7/site-packages/torchvision/datasets/mnist.py in read_sn3_pascalvincent_tensor(path, strict)
    529 
    530     assert parsed.shape[0] == np.prod(s) or not strict
--> 531     return parsed.view(*s)
    532 
    533 

RuntimeError: shape '[60000, 28, 28]' is invalid for input of size 9437168
____________________

I get this strange error when trying to load MNIST.

  • I tried reproducing it in other environments but couldn't: it works fine locally and on Colab.
  • I tried many other versions of torch and torchvision, but none of them work.
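One way to confirm whether the raw training-image file on disk is complete is to compare its real size with the size its IDX header declares. For the MNIST training images the header declares 60000 × 28 × 28 values, so a complete uncompressed file is 47,040,016 bytes (16 header bytes plus one byte per pixel). Assuming torchvision reads every byte after that 16-byte header, the 9,437,168 values in the error would correspond to a 9,437,184-byte file. The function name `check_idx_file` is illustrative and not part of torchvision, and this sketch only handles uint8 IDX files:

```python
import os
import struct


def check_idx_file(path):
    """Return (dims, expected_size, actual_size) for an uncompressed uint8 IDX file.

    IDX layout: two zero bytes, a type byte (0x08 = unsigned byte), a byte
    holding the number of dimensions, then one big-endian uint32 per dimension,
    followed by the raw values.
    """
    with open(path, "rb") as f:
        magic = f.read(4)
        if len(magic) != 4:
            raise ValueError(f"{path}: too short to hold an IDX header")
        _, _, type_code, ndim = struct.unpack(">BBBB", magic)
        if type_code != 0x08:
            raise ValueError(f"{path}: only uint8 IDX files are handled, got type {type_code:#04x}")
        dim_bytes = f.read(4 * ndim)
        if len(dim_bytes) != 4 * ndim:
            raise ValueError(f"{path}: header ends before all dimensions are listed")
        dims = struct.unpack(f">{ndim}I", dim_bytes)

    n_values = 1
    for d in dims:  # plain loop instead of math.prod, which needs Python 3.8+
        n_values *= d
    expected = 4 + 4 * ndim + n_values  # header bytes + one byte per value
    actual = os.path.getsize(path)
    return dims, expected, actual


# Assumes the default layout torchvision uses for root='data'.
path = os.path.join("data", "MNIST", "raw", "train-images-idx3-ubyte")
if os.path.exists(path):
    dims, expected, actual = check_idx_file(path)
    print(f"dims={dims}, expected {expected} bytes, found {actual} bytes")
    if actual < expected:
        print("The file is truncated; delete data/MNIST and download again.")
```

The loop replaces `math.prod` because the traceback shows Python 3.7, where `math.prod` is not available.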

Solution

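  • This error comes from an incomplete copy of the MNIST files on disk. The numbers in the message show it: the training-image file should contain 60000 × 28 × 28 = 47,040,000 pixel values after its 16-byte header, but only 9,437,168 were found, so the file is truncated (for example, by an interrupted download or extraction). Simply re-running the code does not help, because torchvision skips the download whenever the extracted files already exist. Delete the MNIST folder inside the data directory and run the code again to download fresh copies: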

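    import os
    import shutil

    from torchvision import datasets

    # torchvision stores MNIST under <root>/MNIST (raw files in <root>/MNIST/raw).
    # Remove the whole folder so the next call starts from a clean download.
    mnist_folder = 'data/MNIST'
    if os.path.exists(mnist_folder):
        shutil.rmtree(mnist_folder)

    train_data = datasets.MNIST(root='data', train=True, download=True, transform=None)
    print(train_data)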
    

    If this doesn't work, download the files manually from this website and place them in the data/MNIST/raw folder.
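When the files are downloaded by hand, torchvision 0.13 looks for the decompressed files (the archive names without `.gz`) under `data/MNIST/raw`. A minimal sketch for unpacking them, assuming the four `.gz` archives were saved into some local folder; the function name `install_mnist_archives` and its `download_dir` argument are hypothetical:

```python
import gzip
import os
import shutil

# Archive names used by torchvision's MNIST class; after extraction it looks
# for the same names without the ".gz" suffix under <root>/MNIST/raw.
MNIST_ARCHIVES = [
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
]


def install_mnist_archives(download_dir, root="data"):
    """Decompress manually downloaded MNIST archives into <root>/MNIST/raw.

    download_dir is a hypothetical folder holding the four .gz files.
    Returns the list of extracted file paths.
    """
    raw_dir = os.path.join(root, "MNIST", "raw")
    os.makedirs(raw_dir, exist_ok=True)
    extracted = []
    for name in MNIST_ARCHIVES:
        src = os.path.join(download_dir, name)
        dst = os.path.join(raw_dir, name[: -len(".gz")])
        # Stream-decompress so the whole file never has to fit in memory.
        with gzip.open(src, "rb") as fin, open(dst, "wb") as fout:
            shutil.copyfileobj(fin, fout)
        extracted.append(dst)
    return extracted
```

For example, calling `install_mnist_archives('downloads')` and then `datasets.MNIST(root='data', train=True, download=False)` should load the data without contacting the network, provided the archives themselves are complete.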