Tags: keras, hdf5, transfer-learning

Keras - Proper way to extract weights from a nested model


I have a model with a nested DenseNet121, an input layer, and some final dense layers before the output. Here is the code:

image_input = Input(shape, name='image_input')
x = DenseNet121(input_shape=shape, include_top=False, weights=None,
                backend=keras.backend,
                layers=keras.layers,
                models=keras.models,
                utils=keras.utils)(image_input)
x = GlobalAveragePooling2D(name='avg_pool')(x)
x = Dense(1024, activation='relu', name='dense_layer1_image')(x)
x = BatchNormalization()(x)
x = Dropout(0.5)(x)
x = Dense(512, activation='relu', name='dense_layer2_image')(x)
x = BatchNormalization()(x)
x = Dropout(0.5)(x)
output = Dense(num_class, activation='softmax', name='image_output')(x)
classificationModel = Model(inputs=[image_input], outputs=[output])

Now, say I wanted to extract the DenseNet's weights from this model and use them for transfer learning in another, larger model that nests the same DenseNet but has some additional layers after it, such as:

image_input = Input(shape, name='image_input')
x = DenseNet121(input_shape=shape, include_top=False, weights=None,
                backend=keras.backend,
                layers=keras.layers,
                models=keras.models,
                utils=keras.utils)(image_input)
x = GlobalAveragePooling2D(name='avg_pool')(x)
x = Dense(1024, activation='relu', name='dense_layer1_image')(x)
x = BatchNormalization()(x)
x = Dropout(0.5)(x)
x = Dense(512, activation='relu', name='dense_layer2_image')(x)
x = BatchNormalization()(x)
x = Dropout(0.5)(x)
x = Dense(256, activation='relu', name='dense_layer3_image')(x)
x = BatchNormalization()(x)
x = Dropout(0.5)(x)
output = Dense(num_class, activation='sigmoid', name='image_output')(x)
classificationModel = Model(inputs=[image_input], outputs=[output])

Would I just need to do modelB.load_weights(<weights.hdf5>, by_name=True)? Also, should I name the internal DenseNet, and if so, how?


Solution

  • Before using the nested model, assign it to a variable. That makes everything a lot easier:

    densenet = DenseNet121(input_shape=shape, include_top=False, 
                           weights=None,backend=keras.backend,
                           layers=keras.layers,
                           models=keras.models,
                           utils=keras.utils)
    
    image_input = Input(shape, name='image_input')
    x = densenet(image_input)
    x = GlobalAveragePooling2D(name='avg_pool')(x)
    ......
    

    Now copying the weights is simple:

    weights = densenet.get_weights()
    another_densenet.set_weights(weights)
    

    If you start from a loaded file

    Print model.summary() for your loaded model. The DenseNet will be the first or second layer; check this yourself. In a functional model like the one above, model.layers[0] is normally the InputLayer, so the DenseNet is typically at index 1.

    Then get it with densenet = loaded_model.layers[i].

    You can then transfer these weights to the new DenseNet, either with the method described above or directly with new_model.layers[i].set_weights(densenet.get_weights()).
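The approach of keeping the DenseNet in a variable and copying its weights with get_weights()/set_weights() can be sketched end to end as follows. This is a minimal sketch assuming TensorFlow 2.x's tf.keras, where DenseNet121 does not need the backend/layers/models/utils arguments used in the question; SHAPE, NUM_CLASS and the build_classifier helper are hypothetical names introduced only for illustration.

```python
# Minimal sketch assuming TensorFlow 2.x tf.keras (not standalone Keras).
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, layers

SHAPE = (64, 64, 3)  # hypothetical input shape
NUM_CLASS = 5        # hypothetical number of classes


def build_classifier(densenet, hidden_units, activation):
    """Hypothetical helper: wrap an existing DenseNet121 with the question's head."""
    image_input = layers.Input(SHAPE, name='image_input')
    x = densenet(image_input)
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    for i, units in enumerate(hidden_units, start=1):
        x = layers.Dense(units, activation='relu', name=f'dense_layer{i}_image')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Dropout(0.5)(x)
    output = layers.Dense(NUM_CLASS, activation=activation, name='image_output')(x)
    return Model(inputs=[image_input], outputs=[output])


# Keep a reference to each backbone before nesting it.
densenet_a = tf.keras.applications.DenseNet121(
    input_shape=SHAPE, include_top=False, weights=None)
model_a = build_classifier(densenet_a, [1024, 512], 'softmax')
# ... compile and train model_a here ...

densenet_b = tf.keras.applications.DenseNet121(
    input_shape=SHAPE, include_top=False, weights=None)
model_b = build_classifier(densenet_b, [1024, 512, 256], 'sigmoid')

# Copy the backbone: both DenseNet121 instances have identical weight lists.
densenet_b.set_weights(densenet_a.get_weights())

# Optionally copy the head layers whose shapes match in both models.
for name in ('dense_layer1_image', 'dense_layer2_image'):
    model_b.get_layer(name).set_weights(model_a.get_layer(name).get_weights())

# Sanity check: every backbone array now matches.
print(all(np.array_equal(a, b)
          for a, b in zip(densenet_a.get_weights(), densenet_b.get_weights())))
```

Because only the backbone and the identically shaped dense_layer1_image/dense_layer2_image are copied, the image_output layer (a 512×NUM_CLASS kernel in one model, 256×NUM_CLASS in the other) is never touched, so its shape mismatch does not matter.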
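For the loaded-file case, here is a similar sketch, again assuming TensorFlow 2.x's tf.keras. To stay self-contained it builds the "loaded" model in memory instead of calling keras.models.load_model(...). It prints the index/name pairs that model.summary() would show, then looks the nested model up by 'densenet121', the name tf.keras gives DenseNet121 by default, as an alternative to hard-coding layers[i]. SHAPE, NUM_CLASS and build_classifier are hypothetical.

```python
# Minimal sketch assuming TensorFlow 2.x tf.keras (not standalone Keras).
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, layers

SHAPE = (64, 64, 3)  # hypothetical input shape
NUM_CLASS = 5        # hypothetical number of classes


def build_classifier(hidden_units, activation):
    """Hypothetical helper: nested DenseNet121 followed by the question's head."""
    image_input = layers.Input(SHAPE, name='image_input')
    x = tf.keras.applications.DenseNet121(
        input_shape=SHAPE, include_top=False, weights=None)(image_input)
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    for i, units in enumerate(hidden_units, start=1):
        x = layers.Dense(units, activation='relu', name=f'dense_layer{i}_image')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Dropout(0.5)(x)
    output = layers.Dense(NUM_CLASS, activation=activation, name='image_output')(x)
    return Model(inputs=[image_input], outputs=[output])


# Stand-in for keras.models.load_model(...); built in memory to stay self-contained.
loaded_model = build_classifier([1024, 512], 'softmax')
new_model = build_classifier([1024, 512, 256], 'sigmoid')

# The same index/name information that loaded_model.summary() shows.
for i, layer in enumerate(loaded_model.layers):
    print(i, layer.name)

# Look the nested DenseNet up by its default name instead of by index.
densenet = loaded_model.get_layer('densenet121')
new_model.get_layer('densenet121').set_weights(densenet.get_weights())

# Sanity check: every backbone array now matches.
print(all(np.array_equal(a, b)
          for a, b in zip(densenet.get_weights(),
                          new_model.get_layer('densenet121').get_weights())))
```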