Tags: python, tensorflow, keras, neural-network, recurrent-neural-network

Setting the initial state of an RNN represented as a Keras sequential model


How do I set the initial state of the recurrent neural network rnn constructed below?

from tensorflow.keras.layers import Dense, SimpleRNN
from tensorflow.keras.models import Sequential

rnn = Sequential([SimpleRNN(3), Dense(1)])

I'd like to specify the initial state of the first layer before fitting the model with rnn.fit.


Solution

  • According to the tf.keras.layers.RNN documentation, you can specify the initial state symbolically, by passing the initial_state argument when calling the layer, or numerically, by calling the layer's reset_states method.

    Symbolic specification means adding the initial state as an extra input to your model. Here is an example I adapted from the Keras tests:

    from tensorflow.keras.layers import Dense, SimpleRNN, Input
    from tensorflow.keras.models import Model
    import numpy as np
    import tensorflow as tf
    
    timesteps = 3
    embedding_dim = 4
    units = 3
    
    inputs = Input((timesteps, embedding_dim))
    # initial state as Keras Input
    initial_state = Input((units,))
    rnn = SimpleRNN(units)
    hidden = rnn(inputs, initial_state=initial_state)
    output = Dense(1)(hidden)
    
    model = Model([inputs] + [initial_state], output)
    # Dense(1) is a regression head, so use a regression loss
    model.compile(loss='mse', optimizer='adam')
    

    Once the model is defined, you can train it as follows:

    num_samples = 2
    
    inputs = np.random.random((num_samples, timesteps, embedding_dim))
    # random initial state as additional input
    some_initial_state = np.random.random((num_samples, units))
    # targets must match the model output shape (num_samples, 1)
    targets = np.random.random((num_samples, 1))
    model.train_on_batch([inputs] + [some_initial_state], targets)
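
Since the question asks about fit rather than train_on_batch, the same two-input pattern should carry over. Below is a minimal, self-contained sketch. The MSE loss, the Adam optimizer, the sample count, and the names seq_in, state_in, x, h0 and y are illustrative assumptions, not part of the original answer.

```python
import numpy as np
from tensorflow.keras.layers import Dense, Input, SimpleRNN
from tensorflow.keras.models import Model

timesteps, embedding_dim, units = 3, 4, 3
num_samples = 8  # illustrative value

# Two-input model: the sequence plus the initial state of the RNN.
seq_in = Input((timesteps, embedding_dim))
state_in = Input((units,))
hidden = SimpleRNN(units)(seq_in, initial_state=state_in)
model = Model([seq_in, state_in], Dense(1)(hidden))
model.compile(loss="mse", optimizer="adam")  # assumed regression setup

x = np.random.random((num_samples, timesteps, embedding_dim))
h0 = np.random.random((num_samples, units))  # one initial state per sample
y = np.random.random((num_samples, 1))       # matches the Dense(1) output

# fit takes the initial state as a second input array.
history = model.fit([x, h0], y, batch_size=4, epochs=2, verbose=0)

# Same sequences, different initial state: the predictions should differ.
pred_a = model.predict([x, h0], verbose=0)
pred_b = model.predict([x, np.zeros_like(h0)], verbose=0)
```

Because the initial state is an ordinary model input here, it has to be supplied at prediction time as well, for example as an array of zeros.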
    

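    Note that this approach requires the Functional API. For Sequential models, you need a stateful RNN with a fixed batch size (given through batch_input_shape), and you set the state by calling the layer's reset_states method. This relies on the Keras API bundled with TensorFlow 2.x; batch_input_shape and the states argument may not be available in later Keras versions.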

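    from tensorflow.keras.models import Sequential

    # stateful layers need a fixed batch size, given here as the first dimension
    input_shape = (num_samples, timesteps, embedding_dim)
    model = Sequential([
        SimpleRNN(units, stateful=True, batch_input_shape=input_shape),
        Dense(1)])

    # one initial state vector per sample in the batch
    some_initial_state = np.random.random((num_samples, units))
    rnn = model.layers[0]
    rnn.reset_states(states=some_initial_state)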
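
For intuition about what the initial state changes in either approach, here is a rough NumPy sketch of the recurrence a tanh RNN such as SimpleRNN applies with its default activation: h_t = tanh(x_t W + h_{t-1} U + b). The function simple_rnn_forward and the random weights are hypothetical stand-ins for illustration; this is not how Keras implements the layer internally.

```python
import numpy as np

def simple_rnn_forward(x, h0, W, U, b):
    """Run a tanh RNN over x of shape (timesteps, features), starting from h0."""
    h = h0
    for x_t in x:
        # The previous state enters every step, so h0 influences the final state.
        h = np.tanh(x_t @ W + h @ U + b)
    return h

rng = np.random.default_rng(0)
timesteps, embedding_dim, units = 3, 4, 3

# Hypothetical weights, scaled down to keep tanh away from saturation.
W = 0.5 * rng.normal(size=(embedding_dim, units))  # input kernel
U = 0.5 * rng.normal(size=(units, units))          # recurrent kernel
b = np.zeros(units)
x = rng.normal(size=(timesteps, embedding_dim))

h_zero = simple_rnn_forward(x, np.zeros(units), W, U, b)          # zero initial state
h_custom = simple_rnn_forward(x, rng.normal(size=units), W, U, b)  # custom initial state
```

With a zero initial state the first step depends only on the input and the bias; a nonzero initial state adds the h0 U term, which is what initial_state and reset_states let you control.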