Tags: tensorflow, keras, lstm, tensorflow2.0, keras-layer

Gate weights order for LSTM layers in TensorFlow


I have a Keras model that includes some LSTM layers. I know that I can get the weights of the LSTM layer through the get_weights() method, and the result is a list made of three elements: kernel, recurrent kernel and bias.

As the documentation states, each element includes the weights for the 4 gates in the LSTM layer. However, it does not say in which order they are stored. For instance, if the LSTM layer has N units, the bias vector will have 4*N elements. Which of those elements correspond to the 1st/2nd/3rd/4th gate?


Solution

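  • The order is i, f, c, o, which stands for input gate, forget gate, cell gate (the candidate cell state) and output gate respectively. You can confirm this from the LSTMCell implementation in the Keras source, where the kernel is split into four equal blocks in that order. The gates are concatenated along the last axis, so each gate occupies a contiguous block of N columns.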

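    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.layers import LSTM

    lstm = LSTM(100)
    lstm(np.zeros((64, 10, 5)))  # call the layer once so its weights are built
    kernel = lstm.weights[0]     # input kernel, shape (5, 400)
    w_i, w_f, w_c, w_o = tf.split(kernel, 4, axis=1)

    print(*(w.shape for w in (w_i, w_f, w_c, w_o)))  # all are (5, 100)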
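The same i, f, c, o layout applies to the recurrent kernel and to the bias vector the question asks about. Below is a minimal sketch of slicing all three arrays returned by get_weights() into per-gate pieces with NumPy. The helper name `split_lstm_gates` is hypothetical (it is not part of Keras), and the arrays are stand-ins with the shapes an LSTM(100) on 5 input features would produce, so the snippet runs without a model.

```python
import numpy as np

GATE_ORDER = ("i", "f", "c", "o")  # input, forget, cell, output


def split_lstm_gates(weights):
    """Split [kernel, recurrent_kernel, bias] into per-gate arrays.

    Hypothetical helper, not a Keras API. Assumes the i, f, c, o layout
    and a layer created with use_bias=True (three arrays in the list).
    """
    kernel, recurrent_kernel, bias = weights
    units = recurrent_kernel.shape[0]  # recurrent kernel is (units, 4 * units)
    gates = {}
    for idx, name in enumerate(GATE_ORDER):
        cols = slice(idx * units, (idx + 1) * units)
        gates[name] = {
            "kernel": kernel[:, cols],
            "recurrent_kernel": recurrent_kernel[:, cols],
            "bias": bias[cols],
        }
    return gates


# Stand-in arrays; on a real model use weights = layer.get_weights() instead.
units, features = 100, 5
weights = [
    np.zeros((features, 4 * units)),        # kernel
    np.zeros((units, 4 * units)),           # recurrent kernel
    np.arange(4 * units, dtype="float32"),  # bias, values equal to their index
]

gates = split_lstm_gates(weights)
print(gates["f"]["bias"][:3])  # forget-gate bias starts at index N: [100. 101. 102.]
```

One way to sanity-check the order on a real, freshly created layer: with the default `unit_forget_bias=True`, Keras initializes the forget-gate bias to ones and the other gate biases to zeros, so `bias[N:2*N]` should be the only non-zero block before training.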