Tags: tensorflow, keras, recurrent-neural-network

How to create a recurrent connection between 2 layers in Tensorflow/Keras?


Essentially what I would like to do is take the following very simple feedforward graph:

Simple feedforward graph

And then add a recurrent connection that feeds the output of the second Dense layer back as input to the first Dense layer, as demonstrated below. Both models are obviously simplifications of my actual use case, though I suppose the general principle I am asking about holds true for both.

Simple recurrent connection graph

I wonder if there is an efficient way in TensorFlow, or even Keras, to accomplish this, especially regarding GPU processing efficiency. While I am fairly confident that I could hack together a custom model in TensorFlow that would accomplish this function-wise, I am pessimistic about the GPU processing efficiency of such a custom model. I would therefore very much appreciate it if someone knows an efficient way to accomplish these recurrent connections between 2 layers. Thank you for your time! =)
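To make the wiring concrete, here is a rough sketch of one forward step of the recurrence I have in mind, in plain NumPy (shapes only; the weights, names, and activations are placeholders, not my actual model):

```python
import numpy as np

# Hypothetical single-step sketch of the desired feedback connection.
# The first Dense layer sees the current input concatenated with the
# previous output of the second Dense layer.
rng = np.random.default_rng(0)
W1 = rng.standard_normal((128 + 32, 64))  # first Dense: input + feedback -> 64
W2 = rng.standard_normal((64, 32))        # second Dense: 64 -> 32 (fed back)
W3 = rng.standard_normal((32, 16))        # output Dense: 32 -> 16

def step(x, fb_prev):
    h1 = np.tanh(np.concatenate([x, fb_prev], axis=-1) @ W1)
    fb = np.tanh(h1 @ W2)   # this activation is fed back on the next step
    out = fb @ W3
    return out, fb

x = np.ones((1, 128))
fb = np.zeros((1, 32))      # initial feedback state
out, fb = step(x, fb)       # out.shape == (1, 16)
```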


For completeness' sake, here is the code that creates the first simple feedforward graph. The recurrent graph I created through image editing.

import tensorflow as tf

inputs = tf.keras.Input(shape=(128,))

h_1 = tf.keras.layers.Dense(64)(inputs)
h_2 = tf.keras.layers.Dense(32)(h_1)
out = tf.keras.layers.Dense(16)(h_2)

model = tf.keras.Model(inputs, out)

Solution

  • Since my question hasn't received any answers, I would like to share the solution I came up with in case someone finds this question via search.

    Please let me know if you find or come up with a better solution - thanks!

    import tensorflow as tf

    class SimpleModel(tf.keras.Model):
        def __init__(self, input_shape, *args, **kwargs):
            super(SimpleModel, self).__init__(*args, **kwargs)
            # Create node layers
            self.node_1 = tf.keras.layers.InputLayer(input_shape=input_shape)
            self.node_2 = tf.keras.layers.Dense(64, activation='sigmoid')
            self.node_3 = tf.keras.layers.Dense(32, activation='sigmoid')
            self.node_4 = tf.keras.layers.Dense(16, activation='sigmoid')
            self.conn_3_2_recurrent_state = None
    
            # Create recurrent connection states
            node_1_output_shape = self.node_1.compute_output_shape(input_shape)
            node_2_output_shape = self.node_2.compute_output_shape(node_1_output_shape)
            node_3_output_shape = self.node_3.compute_output_shape(node_2_output_shape)
    
            self.conn_3_2_recurrent_state = tf.Variable(initial_value=self.node_3(tf.ones(shape=node_2_output_shape)),
                                                        trainable=False,
                                                        validate_shape=False,
                                                        dtype=tf.float32)
            # OR
            # self.conn_3_2_recurrent_state = tf.random.uniform(shape=node_3_output_shape, minval=0.123, maxval=4.56)
            # OR
            # self.conn_3_2_recurrent_state = tf.ones(shape=node_3_output_shape)
            # OR
            # self.conn_3_2_recurrent_state = tf.zeros(shape=node_3_output_shape)
    
        def call(self, inputs):
            x = self.node_1(inputs)
    
            #tf.print(self.conn_3_2_recurrent_state)
            #tf.print(self.conn_3_2_recurrent_state.shape)
    
            x = tf.keras.layers.Concatenate(axis=-1)([x, self.conn_3_2_recurrent_state])
            x = self.node_2(x)
            x = self.node_3(x)
    
            self.conn_3_2_recurrent_state.assign(x)
            #tf.print(self.conn_3_2_recurrent_state)
            #tf.print(self.conn_3_2_recurrent_state.shape)
    
            x = self.node_4(x)
            return x
    
    
    # Demonstrate statefulness of model (uncomment tf prints in model.call())
    model = SimpleModel(input_shape=(10, 128))
    x = tf.ones(shape=(10, 128))
    model(x)
    model(x)
    
    
    # Demonstrate trainability of the recurrent connection TF model
    x = tf.random.uniform(shape=(10, 128))
    y = tf.ones(shape=(10, 16))
    
    model = SimpleModel(input_shape=(10, 128))
    model.compile(optimizer='adam', loss='binary_crossentropy')
    model.fit(x=x, y=y, epochs=100)
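As a possible alternative (not part of the original answer, just a sketch): the same node_2 → node_3 → node_2 feedback can be expressed as a custom RNN cell and wrapped in `tf.keras.layers.RNN`, which unrolls the recurrence over an explicit time axis and may play more nicely with GPU execution than manually assigning a `tf.Variable` in `call()`. The `FeedbackCell` name and layer sizes below are assumptions mirroring the model above:

```python
import tensorflow as tf

# Hypothetical sketch: the feedback loop as a custom RNN cell.
# A cell only needs a call(inputs, states) method and a state_size attribute.
class FeedbackCell(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.node_2 = tf.keras.layers.Dense(64, activation='sigmoid')
        self.node_3 = tf.keras.layers.Dense(32, activation='sigmoid')
        self.state_size = 32   # size of the fed-back node_3 output
        self.output_size = 32

    def call(self, inputs, states):
        # Concatenate the current input with the previous node_3 output.
        h = tf.concat([inputs, states[0]], axis=-1)
        h = self.node_2(h)
        h = self.node_3(h)
        # The node_3 output doubles as the next recurrent state.
        return h, [h]

# Inputs now carry a time axis: (batch, time, 128).
inputs = tf.keras.Input(shape=(None, 128))
x = tf.keras.layers.RNN(FeedbackCell())(inputs)
out = tf.keras.layers.Dense(16, activation='sigmoid')(x)
model = tf.keras.Model(inputs, out)
```

The trade-off is that the recurrence then lives inside a single RNN layer over a sequence dimension, rather than persisting state across separate `model(x)` calls as in the solution above.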