Tags: python, machine-learning, keras, lstm, keras-layer

Keras: access a layer of a pre-trained model to freeze the other layers


I saved an LSTM model with multiple layers. Now I want to load it and fine-tune only the last LSTM layer. How can I target this layer and change its parameters?

Example of a simple model trained and saved:

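from keras.models import Sequential
from keras.layers import LSTM, Dense

# X is the training data, shaped (samples, timesteps, features)
model = Sequential()
# first LSTM layer: 100 units, returns full sequences for the next LSTM
model.add(LSTM(100, return_sequences=True, input_shape=(X.shape[1], X.shape[2])))
model.add(LSTM(50, return_sequences=True))
model.add(LSTM(25))  # last LSTM layer: the one to fine-tune
model.add(Dense(1))
model.compile(loss='mae', optimizer='adam')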

I can load and retrain it, but I can't find a way to target a specific layer and freeze all the other layers.


Solution

  • If you have previously built and saved the model and now want to load it and fine-tune only the last LSTM layer, you need to set the trainable property of the other layers to False. First, find the name of the layer (or its index, counting from zero starting at the top) using the model.summary() method. For example, this is the output produced for one of my models:

    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    input_10 (InputLayer)        (None, 400, 16)           0         
    _________________________________________________________________
    conv1d_2 (Conv1D)            (None, 400, 32)           4128      
    _________________________________________________________________
    lstm_2 (LSTM)                (None, 32)                8320      
    _________________________________________________________________
    dense_2 (Dense)              (None, 1)                 33        
    =================================================================
    Total params: 12,481
    Trainable params: 12,481
    Non-trainable params: 0
    _________________________________________________________________
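Instead of reading indices off the summary table, you can print them directly. This is a minimal sketch, assuming TensorFlow 2.x with its bundled Keras and, purely for illustration, an input of 10 timesteps with 3 features standing in for the question's `(X.shape[1], X.shape[2])`. Note that for a `Sequential` model like the question's, `model.layers` does not list the input layer, so index 0 is the first LSTM; in the functional model whose summary is shown above, `input_10` occupies index 0.

```python
from tensorflow import keras

# Rebuild the question's architecture. The input shape (10 timesteps,
# 3 features) is an assumption standing in for (X.shape[1], X.shape[2]);
# with a saved model you would get `model` from keras.models.load_model(...).
model = keras.Sequential([
    keras.Input(shape=(10, 3)),
    keras.layers.LSTM(100, return_sequences=True),
    keras.layers.LSTM(50, return_sequences=True),
    keras.layers.LSTM(25),
    keras.layers.Dense(1),
])

# Print each layer's index (what model.layers[i] expects), its
# auto-generated name and its class.
for i, layer in enumerate(model.layers):
    print(i, layer.name, type(layer).__name__)
```

The printed names are generated automatically, so they may differ from the ones in the summary above.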
    

    Then set the trainable attribute of every layer except the LSTM layer to False.

    Approach 1:

    for layer in model.layers:
        if layer.name != 'lstm_2':  # freeze every layer except lstm_2
            layer.trainable = False
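The name a layer receives (`lstm_2`, `lstm_5`, ...) depends on how many layers of that type had already been created in the session when the model was built, so hard-coding it is fragile. A sketch that instead selects the last `LSTM` layer by type, under the same assumptions (TensorFlow 2.x, illustrative 10×3 input shape):

```python
from tensorflow import keras

model = keras.Sequential([
    keras.Input(shape=(10, 3)),  # assumed input shape (timesteps, features)
    keras.layers.LSTM(100, return_sequences=True),
    keras.layers.LSTM(50, return_sequences=True),
    keras.layers.LSTM(25),
    keras.layers.Dense(1),
])

# Pick the last LSTM layer by type, so its auto-generated name is irrelevant.
lstm_layers = [layer for layer in model.layers
               if isinstance(layer, keras.layers.LSTM)]
last_lstm = lstm_layers[-1]

# Only the selected layer stays trainable; every other layer is frozen.
for layer in model.layers:
    layer.trainable = layer is last_lstm

# Recompile so the new trainable settings are used during training.
model.compile(loss='mae', optimizer='adam')
```

For the question's Sequential model, `model.layers[-2]` refers to the same layer, because the `Dense(1)` output layer comes last.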
    

    Approach 2:

    for layer in model.layers:
        layer.trainable = False
    
    model.layers[2].trainable = True  # set lstm to be trainable
    
    # to make sure 2 is the index of the layer
    print(model.layers[2].name)    # prints 'lstm_2'
    

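    Don't forget to compile the model again (model.compile(...)) after changing trainable; the new settings only take effect for training once the model is recompiled.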
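To confirm that the recompiled model really trains only the chosen layer, one option is to compare each layer's weights before and after a short fit. A sketch assuming TensorFlow 2.x, the question's architecture with an illustrative 10×3 input, and random placeholder data in place of real training data:

```python
import numpy as np
from tensorflow import keras

# Random placeholder data: 32 sequences, 10 timesteps, 3 features (assumed).
rng = np.random.default_rng(0)
X = rng.normal(size=(32, 10, 3)).astype('float32')
y = rng.normal(size=(32, 1)).astype('float32')

model = keras.Sequential([
    keras.Input(shape=(10, 3)),
    keras.layers.LSTM(100, return_sequences=True),
    keras.layers.LSTM(50, return_sequences=True),
    keras.layers.LSTM(25),
    keras.layers.Dense(1),
])
model.compile(loss='mae', optimizer='adam')  # a loaded model is usually already compiled

# Freeze everything, then unfreeze the last LSTM (index 2 in this Sequential model).
for layer in model.layers:
    layer.trainable = False
model.layers[2].trainable = True
model.compile(loss='mae', optimizer='adam')  # recompile after changing trainable

before = [layer.get_weights() for layer in model.layers]
model.fit(X, y, epochs=1, batch_size=8, verbose=0)
after = [layer.get_weights() for layer in model.layers]

# True for each layer whose weights changed during training.
changed = [
    any(not np.array_equal(b, a) for b, a in zip(w_before, w_after))
    for w_before, w_after in zip(before, after)
]
print(changed)
```

Only the third entry should come out True: the frozen layers are no longer in `model.trainable_weights`, so the optimizer leaves their weights untouched.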