I'm trying to run the following sample code in JupyterLab (through GCP Vertex AI):
import torch
from torchvision import transforms
from torchvision import datasets
train_data = datasets.MNIST(root='data', train=True, download=True, transform=None)
print(train_data)
with versions torch-1.12.1+cu113 and torchvision-0.13.1+cu113:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/tmp/ipykernel_10081/229378695.py in <module>
11 from torchvision import datasets
12
---> 13 train_data = datasets.MNIST(root='data', train=True, download=True, transform=None)
14 print(train_data)
/opt/conda/lib/python3.7/site-packages/torchvision/datasets/mnist.py in __init__(self, root, train, transform, target_transform, download)
102 raise RuntimeError("Dataset not found. You can use download=True to download it")
103
--> 104 self.data, self.targets = self._load_data()
105
106 def _check_legacy_exist(self):
/opt/conda/lib/python3.7/site-packages/torchvision/datasets/mnist.py in _load_data(self)
121 def _load_data(self):
122 image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
--> 123 data = read_image_file(os.path.join(self.raw_folder, image_file))
124
125 label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
/opt/conda/lib/python3.7/site-packages/torchvision/datasets/mnist.py in read_image_file(path)
542
543 def read_image_file(path: str) -> torch.Tensor:
--> 544 x = read_sn3_pascalvincent_tensor(path, strict=False)
545 if x.dtype != torch.uint8:
546 raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
/opt/conda/lib/python3.7/site-packages/torchvision/datasets/mnist.py in read_sn3_pascalvincent_tensor(path, strict)
529
530 assert parsed.shape[0] == np.prod(s) or not strict
--> 531 return parsed.view(*s)
532
533
RuntimeError: shape '[60000, 28, 28]' is invalid for input of size 9437168
This is the strange error I get when trying to load MNIST.
____________________
This error usually means the MNIST files downloaded onto your system are truncated or corrupted: the raw image file on disk holds far fewer bytes than the 60000 x 28 x 28 images it should contain, so the reshape fails. Delete the MNIST files in the data directory and run the code again so fresh copies are downloaded:
import os
import shutil
from torchvision import datasets

# Remove the (possibly corrupted) MNIST files so they are downloaded again
mnist_folder = 'data/MNIST'
if os.path.exists(mnist_folder):
    shutil.rmtree(mnist_folder)

train_data = datasets.MNIST(root='data', train=True, download=True, transform=None)
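If you want to confirm that the files on disk really are corrupted before deleting anything, a quick size check helps. This is just an illustrative sketch: it assumes torchvision's default layout, where the extracted idx files live under data/MNIST/raw, and uses the standard MNIST file sizes (16-byte header for image files, 8-byte header for label files).

import os

raw_folder = 'data/MNIST/raw'
# Expected byte counts for the uncompressed MNIST idx files
expected_sizes = {
    'train-images-idx3-ubyte': 16 + 60000 * 28 * 28,  # 47,040,016
    'train-labels-idx1-ubyte': 8 + 60000,              # 60,008
    't10k-images-idx3-ubyte': 16 + 10000 * 28 * 28,    # 7,840,016
    't10k-labels-idx1-ubyte': 8 + 10000,                # 10,008
}

for name, expected in expected_sizes.items():
    path = os.path.join(raw_folder, name)
    if not os.path.exists(path):
        print(f'{name}: missing')
        continue
    actual = os.path.getsize(path)
    status = 'OK' if actual == expected else 'corrupted or truncated'
    print(f'{name}: {actual} bytes (expected {expected}) -> {status}')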
If this method doesn't work, visit this website, download the raw MNIST files manually, and place them in the data/MNIST/raw folder.
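As a rough sketch of that manual route (the ~/Downloads path is an assumption here, the four archive names are the standard MNIST file names, and torchvision 0.13 looks for the extracted idx files under data/MNIST/raw): decompress the .gz archives into the raw folder, then load the dataset with download=False.

import gzip
import os
import shutil
from torchvision import datasets

raw_folder = 'data/MNIST/raw'
os.makedirs(raw_folder, exist_ok=True)

archives = [
    'train-images-idx3-ubyte.gz',
    'train-labels-idx1-ubyte.gz',
    't10k-images-idx3-ubyte.gz',
    't10k-labels-idx1-ubyte.gz',
]

for name in archives:
    src = os.path.join(os.path.expanduser('~/Downloads'), name)  # wherever you saved the downloads
    dst = os.path.join(raw_folder, name[:-3])  # drop the .gz suffix
    with gzip.open(src, 'rb') as f_in, open(dst, 'wb') as f_out:
        shutil.copyfileobj(f_in, f_out)

# download=False makes torchvision use the files already in place
train_data = datasets.MNIST(root='data', train=True, download=False, transform=None)
print(train_data)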