Tags: python, neural-network, tensorflow2.0, tf.keras, resnet

How to access and visualize the weights in a pre-trained TensorFlow 2 model?


I have re-trained a pre-trained ResNet50 V2 model in TensorFlow 2 using tf.keras, with two Dense layers added on top. Now I want to visualize the weights of the layers inside the base ResNet model. However, reloading the saved model with

import tensorflow as tf
model = tf.keras.models.load_model('path/to/model.hdf5')
model.summary()

results in

[Image: output of model.summary()]

As you can see, the layers of the ResNet model are not listed individually, meaning that calling

model.layers[0].get_weights()[1]

will only result in

[7 7 3 64]

So how do I access the weights of each layer inside the base ResNet50 V2 model?


Solution

  • The right answer here was to write

    model.layers[0].summary()
    

    instead of

    model.summary()
    

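    which lets me see all the layers inside the pre-trained base model. Note that get_weights() takes no arguments, and the input layer (input_1) has no weights of its own, so a layer that does have weights has to be picked by the name shown in that summary. For example, writing

    model.layers[0].get_layer('conv1_conv').get_weights()
    

    will give me the weights (kernel and bias) of the first convolution in the ResNet base model.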
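
To go from the nested summary to the actual arrays, the base model can be walked like any other Keras model. The following is only a minimal sketch under a few assumptions: it does not load the asker's model.hdf5 but builds an untrained ResNet50V2 (weights=None, so nothing is downloaded) with a made-up two-layer Dense head, and the helper names collect_weights and kernel_grid are hypothetical, introduced just for this illustration. The layer name conv1_conv is the one Keras uses for the first convolution in its ResNet50V2; check it against model.layers[0].summary() for your own model.

```python
import numpy as np
import tensorflow as tf

# Stand-in for the re-trained model (assumption): an untrained ResNet50V2
# base plus a hypothetical Dense head, wrapped in a Sequential model.
base = tf.keras.applications.ResNet50V2(
    include_top=False, weights=None, input_shape=(224, 224, 3), pooling="avg"
)
model = tf.keras.Sequential([
    base,
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])


def collect_weights(nested_model):
    """Map layer name -> list of weight arrays, skipping weightless layers."""
    return {
        layer.name: layer.get_weights()
        for layer in nested_model.layers
        if layer.weights
    }


def kernel_grid(kernel, cols=8):
    """Tile (h, w, channels, n) conv filters into one image scaled to [0, 1]."""
    h, w, c, n = kernel.shape
    lo, hi = kernel.min(), kernel.max()
    scaled = (kernel - lo) / (hi - lo + 1e-12)
    rows = -(-n // cols)  # ceiling division
    grid = np.zeros((rows * h, cols * w, c), dtype=scaled.dtype)
    for i in range(n):
        r, col = divmod(i, cols)
        grid[r * h:(r + 1) * h, col * w:(col + 1) * w] = scaled[..., i]
    return grid


# model.layers[0] is the ResNet base; its own .layers are the real layers.
resnet_weights = collect_weights(model.layers[0])

# First convolution of the base model: kernel and bias arrays.
kernel, bias = resnet_weights["conv1_conv"]
print(kernel.shape, bias.shape)

# One image containing all first-layer filters side by side.
grid = kernel_grid(kernel)
print(grid.shape)
```

The grid array can be passed to matplotlib.pyplot.imshow to look at the first-layer filters. With a trained model loaded via load_model, the same two helpers should apply unchanged, as long as model.layers[0] is the nested ResNet base as in the question.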