I saved an LSTM model with multiple layers. Now I want to load it and fine-tune only the last LSTM layer. How can I target this layer so that only its parameters are updated?
Here is an example of a simple model that I trained and saved:
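from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense

# X: training data of shape (samples, timesteps, features)
model = Sequential()
# first LSTM layer: 100 units, return_sequences=True passes the full sequence to the next LSTM
model.add(LSTM(100, return_sequences=True, input_shape=(X.shape[1], X.shape[2])))
model.add(LSTM(50, return_sequences=True))
model.add(LSTM(25))  # last LSTM layer: outputs only the final time step
model.add(Dense(1))
model.compile(loss='mae', optimizer='adam')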
I can load and retrain it, but I can't find a way to target a specific layer and freeze all the other layers.
If you have previously built and saved the model and now want to load it and fine-tune only the last LSTM layer, you need to set the other layers' trainable property to False. First, find the name of the layer (or its index, counting from zero from the top) using the model.summary() method. For example, this is the output produced for one of my models:
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_10 (InputLayer)        (None, 400, 16)           0
_________________________________________________________________
conv1d_2 (Conv1D)            (None, 400, 32)           4128
_________________________________________________________________
lstm_2 (LSTM)                (None, 32)                8320
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 33
=================================================================
Total params: 12,481
Trainable params: 12,481
Non-trainable params: 0
_________________________________________________________________
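Instead of reading the summary, you can also print each layer's index and name directly. This is a minimal sketch, assuming TensorFlow 2.x with `from tensorflow import keras`; the model below is a stand-in that mirrors the question's architecture, and the input shape of (10, 3) timesteps/features is a made-up value for illustration:

```python
from tensorflow import keras

# Stand-in for the loaded model; (10, 3) is a hypothetical (timesteps, features) shape.
model = keras.Sequential([
    keras.Input(shape=(10, 3)),
    keras.layers.LSTM(100, return_sequences=True),
    keras.layers.LSTM(50, return_sequences=True),
    keras.layers.LSTM(25),
    keras.layers.Dense(1),
])

# Print the index you would use with model.layers[i], the layer's name and its class.
for i, layer in enumerate(model.layers):
    print(i, layer.name, type(layer).__name__)
```

Layer names such as lstm_2 depend on how many layers were created earlier in the session, so the index or class is often a more reliable handle. In the functional model summarized above, the InputLayer is listed first, which is why lstm_2 ends up at index 2 there.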
Then set the trainable attribute of every layer except the LSTM layer to False.
Approach 1:
# freeze every layer except the one named 'lstm_2'
for layer in model.layers:
    if layer.name != 'lstm_2':
        layer.trainable = False
Approach 2:
# freeze all layers, then unfreeze only the LSTM layer
for layer in model.layers:
    layer.trainable = False
model.layers[2].trainable = True  # make the LSTM layer trainable again
# check that index 2 really is the LSTM layer
print(model.layers[2].name)  # prints 'lstm_2'
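Since the goal is to fine-tune the last LSTM layer specifically, a third option is to locate it by type rather than by a hard-coded name or index. This is a minimal sketch, assuming TensorFlow 2.x with `from tensorflow import keras`; `last_lstm_index` is a hypothetical helper, and the stand-in model uses a made-up input shape of (10, 3):

```python
from tensorflow import keras

def last_lstm_index(model):
    """Hypothetical helper: index in model.layers of the last LSTM layer."""
    indices = [i for i, layer in enumerate(model.layers)
               if isinstance(layer, keras.layers.LSTM)]
    if not indices:
        raise ValueError("model has no LSTM layer")
    return indices[-1]

# Stand-in for the loaded model; (10, 3) is a hypothetical input shape.
model = keras.Sequential([
    keras.Input(shape=(10, 3)),
    keras.layers.LSTM(100, return_sequences=True),
    keras.layers.LSTM(50, return_sequences=True),
    keras.layers.LSTM(25),
    keras.layers.Dense(1),
])

target = last_lstm_index(model)
# Freeze everything except the last LSTM layer (the Dense head is frozen too).
for i, layer in enumerate(model.layers):
    layer.trainable = (i == target)
```

This keeps working if layers are renamed or an extra layer is inserted before the LSTM stack.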
Don't forget to call model.compile() again afterwards: changes to the trainable attribute only take effect once the model is recompiled.
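Putting the steps together, the sketch below trains the question's architecture on random toy data, saves it, loads it back, freezes everything except the last LSTM layer, recompiles and fine-tunes. It assumes TensorFlow 2.x with `from tensorflow import keras`; the data shapes and the file name `lstm_model.keras` are hypothetical choices for illustration:

```python
import os
import tempfile

import numpy as np
from tensorflow import keras

# Hypothetical toy data: 32 sequences, 10 timesteps, 3 features.
rng = np.random.default_rng(0)
X = rng.normal(size=(32, 10, 3)).astype("float32")
y = rng.normal(size=(32, 1)).astype("float32")

# Build, train and save the original model (same layers as in the question).
model = keras.Sequential([
    keras.Input(shape=(X.shape[1], X.shape[2])),
    keras.layers.LSTM(100, return_sequences=True),
    keras.layers.LSTM(50, return_sequences=True),
    keras.layers.LSTM(25),
    keras.layers.Dense(1),
])
model.compile(loss="mae", optimizer="adam")
model.fit(X, y, epochs=1, verbose=0)
path = os.path.join(tempfile.mkdtemp(), "lstm_model.keras")  # hypothetical file name
model.save(path)

# Later: load the model and freeze every layer except the last LSTM layer.
loaded = keras.models.load_model(path)
for layer in loaded.layers:
    layer.trainable = False
loaded.layers[2].trainable = True  # index 2 is the last LSTM in this Sequential model

# Recompile so the new trainable flags take effect, then fine-tune.
loaded.compile(loss="mae", optimizer="adam")
dense_before = [np.copy(w) for w in loaded.layers[3].get_weights()]
lstm_before = [np.copy(w) for w in loaded.layers[2].get_weights()]
loaded.fit(X, y, epochs=1, verbose=0)
```

Comparing the weights before and after fit() is a quick way to confirm that only the LSTM layer was updated while the frozen Dense layer stayed unchanged.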