Tags: python, pytorch, valueerror, generative-adversarial-network

ValueError: torch target size and torch input size in GAN do not match


Hi, I am working on a GAN with custom images and I got the following error, which doesn't add up for me:

ValueError: Using a target size (torch.Size([64, 1])) that is different to the input size (torch.Size([47, 1])) is deprecated. Please ensure they have the same size.

I do not see where either of these sizes comes from. Could someone please help me out? The error occurs at loss_discriminator during training (marked with an arrow) after epoch 0. Below you will find the related code. I am using VS Code on Windows.

Also, is it normal that epoch 0 works and then the problem appears?

[Screenshot of terminal: epoch 0 loss for discriminator and generator][1]

import torch
import torch.nn as nn  # required for the nn.Module subclasses below
from glob import glob
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from skimage import io
import matplotlib.pyplot as plt


path = 'Punks'
image_paths = glob(path + '/*.png')

img_size = 28
batch_size = 32

transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)


class ImageDataset(Dataset):
    def __init__(self, paths, transform):
        self.paths = paths
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        image_path = self.paths[index]
        image = io.imread(image_path)

        if self.transform:
            image_tensor = self.transform(image)

        return image_tensor


if __name__ == '__main__':

    dataset = ImageDataset(image_paths, transform)

    train_loader = DataLoader(
        dataset, batch_size=batch_size, num_workers=1, shuffle=True)

    # PLOTTING SAMPLES

    real_samples = next(iter(train_loader))
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(real_samples[i].reshape(28, 28, 3))
        plt.xticks([])
        plt.yticks([])
        plt.show()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    class Discriminator(nn.Module):
        def __init__(self):
            super().__init__()
            self.model = nn.Sequential(
                nn.Linear(784*3, 2048),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(2048, 1024),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(1024, 512),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(256, 1),
                nn.Sigmoid(),
            )

        def forward(self, x):
            x = x.view(x.size(0), 784*3)  # change required for 3 channel image
            output = self.model(x)
            return output

    discriminator = Discriminator().to(device=device)

    class Generator(nn.Module):
        def __init__(self):
            super().__init__()
            self.model = nn.Sequential(
                nn.Linear(100, 256),
                nn.ReLU(),
                nn.Linear(256, 512),
                nn.ReLU(),
                nn.Linear(512, 1024),
                nn.ReLU(),
                nn.Linear(1024, 2048),
                nn.ReLU(),
                nn.Linear(2048, 784*3),
                nn.Tanh(),
            )

        def forward(self, x):
            output = self.model(x)
            output = output.view(x.size(0), 3, 28, 28)
            return output

    generator = Generator().to(device=device)

    # TRAINING PARAMS

    lr = 0.0001
    num_epochs = 10
    loss_function = nn.BCELoss()

    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr)
    optimizer_generator = torch.optim.Adam(generator.parameters(), lr=lr)
    for epoch in range(num_epochs):
        for n, real_samples in enumerate(train_loader):
            # Data for training the discriminator
            real_samples = real_samples.to(device=device)
            real_samples_labels = torch.ones((batch_size, 1)).to(
                device=device
            )
            latent_space_samples = torch.randn((batch_size, 100)).to(
                device=device
            )
            print(f'Latent space samples : {latent_space_samples.shape}')
            generated_samples = generator(latent_space_samples)
            generated_samples_labels = torch.zeros((batch_size, 1)).to(
                device=device
            )
            all_samples = torch.cat((real_samples, generated_samples))
            print(f'Real samples : {real_samples.shape}, generated samples : {generated_samples.shape}')
            all_samples_labels = torch.cat(
                (real_samples_labels, generated_samples_labels)
            )

            # Training the discriminator
            discriminator.zero_grad()
            output_discriminator = discriminator(all_samples)
            loss_discriminator = loss_function(
                output_discriminator, all_samples_labels
            )
            loss_discriminator.backward()  # <------- ValueError raised here
            optimizer_discriminator.step()

            # Data for training the generator
            latent_space_samples = torch.randn((batch_size, 100)).to(
                device=device
            )

            # Training the generator
            generator.zero_grad()
            generated_samples = generator(latent_space_samples)
            output_discriminator_generated = discriminator(generated_samples)
            loss_generator = loss_function(
                output_discriminator_generated, real_samples_labels
            )
            loss_generator.backward()
            optimizer_generator.step()

            # Show loss
            if n == batch_size - 1:
                print(f"Epoch: {epoch} Loss D.: {loss_discriminator}")
                print(f"Epoch: {epoch} Loss G.: {loss_generator}")

    latent_space_samples = torch.randn(batch_size, 100).to(device=device)
    generated_samples = generator(latent_space_samples)

    generated_samples = generated_samples.cpu().detach()
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(generated_samples[i].reshape(28, 28, 3))
        plt.xticks([])
        plt.yticks([])
        plt.show()

Solution

  • Since this is happening at the end of the first epoch, what's essentially going on is that you have specified a batch size of 64, but the number of images in your dataset is some_integer_number * 64 + 47. When the data is read in batches, each batch contains batch_size samples, except possibly the last one of an epoch: once fewer than batch_size examples are left, the loader returns a smaller batch, as the short sketch below shows.
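
    Here is a minimal standalone sketch of that behaviour; the dataset length 239 = 3 * 64 + 47 is made up purely for illustration:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Dummy dataset whose length is not a multiple of the batch size.
    data = torch.randn(239, 3, 28, 28)  # 239 = 3 * 64 + 47
    loader = DataLoader(TensorDataset(data), batch_size=64)

    for (batch,) in loader:
        print(batch.shape)
    # torch.Size([64, 3, 28, 28])  -- printed three times
    # torch.Size([47, 3, 28, 28])  -- the short final batch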

    In your code, the number of real images in the last step of the 0th epoch is therefore 47, whereas the number of fake images (and of labels) is still 64, since you use the fixed batch_size to sample the latent vectors and to build the label tensors. The input and the target handed to BCELoss then differ in their first dimension, which is exactly what the ValueError reports.
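
    You can reproduce the exact error in two lines (the shapes are chosen to mirror your error message):

    import torch
    import torch.nn as nn

    # Input and target with different first dimensions -> ValueError.
    nn.BCELoss()(torch.rand(47, 1), torch.rand(64, 1))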

    A simple solution is to use len(real_samples) in place of batch_size everywhere inside the loop. You can do this by setting batch_size = len(real_samples) as the first line of the inner loop:

    for epoch in range(num_epochs):
        for n, real_samples in enumerate(train_loader):
            # Data for training the discriminator
            batch_size = len(real_samples)  # actual size of this batch; smaller on the last step
            real_samples = real_samples.to(device=device)
            real_samples_labels = torch.ones((batch_size, 1)).to(device=device)

            # rest of the code continues
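
    Alternatively, if you would rather keep every batch at exactly batch_size samples, the DataLoader accepts a drop_last flag that simply discards the incomplete final batch of each epoch (at the cost of not training on those leftover samples):

    train_loader = DataLoader(
        dataset, batch_size=batch_size, num_workers=1, shuffle=True,
        drop_last=True,  # drop the short final batch instead of returning it
    )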
    

    I hope this solves your issue.