
Why is the Generator in my GAN not working?


I'm currently coding a Generative Adversarial Network (GAN) from scratch with my own neural network library to generate MNIST handwritten digits. The discriminator seems to work fine, but the generator doesn't really learn anything over time. Maybe my training approach is wrong.

So my question is whether I can actually train my generator this way.

First I train my discriminator on real examples with the target output 1, and then on fake examples generated by the generator with the target output 0. This works fine. Next I train the generator: I run the discriminator on fake examples, but with the target output 1 (the generator wants the discriminator to classify its generated images as real), and I backpropagate the error all the way back to the input layer of the discriminator, without updating the discriminator's weights. I then backpropagate this input-layer error through the generator and update the generator's weights based on it.

Can I actually do that, i.e. backpropagate the discriminator's error through the generator? The generator's output is essentially the input to the discriminator, right? Or is there a better way to do it? Any help is appreciated.
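A minimal sketch of the generator update described above, assuming a hypothetical 1-D toy setup in plain Python rather than MNIST images: a generator G(z) = a*z + b and a logistic discriminator D(x) = sigmoid(w*x + c). The names `a`, `b`, `w`, `c`, `generator_grads` and `generator_loss` are illustrative only. The discriminator's parameters are read but never updated, and the error at its input is handed on to the generator through the chain rule; a finite-difference check confirms the hand-derived gradients in this toy case.

```python
import math

def sigmoid(u):
    # Numerically stable logistic function
    if u >= 0:
        return 1.0 / (1.0 + math.exp(-u))
    e = math.exp(u)
    return e / (1.0 + e)

def generator_grads(a, b, w, c, z):
    """Gradients of the generator loss -log D(G(z)) w.r.t. the generator
    parameters (a, b). The discriminator parameters (w, c) are frozen:
    they are used in the forward and backward pass but not updated."""
    # Forward pass: generator, then discriminator
    x = a * z + b            # fake sample G(z)
    u = w * x + c            # discriminator logit
    d = sigmoid(u)           # D(G(z)), probability of "real"
    # Backward pass through the discriminator, stopping at its input.
    # Target label is 1 ("real"), so loss = -log(d) and dloss/du = d - 1.
    dloss_dx = (d - 1.0) * w  # error at the discriminator's input layer
    # Continue backpropagating that input error through the generator.
    grad_a = dloss_dx * z
    grad_b = dloss_dx
    return grad_a, grad_b

def generator_loss(a, b, w, c, z):
    return -math.log(sigmoid(w * (a * z + b) + c))

# Check the hand-derived gradients against central finite differences.
a, b, w, c, z = 0.5, -1.0, 2.0, 0.3, 0.7
ga, gb = generator_grads(a, b, w, c, z)
eps = 1e-6
num_ga = (generator_loss(a + eps, b, w, c, z) - generator_loss(a - eps, b, w, c, z)) / (2 * eps)
num_gb = (generator_loss(a, b + eps, w, c, z) - generator_loss(a, b - eps, w, c, z)) / (2 * eps)
print(abs(ga - num_ga) < 1e-6, abs(gb - num_gb) < 1e-6)  # → True True
```

In a real network the same chain rule runs layer by layer: the discriminator's backward pass produces the gradient with respect to its input image, and that tensor becomes the output gradient for the generator's backward pass.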


Solution

  • From your question, I assume you are proposing an approach like this: while training the discriminator, you want to backpropagate all the way into the generator (to the point where the noise is provided) instead of detaching at the start of the discriminator (its first layer)?

    If this is the case, then you are updating the generator's parameters with respect to the discriminator's loss. The discriminator's job is to update its parameters so that it can tell real from fake. If you don't stop the backpropagation and let it flow into the generator, the generator's parameters get updated with respect to the discriminator's loss, which pushes the generator toward images the discriminator can easily distinguish. This creates a conflict: you are training the generator to fool the discriminator and, at the same time, training it to make the discriminator's job easier.

    The approach is simply:

    1. Generate an image with the generator.
    2. Pass the real image to the discriminator and use the label 1.
    3. Pass the fake image to the discriminator and use the label 0 (or vice versa).
    4. Perform backprop, making sure to detach the fake image (fake.detach() in PyTorch). That halts backpropagation at the discriminator's input, so the generator's parameters are not updated.

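    Then train the generator by passing the fake image through the discriminator with the label 1 (or 0 if you chose the vice-versa case above), this time without detaching it, so the error flows back into the generator, and update only the generator's parameters.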

    GANs can take a long time to train. For tips on getting the best results, follow the hacks collected at https://github.com/soumith/ganhacks.
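The two training phases described in the answer can be sketched without any particular library. This is a minimal plain-Python sketch under stated assumptions: a hypothetical 1-D toy model with generator G(z) = a*z + b, discriminator D(x) = sigmoid(w*x + c), an arbitrarily chosen learning rate `lr`, and made-up "real" data drawn from a Gaussian centred at 4. In a hand-written library there is no autograd graph to detach; the equivalent of `fake.detach()` is simply not propagating the discriminator-step error into the generator.

```python
import math
import random

def sigmoid(u):
    # Numerically stable logistic function
    if u >= 0:
        return 1.0 / (1.0 + math.exp(-u))
    e = math.exp(u)
    return e / (1.0 + e)

def train_step(params, real_x, z, lr=0.05):
    """One GAN iteration on a hypothetical toy 1-D model.
    params is a dict with generator (a, b) and discriminator (w, c) values."""
    a, b, w, c = params["a"], params["b"], params["w"], params["c"]

    # Phase 1: discriminator update. The fake sample is treated as a plain
    # constant (the hand-written analogue of fake.detach()), so no gradient
    # reaches a or b here.
    fake_x = a * z + b
    d_real = sigmoid(w * real_x + c)
    d_fake = sigmoid(w * fake_x + c)
    # Loss = -log(d_real) - log(1 - d_fake); gradients w.r.t. the logits:
    g_real = d_real - 1.0    # label 1 for the real sample
    g_fake = d_fake          # label 0 for the fake sample
    w -= lr * (g_real * real_x + g_fake * fake_x)
    c -= lr * (g_real + g_fake)

    # Phase 2: generator update through the now-fixed discriminator, with
    # label 1 for the fake sample. Only a and b change.
    fake_x = a * z + b
    d_fake = sigmoid(w * fake_x + c)
    dloss_dx = (d_fake - 1.0) * w    # error at the discriminator's input
    a -= lr * dloss_dx * z
    b -= lr * dloss_dx

    return {"a": a, "b": b, "w": w, "c": c}

random.seed(0)
params = {"a": 1.0, "b": 0.0, "w": 0.1, "c": 0.0}
for _ in range(200):
    real_x = random.gauss(4.0, 0.5)   # hypothetical "real" data
    z = random.gauss(0.0, 1.0)        # generator noise
    params = train_step(params, real_x, z)
print(params)
```

The key design point is the asymmetry between the phases: phase 1 changes only `w` and `c`, and phase 2 changes only `a` and `b`, even though the generator's gradient is computed through the discriminator.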