
Create Keras / TensorFlow layer that computes 2D DCT


I would like to add a layer to my CNN that computes the DCT. I have seen that TensorFlow only has a 1D DCT.

How can I create a layer that performs a 2D DCT on a batch of images in the middle of the network?


Solution

  • The transformation would look like this. You could call it anywhere in your model and, in Keras versions that support it, it would be converted into an op layer automatically. Or, if you prefer, you could wrap it in a Lambda layer.

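    import tensorflow as tf

    def dct_2d(
            feature_map,
            norm=None  # can also be 'ortho'
    ):
        # Expects a (batch, channels, H, W) tensor; tf.signal.dct acts on the last axis.
        X1 = tf.signal.dct(feature_map, type=2, norm=norm)  # 1D DCT along W
        X1_t = tf.transpose(X1, perm=[0, 1, 3, 2])  # swap H and W
        X2 = tf.signal.dct(X1_t, type=2, norm=norm)  # 1D DCT along H
        X2_t = tf.transpose(X2, perm=[0, 1, 3, 2])  # restore (batch, channels, H, W)
        return X2_t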
    

    Bear in mind that in TensorFlow the DCT is always applied to the last axis (axis -1). So if your feature map is laid out as (batch, H, W, channels), you need to transpose it to (batch, channels, H, W) first. Only then are the transposes above correct.
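
    For the common channels-last case, the layout change can be folded into the function itself. The following is a minimal sketch, not a definitive implementation: the name dct_2d_channels_last and the example shapes are made up for illustration, and it assumes a 4D (batch, H, W, channels) input with eager execution.

```python
import tensorflow as tf


def dct_2d_channels_last(feature_map, norm=None):
    """2D DCT-II over the H and W axes of a (batch, H, W, channels) tensor.

    Hypothetical helper for illustration; norm may be None or 'ortho'.
    """
    # Move channels ahead of the spatial axes: (batch, channels, H, W).
    x = tf.transpose(feature_map, perm=[0, 3, 1, 2])
    # 1D DCT along W, which is now the last axis.
    x = tf.signal.dct(x, type=2, norm=norm)
    # Swap H and W so that H is last, then take the 1D DCT along H.
    x = tf.transpose(x, perm=[0, 1, 3, 2])
    x = tf.signal.dct(x, type=2, norm=norm)
    # Layout is now (batch, channels, W, H); restore (batch, H, W, channels).
    return tf.transpose(x, perm=[0, 3, 2, 1])


# Wrap the function in a Lambda layer so it can sit between other layers.
dct_layer = tf.keras.layers.Lambda(
    lambda t: dct_2d_channels_last(t, norm="ortho")
)

# Non-square images make a mixed-up H/W axis easier to spot.
images = tf.random.uniform((2, 4, 6, 3))
coeffs = dct_layer(images)
print(coeffs.shape)  # (2, 4, 6, 3)
```

    The output keeps the input layout, so the layer can be dropped between convolutional layers without further reshaping. With norm='ortho' the transform is orthonormal, which gives a quick sanity check: the sum of squares of the output should match that of the input up to floating-point error.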