Tags: machine-learning, keras, autoencoder

Variational Autoencoder cross-entropy loss (xent_loss) with 3D convolutional layers


I am adapting this VAE implementation, https://github.com/keras-team/keras/blob/master/examples/variational_autoencoder.py, which I found described here: https://blog.keras.io/building-autoencoders-in-keras.html

This implementation does not use convolutional layers, so everything happens in 1D, so to speak. My goal is to add 3D convolutional layers to this model.
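
For illustration, a minimal 3D-convolutional encoder of the kind I mean might look like this (the input shape (40, 20, 40, 1) matches the loss shape printed below; the filter counts and kernel sizes are placeholders):

from keras.layers import Input, Conv3D, Flatten, Dense

latent_dim = 2  # as in the original example

# hypothetical input: a (40, 20, 40) voxel grid with one channel
x = Input(shape=(40, 20, 40, 1))
h = Conv3D(16, (3, 3, 3), activation='relu', padding='same')(x)
h = Conv3D(32, (3, 3, 3), activation='relu', padding='same')(h)
h = Flatten()(h)
z_mean = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)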

However, I run into a shape mismatch in the loss function when running batches of 128 samples:

def vae_loss(self, x, x_decoded_mean):
    xent_loss = original_dim * metrics.binary_crossentropy(x, x_decoded_mean)
    #xent_loss.shape >> [128, 40, 20, 40, 1]
    kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    #kl_loss.shape >> [128]
    return K.mean(xent_loss + kl_loss) # >> error: shape mismatch

The two terms have incompatible shapes: the cross-entropy term is per voxel, (128, 40, 20, 40, 1), while the KL term is per sample, (128,).

Almost the same question is already answered here: Keras - Variational Autoencoder Incompatible shape, for a model with 1D convolutional layers. However, I can't work out how to extend that answer to my case, which has a more complex input shape.

I have tried this solution:

xent_loss = original_dim * metrics.binary_crossentropy(K.flatten(x), K.flatten(x_decoded_mean))

But I don't know whether this is a mathematically valid solution, although the model now runs.
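
For completeness, this is how the flattened variant slots into the full loss function (the shape notes in the comments are my own tracing):

def vae_loss(self, x, x_decoded_mean):
    # flattening collapses batch and spatial dims into one axis,
    # so binary_crossentropy reduces to a single scalar here
    xent_loss = original_dim * metrics.binary_crossentropy(
        K.flatten(x), K.flatten(x_decoded_mean))
    kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    # scalar + (128,) broadcasts to (128,); K.mean reduces it to a scalar again
    return K.mean(xent_loss + kl_loss)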


Solution

  • Your approach is right, but it is highly dependent on the K.binary_crossentropy implementation. The tensorflow and theano versions should work for you (as far as I know). To make it cleaner and implementation-independent, I suggest the following:

    xent_loss_vec = original_dim * metrics.binary_crossentropy(x, x_decoded_mean)
    xent_loss = K.mean(xent_loss_vec, axis=[1, 2, 3, 4])
    # xent_loss.shape = (128,)
    

    Now you are taking the mean of the losses over every voxel of each sample, and thanks to that every valid implementation of binary_crossentropy should work fine for you.
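
    Putting it together, the complete loss function with this fix might read as follows (the shapes in the comments follow the ones reported in the question):

    def vae_loss(self, x, x_decoded_mean):
        # per-voxel cross-entropy, shape (128, 40, 20, 40, 1) as reported above
        xent_loss_vec = original_dim * metrics.binary_crossentropy(x, x_decoded_mean)
        # average over all non-batch axes -> one loss per sample, shape (128,)
        xent_loss = K.mean(xent_loss_vec, axis=[1, 2, 3, 4])
        # per-sample KL divergence, shape (128,)
        kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
        # both terms now have shape (128,); the final mean reduces over the batch
        return K.mean(xent_loss + kl_loss)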