Tags: python, tensorflow, keras, tf.keras, transfer-learning

Remove top layer from pre-trained model, transfer learning, tensorflow (load_model)


I have pre-trained a model (my own saved model) with two classes, which I want to use for transfer learning to train a model with six classes. I have loaded the pre-trained model into the new training script:

base_model = tf.keras.models.load_model("base_model_path")

How can I remove the top/head layer (a Conv1D layer)?

I see that in Keras one can use base_model.pop(), and for tf.keras.applications one can simply use include_top=False, but is there something similar when using tf.keras and load_model?

(I have tried something like this:

for layer in base_model.layers[:-1]:
    layer.trainable = False

and then add it to a new model(?), but I am not sure how to continue.)

Thanks for any help!


Solution

  • You could try something like this:

    The base model is made up of a simple Conv1D network with an output layer with two classes:

    import tensorflow as tf
    
    samples = 100
    timesteps = 5
    features = 2
    classes = 2
    dummy_x, dummy_y = tf.random.normal((samples, timesteps, features)), tf.random.uniform((samples, 1), maxval=classes, dtype=tf.int32)
    
    base_model = tf.keras.Sequential()
    base_model.add(tf.keras.layers.Conv1D(32, 3, activation='relu', input_shape=(timesteps, features)))
    base_model.add(tf.keras.layers.GlobalMaxPool1D())
    base_model.add(tf.keras.layers.Dense(32, activation='relu'))
    base_model.add(tf.keras.layers.Dense(classes, activation='softmax'))
    
    base_model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy())
    print(base_model.summary())
    base_model.fit(dummy_x, dummy_y, batch_size=16, epochs=1)
    base_model.save("base_model")
    base_model = tf.keras.models.load_model("base_model")
    
    Model: "sequential_8"
    _________________________________________________________________
     Layer (type)                Output Shape              Param #   
    =================================================================
     conv1d_31 (Conv1D)          (None, 3, 32)             224       
                                                                     
     global_max_pooling1d_13 (Gl  (None, 32)               0         
     obalMaxPooling1D)                                               
                                                                     
     dense_17 (Dense)            (None, 32)                1056      
                                                                     
     dense_18 (Dense)            (None, 2)                 66        
                                                                     
    =================================================================
    Total params: 1,346
    Trainable params: 1,346
    Non-trainable params: 0
    _________________________________________________________________
    None
    7/7 [==============================] - 0s 3ms/step - loss: 0.6973
    INFO:tensorflow:Assets written to: base_model/assets
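
    If you want the transferred layers to keep their pre-trained weights during the new training, you could freeze them first, along the lines of the loop in the question (a minimal sketch):

    for layer in base_model.layers[1:-1]:  # the layers that are reused in the new model below
        layer.trainable = False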
    

    The new model is also made up of a simple Conv1D network, but with an output layer with six classes. It reuses all the layers of the base_model except the first Conv1D layer and the last output layer:

    classes = 6
    dummy_x, dummy_y = tf.random.normal((100, 5, 2)), tf.random.uniform((100, 1), maxval=6, dtype=tf.int32)
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Conv1D(64, 3, activation='relu', input_shape=(5, 2)))
    model.add(tf.keras.layers.Conv1D(32, 2, activation='relu'))
    for layer in base_model.layers[1:-1]: # Skip first and last layer
      model.add(layer)
    model.add(tf.keras.layers.Dense(classes, activation='softmax'))
    model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy())
    print(model.summary())
    model.fit(dummy_x, dummy_y, batch_size=16, epochs=1)
    
    Model: "sequential_9"
    _________________________________________________________________
     Layer (type)                Output Shape              Param #   
    =================================================================
     conv1d_32 (Conv1D)          (None, 3, 64)             448       
                                                                     
     conv1d_33 (Conv1D)          (None, 2, 32)             4128      
                                                                     
     global_max_pooling1d_13 (Gl  (None, 32)               0         
     obalMaxPooling1D)                                               
                                                                     
     dense_17 (Dense)            (None, 32)                1056      
                                                                     
     dense_19 (Dense)            (None, 6)                 198       
                                                                     
    =================================================================
    Total params: 5,830
    Trainable params: 5,830
    Non-trainable params: 0
    _________________________________________________________________
    None
    7/7 [==============================] - 0s 3ms/step - loss: 1.8069
    <keras.callbacks.History at 0x7f90c87a3c50>
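
    As an alternative to copying layers into a new Sequential model, the classification head can also be cut off with the functional API after load_model; a minimal sketch, assuming the output of the penultimate layer is the feature representation you want to reuse (names like headless and new_model are just illustrative):

    # Headless feature extractor: everything up to (but not including) the old 2-class output layer
    headless = tf.keras.Model(inputs=base_model.input,
                              outputs=base_model.layers[-2].output)
    headless.trainable = False  # optional: freeze the pre-trained weights
    
    # Attach a fresh 6-class head on top of the frozen feature extractor
    inputs = tf.keras.Input(shape=(5, 2))
    outputs = tf.keras.layers.Dense(6, activation='softmax')(headless(inputs))
    new_model = tf.keras.Model(inputs, outputs)
    new_model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy())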