tensorflow generative-adversarial-network adversarial-machines

How to update GAN Generator and Discriminator asynchronously in Tensorflow?


I want to develop a GAN with TensorFlow, where the Generator is an autoencoder and the Discriminator is a convolutional neural network with a binary output. Developing the autoencoder and the CNN is not a problem, but my idea is to train each component (Discriminator and Generator) for 1 epoch and repeat this cycle for 1000 epochs, keeping the results (weights) of each training epoch for the next one. How can I operationalize this?
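The cycle described in the question can be sketched in plain Python. This is a minimal sketch, assuming a hypothetical `ToyModel` class whose single weight and quadratic toy loss stand in for the autoencoder and the CNN (they are not a GAN objective). The point is only the loop structure: both models are created and initialized once, outside the loop, so every epoch resumes from the weights left by the previous one.

```python
class ToyModel:
    """Hypothetical stand-in for a network: one scalar weight trained by
    gradient descent on the toy loss (w - target) ** 2."""

    def __init__(self, target, lr=0.1, steps_per_epoch=5):
        self.target = target
        self.lr = lr
        self.steps_per_epoch = steps_per_epoch
        self.w = 0.0              # initialized once, like sess.run(init)
        self.epochs_trained = 0

    def loss(self):
        return (self.w - self.target) ** 2

    def train_one_epoch(self):
        # Resumes from the current self.w; nothing is re-initialized here.
        for _ in range(self.steps_per_epoch):
            grad = 2.0 * (self.w - self.target)
            self.w -= self.lr * grad
        self.epochs_trained += 1
        return self.loss()


def alternate(discriminator, generator, num_cycles):
    """Train D for one epoch, then G for one epoch, num_cycles times."""
    history = []
    for cycle in range(1, num_cycles + 1):
        d_loss = discriminator.train_one_epoch()
        g_loss = generator.train_one_epoch()
        history.append((cycle, d_loss, g_loss))
    return history


discriminator = ToyModel(target=1.0)
generator = ToyModel(target=-2.0)
history = alternate(discriminator, generator, num_cycles=3)  # 1000 in the question
for cycle, d_loss, g_loss in history:
    print(f"cycle {cycle}: D loss {d_loss:.6f}, G loss {g_loss:.6f}")
```

Because the weights live in the model objects, the losses keep decreasing across cycles instead of starting over. In TensorFlow 1.x the same idea means building the graph and running the initializer once, before the loop, and keeping the same `Session` open; running the initializer again inside the loop would reset the weights.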


Solution

  • I solved the problem. What I wanted was for the output of the autoencoder to be the input of the CNN, which connects the two parts of the GAN, and to update their weights in a 1:1 ratio. I had to be careful to give the generator and discriminator losses different variable names: otherwise, at the start of the second loop iteration, the Generator's loss tensor (`loss`) is replaced by a float, the last loss value fetched for the Discriminator, and `sess.run` can no longer fetch it.
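The variable-name pitfall can be reproduced without TensorFlow. In this minimal sketch, `FakeTensor` and `fake_run` are hypothetical stand-ins for a graph tensor and `sess.run` (which, in TensorFlow 1.x, only accepts graph elements such as tensors and operations as fetches, and raises a `TypeError` for a plain float). Rebinding `loss` to the fetched discriminator value makes the second iteration fail, while a separate name such as `loss3` keeps the loop working.

```python
class FakeTensor:
    """Hypothetical stand-in for a graph tensor such as the generator's `loss`."""

    def __init__(self, name, value):
        self.name = name
        self.value = value


def fake_run(fetch):
    """Hypothetical stand-in for sess.run: only graph tensors can be fetched."""
    if not isinstance(fetch, FakeTensor):
        raise TypeError(f"Fetch argument {fetch!r} has invalid type {type(fetch).__name__}")
    return fetch.value


def train(num_steps, reuse_name):
    loss = FakeTensor("generator_loss", 0.25)         # Generator loss tensor
    loss_op2 = FakeTensor("discriminator_loss", 0.69)  # Discriminator loss tensor
    g_losses, d_losses = [], []
    for _ in range(num_steps):
        g_losses.append(fake_run(loss))   # Generator step fetches the tensor `loss`
        if reuse_name:
            loss = fake_run(loss_op2)     # BUG: rebinds `loss` to a float
            d_losses.append(loss)
        else:
            loss3 = fake_run(loss_op2)    # OK: `loss` is still a tensor
            d_losses.append(loss3)
    return g_losses, d_losses


print(train(3, reuse_name=False))
try:
    train(3, reuse_name=True)
except TypeError as err:
    print("second iteration fails:", err)
```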

    Here's the code:

    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x graph/session API

    # The graph (X, Y, optimizer, loss, decoder_op, X2, Y2, keep_prob, train_op,
    # loss_op2, accuracy, init) and next_batch() are defined earlier.
    with tf.Session() as sess:
        sess.run(init)  # initialize the weights once; they persist across iterations
        for i in range(1, num_steps + 1):
            # Here the Generator (autoencoder) is trained
            batch_x, batch_y = next_batch(batch_size, x_train_noisy, x_train)
            _, l = sess.run([optimizer, loss],
                            feed_dict={X: batch_x.reshape(-1, 784), Y: batch_y})
            if i % display_step == 0 or i == 1:
                print('Epoch %i: Denoising Loss: %f' % (i, l))

            # Here the output of the Generator is used as input for the Discriminator
            output = sess.run(decoder_op, feed_dict={X: x_train})
            x_train2 = np.asarray(output).reshape(-1, 784).astype(np.float64)

            # Here the Discriminator (CNN) is trained
            batch_x2, batch_y2 = next_batch(batch_size, x_train2, y_train)
            sess.run(train_op, feed_dict={X2: batch_x2.reshape(-1, 784),
                                          Y2: batch_y2, keep_prob: 0.8})
            if i % display_step == 0 or i == 1:
                # Distinct names (loss3, acc) so the Generator's `loss` tensor is not overwritten
                loss3, acc = sess.run([loss_op2, accuracy],
                                      feed_dict={X2: batch_x2.reshape(-1, 784),
                                                 Y2: batch_y2,
                                                 keep_prob: 1.0})
                print("Epoch " + str(i) + ", CNN Loss= " +
                      "{:.4f}".format(loss3) + ", Training Accuracy= " +
                      "{:.3f}".format(acc))
    

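    This way, the alternating (asynchronous) updates can be run in a 1:1 ratio, or in 1:5, 5:1 (Discriminator : Generator) or any other ratio, by changing how many Discriminator and Generator training steps are executed in each iteration of the loop.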
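Changing the ratio amounts to changing how many updates of each kind run per cycle. The following is a minimal sketch, assuming hypothetical `train_generator_step` and `train_discriminator_step` callables; in the TensorFlow 1.x loop above they would wrap the `sess.run([optimizer, loss], ...)` and `sess.run(train_op, ...)` calls respectively. Within each cycle the Generator is updated first, matching the order of the loop above.

```python
def alternating_schedule(d_steps, g_steps, num_cycles):
    """Return the order of updates for a d_steps:g_steps (Discriminator : Generator)
    ratio, e.g. d_steps=1, g_steps=5 for 1:5."""
    if d_steps < 0 or g_steps < 0 or num_cycles < 0:
        raise ValueError("step counts and num_cycles must be non-negative")
    schedule = []
    for _ in range(num_cycles):
        schedule.extend(["G"] * g_steps)  # Generator updates first
        schedule.extend(["D"] * d_steps)  # then Discriminator updates
    return schedule


def run_schedule(schedule, train_generator_step, train_discriminator_step):
    """Dispatch each scheduled update to the matching (hypothetical) step function."""
    for step, who in enumerate(schedule, start=1):
        if who == "G":
            train_generator_step(step)
        else:
            train_discriminator_step(step)


calls = []
run_schedule(alternating_schedule(d_steps=1, g_steps=5, num_cycles=2),
             train_generator_step=lambda s: calls.append(("G", s)),
             train_discriminator_step=lambda s: calls.append(("D", s)))
print([who for who, _ in calls])
```

Keeping the schedule separate from the step functions means the ratio can be changed without touching the training code itself.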