
How can I save my GAN generated images after every epoch?


I tried to build my very own GAN in PyTorch. I wanted to see how the model learns to generate images over time, so I tried to save the images it created after each epoch, but it saved the same image after every epoch, so I guess I really am saving the same image every time. You can see the first 3 epochs' images. In addition, as you can see, it combines all the images into a single file when saving; can I choose only one?

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.disc(x)
        
class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 512),
            nn.LeakyReLU(0.1),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.1),
            nn.Linear(1024, img_dim),
            nn.Tanh() 
        )

    def forward(self, x):
        return self.gen(x)

lr = 3e-4
z_dim = 64
image_dim = 256 * 256 * 3
batch_size = 32
num_epochs = 16

disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
fixed_noise = torch.randn((batch_size, z_dim)).to(device)

# dataset = load_dataset(data_path="mountain_dataset", transform=transform)
loader = DataLoader(dataset, batch_size=32, num_workers=0, shuffle=True)
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()

Second part:

from torchvision.utils import save_image

step = 0

for epoch in range(num_epochs):
    for batch_idx, real in enumerate(dataset):
        real = real.view(-1, image_dim).to(device)
        batch_size = real.shape[0]

        ### Train Discriminator: max log(D(real)) + log(1 - D(G(z)))
        noise = torch.randn(batch_size, z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))

        disc_fake = disc(fake).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2

        disc.zero_grad()
        lossD.backward(retain_graph=True)
        opt_disc.step()

        ### Train Generator maximize log(D(G(z)))
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        
        if batch_idx == 0:
            print(
                f"Epoch: [{epoch+1}/{num_epochs}]"
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 3, 256, 256)
                data = real.reshape(-1, 3, 256, 256)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                # img_grid_fake is already a PyTorch tensor, so it can be saved directly as an image
                save_image(img_grid_fake, f"generated_images/epoch{epoch}.png", normalize=True)
                
                step += 1

Solution

  • Question 1: Images are the same

    First, in the line for batch_idx, real in enumerate(dataset): you iterate over the dataset. I.e., real represents one image, not one batch. If you add print(real.shape) as the first line inside the loop, it will print torch.Size([3, 256, 256]), which is one image rather than one batch. Accordingly, your batch_size will always be 3, because it actually picks up the number of channels.

    So you have to change this line to for batch_idx, real in enumerate(loader):. Then, the print will give you torch.Size([32, 3, 256, 256]), which is what you actually want.
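
    For example, here is a minimal sketch of the corrected start of the loop, assuming your custom loader yields plain image batches without labels:

    for batch_idx, real in enumerate(loader):
        print(real.shape)                           # torch.Size([32, 3, 256, 256])
        real = real.view(-1, image_dim).to(device)  # flatten each image for the linear layers
        batch_size = real.shape[0]                  # the actual batch size (32), not the channel count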

    After that modification, the saved images do differ from epoch to epoch for the first few epochs. I tested this with CIFAR-10 upscaled to your image size:

    image_shape = (256, 256)  # upscale CIFAR-10 to the 256x256 resolution the models expect
    dataset = torchvision.datasets.CIFAR10(root="dataset/", transform=torchvision.transforms.Compose([
        torchvision.transforms.Resize(image_shape),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,))
        ]), download=True)
    dataset = torch.utils.data.Subset(dataset, range(0, 1000))
    # Note: CIFAR-10 yields (image, label) pairs, so unpack with `for batch_idx, (real, _) in enumerate(loader):`
    loader = DataLoader(dataset, batch_size=32, shuffle=True, pin_memory=True)
    

    However, after 3 epochs, the generated images do indeed stay the same. This is because your model stops learning:

    Epoch: [1/16] Loss D: 0.6715, Loss G: 3.8073
    Epoch: [2/16] Loss D: 50.3202, Loss G: 0.0000
    Epoch: [3/16] Loss D: 50.0000, Loss G: 0.0000
    Epoch: [4/16] Loss D: 50.0030, Loss G: 0.0000
    Epoch: [5/16] Loss D: 50.0000, Loss G: 0.0000
    

    You will have to modify your architecture in order to keep training for more epochs. For instance, your discriminator is much smaller than the generator. As a starting point, you might want both models to have a similar number of parameters. You can see the parameter counts like this:

    print(
        f"Discriminator Parameters:\t{sum(p.numel() for p in disc.parameters())}\n"
        f"Generator Parameters:\t\t{sum(p.numel() for p in gen.parameters())}"
    )
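
    For the architectures above, this should report roughly 25 million parameters for the discriminator versus about 202 million for the generator, so the discriminator is indeed much smaller.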
    

    For instance, enlarging the discriminator like this gave me an additional three epochs of learning:

    class Discriminator(nn.Module):
        def __init__(self, img_dim):
            super().__init__()
            self.disc = nn.Sequential(
                nn.Linear(img_dim, 512),
                nn.LeakyReLU(0.1),
                nn.Linear(512, 1024),
                nn.LeakyReLU(0.1),
                nn.Linear(1024, 1),
                nn.Sigmoid()
            )
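
    Assuming the same 256×256×3 inputs, this enlarged discriminator has roughly 101 million parameters, which is much closer to the generator's roughly 202 million.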
    

  • Question 2: How to save only one image

    With the line img_grid_fake = torchvision.utils.make_grid(fake, normalize=True), you create a grid of all 32 (batch_size) images which is then saved to a file.

    If you only want to randomly choose four of these for saving, you could do it like this:

    import random

    fake = gen(fixed_noise).reshape(-1, 3, 256, 256)
    # Select 4 random images for the grid (note: random.choices samples with replacement)
    fake = random.choices(fake, k=4)
    img_grid_fake = torchvision.utils.make_grid(fake, normalize=True, nrow=2)
    

    If you only want to save a single image instead of a grid:

    # Save only image number 6
    index = 5 # For random: random.randint(0, batch_size-1)
    fake = gen(fixed_noise[index]).reshape(3, 256, 256)
    save_image(fake, f"generated_images/epoch{epoch}.png", normalize=True)
    

    Note

    I deleted the lines

    data = real.reshape(-1, 3, 256, 256)
    img_grid_real = torchvision.utils.make_grid(data, normalize=True)
    

    because they are unnecessary unless you actually want to save some real samples, too.
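
    If you do want a reference grid of real samples as well, a minimal sketch could look like this (the file name real_epoch{epoch}.png is just an example):

    data = real.reshape(-1, 3, 256, 256)
    img_grid_real = torchvision.utils.make_grid(data, normalize=True)
    save_image(img_grid_real, f"generated_images/real_epoch{epoch}.png", normalize=True)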