Tags: conv-neural-network, tensorflow2.0, recurrent-neural-network

How to add a SimpleRNN layer within Convolutional layers in tensorflow without changing ndim?


I'm trying to add an RNN layer between convolutional layers. Unfortunately, because of a difference in ndim, the model fails to build.

Model:

model = keras.Sequential(
[
    # layers.Rescaling(1.0/255),
    keras.Input(shape=(256, 256, 3)),
    layers.Conv2D(32, (3,3), padding="valid", activation='swish'),
    layers.MaxPooling2D(pool_size=(2,2)),
    layers.Conv2D(64, 3, activation="swish"),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, activation="swish"),
    layers.Flatten(),
    layers.Dense(64, activation='sigmoid'),
    layers.Dense(10),
    layers.Dense(2),
    ]
)

It would be really helpful if someone could help me figure this out :)

EDIT

Code that gave the error

model = keras.Sequential(
[
    # layers.Rescaling(1.0/255),
    keras.Input(shape=(256, 256, 3)),
    layers.Conv2D(32, (3,3), padding="valid", activation='swish'),
    layers.SimpleRNN(512, activation='relu'),
    layers.MaxPooling2D(pool_size=(2,2)),
    layers.Conv2D(64, 3, activation="swish"),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, activation="swish"),
    layers.Flatten(),
    layers.Dense(64, activation='sigmoid'),
    layers.Dense(10),
    layers.Dense(2),
]
)

Error Message

Input 0 of layer "simple_rnn" is incompatible with the layer: expected ndim=3, found ndim=4. Full shape received: (None, 254, 254, 32)
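The mismatch can be seen with a plain NumPy sketch (the shapes are taken from the error message; TensorFlow is not involved here). SimpleRNN expects a 3-D input of the form (batch, timesteps, features), but Conv2D emits a 4-D tensor; collapsing the two spatial axes into a single "timesteps" axis resolves the ndim:

```python
import numpy as np

# Shape reported in the error: Conv2D emits (batch, height, width, channels).
conv_out = np.zeros((1, 254, 254, 32))
assert conv_out.ndim == 4  # SimpleRNN expects ndim=3: (batch, timesteps, features)

# Collapse the two spatial axes into a single "timesteps" axis,
# which is what layers.Reshape((-1, 32)) does inside the model.
rnn_in = conv_out.reshape(1, -1, 32)
print(rnn_in.shape)  # (1, 64516, 32): 254 * 254 = 64516 timesteps of 32 features
```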

Solution

  • If you add a layers.Reshape layer before and after the RNN layer, the shape issue is resolved.

    model = keras.Sequential(
        [
            keras.Input(shape=(256, 256, 3)),
            layers.Conv2D(32, (3,3), padding="valid", activation='swish'),
            layers.Reshape((-1, 32)),  # flatten only the two spatial dimensions
            layers.SimpleRNN(512),
            layers.Reshape((16, 16, 2)),  # any shape whose elements multiply to 512
            layers.MaxPooling2D(pool_size=(2,2)),
            layers.Conv2D(64, 3, activation="swish"),
            layers.MaxPooling2D(),
            layers.Conv2D(32, 3, activation="swish"),
            layers.Flatten(),
            layers.Dense(64, activation='sigmoid'),
            layers.Dense(10),
            layers.Dense(2),
        ]
    )
    

    I don't know if it makes sense to apply 2D convolutions after the RNN layer; I think it might destroy the spatial semantics of the input. The RNN layer will also be expensive: besides its weights, it has to step through a very long flattened sequence. That may speak for adding the RNN layer at the end instead. Reshaping will work there as well, but you will have to add a dimension instead of flattening one.
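As a rough check of the numbers in the sketch above (assuming the standard SimpleRNN parameterization of input kernel, recurrent kernel, and bias):

```python
# SimpleRNN(512) on 32 input features: trainable-weight count
features, units = 32, 512
params = features * units + units * units + units  # kernel + recurrent kernel + bias
print(params)  # 279040

# The RNN's output is a flat (batch, 512) vector, so Reshape((16, 16, 2))
# is valid only because the element counts match:
assert 16 * 16 * 2 == units
```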