Search code examples
pythonpytorchgenerative-adversarial-network

RuntimeError: running_mean should contain 1 elements not 200


I am implementing a conditional gan for image generation with text embedding from scratch and I am getting the above error exactly in the BatchNorm1d layer from the embedding_layers in the generator
generator class :

import torch.nn as nn
class Generator(nn.Module):
    def __init__(self, embedding_dim=300, latent_dim=100, image_size=64, num_channels=3):
        super(Generator, self).__init__()

        self.embedding_size = embedding_dim
        self.latent_dim = latent_dim
        self.image_size = image_size

        # Define embedding processing layers
        self.embedding_layers = nn.Sequential(
            nn.Linear(embedding_dim,latent_dim),
            nn.BatchNorm1d(latent_dim),
            nn.LeakyReLU(0.2, inplace=True)
        )

        # Define noise processing layers
        self.noise_layers = nn.Sequential(
            nn.Linear(latent_dim, image_size * image_size * 4),
            nn.BatchNorm1d(image_size * image_size * 4),
            nn.LeakyReLU(0.2, inplace=True)
        )

        # Define image processing layers
        self.conv_layers = nn.Sequential(
            nn.ConvTranspose2d(latent_dim + 256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, num_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh()
        )
    def get_latent_dim(self):
        return self.latent_dim
    def forward(self, embeddings,noise):
        # Process embedding
        embedding_features = self.embedding_layers(embeddings)

        # Process noise
        
        noise_features = self.noise_layers(noise)

        # Combine features
        features = torch.cat((embedding_features, noise_features), dim=1)
        features = features.view(features.shape[0], -1, self.image_size // 16, self.image_size // 16)

        # Generate image
        image = self.conv_layers(features)

        return image

discriminator class:

import torch.nn as nn
class Discriminator(nn.Module):
    def __init__(self, embedding_dim=300, image_size=64, num_channels=3):
        super(Discriminator, self).__init__()

        # Define image processing layers
        self.conv_layers = nn.Sequential(
            nn.Conv2d(num_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1024, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(1024, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid()
        )

        # Define embedding processing layers
        self.embedding_layers = nn.Sequential(
            nn.Linear(embedding_dim, image_size * image_size),
            nn.BatchNorm1d(image_size * image_size),
            nn.LeakyReLU(0.2, inplace=True)
        )

    def forward(self, images, embeddings):
        # Process image
        image_features = self.conv_layers(images)

        # Process embedding
        embedding_features = self.embedding_layers(embeddings)
        embedding_features = embedding_features.view(embedding_features.shape[0], 1, 64, 64)

        # Combine features
        features = torch.cat((image_features, embedding_features), dim=1)

        # Classify
        classification = self.classification_layers(features).view(features.shape[0], -1)
        validity = self.validity_layers(features).view(features.shape[0], -1)

        return validity, classification

train function:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm

def train_gan(generator, discriminator, dataset, batch_size, num_epochs, device):
    """
    Trains a conditional GAN with a generator and a discriminator using a PyTorch dataset containing text embeddings and images.
    """
    # Set up loss functions and optimizers
    adversarial_loss = nn.BCELoss()
    generator_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    
    # Set up data loader
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    generator.to(device)
    discriminator.to(device)
    # Train the GAN
    for epoch in range(num_epochs):
        for i, data in enumerate(tqdm(data_loader)):
            # Load data and labels onto the device
            text_embeddings = data['text_embedding'].to(device)
            real_images = data['image'].to(device)
            
            # Generate fake images using the generator and the text embeddings
            noise = torch.randn(batch_size,generator.latent_dim).to(device)
            fake_images = generator(text_embeddings,noise)
            
            # Train the discriminator
            discriminator_optimizer.zero_grad()
            real_labels = torch.ones(real_images.size(0), 1).to(device)
            fake_labels = torch.zeros(fake_images.size(0), 1).to(device)
            real_predictions = discriminator(real_images, text_embeddings)
            real_loss = adversarial_loss(real_predictions, real_labels)
            fake_predictions = discriminator(fake_images.detach(), text_embeddings)
            fake_loss = adversarial_loss(fake_predictions, fake_labels)
            discriminator_loss = real_loss + fake_loss
            discriminator_loss.backward()
            discriminator_optimizer.step()
            
            # Train the generator
            generator_optimizer.zero_grad()
            fake_predictions = discriminator(fake_images, text_embeddings)
            generator_loss = adversarial_loss(fake_predictions, real_labels)
            generator_loss.backward()
            generator_optimizer.step()
            
            # Save generated images and model checkpoints every 500 batches
            if i % 500 == 0:
                with torch.no_grad():
                    fake_images = generator(text_embeddings[:16]).detach().cpu()
                save_image(fake_images, f"images\generated_images_epoch_{epoch}_batch_{i}.png", normalize=True, nrow=4)
                torch.save(generator.state_dict(), f"images\generator_checkpoint_epoch_{epoch}_batch_{i}.pt")
                torch.save(discriminator.state_dict(), f"images\discriminator_checkpoint_epoch_{epoch}_batch_{i}.pt")
                
        # Print loss at the end of each epoch
        print(f"Epoch [{epoch+1}/{num_epochs}] Discriminator Loss: {discriminator_loss.item()}, Generator Loss: {generator_loss.item()}")

main

# defining hyperparamter 
torch.cuda.empty_cache()
embedding_dim=768
img_size=512
latent_dim=200
batch_size=32
num_epochs=100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#define the main components
generator=Generator(embedding_dim=embedding_dim, latent_dim=latent_dim, image_size=img_size)
discriminator=Discriminator(embedding_dim=embedding_dim,image_size=img_size)

train_gan(generator=generator,
          discriminator=discriminator,
          dataset=dataset,
          batch_size=batch_size,
          num_epochs=num_epochs,
          device=device,)


as for my dataset, it consists of images and text embeddings with the following shape

torch.Size([3, 512, 512])
torch.Size([1, 768])


Solution

  • If you refer to https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html

    it says that input either (N, C) or (N, C,L). Your input on the other hand, is of shape (batch_size, 1, emb), which means that C is 1, and that is giving you the error, what you need to do is to remove the extra dim

    from torch import nn
    import torch
    embedding_dim = 768
    latent_dim = 200
    batch_size = 10
    inputs = torch.randn(batch_size, 1, 768).squeeze()
    model =nn.Sequential(
                nn.Linear(embedding_dim,latent_dim),
                nn.BatchNorm1d(latent_dim),
                nn.LeakyReLU(0.2, inplace=True)
            )
    print(model(inputs).shape)