
Keras: Understanding the role of Embedding layer in a Conditional GAN


I am working to understand Erik Linder-Norén's implementation of the Conditional GAN (CGAN) model, and am confused by the generator in that model:

def build_generator(self):
    model = Sequential()
    # ...some lines removed...    
    model.add(Dense(np.prod(self.img_shape), activation='tanh'))
    model.add(Reshape(self.img_shape))
    model.summary()

    noise = Input(shape=(self.latent_dim,))
    label = Input(shape=(1,), dtype='int32')
    label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))
    model_input = multiply([noise, label_embedding])
    img = model(model_input)

    return Model([noise, label], img)

My question is: How does the Embedding() layer work here?

I know that noise is a vector that has length 100, and label is an integer, but I don't understand what the label_embedding object contains or how it functions here.

I tried printing the shape of label_embedding to try and figure out what's going on in that Embedding() line but that returns (?,?).

If anyone could help me understand how the Embedding() lines here work, I'd be very grateful for their assistance!


Solution

  • From the Keras documentation, https://keras.io/layers/embeddings/#embedding:

    Turns positive integers (indexes) into dense vectors of fixed size. eg. [[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]

    In the CGAN model, the input integer (0-9) is converted to a vector of length 100. With the short code snippet below, we can feed some test input and check the output shape of the Embedding layer.

    from keras.layers import Input, Embedding
    from keras.models import Model
    import numpy as np

    latent_dim = 100
    num_classes = 10

    # Build a tiny model containing only the Embedding layer
    label = Input(shape=(1,), dtype='int32')
    label_embedding = Embedding(num_classes, latent_dim)(label)
    mod = Model(label, label_embedding)

    # Feed a single label (class 0) and inspect the output shape
    test_input = np.zeros((1,))
    print(f'output shape is {mod.predict(test_input).shape}')
    mod.summary()
    

    output shape is (1, 1, 100)

    From the model summary, the output shape of the embedding layer is (None, 1, 100); with the batch dimension filled in, this matches the (1, 1, 100) shape returned by predict.

    embedding_1 (Embedding) (None, 1, 100) 1000

    One additional point: in the output shape (1, 1, 100), the leftmost 1 is the batch size and the middle 1 is the input length (we fed a single label, so the input length is 1). In the CGAN generator, Flatten() then collapses this to (batch_size, 100) so it can be multiplied element-wise with the noise vector.
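To see why the multiply in the generator works, the embedding lookup and the conditioning step can be mimicked in plain NumPy. This is an illustrative sketch, not the Keras internals: the random weight matrix below stands in for the Embedding layer's trained weights.

```python
import numpy as np

rng = np.random.default_rng(0)
latent_dim = 100
num_classes = 10

# An Embedding layer is essentially a trainable lookup table of shape
# (num_classes, latent_dim): row i holds the learned vector for label i.
# (Random values here stand in for trained weights.)
embedding_weights = rng.normal(size=(num_classes, latent_dim))

label = 7                                   # an integer class label, 0-9
label_embedding = embedding_weights[label]  # row lookup -> vector of length 100

noise = rng.normal(size=(latent_dim,))      # the generator's noise vector

# multiply([noise, label_embedding]) in the generator is an element-wise
# product, so the conditioned input keeps the same shape as the noise.
model_input = noise * label_embedding

print(label_embedding.shape)  # (100,)
print(model_input.shape)      # (100,)
```

Because the product is element-wise, the generator's first Dense layer sees an ordinary length-100 vector, but one whose values now depend on both the noise and the label.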