I am working on implementing a Generative Adversarial Network (GAN) in PyTorch 1.5.0.
For computing the loss of the generator, I compute both the negative probabilities that the discriminator mis-classifies an all-real minibatch and an all-(generator-generated-)fake minibatch. Then, I back-propagate both parts sequentially and finally apply the step function.
Calculating and back-propagating the part of the loss that depends on the mis-classification of the generated fake data seems straightforward, since during back-propagation of that loss term the backward path leads through the generator that produced the fake data in the first place.
However, classifying all-real minibatches does not involve passing data through the generator. Therefore, I was wondering whether the following code snippet would still calculate gradients for the generator, or whether it would not calculate any gradients at all (since the backward path does not lead through the generator and the discriminator is in eval mode while updating the generator)?
# Update generator #
net.generator.train()
net.discriminator.eval()
net.generator.zero_grad()
# All-real minibatch
x_real = get_all_real_minibatch()
y_true = torch.full((batch_size,), label_fake).long() # Pretend true targets were fake
y_pred = net.discriminator(x_real) # Produces softmax probability distribution over (0=label_fake,1=label_real)
loss_real = torch.nn.functional.nll_loss(torch.log(y_pred), y_true) # NLL expects log-probabilities
loss_real.backward()
optimizer_generator.step()
If this doesn’t work as intended, how could I make it work? Thanks in advance!
No gradients are propagated to the generator, because none of the generator's parameters took part in the computation. The discriminator being in eval mode is not the reason: eval mode does not prevent gradients from propagating to the generator, although they would be slightly different if you are using layers that behave differently in eval mode than in train mode, such as dropout.
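To check this yourself, here is a minimal sketch. The two small `nn.Linear`-based modules are hypothetical stand-ins for `net.generator` and `net.discriminator`, and the random tensor stands in for `get_all_real_minibatch()`; none of these shapes come from your code. After backpropagating a loss computed only from real data, the generator's parameters have no gradient at all, while the discriminator's parameters do, even in eval mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for net.generator and net.discriminator
generator = nn.Linear(4, 8)
discriminator = nn.Sequential(nn.Linear(8, 2), nn.LogSoftmax(dim=1))

label_fake = 0
batch_size = 5

discriminator.eval()  # eval mode does not disable gradient tracking

# Random data standing in for get_all_real_minibatch()
x_real = torch.randn(batch_size, 8)
y_true = torch.full((batch_size,), label_fake, dtype=torch.long)

# The discriminator here outputs log-probabilities directly
loss_real = nn.functional.nll_loss(discriminator(x_real), y_true)
loss_real.backward()

generator_grads = [p.grad for p in generator.parameters()]
discriminator_grads = [p.grad for p in discriminator.parameters()]

# The generator never appeared in the graph, so its .grad is still None
print(all(g is None for g in generator_grads))
# The discriminator did receive gradients despite being in eval mode
print(all(g is not None for g in discriminator_grads))
```

In your snippet, `optimizer_generator.step()` therefore has no gradient information from `loss_real` to work with.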
The misclassification of real images is not part of training the generator, because the generator gains nothing from this information. Conceptually, what should the generator learn from the fact that the discriminator failed to correctly classify a real image? The sole task of the generator is to create a fake image that the discriminator thinks is real, so the only relevant information for the generator is whether the discriminator was able to identify the fake image. If the discriminator was indeed able to identify the fake image, the generator needs to adjust itself to create a more convincing fake.
Of course it's not a binary case: the generator always tries to improve the fake image so that the discriminator is even more convinced that it is a real image. The generator's goal is not to make the discriminator doubtful (probability of 0.5 that it's real or fake), but to make the discriminator fully convinced that the image is real, even though it's fake. That's why they are adversarial, not cooperative.
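A generator update along these lines could look like the following sketch. The module definitions, `latent_dim`, `data_dim`, and the SGD optimizer are assumptions made for illustration, not taken from your code. The fake batch is produced by the generator, labelled as real, and the loss is backpropagated through the discriminator into the generator.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes and labels for illustration
latent_dim, data_dim, batch_size = 4, 8, 5
label_fake, label_real = 0, 1

# Hypothetical stand-ins for net.generator and net.discriminator
generator = nn.Linear(latent_dim, data_dim)
discriminator = nn.Sequential(nn.Linear(data_dim, 2), nn.LogSoftmax(dim=1))
optimizer_generator = torch.optim.SGD(generator.parameters(), lr=0.1)

generator.train()
discriminator.eval()
generator.zero_grad()

# All-fake minibatch: the data passes through the generator
z = torch.randn(batch_size, latent_dim)
x_fake = generator(z)

# Pretend the true targets were real
y_target = torch.full((batch_size,), label_real, dtype=torch.long)
log_probs = discriminator(x_fake)
loss_fake = nn.functional.nll_loss(log_probs, y_target)
loss_fake.backward()

generator_has_grads = all(p.grad is not None for p in generator.parameters())

weights_before = generator.weight.detach().clone()
optimizer_generator.step()  # only the generator's parameters are updated
weights_changed = not torch.equal(weights_before, generator.weight.detach())
```

Note that `loss_fake.backward()` also accumulates gradients in the discriminator's parameters, so those should be zeroed before the discriminator's own update.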