Tags: deep-learning, neural-network, pytorch, torch, generative-adversarial-network

PyTorch: different behaviours in GAN training with different, but conceptually equivalent, code


I'm trying to implement a simple GAN in PyTorch. The following training code works:

    for epoch in range(max_epochs):  # loop over the dataset multiple times
        print(f'epoch: {epoch}')
        running_loss = 0.0

        for batch_idx,(data,_) in enumerate(data_gen_fn):
   
            # data preparation
            real_data            = data
            input_shape          = real_data.shape
            inputs_generator     = torch.randn(*input_shape).detach() 

            # generator forward
            fake_data            = generator(inputs_generator).detach()
            # discriminator forward
            optimizer_generator.zero_grad()
            optimizer_discriminator.zero_grad()

            #################### ALERT CODE #######################
            predictions_on_real = discriminator(real_data)
            predictions_on_fake = discriminator(fake_data)

            predictions = torch.cat((predictions_on_real,
                                     predictions_on_fake), dim=0)
            #########################################################

            # loss discriminator
            labels_real_fake           = torch.tensor([1]*batch_size + [0]*batch_size)
            loss_discriminator_batch   = criterion_discriminator(predictions, 
                                                          labels_real_fake)
            # update discriminator
            loss_discriminator_batch.backward()
            optimizer_discriminator.step()


            # generator
            # zero the parameter gradients
            optimizer_discriminator.zero_grad()
            optimizer_generator.zero_grad()

            fake_data            = generator(inputs_generator) # make again fake data but without detaching
            predictions_on_fake  = discriminator(fake_data) # D(G(encoding))
            
            # loss generator           
            labels_fake          = torch.tensor([1]*batch_size)
            loss_generator_batch = criterion_generator(predictions_on_fake, 
                                                       labels_fake)
  
            loss_generator_batch.backward()  # dL(D(G(encoding)))/dW_{G,D}
            optimizer_generator.step()

If I plot the generated images for each iteration, I see that the generated images look like the real ones, so the training procedure seems to work well.

However, if I try to change the code in the ALERT CODE part, i.e., instead of:

    #################### ALERT CODE #######################
    predictions_on_real = discriminator(real_data)
    predictions_on_fake = discriminator(fake_data)

    predictions = torch.cat((predictions_on_real,
                             predictions_on_fake), dim=0)
    #########################################################

I use the following:

    #################### ALERT CODE #######################
    predictions = discriminator(torch.cat((real_data, fake_data), dim=0))
    #######################################################

That is conceptually the same: in a nutshell, instead of doing two separate forward passes on the discriminator (the former on the real data, the latter on the fake data) and then concatenating the results, the new code first concatenates real and fake data and then makes just one forward pass on the concatenated batch.

However, this code version does not work: the generated images always look like random noise.

Is there any explanation for this behavior?


Solution

  • Why do we get different results?

    Supplying inputs in either the same batch or separate batches can make a difference if the model includes dependencies between different elements of the batch. By far the most common source of such dependencies in current deep learning models is batch normalization. As you mentioned, the discriminator does include batchnorm, so this is likely the reason for the different behaviors. Here is an example, using single numbers and a batch size of 4:

    import numpy as np

    features = [1., 2., 5., 6.]
    print("mean {}, std {}".format(np.mean(features), np.std(features)))
    
    print("normalized features", (features - np.mean(features)) / np.std(features))
    
    >>>mean 3.5, std 2.0615528128088303
    >>>normalized features [-1.21267813 -0.72760688  0.72760688  1.21267813]
    

    Now we split the batch into two parts. First part:

    features = [1., 2.]
    print("mean {}, std {}".format(np.mean(features), np.std(features)))
    
    print("normalized features", (features - np.mean(features)) / np.std(features))
    
    >>>mean 1.5, std 0.5
    >>>normalized features [-1.  1.]
    

    Second part:

    features = [5., 6.]
    print("mean {}, std {}".format(np.mean(features), np.std(features)))
    
    print("normalized features", (features - np.mean(features)) / np.std(features))
    
    >>>mean 5.5, std 0.5
    >>>normalized features [-1.  1.]
    

    As we can see, in the split-batch version the two halves are normalized to exactly the same numbers, even though the original inputs are very different. In the joint-batch version, on the other hand, the larger numbers remain larger than the smaller ones, because all values are normalized with the same statistics.
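    The same effect can be reproduced directly with PyTorch's batchnorm layer. The following is a minimal sketch: the tiny `nn.BatchNorm1d` module below stands in for the discriminator (it is not the actual model from the question), using the same numbers as above.

```python
import torch
import torch.nn as nn

# A hypothetical minimal "discriminator": a single BatchNorm1d layer in
# training mode, so it normalizes with the statistics of whatever batch it sees.
bn = nn.BatchNorm1d(num_features=1)
bn.train()

real = torch.tensor([[1.], [2.]])  # stands in for real data
fake = torch.tensor([[5.], [6.]])  # stands in for generated data

# Two separate forward passes: each half is normalized with its own statistics.
separate = torch.cat((bn(real), bn(fake)), dim=0)

# One joint forward pass: both halves share the joint batch statistics.
joint = bn(torch.cat((real, fake), dim=0))

print(separate.flatten())  # both halves normalize to roughly [-1, 1]
print(joint.flatten())     # the real half stays smaller than the fake half
print(torch.allclose(separate, joint))  # the two versions differ
```

    Note that the two outputs differ only because batchnorm is in training mode; the layer's learnable affine parameters are identical in both calls.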

    Why does this matter?

    With deep learning, it's always hard to say, and especially with GANs and their complex training dynamics. A possible explanation is that, as we can see in the example above, the separate batches result in more similar features after normalization even if the original inputs are quite different. This may help early in training, as the generator tends to output "garbage" which has very different statistics from real data.

    With a joint batch, these differing statistics make it easy for the discriminator to tell the real and generated data apart, and we end up in a situation where the discriminator "overpowers" the generator.

    By using separate batches, however, the different normalizations make the generated and real data look more similar, which makes the task less trivial for the discriminator and allows the generator to learn.
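    As a side note (not part of the original answer): if a single joint forward pass is preferred, one option is to use a normalization layer without cross-sample dependencies, such as `nn.LayerNorm`, which normalizes each sample independently. A minimal sketch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# LayerNorm normalizes each sample over its feature dimension, independently
# of the other samples in the batch, so a joint forward pass and two separate
# passes give identical outputs (unlike BatchNorm).
ln = nn.LayerNorm(normalized_shape=4)

real = torch.randn(2, 4)  # stands in for real data
fake = torch.randn(2, 4)  # stands in for generated data

separate = torch.cat((ln(real), ln(fake)), dim=0)
joint = ln(torch.cat((real, fake), dim=0))

print(torch.allclose(separate, joint))  # True: no cross-sample dependency
```

    Whether this is a good idea for a given GAN is a separate question; it only removes the batching sensitivity, not the discriminator-overpowering dynamics.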