python, tensorflow, keras, deep-learning

This variational autoencoder fails with InvalidArgumentError


Here is the code implementing a variational autoencoder in Keras. Training terminates at the end of the first epoch with the exception shown at the bottom of this report.

import keras
from keras.layers import Conv2D, Conv2DTranspose, Input, Flatten, Dense, Lambda, Reshape
#from keras.layers import BatchNormalization
from keras.models import Model
from keras.datasets import mnist
from keras import backend as K
import numpy as np
from tensorflow.python.framework.ops import disable_eager_execution
disable_eager_execution()
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train = x_train / 255
x_test = x_test / 255
img_width  = x_train.shape[1]
img_height = x_train.shape[2]
num_channels = 1 #MNIST --> grey scale so 1 channel
x_train = x_train.reshape(x_train.shape[0], img_height, img_width, num_channels)
x_test = x_test.reshape(x_test.shape[0], img_height, img_width, num_channels)
input_shape = (img_height, img_width, num_channels)
# ========================
latent_dim = 2 # Number of latent dim parameters

input_img = Input(shape=input_shape, name='encoder_input')
x = Conv2D(32, 3, padding='same', activation='relu')(input_img)
x = Conv2D(64, 3, padding='same', activation='relu',strides=(2, 2))(x)
x = Conv2D(64, 3, padding='same', activation='relu')(x)
x = Conv2D(64, 3, padding='same', activation='relu')(x)

conv_shape = K.int_shape(x) #Shape of conv to be provided to decoder
x = Flatten()(x)
x = Dense(32, activation='relu')(x)
z_mu = Dense(latent_dim, name='z_mu')(x)
z_sigma = Dense(latent_dim, name='z_sigma')(x)

def sample_z(args):
  z_mu, z_sigma = args
  eps = K.random_normal(shape=(K.shape(z_mu)[0], K.int_shape(z_mu)[1]))
  return z_mu + K.exp(z_sigma / 2) * eps

z = Lambda(sample_z, output_shape=(latent_dim, ), name='z')([z_mu, z_sigma])

encoder = Model(input_img, [z_mu, z_sigma, z], name='encoder')
print(encoder.summary())


decoder_input = Input(shape=(latent_dim, ), name='decoder_input')

x = Dense(conv_shape[1]*conv_shape[2]*conv_shape[3], activation='relu')(decoder_input)
x = Reshape((conv_shape[1], conv_shape[2], conv_shape[3]))(x)
x = Conv2DTranspose(32, 3, padding='same', activation='relu',strides=(2, 2))(x)
x = Conv2DTranspose(num_channels, 3, padding='same', activation='sigmoid', name='decoder_output')(x)

decoder = Model(decoder_input, x, name='decoder')
decoder.summary()

z_decoded = decoder(z)

def vae_loss(x, z_decoded):
    x = K.flatten(x)
    z_decoded = K.flatten(z_decoded)
    recon_loss = keras.metrics.binary_crossentropy(x, z_decoded)
    kl_loss = -5e-4 * K.mean(1 + z_sigma - K.square(z_mu) - K.exp(z_sigma), axis=-1)
    return K.mean(recon_loss + kl_loss)

y = decoder(z)


vae = Model(input_img, y, name='vae')

# Compile VAE
vae.compile(optimizer='adam', loss=vae_loss)
vae.summary()

vae.fit(x_train, None, epochs = 10, batch_size = 32, validation_split = 0.2)

It ends with the following exception:

tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'decoder_target' with dtype float and shape [?,?,?,?] [[{{node decoder_target}}]]


Solution

  • The error occurs because the original vae_loss uses its first argument x, which is the model's target placeholder (the decoder_target tensor named in the message); since vae.fit() is called with None as the target, that placeholder is never fed. The following modification fixes the error, and training then shows a steadily decreasing loss across the 10 epochs:

    • change the vae_loss() function so that it ignores the target argument and reads the encoder input input_img from the enclosing scope instead.

    Here are the modified Python instructions:

    def vae_loss(_, z_decoded):
        x = K.flatten(input_img)
        z_decoded = K.flatten(z_decoded)
        recon_loss = keras.metrics.binary_crossentropy(x, z_decoded)
        kl_loss = -5e-4 * K.mean(1 + z_sigma - K.square(z_mu) - K.exp(z_sigma), axis=-1)
        return K.mean(recon_loss + kl_loss)
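
    With this change the loss no longer touches the target tensor, so passing None as the target in vae.fit() is harmless. An alternative pattern that avoids the unused target placeholder altogether is to register the loss on the model with add_loss() and compile without a loss argument. The lines below are an untested sketch of that approach; they assume the model definitions above are unchanged and that they replace the original compile() and fit() calls:

    # Alternative sketch (assumes the rest of the script is unchanged):
    # register the VAE loss via add_loss() so no target tensor is needed at all.
    recon_loss = keras.metrics.binary_crossentropy(K.flatten(input_img), K.flatten(y))
    kl_loss = -5e-4 * K.mean(1 + z_sigma - K.square(z_mu) - K.exp(z_sigma), axis=-1)
    vae.add_loss(K.mean(recon_loss + kl_loss))

    vae.compile(optimizer='adam')  # no loss argument; the loss was attached above
    vae.fit(x_train, epochs=10, batch_size=32, validation_split=0.2)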