
What parameter do I need to change for it to match requirements?


I am trying to train a model on a modified MNIST dataset so that it classifies random-noise images with a new label, 10. I keep getting a TypeError when iterating over the training DataLoader.
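Adding label 10 means the classifier now has 11 classes, so its final layer has to output 11 logits: nn.CrossEntropyLoss expects class-index targets in the range 0 to C - 1. The question doesn't show the model, so the architecture below (the layer sizes and the NUM_CLASSES name) is an assumption; it is only a minimal sketch of the output-size requirement.

```python
import torch
import torch.nn as nn

NUM_CLASSES = 11  # digits 0-9 plus the new "noise" class 10

# Hypothetical model: the layers are assumptions; only the size of the
# last layer (NUM_CLASSES) matters for label 10 to be a valid target.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, NUM_CLASSES),
)
criterion = nn.CrossEntropyLoss()

images = torch.randn(4, 1, 28, 28)
labels = torch.tensor([0, 3, 9, 10])  # includes the noise label
outputs = model(images)               # one row of 11 logits per image
loss = criterion(outputs, labels)     # valid because every label < NUM_CLASSES
print(outputs.shape, loss.item())
```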

import torch
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

dataset1 = datasets.MNIST(root='./data', train=True, transform = transform)
dataset2 = datasets.MNIST(root='./data', train=False, transform=transform)
num_new_images = 7000
noisy_images = torch.randn(num_new_images, 1, 28, 28)
mean = 0.1307
std = 0.3081
random_images = (noisy_images-mean)/std
noisy_labels = torch.full((num_new_images,),10, dtype=torch.long)
new_dataset = torch.utils.data.TensorDataset(random_images, noisy_labels)  # use the normalized noise, as for the validation set
combined_dataset = torch.utils.data.ConcatDataset([dataset1, new_dataset])
len(combined_dataset)   
num_val_images = 1000
noisy_images = torch.randn(num_val_images, 1, 28, 28)
random_val_images = (noisy_images-mean)/std
noisy_val_labels = torch.full((num_val_images,),10, dtype=torch.long)
new_val_dataset = torch.utils.data.TensorDataset(random_val_images, noisy_val_labels)
combined_val_dataset = torch.utils.data.ConcatDataset([dataset2, new_val_dataset])
batch_size = 128
train_loader = torch.utils.data.DataLoader(combined_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(combined_val_dataset, batch_size=batch_size, shuffle=False)

Error:

TypeError                                 Traceback (most recent call last)
Cell In[12], line 74
     72 # Train the neural network
     73 for epoch in range(num_epochs):
---> 74     for images, labels in train_loader:
     75         outputs = model(images)
     76         loss = criterion(outputs, labels)

File ~\PycharmProjects\tensorflow_start\venv\Lib\site-packages\torch\utils\data\dataloader.py:633, in _BaseDataLoaderIter.__next__(self)
    630 if self._sampler_iter is None:
    631     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    632     self._reset()  # type: ignore[call-arg]
--> 633 data = self._next_data()
    634 self._num_yielded += 1
    635 if self._dataset_kind == _DatasetKind.Iterable and \
    636         self._IterableDataset_len_called is not None and \
    637         self._num_yielded > self._IterableDataset_len_called:

I have already tried changing the data type of the labels, but that didn't help.
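The traceback above is cut off before the final error message, but the mismatch can be reproduced without downloading MNIST. In this minimal sketch, the hypothetical IntLabelDataset class stands in for datasets.MNIST, which returns each label as a Python int, while TensorDataset returns each label as a 0-dimensional tensor. When the first label in a batch is a tensor, the DataLoader's default collate function calls torch.stack on all of the labels and raises a TypeError on the ints:

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset, TensorDataset


class IntLabelDataset(Dataset):
    """Hypothetical stand-in for datasets.MNIST: image tensor + int label."""

    def __init__(self, n):
        self.images = torch.randn(n, 1, 28, 28)
        self.labels = [i % 10 for i in range(n)]  # plain Python ints

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]


# Noise samples: TensorDataset yields each label as a 0-dim tensor.
noise = TensorDataset(torch.randn(4, 1, 28, 28),
                      torch.full((4,), 10, dtype=torch.long))

# Put the tensor-label samples first so the batch's first label is a tensor.
combined = ConcatDataset([noise, IntLabelDataset(4)])
loader = DataLoader(combined, batch_size=8, shuffle=False)

try:
    next(iter(loader))
    error = None
except TypeError as exc:  # torch.stack() rejects the int labels
    error = exc
print(type(error).__name__, error)
```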


Solution

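  • datasets.MNIST returns each label as a Python int, whereas your TensorDataset returns each label as a tensor, so the DataLoader's default collate function ends up trying to combine ints and tensors in the same batch. It should work if you convert the MNIST labels (targets) to tensors as well, for both the training and the test dataset, e.g.,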

    import torch
    from torchvision import datasets, transforms
    
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    def target_transform(t):
        return torch.tensor(t)
    
    dataset1 = datasets.MNIST(
        root='./data',
        train=True,
        transform=transform,
        target_transform=target_transform,
    )
    
    ...
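To check the fix without downloading MNIST, the sketch below uses a hypothetical IntLabelDataset class that mimics how datasets.MNIST applies target_transform to each int label. Once the labels pass through target_transform, every label in a batch is a tensor and the DataLoader collates them into a single label tensor:

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset, TensorDataset


def target_transform(t):
    # Same conversion as in the answer: Python int -> 0-dim tensor.
    return torch.tensor(t)


class IntLabelDataset(Dataset):
    """Hypothetical stand-in for datasets.MNIST that, like MNIST,
    applies an optional target_transform to each int label."""

    def __init__(self, n, target_transform=None):
        self.images = torch.randn(n, 1, 28, 28)
        self.labels = [i % 10 for i in range(n)]
        self.target_transform = target_transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        label = self.labels[idx]
        if self.target_transform is not None:
            label = self.target_transform(label)
        return self.images[idx], label


noise = TensorDataset(torch.randn(4, 1, 28, 28),
                      torch.full((4,), 10, dtype=torch.long))
digits = IntLabelDataset(4, target_transform=target_transform)

# Same order as the failing case: a tensor label comes first in the batch.
loader = DataLoader(ConcatDataset([noise, digits]), batch_size=8, shuffle=False)

images, labels = next(iter(loader))
print(images.shape, labels)  # all labels are tensors now, so stacking works
```

torch.tensor on a Python int produces an int64 (torch.long) tensor, the same dtype as the noise labels created with dtype=torch.long, so the collated batch is a single LongTensor of class indices, which is what nn.CrossEntropyLoss expects.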