
Add class information to keras network


I am trying to figure out how to use the label information of my dataset with Generative Adversarial Networks. I am trying to use the following implementation of conditional GANs that can be found here. My dataset contains two different image domains (real objects and sketches) with shared class information (chair, tree, orange, etc.). I opted for this implementation, which only considers the two domains as different "classes" for the correspondence (the training samples X correspond to the real images, while the target samples y correspond to the sketch images).

Is there a way to modify my code to take the class information (chair, tree, etc.) into account in my whole architecture? I actually want my discriminator to predict whether or not the images produced by my generator belong to a specific class, and not only whether they are real or not. As it is, with the current architecture, the system learns to create similar sketches in all cases.

Update: The discriminator returns a tensor of size 1x7x7; both y_true and y_pred are then flattened before the loss is computed:

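from keras import backend as K

def discriminator_loss(y_true, y_pred):
    # y_true is ignored: the batch holds BATCH_SIZE real pairs followed by BATCH_SIZE fake pairs
    BATCH_SIZE=100
    real_targets = K.ones_like(K.flatten(y_pred[:BATCH_SIZE,:,:,:]))
    fake_targets = K.zeros_like(K.flatten(y_pred[:BATCH_SIZE,:,:,:]))
    return K.mean(K.binary_crossentropy(K.flatten(y_pred), K.concatenate([real_targets, fake_targets])), axis=-1)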

and the loss of the discriminator on the generator's outputs (used to train the generator):

def discriminator_on_generator_loss(y_true,y_pred):
    # y_true is ignored: G wants D to score every patch of the fake pairs as real (1)
    return K.mean(K.binary_crossentropy(K.flatten(y_pred), K.ones_like(K.flatten(y_pred))), axis=-1)

Furthermore, here is my modification of the discriminator model so that it outputs a single unit:

model.add(Flatten())
model.add(Dense(1, activation='sigmoid'))
#model.add(Activation('sigmoid'))

Now the discriminator outputs a single value. How can I modify the above-mentioned loss functions accordingly? Should I have 7 outputs instead of 1, i.e. n_classes = 6 plus one class for predicting real vs. fake pairs?


Solution

  • Suggested Solution

    Reusing the code from the repository you shared, here are some suggested modifications to train a classifier alongside your generator and discriminator (their architectures and other losses are left untouched):

    import numpy as np
    import cv2
    from keras import backend as K
    from keras.models import Sequential, Model
    from keras.layers import Input, merge
    from keras.layers.core import Dense, Dropout, Activation, Flatten
    from keras.layers.convolutional import Convolution2D, MaxPooling2D
    from keras.optimizers import Adagrad
    # Keras 1.x API, as in the repository. discriminator_model, generator_model, get_data,
    # combine_images, generator_l1_loss, discriminator_loss, discriminator_on_generator_loss,
    # IN_CH, img_cols and img_rows are assumed to come from that repository.
    
    def lenet_classifier_model(nb_classes, in_shape):
        # Snippet by Fabien Tanc - https://www.kaggle.com/ftence/keras-cnn-inspired-by-lenet-5
        # Replace with your favorite classifier...
        # in_shape: shape of a single sketch image, e.g. generator.output_shape[1:]
        model = Sequential()
        model.add(Convolution2D(12, 5, 5, activation='relu', input_shape=in_shape, init='he_normal'))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Convolution2D(25, 5, 5, activation='relu', init='he_normal'))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Flatten())
        model.add(Dense(180, activation='relu', init='he_normal'))
        model.add(Dropout(0.5))
        model.add(Dense(100, activation='relu', init='he_normal'))
        model.add(Dropout(0.5))
        model.add(Dense(nb_classes, activation='softmax', init='he_normal'))
        return model
    
    def generator_containing_discriminator_and_classifier(generator, discriminator, classifier):
        inputs = Input((IN_CH, img_cols, img_rows))
        x_generator = generator(inputs)
    
        # D judges (input, generated) pairs stacked along the channel axis
        merged = merge([inputs, x_generator], mode='concat', concat_axis=1)
        discriminator.trainable = False
        x_discriminator = discriminator(merged)
    
        # The classifier (Q) predicts the class of the generated sketch
        classifier.trainable = False
        x_classifier = classifier(x_generator)
    
        model = Model(input=inputs, output=[x_generator, x_discriminator, x_classifier])
    
        return model
    
    
    def train(BATCH_SIZE):
        (X_train, Y_train, LABEL_train) = get_data('train')  # replace with your data here
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        Y_train = (Y_train.astype(np.float32) - 127.5) / 127.5
        discriminator = discriminator_model()
        generator = generator_model()
        classifier = lenet_classifier_model(6, generator.output_shape[1:])  # 6 classes, input = one sketch image
        generator.summary()
        discriminator_and_classifier_on_generator = generator_containing_discriminator_and_classifier(
            generator, discriminator, classifier)
        d_optim = Adagrad(lr=0.005)
        g_optim = Adagrad(lr=0.005)
        generator.compile(loss='mse', optimizer="rmsprop")
        discriminator_and_classifier_on_generator.compile(
            loss=[generator_l1_loss, discriminator_on_generator_loss, "categorical_crossentropy"],
            optimizer="rmsprop")
        discriminator.trainable = True
        discriminator.compile(loss=discriminator_loss, optimizer="rmsprop")
        classifier.trainable = True
        classifier.compile(loss="categorical_crossentropy", optimizer="rmsprop")
    
        for epoch in range(100):
            print("Epoch is", epoch)
            print("Number of batches", int(X_train.shape[0] / BATCH_SIZE))
            for index in range(int(X_train.shape[0] / BATCH_SIZE)):
                image_batch = Y_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]
                label_batch = LABEL_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]  # replace with your data here; one-hot, shape (BATCH_SIZE, 6)
    
                generated_images = generator.predict(X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE])
                if index % 20 == 0:
                    image = combine_images(generated_images)
                    image = image * 127.5 + 127.5
                    image = np.swapaxes(image, 0, 2)
                    cv2.imwrite(str(epoch) + "_" + str(index) + ".png", image)
                    # Image.fromarray(image.astype(np.uint8)).save(str(epoch)+"_"+str(index)+".png")
    
                # Training D:
                real_pairs = np.concatenate((X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], image_batch),
                                            axis=1)
                fake_pairs = np.concatenate(
                    (X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], generated_images), axis=1)
                X = np.concatenate((real_pairs, fake_pairs))
                y = np.zeros((2 * BATCH_SIZE, 1, 64, 64))  # dummy targets: discriminator_loss builds its own [1] * BATCH_SIZE + [0] * BATCH_SIZE
                d_loss = discriminator.train_on_batch(X, y)
                print("batch %d d_loss : %f" % (index, d_loss))
                discriminator.trainable = False
    
                # Training C:
                c_loss = classifier.train_on_batch(image_batch, label_batch)
                print("batch %d c_loss : %f" % (index, c_loss))
                classifier.trainable = False
    
                # Train G:
                g_loss = discriminator_and_classifier_on_generator.train_on_batch(
                    X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], 
                    [image_batch, np.ones((BATCH_SIZE, 1, 64, 64)), label_batch])  # 2nd target is ignored by discriminator_on_generator_loss
                discriminator.trainable = True
                classifier.trainable = True
                print("batch %d g_loss : %f" % (index, g_loss[1]))
                if index % 20 == 0:
                    generator.save_weights('generator', True)
                    discriminator.save_weights('discriminator', True)
    

    Theoretical Details

    I believe there are some misunderstandings regarding how conditional GANs work and what the discriminator's role is in such schemes.

    Role of the Discriminator

    In the min-max game that GAN training amounts to [4], the discriminator D plays against the generator G (the network you actually care about), so that under D's scrutiny G gets better at producing realistic results.

    To that end, D is trained to tell real samples apart from samples produced by G, while G is trained to fool D by generating realistic results, i.e. results that follow the target distribution.
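    The two objectives above can be written as binary cross-entropy terms on D's probabilities. Below is a minimal NumPy sketch, assuming `d_real` and `d_fake` hold D's sigmoid outputs on real and generated samples (all names are illustrative, not from the repository). For G it uses the "non-saturating" variant suggested in [4], where G maximizes log D(G(z)) instead of minimizing log(1 - D(G(z))), which gives stronger gradients early in training.

```python
import numpy as np

EPS = 1e-7  # keeps log() away from 0

def bce(pred, target):
    """Mean binary cross-entropy between probabilities `pred` and 0/1 `target`."""
    pred = np.clip(pred, EPS, 1.0 - EPS)
    return float(np.mean(-(target * np.log(pred) + (1.0 - target) * np.log(1.0 - pred))))

def d_loss(d_real, d_fake):
    """D is rewarded for scoring real samples as 1 and generated samples as 0."""
    return bce(d_real, np.ones_like(d_real)) + bce(d_fake, np.zeros_like(d_fake))

def g_loss(d_fake):
    """Non-saturating G loss: G is rewarded when D scores its samples as real (1)."""
    return bce(d_fake, np.ones_like(d_fake))

d_real = np.array([0.9, 0.8])  # D is fairly sure these real samples are real
d_fake = np.array([0.2, 0.1])  # ... and that these generated samples are fake
print(d_loss(d_real, d_fake))  # ~0.33: D is winning
print(g_loss(d_fake))          # ~1.96: G is losing, so its gradient pushes d_fake up
```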

    Note: in the case of conditional GANs, i.e. GANs mapping an input sample from one domain A (e.g. a real picture) to another domain B (e.g. a sketch), D is usually fed pairs of samples stacked together and has to discriminate between "real" pairs (input sample from A + corresponding target sample from B) and "fake" pairs (input sample from A + corresponding output from G) [1].

    Training a conditional generator against D (as opposed to training G alone with only an L1/L2 loss, e.g. as a denoising auto-encoder (DAE)) improves G's sampling capability, forcing it to output crisp, realistic results instead of averaging over the target distribution, which tends to give blurry outputs [1].
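    Concretely, in pix2pix [1] G minimizes an adversarial term plus a weighted L1 term, and the paper uses a weight λ = 100. A NumPy sketch under that assumption; `generator_objective` and its arguments are illustrative names, and `d_fake` stands for D's probabilities on the fake pairs:

```python
import numpy as np

EPS = 1e-7

def generator_objective(d_fake, fake_b, real_b, l1_weight=100.0):
    """pix2pix-style G loss: adversarial term + l1_weight * L1 reconstruction term.

    d_fake: D's probabilities on the fake pairs (any shape, e.g. a PatchGAN map).
    fake_b, real_b: generated and ground-truth target images of the same shape.
    """
    d_fake = np.clip(d_fake, EPS, 1.0 - EPS)
    adversarial = float(np.mean(-np.log(d_fake)))  # low when D takes the fakes for real
    l1 = float(np.mean(np.abs(fake_b - real_b)))    # low when G matches the paired target
    return adversarial + l1_weight * l1
```

    With D's outputs all at 0.5 and a constant pixel error of 0.1, the L1 term contributes 100 × 0.1 = 10 against ln 2 ≈ 0.69 for the adversarial term, which shows how strongly λ = 100 favours reconstruction. In the answer's code, `generator_l1_loss` and `discriminator_on_generator_loss` presumably play the roles of these two terms; Keras' `compile(..., loss_weights=[...])` argument is one place where such a relative weight could be set.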

    Even though discriminators can have multiple sub-networks to cover other tasks (see next paragraphs), D should keep at least one sub-network/output dedicated to its main task: telling real samples apart from generated ones. Asking D to regress further semantic information (e.g. classes) alongside may interfere with this main purpose.

    Note: D's output is often not a single scalar/boolean. It is common to have a discriminator (e.g. PatchGAN [1, 2]) that returns a matrix of probabilities, each estimating how realistic one patch of its input is.
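    The question's loss functions already follow this pattern: they flatten `y_pred` so that every patch score gets its own target, and the loss is averaged over all of them. A NumPy sketch of the same averaging that derives the targets from the output's own shape instead of a hard-coded `BATCH_SIZE`, assuming real pairs come first in the batch as in `X = np.concatenate((real_pairs, fake_pairs))` (the function name is illustrative):

```python
import numpy as np

EPS = 1e-7

def patch_discriminator_loss(d_out):
    """BCE over a batch laid out as [real pairs..., fake pairs...] along axis 0.

    d_out can be a PatchGAN map of shape (2N, 1, H, W) or a single sigmoid unit of
    shape (2N, 1): the scalar case is just a 1x1 patch map.
    """
    n_real = d_out.shape[0] // 2
    target = np.zeros_like(d_out)
    target[:n_real] = 1.0  # every patch of a real pair should be scored "real"
    p = np.clip(d_out, EPS, 1.0 - EPS)
    return float(np.mean(-(target * np.log(p) + (1.0 - target) * np.log(1.0 - p))))
```

    In the Keras-backend version, slicing `y_pred[:BATCH_SIZE]` rather than `y_pred[:BATCH_SIZE,:,:,:]` removes the assumption of a 4-D output, so the same loss also works after the `Dense(1)` change; `discriminator_on_generator_loss` is already shape-agnostic.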


    Conditional GANs

    Traditional GANs are trained in an unsupervised manner to generate realistic data (e.g. images) from a random noise vector given as input [4].

    As mentioned above, conditional GANs take additional input conditions. Alongside (or instead of) the noise vector, they take as input a sample from a domain A and return a corresponding sample from a domain B. A can be a completely different modality, e.g. A = discrete label and B = sketch image, or A = RGB image and B = volumetric data [3].

    Such GANs can also be conditioned on multiple inputs, e.g. A = real image + discrete label and B = sketch image. A well-known work introducing such methods is InfoGAN [5]. It shows how to condition GANs on multiple continuous or discrete inputs (e.g. A = digit class + writing style, B = handwritten digit image), using a more advanced discriminator whose second task is to force G to maximize the mutual information between its conditioning inputs and its outputs.
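    One common way to give a convolutional G both an image Ia and a discrete label c (an assumption here, not a method prescribed by [5]) is to one-hot encode c, tile it over the spatial dimensions and stack it onto the image as extra input channels. A NumPy sketch, channels first like the answer's code; `concat_label_channels` is a hypothetical helper:

```python
import numpy as np

def concat_label_channels(images, labels, n_classes):
    """Append one-hot class labels to a batch of images as constant feature maps.

    images: channels-first float array (N, C, H, W), as in the answer's code.
    labels: integer array (N,) with class indices in [0, n_classes).
    Returns an array of shape (N, C + n_classes, H, W).
    """
    n, _, h, w = images.shape
    one_hot = np.eye(n_classes, dtype=images.dtype)[labels]  # (N, n_classes)
    # Broadcast each one-hot vector over H x W: channel C + k is all ones iff label == k.
    label_maps = np.broadcast_to(one_hot[:, :, None, None], (n, n_classes, h, w))
    return np.concatenate([images, label_maps], axis=1)

batch = np.zeros((2, 3, 64, 64), dtype=np.float32)  # two RGB images Ia
g_input = concat_label_channels(batch, np.array([4, 1]), n_classes=6)
print(g_input.shape)  # (2, 9, 64, 64)
```

    G's first layer would then have to accept C + n_classes input channels (e.g. `IN_CH + 6` with the repository's naming).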


    Maximizing the Mutual Information for cGANs

    The InfoGAN discriminator has two heads/sub-networks, one for each of its two tasks [5]:

    • One head D1 does the traditional real/generated discrimination -- G has to minimize this result, i.e. it has to fool D1 so that D1 cannot tell real from generated data;
    • Another head D2 (also called the Q network) tries to regress the input A information -- G has to maximize this result, i.e. it has to output data that "show" the requested semantic information (cf. the mutual-information maximization between G's conditional inputs and its outputs).
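    Applied to your 6 classes, one option in line with this (a suggestion, not the only possible design) is to keep D's single sigmoid real/fake output and add a separate 6-way softmax head, rather than merging both tasks into one 7-way output; the auxiliary classifier GAN [7] uses a similar two-output discriminator. The NumPy sketch below (function names are illustrative) shows how the two heads' losses combine for D/Q and for G, following the "Wrapping Up" scheme where Q is trained on real samples and `c` is one-hot:

```python
import numpy as np

EPS = 1e-7

def bce(p, t):
    """Mean binary cross-entropy (head D1: real vs. generated)."""
    p = np.clip(p, EPS, 1.0 - EPS)
    return float(np.mean(-(t * np.log(p) + (1.0 - t) * np.log(1.0 - p))))

def cce(probs, one_hot):
    """Mean categorical cross-entropy (head Q), probs of shape (N, n_classes)."""
    return float(np.mean(-np.sum(one_hot * np.log(np.clip(probs, EPS, 1.0)), axis=1)))

def d_and_q_loss(d1_real, d1_fake, q_real, c):
    # D1 separates real from generated pairs; Q learns to recover c from real samples Ib.
    return bce(d1_real, np.ones_like(d1_real)) + bce(d1_fake, np.zeros_like(d1_fake)) + cce(q_real, c)

def g_loss(d1_fake, q_fake, c):
    # G is rewarded for fooling D1 and for producing outputs Ib' that Q classifies as c.
    return bce(d1_fake, np.ones_like(d1_fake)) + cce(q_fake, c)
```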

    For instance, you can find a Keras implementation here: https://github.com/eriklindernoren/Keras-GAN/tree/master/infogan.

    Several works use similar schemes to improve control over what a GAN generates, by using provided labels and maximizing the mutual information between these inputs and G's outputs [6, 7]. The basic idea is always the same, though:

    • Train G to generate elements of domain B, given some inputs of domain(s) A;
    • Train D to discriminate "real"/"fake" results -- G has to minimize this;
    • Train Q (e.g. a classifier; it can share layers with D) to estimate the original A inputs from B samples -- G has to maximize this.

    Wrapping Up

    In your case, it seems you have the following training data:

    • real images Ia
    • corresponding sketch images Ib
    • corresponding class labels c

    And you want to train a generator G so that given an image Ia and its class label c, it outputs a proper sketch image Ib'.

    All in all, you have a lot of information, and you can supervise your training on both the paired images and the class labels. Inspired by the aforementioned methods [1, 2, 5, 6, 7], here is one possible way of using all this information to train your conditional G:

    Network G:
    • Inputs: Ia + c
    • Output: Ib'
    • Architecture: up to you (e.g. U-Net, ResNet, ...)
    • Losses: L1/L2 loss between Ib' & Ib, -D loss, Q loss
    Network D:
    • Inputs: Ia + Ib (real pair), Ia + Ib' (fake pair)
    • Output: "fakeness" scalar/matrix
    • Architecture: up to you (e.g. PatchGAN)
    • Loss: cross-entropy on the "fakeness" estimation
    Network Q:
    • Inputs: Ib (real sample, for training Q), Ib' (fake sample, when back-propagating through G)
    • Output: c' (estimated class)
    • Architecture: up to you (e.g. LeNet, ResNet, VGG, ...)
    • Loss: cross-entropy between c and c'
    Training Phase:
    1. Train D on a batch of real pairs Ia + Ib then on a batch of fake pairs Ia + Ib';
    2. Train Q on a batch of real samples Ib;
    3. Fix D and Q weights;
    4. Train G, passing its generated outputs Ib' to D and Q to back-propagate through them.

    Note: this is a really rough architecture description. I'd recommend going through the literature ([1, 5, 6, 7] as a good start) to get more details and maybe a more elaborate solution.


    References

    1. Isola, Phillip, et al. "Image-to-image translation with conditional adversarial networks." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR). 2017. http://openaccess.thecvf.com/content_cvpr_2017/papers/Isola_Image-To-Image_Translation_With_CVPR_2017_paper.pdf
    2. Zhu, Jun-Yan, et al. "Unpaired image-to-image translation using cycle-consistent adversarial networks." Proceedings of the IEEE International Conference on Computer Vision (ICCV). 2017. http://openaccess.thecvf.com/content_ICCV_2017/papers/Zhu_Unpaired_Image-To-Image_Translation_ICCV_2017_paper.pdf
    3. Mirza, Mehdi, and Simon Osindero. "Conditional generative adversarial nets." arXiv preprint arXiv:1411.1784 (2014). https://arxiv.org/pdf/1411.1784
    4. Goodfellow, Ian, et al. "Generative adversarial nets." Advances in Neural Information Processing Systems. 2014. http://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf
    5. Chen, Xi, et al. "InfoGAN: Interpretable representation learning by information maximizing generative adversarial nets." Advances in Neural Information Processing Systems. 2016. http://papers.nips.cc/paper/6399-infogan-interpretable-representation-learning-by-information-maximizing-generative-adversarial-nets.pdf
    6. Lee, Minhyeok, and Junhee Seok. "Controllable Generative Adversarial Network." arXiv preprint arXiv:1708.00598 (2017). https://arxiv.org/pdf/1708.00598.pdf
    7. Odena, Augustus, Christopher Olah, and Jonathon Shlens. "Conditional image synthesis with auxiliary classifier GANs." Proceedings of the 34th International Conference on Machine Learning (ICML). 2017. http://proceedings.mlr.press/v70/odena17a/odena17a.pdf