Tags: tensorflow, lstm, recurrent-neural-network

In TensorFlow 2.0, how to pass the output of an LSTM model at the previous time-step as input to the next time-step?


I want to build an LSTM model where the input at the (n+1)-th timestep is a function of the output at the n-th timestep. I don't see a way to do this in the current framework. People have mentioned using raw_rnn, which I believe is deprecated in TensorFlow 2.0. Can anyone help me with this issue? Currently this is what I have:

import tensorflow as tf

class RNN(tf.keras.Model):
    def __init__(self):
        super(RNN, self).__init__()
        rnn_units = 16
        self.bn_layer = tf.keras.layers.BatchNormalization(
            momentum=0.99,
            epsilon=1e-6,
            beta_initializer=tf.random_normal_initializer(0.0, stddev=0.1),
            gamma_initializer=tf.random_uniform_initializer(0.1, 0.5)
        )
        self.lstm1 = tf.keras.layers.LSTM(rnn_units,
                                          return_sequences=True,
                                          return_state=True,
                                          recurrent_initializer='glorot_uniform',
                                          input_shape=[None, 4])
        self.lstm2 = tf.keras.layers.LSTM(rnn_units,
                                          return_sequences=True,
                                          return_state=True,
                                          recurrent_initializer='glorot_uniform')
        self.dense = tf.keras.layers.Dense(4)

    def call(self, x, training=False):
        # sequence_length and process_output_to_input are defined elsewhere.
        init_state1 = None
        init_state2 = None
        for i in range(sequence_length):
            x = self.bn_layer(x, training=training)
            lstm_output, new_h1, new_c1 = self.lstm1(x, initial_state=init_state1)
            lstm_output, new_h2, new_c2 = self.lstm2(lstm_output, initial_state=init_state2)
            output = self.dense(lstm_output)

            # Output at step n becomes (after processing) the input at step n+1.
            x = process_output_to_input(output)
            init_state1 = [new_h1, new_c1]
            init_state2 = [new_h2, new_c2]

        return output
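
For comparison, a minimal sketch of this output-to-input feedback using tf.keras.layers.LSTMCell, which advances one timestep per call; the unit count, feature size, batch size, and step count below are placeholders matching the code above:

import tensorflow as tf

# Sketch: two stacked LSTMCells stepped manually, feeding each step's
# output back in as the next step's input.
cell1 = tf.keras.layers.LSTMCell(16)
cell2 = tf.keras.layers.LSTMCell(16)
dense = tf.keras.layers.Dense(4)

batch_size, num_steps = 8, 10   # placeholder sizes
x = tf.zeros([batch_size, 4])   # placeholder first-step input
state1 = cell1.get_initial_state(batch_size=batch_size, dtype=tf.float32)
state2 = cell2.get_initial_state(batch_size=batch_size, dtype=tf.float32)

for _ in range(num_steps):
    h1, state1 = cell1(x, state1)
    h2, state2 = cell2(h1, state2)
    output = dense(h2)
    x = output                  # output at step n becomes input at step n+1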

Solution

  • I found a solution that uses the stateful property of the LSTM layers; see https://adgefficiency.com/tf2-lstm-hidden/ for reference. My implementation is:

    import tensorflow as tf

    class SingleStepLSTM(tf.keras.Model):
        def __init__(self, config):
            super(SingleStepLSTM, self).__init__()
            state_dim = config.state_dim
            rnn_units = config.rnn_units
            self.bn_layer = tf.keras.layers.BatchNormalization(
                momentum=0.99,
                epsilon=1e-6,
                beta_initializer=tf.random_normal_initializer(0.0, stddev=0.1),
                gamma_initializer=tf.random_uniform_initializer(0.1, 0.5)
            )
            # stateful=True lets each LSTM carry its hidden state across calls,
            # so no state needs to be threaded through manually.
            self.lstm1 = tf.keras.layers.LSTM(rnn_units,
                                              return_sequences=True,
                                              recurrent_initializer='glorot_uniform',
                                              stateful=True,
                                              input_shape=[None, state_dim])
            self.lstm2 = tf.keras.layers.LSTM(rnn_units,
                                              return_sequences=True,
                                              stateful=True,
                                              recurrent_initializer='glorot_uniform')
            self.dense = tf.keras.layers.Dense(state_dim)

        def call(self, x, training=True):
            x = self.bn_layer(x, training=training)
            h = self.lstm1(x)
            h = self.lstm2(h)
            x = self.dense(h)

            return x
    

    This is a single-step LSTM model; when training or testing we can loop through it:

    single_lstm_step = SingleStepLSTM(config)
    x = initial_input  # placeholder: the first-step input
    for i in range(num_seqs):
        output = single_lstm_step(x)
        x = process_to_input(output)  # user-defined mapping from output to next input
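
    One caveat with stateful layers: the carried state persists across calls and the batch size must stay fixed, so before starting a new, independent sequence the states should be cleared. A minimal sketch, assuming the model above:

    # Stateful Keras RNN layers expose reset_states(); call it between
    # independent sequences so state from one does not leak into the next.
    single_lstm_step.lstm1.reset_states()
    single_lstm_step.lstm2.reset_states()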
    

    And I think the implementation in the problem statement would also work; using stateful RNNs is, in my opinion, just a more elegant solution.