
Initializing the state of CUDNN LSTMs


I think we may use the following code segment to create a stack of LSTMs and initialize its states to zero:

 lstm_cell = tf.contrib.rnn.BasicLSTMCell(
     hidden_size, forget_bias=0.0, state_is_tuple=True)
 cell = tf.contrib.rnn.MultiRNNCell(
     [lstm_cell] * num_layers, state_is_tuple=True)
 initial_state = cell.zero_state(batch_size, tf.float32)

Instead of using BasicLSTMCell, I would like to use the CUDNN variant:

cudnn_cell = tf.contrib.cudnn_rnn.CudnnLSTM(
    num_layers, hidden_size,
    dropout=1.0 - config.keep_prob)  # `dropout` expects a drop probability, not a keep probability

In this case, how can I do the same thing as cell.zero_state(batch_size, tf.float32) on cudnn_cell?


Solution

  • The definition can be found in TensorFlow's cudnn_rnn code.

    Regarding initial_states:

    with tf.Graph().as_default():
        lstm = CudnnLSTM(num_layers, num_units, direction, ...)
        outputs, output_states = lstm(inputs, initial_states, training=True)
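
    In fact, CudnnLSTM already starts from zero states whenever the initial state argument is omitted, so the equivalent of cell.zero_state(batch_size, tf.float32) comes for free. If you want to build the zero state explicitly, here is a minimal sketch (assuming a unidirectional stack, where each state tensor has shape (num_layers, batch_size, hidden_size); the contrib layer's call keyword is initial_state, and since both tensors are zeros their order in the tuple does not matter):

    # One zero tensor each for the hidden and cell states.
    zero_state = tuple(
        tf.zeros([num_layers, batch_size, hidden_size], dtype=tf.float32)
        for _ in range(2))
    outputs, output_states = cudnn_cell(inputs, initial_state=zero_state)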
    

    So you only need to pass the initial states alongside the embedding inputs. In an encoder-decoder system, it would look like:

    encoder_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers, hidden_size)
    encoder_output, encoder_state = encoder_cell(encoder_embedding_input)
    decoder_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers, hidden_size)
    # Call the decoder cell (not the encoder), seeding it with the
    # encoder's final state.
    decoder_output, decoder_state = decoder_cell(decoder_embedding_input,
                                                 initial_state=encoder_state)
    

    Here, encoder_state is a tuple (final_c_state, final_h_state), and the shape of both states is (1, batch, hidden_size). (More generally, the leading dimension is num_layers * num_dirs.)
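
    As a quick sanity check of those shapes (a hypothetical snippet reusing the names above):

    # Static shapes of the two state tensors; for a single unidirectional
    # layer the leading dimension is 1.
    final_c_state, final_h_state = encoder_state
    print(final_c_state.shape)  # (1, batch, hidden_size)
    print(final_h_state.shape)  # (1, batch, hidden_size)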

    If your encoder is a bidirectional RNN, it is a bit trickier, since the output states now have shape (2, batch, hidden_size).

    Hence, I use a roundabout way to solve it:

    encoder_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers, hidden_size,
                                                  direction="bidirectional")
    encoder_output, (encoder_c_state, encoder_h_state) = encoder_cell(encoder_embedding_input)
    # Split each (2, batch, hidden_size) state into its forward and
    # backward halves along axis 0 ...
    fw_c, bw_c = tf.split(encoder_c_state, [1, 1], axis=0)
    fw_h, bw_h = tf.split(encoder_h_state, [1, 1], axis=0)
    # ... and concatenate the halves along the feature axis, giving
    # states of shape (1, batch, 2 * hidden_size).
    reshape_encoder_c_state = tf.concat((fw_c, bw_c), axis=2)
    reshape_encoder_h_state = tf.concat((fw_h, bw_h), axis=2)
    encoder_state = (reshape_encoder_c_state, reshape_encoder_h_state)
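
    The concatenated states can then seed a unidirectional decoder, whose hidden size must be doubled to match. A minimal continuation sketch (decoder_cell and decoder_embedding_input are hypothetical names following the earlier example):

    # The state features are now 2 * hidden_size wide, so the decoder's
    # hidden size is doubled for the shapes to line up.
    decoder_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers, 2 * hidden_size)
    decoder_output, decoder_state = decoder_cell(decoder_embedding_input,
                                                 initial_state=encoder_state)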
    

    Although I haven't tried multi-layer RNNs, I think they can be handled in a similar way; see the sketch below.
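
    A hedged, untested sketch of that generalization, assuming CuDNN stacks the bidirectional states layer by layer along axis 0 as (layer0_fw, layer0_bw, layer1_fw, layer1_bw, ...), i.e. shape (2 * num_layers, batch, hidden_size); merge_bidirectional_state is a hypothetical helper:

    def merge_bidirectional_state(state, num_layers):
        # Assumed layout: 2 * num_layers slices of shape (1, batch, hidden_size),
        # with each layer's forward/backward pair adjacent along axis 0.
        parts = tf.split(state, 2 * num_layers, axis=0)
        merged = []
        for layer in range(num_layers):
            fw, bw = parts[2 * layer], parts[2 * layer + 1]
            # Merge each pair along the feature axis: (1, batch, 2 * hidden_size).
            merged.append(tf.concat((fw, bw), axis=2))
        # Restack the layers: (num_layers, batch, 2 * hidden_size).
        return tf.concat(merged, axis=0)

    encoder_state = (merge_bidirectional_state(encoder_c_state, num_layers),
                     merge_bidirectional_state(encoder_h_state, num_layers))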