Tags: deep-learning, keras, keras-layer

How to split a tensor column-wise in Keras to implement STFCN


I'd like to implement the Spatiotemporal Fully Convolutional Network (STFCN) in Keras. I need to feed each depth column of a 3D convolutional output, e.g. a tensor with shape (64, 16, 16), as input to a separate LSTM.

To make this clear, I have a (64 x 16 x 16) tensor of dimensions (channels, height, width). I need to split (explicitly or implicitly) the tensor into 16 * 16 = 256 tensors of shape (64 x 1 x 1).
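
As a rough illustration of the split described above, here is a minimal NumPy sketch. NumPy only stands in for the Keras tensor, there is no batch dimension, and the names `features`, `columns`, and `stacked` are made up for this example.

```python
import numpy as np

# Hypothetical stand-in for one convolutional output: (channels, height, width).
features = np.arange(64 * 16 * 16, dtype=np.float32).reshape(64, 16, 16)

# Explicit split: one (64, 1, 1) tensor per spatial position, in row-major order.
columns = [
    features[:, h:h + 1, w:w + 1]
    for h in range(16)
    for w in range(16)
]

print(len(columns))        # 256
print(columns[0].shape)    # (64, 1, 1)

# Implicit split: a single reshape keeps the same columns as rows of a (256, 64) array,
# where row h * 16 + w holds the 64 channel values at position (h, w).
stacked = features.reshape(64, 16 * 16).T
```

The explicit list matches the wording of the question, while the reshaped array holds the same values without creating 256 separate tensors.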

Here is a diagram from the STFCN paper to illustrate the Spatio-Temporal module. What I described above is the arrow between 'Spatial Features' and 'Spatio-Temporal Module'.

[Figure: the connection between FCn and the Spatio-Temporal Module is the relevant part of the diagram.]

How would this idea be best implemented in Keras?


Solution

  • You can use tf.split from TensorFlow inside a Keras Lambda layer.

    Wrap the split in a Lambda layer to turn a tensor of shape (64, 16, 16) into one of shape (64, 1, 1, 256), then select whichever indices you need along the last axis.

    import numpy as np
    import tensorflow as tf
    import keras.backend as K
    from keras.models import Model
    from keras.layers import Input, Lambda
    
    # input data
    data = np.ones((3,64,16,16))
    
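    # Lambda function that splits the feature map into per-pixel columns
    def lambda_fun(x):
        x = K.expand_dims(x, 4)        # (batch, 64, 16, 16, 1)
        split1 = tf.split(x, 16, 2)    # 16 tensors of shape (batch, 64, 1, 16, 1)
        x = K.concatenate(split1, 4)   # (batch, 64, 1, 16, 16)
        split2 = tf.split(x, 16, 3)    # 16 tensors of shape (batch, 64, 1, 1, 16)
        x = K.concatenate(split2, 4)   # (batch, 64, 1, 1, 256)
        return x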
    
    # check that the splitting works
    inputs = Input(shape=(64, 16, 16))
    ll = Lambda(lambda_fun)(inputs)
    model = Model(inputs=inputs, outputs=ll)
    res = model.predict(data)
    print(np.shape(res))    # (3, 64, 1, 1, 256)
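
To see which spatial position ends up at which index of the last axis, the same expand/split/concatenate sequence can be replayed in NumPy. This is only a sketch: `np.expand_dims`, `np.split`, and `np.concatenate` stand in for the backend calls above, and the helper name `split_columns` is made up for illustration.

```python
import numpy as np

def split_columns(x):
    """NumPy replay of lambda_fun for x of shape (batch, 64, 16, 16)."""
    x = np.expand_dims(x, 4)                               # (batch, 64, 16, 16, 1)
    x = np.concatenate(np.split(x, 16, axis=2), axis=4)    # (batch, 64, 1, 16, 16)
    x = np.concatenate(np.split(x, 16, axis=3), axis=4)    # (batch, 64, 1, 1, 256)
    return x

data = np.random.rand(3, 64, 16, 16)
res = split_columns(data)
print(res.shape)  # (3, 64, 1, 1, 256)

# The column at (height h, width w) lands at index w * 16 + h of the last axis.
h, w = 5, 3
column = res[:, :, 0, 0, w * 16 + h]              # shape (3, 64)
print(np.array_equal(column, data[:, :, h, w]))   # True
```

Because the height axis is split first, the last axis is ordered width-major, which is worth keeping in mind when subsetting the column for a particular pixel.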