Tags: python, machine-learning, keras, lstm, flatten

What is this flatten layer doing in my LSTM?


I am creating an LSTM for sentiment analysis with (a subset of) the IMDB dataset, using Keras. My training, validation and testing accuracy improve dramatically if I add a flatten layer before the final dense layer:

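from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

# vocab_size and maxlen are defined elsewhere in the original script;
# the model summary below implies maxlen = 500.
def lstm_model_flatten():
    embedding_dim = 128
    model = Sequential()
    model.add(layers.Embedding(vocab_size, embedding_dim, input_length=maxlen))
    # return_sequences=True: output every hidden state, shape (maxlen, 128) per sample
    model.add(layers.LSTM(128, return_sequences=True, dropout=0.2))
    # Flatten layer: reshapes (maxlen, 128) into a vector of maxlen * 128 values
    model.add(layers.Flatten())
    model.add(layers.Dense(1, activation='sigmoid'))
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    model.summary()
    return model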

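This overfits quickly (validation loss starts rising after epoch 4 while training loss keeps falling), but the validation accuracy still gets up to around 76% at epoch 4 and about 81% by epoch 7: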

Model: "sequential_43"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding_42 (Embedding)     (None, 500, 128)          4768256   
_________________________________________________________________
lstm_63 (LSTM)               (None, 500, 128)          131584    
_________________________________________________________________
flatten_10 (Flatten)         (None, 64000)             0         
_________________________________________________________________
dense_40 (Dense)             (None, 1)                 64001     
=================================================================
Total params: 4,963,841
Trainable params: 4,963,841
Non-trainable params: 0
_________________________________________________________________
Epoch 1/7
14/14 [==============================] - 26s 2s/step - loss: 0.6911 - accuracy: 0.5290 - val_loss: 0.6802 - val_accuracy: 0.5650
Epoch 2/7
14/14 [==============================] - 23s 2s/step - loss: 0.6451 - accuracy: 0.6783 - val_loss: 0.6074 - val_accuracy: 0.6950
Epoch 3/7
14/14 [==============================] - 23s 2s/step - loss: 0.4594 - accuracy: 0.7910 - val_loss: 0.5237 - val_accuracy: 0.7300
Epoch 4/7
14/14 [==============================] - 23s 2s/step - loss: 0.2566 - accuracy: 0.9149 - val_loss: 0.4753 - val_accuracy: 0.7650
Epoch 5/7
14/14 [==============================] - 23s 2s/step - loss: 0.1397 - accuracy: 0.9566 - val_loss: 0.6011 - val_accuracy: 0.8050
Epoch 6/7
14/14 [==============================] - 23s 2s/step - loss: 0.0348 - accuracy: 0.9898 - val_loss: 0.7648 - val_accuracy: 0.8100
Epoch 7/7
14/14 [==============================] - 23s 2s/step - loss: 0.0136 - accuracy: 0.9955 - val_loss: 0.8829 - val_accuracy: 0.8150
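
The parameter counts in the summary follow directly from the layer shapes. The plain-Python arithmetic below is a sketch that reproduces them; the value vocab_size = 37252 is an assumption, inferred by dividing the Embedding count 4,768,256 by embedding_dim = 128, since the post does not state it.

```python
# Reproduce the "Param #" column of the model summary by hand.
# Assumption: vocab_size is inferred from 4,768,256 / 128; the post does not state it.
vocab_size = 37252   # hypothetical, inferred from the Embedding parameter count
maxlen = 500         # sequence length, from the (None, 500, 128) output shape
embedding_dim = 128
lstm_units = 128

# Embedding: one embedding_dim-sized vector per vocabulary entry.
embedding_params = vocab_size * embedding_dim

# LSTM: 4 gates, each with an input kernel, a recurrent kernel and a bias.
lstm_params = 4 * (lstm_units * (embedding_dim + lstm_units) + lstm_units)

# Flatten: no weights; it only reshapes (500, 128) into 64000 values per sample.
flatten_size = maxlen * lstm_units

# Dense(1): one weight per flattened input value plus one bias.
dense_params = flatten_size * 1 + 1

total = embedding_params + lstm_params + dense_params
print(embedding_params, lstm_params, flatten_size, dense_params, total)
# prints: 4768256 131584 64000 64001 4963841
```

The total matches the "Total params: 4,963,841" line, and the Dense count shows that almost all of the classifier head's weights come from the 64000 flattened inputs.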

Using the same architecture without the flatten layer (and using return_sequences = False on the LSTM layer) only produces a validation accuracy of around 50%.

The comments on this post recommend that return_sequences = False is used before the dense layer, rather than a flatten layer.

But why is that the case? Is it ok to use a flatten layer if it improves my model? What exactly is the flatten layer doing here, and why does it improve the accuracy?


Solution

  • An LSTM layer applies the same LSTM cell sequentially, once per time step (the separate "cells" in the figure below are that one cell unrolled over time). As seen in the figure below, the first step takes an input (here an embedding vector) and computes a hidden state; each subsequent step uses its own input together with the hidden state from the previous time step to compute its hidden state. In other words, the arrows between the cells pass the hidden states along. If you set return_sequences=False, the LSTM layer outputs only the very last hidden state (h_4 in the figure). All the information from every input and time step therefore has to be squeezed into a single fixed-size vector (128 values in your model), which can hold only a limited amount of information. This is why your accuracy is poor when you use only the last hidden state.

    When you set return_sequences=True, the LSTM layer outputs every hidden state, so the following layers have access to all of them, and together they naturally contain more information. However, the LSTM layer then returns a matrix for each sample. You can see this in your model summary: the output shape is (None, 500, 128). None is the batch dimension (the number of samples in a batch), so you can ignore it; 500 is your sequence length (the number of time steps, i.e. maxlen), and 128 is your hidden state size. A Dense layer applied directly to this 3D output would act only on the last axis, producing one output per time step, shape (None, 500, 1), instead of a single prediction per review. That is why you need Flatten: it simply unrolls each sample's 500*128 matrix into a 1D vector. Therefore, the output size of your Flatten layer is 64000, because 500*128 = 64000, and the Dense layer has 64000 weights plus 1 bias, i.e. 64001 parameters. And with all the hidden states available, the accuracy is better because they contain more information.

    [Figure: An example of LSTM networks]
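
To make the two options concrete, here is a minimal NumPy sketch of a single-layer LSTM forward pass for one sample, using the shapes from the model summary. It is an illustrative re-implementation, not Keras's actual code: the function lstm_forward, the random weights and the gate order (input, forget, candidate, output) are assumptions chosen for the example. It shows that return_sequences=False keeps only the last row of the full output, that flattening concatenates the 500 hidden states into one 64000-value vector, and that a Dense-style kernel applied without flattening yields one score per time step.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_forward(x, W, U, b, return_sequences=True):
    """Minimal LSTM over one sequence x of shape (timesteps, input_dim).

    W: (input_dim, 4*units), U: (units, 4*units), b: (4*units,).
    Gate order (i, f, g, o) is an assumption for illustration.
    """
    units = U.shape[0]
    h = np.zeros(units)  # hidden state passed from step to step
    c = np.zeros(units)  # cell state passed from step to step
    outputs = []
    for x_t in x:
        z = x_t @ W + h @ U + b
        i = sigmoid(z[:units])                 # input gate
        f = sigmoid(z[units:2 * units])        # forget gate
        g = np.tanh(z[2 * units:3 * units])    # candidate values
        o = sigmoid(z[3 * units:])             # output gate
        c = f * c + i * g
        h = o * np.tanh(c)
        outputs.append(h)
    seq = np.stack(outputs)  # (timesteps, units): every hidden state
    return seq if return_sequences else seq[-1]  # or only the last one

rng = np.random.default_rng(0)
timesteps, input_dim, units = 500, 128, 128  # shapes from the model summary
x = rng.normal(size=(timesteps, input_dim))
W = rng.normal(scale=0.05, size=(input_dim, 4 * units))
U = rng.normal(scale=0.05, size=(units, 4 * units))
b = np.zeros(4 * units)

all_states = lstm_forward(x, W, U, b, return_sequences=True)
last_state = lstm_forward(x, W, U, b, return_sequences=False)
print(all_states.shape, last_state.shape)  # (500, 128) (128,)

# return_sequences=False keeps only the final row of the full output.
assert np.allclose(last_state, all_states[-1])

# Flatten: concatenate the 500 hidden states row by row into one vector.
flat = all_states.reshape(-1)
print(flat.shape)  # (64000,)
assert np.array_equal(flat[:units], all_states[0])

# A Dense(1)-style kernel applied to the un-flattened states acts on the
# last axis only, giving one score per time step instead of one per review.
w_dense = rng.normal(size=(units, 1))
print((all_states @ w_dense).shape)  # (500, 1)
```

In the real model the same happens per sample inside a batch, so the shapes carry an extra leading None dimension.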