Autoencoder with 3D convolutions and convolutional LSTMs

I have implemented a variational autoencoder with CNN layers in the encoder and decoder. The code is shown below. My training data (train_X) consists of 40'000 images with size 64 x 80 x 1 and my validation data (valid_X) consists of 4500 images of size 64 x 80 x 1.

I would like to adapt my network in the following two ways:

  1. Instead of using 2D convolutions (Conv2D and Conv2DTranspose) I would like to use 3D convolutions to take time into account (as the third dimension). For that I would like to use slices of 10 images, i.e. I will have images of size 64 x 80 x 1 x 10. Can I just use Conv3D and Conv3DTranspose or are other changes necessary?

  2. I would like to try out convolutional LSTMs (ConvLSTM2D) in the encoder and decoder instead of plain 2D convolutions. Again, the input size of the images would be 64 x 80 x 1 x 10 (i.e. time series of 10 images). How can I adapt my network to work with ConvLSTM2D?

import keras
from keras import backend as K
from keras.layers import (Dense, Input, Flatten)
from keras.layers import Lambda, Conv2D
from keras.models import Model
from keras.layers import Reshape, Conv2DTranspose
from keras.losses import mse

def sampling(args):
    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

inner_dim = 16
latent_dim = 6

image_size = (64, 80, 1)
inputs = Input(shape=image_size, name='encoder_input')
x = inputs

x = Conv2D(32, 3, strides=2, activation='relu', padding='same')(x)
x = Conv2D(64, 3, strides=2, activation='relu', padding='same')(x)

# shape info needed to build decoder model
shape = K.int_shape(x)

# generate latent vector Q(z|X)
x = Flatten()(x)
x = Dense(inner_dim, activation='relu')(x)
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)

z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])

# instantiate encoder model
encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')

# build decoder model
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = Dense(inner_dim, activation='relu')(latent_inputs)
x = Dense(shape[1] * shape[2] * shape[3], activation='relu')(x)
x = Reshape((shape[1], shape[2], shape[3]))(x)

x = Conv2DTranspose(64, 3, strides=2, activation='relu', padding='same')(x)
x = Conv2DTranspose(32, 3, strides=2, activation='relu', padding='same')(x)

outputs = Conv2DTranspose(filters=1, kernel_size=3, activation='sigmoid', padding='same', name='decoder_output')(x)

# instantiate decoder model
decoder = Model(latent_inputs, outputs, name='decoder')

# instantiate VAE model
outputs = decoder(encoder(inputs)[2])
vae = Model(inputs, outputs, name='vae')

def vae_loss(x, x_decoded_mean):
    reconstruction_loss = mse(K.flatten(x), K.flatten(x_decoded_mean))
    reconstruction_loss *= image_size[0] * image_size[1]
    kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = K.sum(kl_loss, axis=-1)
    kl_loss *= -0.5
    vae_loss = K.mean(reconstruction_loss + kl_loss)
    return vae_loss

optimizer = keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.000)
vae.compile(loss=vae_loss, optimizer=optimizer)
vae.fit(train_X, train_X,
        validation_data=(valid_X, valid_X))

  • Use an input shape of (10, 64, 80, 1), i.e. (frames, height, width, channels), and just replace the layers.
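
One detail to check when swapping in Conv3D and Conv3DTranspose is the stride along the time axis. The sketch below is plain shape arithmetic, not Keras code, and the helper names are made up for illustration. It assumes padding='same', where a strided convolution produces ceil(n / stride) outputs and a strided transposed convolution produces n * stride, and traces a 10-frame clip through two encoder and two decoder layers as in the question's network:

```python
import math

def conv_same(n, stride):
    """Output length of a strided convolution with padding='same'."""
    return math.ceil(n / stride)

def deconv_same(n, stride):
    """Output length of a strided transposed convolution with padding='same'."""
    return n * stride

def round_trip(shape, strides, num_layers=2):
    """Trace (frames, height, width) through num_layers convs and back again."""
    encoded = shape
    for _ in range(num_layers):
        encoded = tuple(conv_same(n, s) for n, s in zip(encoded, strides))
    decoded = encoded
    for _ in range(num_layers):
        decoded = tuple(deconv_same(n, s) for n, s in zip(decoded, strides))
    return encoded, decoded

clip = (10, 64, 80)  # (frames, height, width)
print(round_trip(clip, strides=(2, 2, 2)))  # time axis: 10 -> 5 -> 3 -> 6 -> 12
print(round_trip(clip, strides=(1, 2, 2)))  # time axis stays at 10
```

With strides=2 on every axis the decoder would return 12 frames for a 10-frame input, so the reconstruction no longer matches the target. Using strides=(1, 2, 2), or a clip length divisible by 4, avoids that mismatch.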

    The tedious part is organizing the input data: decide whether to use sliding windows or simply reshape from (images, 64, 80, 1) to (images//10, 10, 64, 80, 1).
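
Both ways of organizing the data can be sketched with NumPy. This is a minimal sketch that assumes the images are stored in temporal order in one array with at least `frames` entries; the function names and the `step` parameter are illustrative and not part of the original code:

```python
import numpy as np

def non_overlapping_clips(images, frames=10):
    """Reshape (images, H, W, C) to (images // frames, frames, H, W, C)."""
    usable = (images.shape[0] // frames) * frames  # drop the incomplete tail
    return images[:usable].reshape((-1, frames) + images.shape[1:])

def sliding_clips(images, frames=10, step=1):
    """Overlapping windows of `frames` images, one window every `step` images."""
    starts = range(0, images.shape[0] - frames + 1, step)
    return np.stack([images[s:s + frames] for s in starts])

images = np.zeros((45, 64, 80, 1), dtype=np.float32)
print(non_overlapping_clips(images).shape)  # (4, 10, 64, 80, 1)
print(sliding_clips(images).shape)          # (36, 10, 64, 80, 1)
```

Note that sliding windows with step=1 multiply the memory footprint by roughly the window length, so a larger step or a generator may be preferable for 40'000 images.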

    Sliding windows (overlapping) or not?

    1 - If you want your model to understand individual segments of 10 images, you may overlap or not; it's your choice. Performance may be better with overlapping, but not necessarily.

    The order of the segments doesn't really matter, as long as the 10 frames within each segment are in order.

    This is supported by Conv3D and by LSTM with stateful=False.

    2 - But if you want your model to understand the entire sequence, and you are splitting it into chunks only because of memory limits, only an LSTM with stateful=True can support this.

    (A Conv3D with kernel size = (frames, w, h) will work, but it is limited to that number of frames and will never understand sequences longer than that. It may still be able to detect isolated events, but not long-range relationships across the sequence.)

    In this case, for the LSTM you will need to:

    • set shuffle=False in training
    • use a fixed batch size of sequences
    • not overlap images
    • write a manual training loop that calls model.reset_states() every time you start feeding "new sequences", both for training AND predicting

    The loop structure would be:

    for epoch in range(epochs):
        for group in range(groups):
            # shape: (sequences, total_length, ....)
            sequences = getAGroupOfCompleteSequences()
            for i in range(slide_divisions):
                # consecutive, non-overlapping chunks of 10 frames
                batch = sequences[:, 10*i : 10*(i+1)]
                model.train_on_batch(batch, ....)
            # the next group holds new sequences, so clear the LSTM states
            model.reset_states()
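
To make the chunking and the placement of the state reset concrete, here is a runnable sketch of the same loop. RecordingModel is a hypothetical stand-in that only logs calls; in practice it would be a Keras model built with stateful=True layers. Using the chunk as its own target is an assumption based on the autoencoder setup in the question:

```python
import numpy as np

class RecordingModel:
    """Hypothetical stand-in for a stateful Keras model; it only logs calls."""
    def __init__(self):
        self.log = []

    def train_on_batch(self, x, y):
        self.log.append(("train", x.shape))

    def reset_states(self):
        self.log.append(("reset",))

def train_stateful(model, groups, epochs=1, frames=10):
    """Feed each group of full-length sequences in consecutive chunks."""
    for _ in range(epochs):
        for sequences in groups:  # shape: (sequences, total_length, ...)
            total_length = sequences.shape[1]
            for start in range(0, total_length - frames + 1, frames):
                chunk = sequences[:, start:start + frames]
                model.train_on_batch(chunk, chunk)  # autoencoder: target = input
            model.reset_states()  # the next group starts new sequences

model = RecordingModel()
groups = [np.zeros((4, 30, 64, 80, 1), dtype=np.float32) for _ in range(2)]
train_stateful(model, groups)
print(model.log)
```

Each group of 4 sequences with 30 frames yields three training calls on chunks of shape (4, 10, 64, 80, 1), followed by one reset before the next group begins.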