Hi guys, I am working with this code from MachineCurve.
The encoder part has this architecture; the inputs are 28x28 images:
# Encoder: a stack of stride-2 convolutions, each halving the spatial size
# (28 -> 14 -> 7 -> 4 -> 2, per the encoder summary printed below), followed
# by a dense bottleneck that produces the latent-distribution parameters.
i = Input(shape=input_shape, name='encoder_input')
cx = Conv2D(filters=128, kernel_size=5, strides=2, padding='same', activation='relu')(i)    # -> (14, 14, 128)
cx = BatchNormalization()(cx)
cx = Conv2D(filters=256, kernel_size=5, strides=2, padding='same', activation='relu')(cx)   # -> (7, 7, 256)
cx = BatchNormalization()(cx)
cx = Conv2D(filters=512, kernel_size=5, strides=2, padding='same', activation='relu')(cx)   # -> (4, 4, 512)
cx = BatchNormalization()(cx)
cx = Conv2D(filters=1024, kernel_size=5, strides=2, padding='same', activation='relu')(cx)  # -> (2, 2, 1024)
cx = BatchNormalization()(cx)
x = Flatten()(cx)  # 2 * 2 * 1024 = 4096 features
x = Dense(20, activation='relu')(x)
x = BatchNormalization()(x)
mu = Dense(latent_dim, name='latent_mu')(x)        # mean of q(z|x)
sigma = Dense(latent_dim, name='latent_sigma')(x)  # log-variance (sample_z applies exp(sigma / 2))
The decoder part is as follows; it tries to reverse the layers defined in the encoder:
# Decoder: mirror of the encoder.  Four stride-2 transposed convolutions
# double the spatial size each time (2 -> 4 -> 8 -> 16 -> 32), so a
# 'same'-padded output layer would emit 32x32 images while the encoder
# input is 28x28 — the mismatch behind the
# "Incompatible shapes: [100352] vs. [131072]" error (128*28*28 vs 128*32*32).
d_i = Input(shape=(latent_dim, ), name='decoder_input')
x = Dense(conv_shape[1] * conv_shape[2] * conv_shape[3], activation='relu')(d_i)
x = BatchNormalization()(x)
x = Reshape((conv_shape[1], conv_shape[2], conv_shape[3]))(x)  # back to (2, 2, 1024)
cx = Conv2DTranspose(filters=1024, kernel_size=5, strides=2, padding='same', activation='relu')(x)   # -> (4, 4, 1024)
cx = BatchNormalization()(cx)
cx = Conv2DTranspose(filters=512, kernel_size=5, strides=2, padding='same', activation='relu')(cx)   # -> (8, 8, 512)
cx = BatchNormalization()(cx)
cx = Conv2DTranspose(filters=256, kernel_size=5, strides=2, padding='same', activation='relu')(cx)   # -> (16, 16, 256)
cx = BatchNormalization()(cx)
cx = Conv2DTranspose(filters=128, kernel_size=5, strides=2, padding='same', activation='relu')(cx)   # -> (32, 32, 128)
cx = BatchNormalization()(cx)
# Fix: an unpadded ('valid') Conv2D with kernel_size=5 crops 32x32 down to
# 28x28 (32 - 5 + 1 = 28), so the reconstruction matches the encoder input.
o = Conv2D(filters=num_channels, kernel_size=5, activation='sigmoid', name='decoder_output')(cx)
As we can see below, the decoder_output shape must match the encoder_input shape:
Model: "vae"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
encoder_input (InputLayer) (None, 28, 28, 1) 0
_________________________________________________________________
encoder (Model) [(None, 2), (None, 2), (N 17298104
_________________________________________________________________
decoder (Model) (None, 32, 32, 1) 43457025
=================================================================
Total params: 60,755,129
Trainable params: 60,739,217
Non-trainable params: 15,912
And then, when the model is trained, we get this error:
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-46-44d4cd644e8f> in <module>
3
4 # Train autoencoder
----> 5 vae.fit(input_train, input_train, epochs = no_epochs, batch_size = batch_size, validation_split = validation_split)
~\.conda\envs\keypoints\lib\site-packages\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
1237 steps_per_epoch=steps_per_epoch,
1238 validation_steps=validation_steps,
-> 1239 validation_freq=validation_freq)
1240
1241 def evaluate(self,
~\.conda\envs\keypoints\lib\site-packages\keras\engine\training_arrays.py in fit_loop(model, fit_function, fit_inputs, out_labels, batch_size, epochs, verbose, callbacks, val_function, val_inputs, shuffle, initial_epoch, steps_per_epoch, validation_steps, validation_freq)
194 ins_batch[i] = ins_batch[i].toarray()
195
--> 196 outs = fit_function(ins_batch)
197 outs = to_list(outs)
198 for l, o in zip(out_labels, outs):
~\.conda\envs\keypoints\lib\site-packages\tensorflow_core\python\keras\backend.py in __call__(self, inputs)
3738 value = math_ops.cast(value, tensor.dtype)
3739 converted_inputs.append(value)
-> 3740 outputs = self._graph_fn(*converted_inputs)
3741
3742 # EagerTensor.numpy() will often make a copy to ensure memory safety.
~\.conda\envs\keypoints\lib\site-packages\tensorflow_core\python\eager\function.py in __call__(self, *args, **kwargs)
1079 TypeError: For invalid positional/keyword argument combinations.
1080 """
-> 1081 return self._call_impl(args, kwargs)
1082
1083 def _call_impl(self, args, kwargs, cancellation_manager=None):
~\.conda\envs\keypoints\lib\site-packages\tensorflow_core\python\eager\function.py in _call_impl(self, args, kwargs, cancellation_manager)
1119 raise TypeError("Keyword arguments {} unknown. Expected {}.".format(
1120 list(kwargs.keys()), list(self._arg_keywords)))
-> 1121 return self._call_flat(args, self.captured_inputs, cancellation_manager)
1122
1123 def _filtered_call(self, args, kwargs):
~\.conda\envs\keypoints\lib\site-packages\tensorflow_core\python\eager\function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
1222 if executing_eagerly:
1223 flat_outputs = forward_function.call(
-> 1224 ctx, args, cancellation_manager=cancellation_manager)
1225 else:
1226 gradient_name = self._delayed_rewrite_functions.register()
~\.conda\envs\keypoints\lib\site-packages\tensorflow_core\python\eager\function.py in call(self, ctx, args, cancellation_manager)
509 inputs=args,
510 attrs=("executor_type", executor_type, "config_proto", config),
--> 511 ctx=ctx)
512 else:
513 outputs = execute.execute_with_cancellation(
~\.conda\envs\keypoints\lib\site-packages\tensorflow_core\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
65 else:
66 message = e.message
---> 67 six.raise_from(core._status_to_exception(e.code, message), None)
68 except TypeError as e:
69 keras_symbolic_tensors = [
~\.conda\envs\keypoints\lib\site-packages\six.py in raise_from(value, from_value)
InvalidArgumentError: Incompatible shapes: [100352] vs. [131072]
[[node gradients/loss/decoder_loss/kl_reconstruction_loss/mul_1_grad/BroadcastGradientArgs (defined at C:\Users\XXXXX\.conda\envs\keypoints\lib\site-packages\tensorflow_core\python\framework\ops.py:1751) ]] [Op:__inference_keras_scratch_graph_22124]
Function call stack:
keras_scratch_graph
Do you have any idea how to solve this issue, please?
I also define:
# Define sampling with reparameterization trick
# (indentation restored — the pasted snippet had lost it and was not valid Python)
def sample_z(args):
    """Draw z = mu + exp(sigma / 2) * eps with eps ~ N(0, I).

    `sigma` is the encoder's log-variance output, hence exp(sigma / 2)
    is the standard deviation.
    """
    mu, sigma = args
    batch = K.shape(mu)[0]    # dynamic (runtime) batch size
    dim = K.int_shape(mu)[1]  # static latent dimensionality
    eps = K.random_normal(shape=(batch, dim))
    return mu + K.exp(sigma / 2) * eps

# Use reparameterization trick to ensure correct gradient
z = Lambda(sample_z, output_shape=(latent_dim, ), name='z')([mu, sigma])
And the encoder will be defined:
# Encoder model: maps an input image to [mu, sigma, z] (z is the sampled latent).
encoder = Model(i, [mu, sigma, z], name='encoder')
The architecture is:
Model: "encoder"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
encoder_input (InputLayer) (None, 28, 28, 1) 0
__________________________________________________________________________________________________
conv2d_10 (Conv2D) (None, 14, 14, 128) 3328 encoder_input[0][0]
__________________________________________________________________________________________________
batch_normalization_25 (BatchNo (None, 14, 14, 128) 512 conv2d_10[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D) (None, 7, 7, 256) 819456 batch_normalization_25[0][0]
__________________________________________________________________________________________________
batch_normalization_26 (BatchNo (None, 7, 7, 256) 1024 conv2d_11[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D) (None, 4, 4, 512) 3277312 batch_normalization_26[0][0]
__________________________________________________________________________________________________
batch_normalization_27 (BatchNo (None, 4, 4, 512) 2048 conv2d_12[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D) (None, 2, 2, 1024) 13108224 batch_normalization_27[0][0]
__________________________________________________________________________________________________
batch_normalization_28 (BatchNo (None, 2, 2, 1024) 4096 conv2d_13[0][0]
__________________________________________________________________________________________________
flatten_4 (Flatten) (None, 4096) 0 batch_normalization_28[0][0]
__________________________________________________________________________________________________
dense_7 (Dense) (None, 20) 81940 flatten_4[0][0]
__________________________________________________________________________________________________
batch_normalization_29 (BatchNo (None, 20) 80 dense_7[0][0]
__________________________________________________________________________________________________
latent_mu (Dense) (None, 2) 42 batch_normalization_29[0][0]
__________________________________________________________________________________________________
latent_sigma (Dense) (None, 2) 42 batch_normalization_29[0][0]
__________________________________________________________________________________________________
z (Lambda) (None, 2) 0 latent_mu[0][0]
latent_sigma[0][0]
==================================================================================================
Total params: 17,298,104
Similar the decoder part is defined:
# Decoder model: maps a latent vector back to image space.
decoder = Model(d_i, o, name='decoder')
The architecture of the decoder is:
Model: "decoder"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
decoder_input (InputLayer) (None, 2) 0
_________________________________________________________________
dense_8 (Dense) (None, 4096) 12288
_________________________________________________________________
batch_normalization_30 (Batc (None, 4096) 16384
_________________________________________________________________
reshape_4 (Reshape) (None, 2, 2, 1024) 0
_________________________________________________________________
conv2d_transpose_10 (Conv2DT (None, 4, 4, 1024) 26215424
_________________________________________________________________
batch_normalization_31 (Batc (None, 4, 4, 1024) 4096
_________________________________________________________________
conv2d_transpose_11 (Conv2DT (None, 8, 8, 512) 13107712
_________________________________________________________________
batch_normalization_32 (Batc (None, 8, 8, 512) 2048
_________________________________________________________________
conv2d_transpose_12 (Conv2DT (None, 16, 16, 256) 3277056
_________________________________________________________________
batch_normalization_33 (Batc (None, 16, 16, 256) 1024
_________________________________________________________________
conv2d_transpose_13 (Conv2DT (None, 32, 32, 128) 819328
_________________________________________________________________
batch_normalization_34 (Batc (None, 32, 32, 128) 512
_________________________________________________________________
decoder_output (Conv2DTransp (None, 32, 32, 1) 1153
=================================================================
Total params: 43,457,025
Trainable params: 43,444,993
Non-trainable params: 12,032
And finally we put it all together:
# =================
# VAE as a whole
# =================
# Instantiate VAE: the encoder returns [mu, sigma, z], so index 2 selects
# the sampled latent z, which is fed to the decoder.
vae_outputs = decoder(encoder(i)[2])
vae = Model(i, vae_outputs, name='vae')
This is a problem due to the output shape of your decoder. You can simply solve it by changing the final layer of your decoder to:
Conv2D(filters=num_channels, kernel_size=5, activation='sigmoid', name='decoder_output')
here the full code:
num_channels = 1       # grayscale images
latent_dim = 2         # size of the latent space
input_shape = (28,28,1)
# ---- Encoder: stride-2 convs halve the spatial size (28 -> 14 -> 7 -> 4 -> 2) ----
i = Input(shape=input_shape, name='encoder_input')
cx = Conv2D(filters=128, kernel_size=5, strides=2, padding='same', activation='relu')(i)    # -> (14, 14, 128)
cx = BatchNormalization()(cx)
cx = Conv2D(filters=256, kernel_size=5, strides=2, padding='same', activation='relu')(cx)   # -> (7, 7, 256)
cx = BatchNormalization()(cx)
cx = Conv2D(filters=512, kernel_size=5, strides=2, padding='same', activation='relu')(cx)   # -> (4, 4, 512)
cx = BatchNormalization()(cx)
cx = Conv2D(filters=1024, kernel_size=5, strides=2, padding='same', activation='relu')(cx)  # -> (2, 2, 1024)
cx = BatchNormalization()(cx)
x = Flatten()(cx)  # 2 * 2 * 1024 = 4096 features
x = Dense(20, activation='relu')(x)
x = BatchNormalization()(x)
mu = Dense(latent_dim, name='latent_mu')(x)        # mean of q(z|x)
sigma = Dense(latent_dim, name='latent_sigma')(x)  # log-variance (sample_z applies exp(sigma / 2))
# Remember the last conv feature-map shape so the decoder can un-flatten.
conv_shape = K.int_shape(cx)
# ---- Decoder: transposed convs double the spatial size (2 -> 4 -> 8 -> 16 -> 32) ----
d_i = Input(shape=(latent_dim, ), name='decoder_input')
x = Dense(np.prod(conv_shape[1:]), activation='relu')(d_i)
x = BatchNormalization()(x)
x = Reshape(conv_shape[1:])(x)  # back to (2, 2, 1024)
cx = Conv2DTranspose(filters=1024, kernel_size=5, strides=2, padding='same', activation='relu')(x)   # -> (4, 4, 1024)
cx = BatchNormalization()(cx)
cx = Conv2DTranspose(filters=512, kernel_size=5, strides=2, padding='same', activation='relu')(cx)   # -> (8, 8, 512)
cx = BatchNormalization()(cx)
cx = Conv2DTranspose(filters=256, kernel_size=5, strides=2, padding='same', activation='relu')(cx)   # -> (16, 16, 256)
cx = BatchNormalization()(cx)
cx = Conv2DTranspose(filters=128, kernel_size=5, strides=2, padding='same', activation='relu')(cx)   # -> (32, 32, 128)
cx = BatchNormalization()(cx)
# Unpadded ('valid') Conv2D crops 32x32 to 28x28 (32 - 5 + 1 = 28),
# so the reconstruction matches the encoder input shape.
o = Conv2D(filters=num_channels, kernel_size=5, activation='sigmoid', name='decoder_output')(cx)
The sampling layer:
# (indentation restored — the pasted snippet had lost it and was not valid Python)
def sample_z(args):
    """Draw z = mu + exp(sigma / 2) * eps with eps ~ N(0, I).

    `sigma` is the encoder's log-variance output, hence exp(sigma / 2)
    is the standard deviation.
    """
    mu, sigma = args
    batch = K.shape(mu)[0]    # dynamic (runtime) batch size
    dim = K.int_shape(mu)[1]  # static latent dimensionality
    eps = K.random_normal(shape=(batch, dim))
    return mu + K.exp(sigma / 2) * eps

# Use reparameterization trick to ensure correct gradient
z = Lambda(sample_z, output_shape=(latent_dim, ), name='z')([mu, sigma])
final VAE:
# Encoder outputs [mu, sigma, z]; decoder maps a latent vector to an image.
encoder = Model(i, [mu, sigma, z], name='encoder')
decoder = Model(d_i, o, name='decoder')
# Feed the sampled latent (index 2 of the encoder outputs) into the decoder.
vae_outputs = decoder(encoder(i)[2])
vae = Model(i, vae_outputs, name='vae')
summary:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
encoder_input (InputLayer) [(None, 28, 28, 1)] 0
_________________________________________________________________
encoder (Model) [(None, 2), (None, 2), (N 17298104
_________________________________________________________________
decoder (Model) (None, 28, 28, 1) 43459073
=================================================================
As you can see, the input and output shapes now match.