Tags: python-2.7, keras, keras-layer

How to view the summary of neural networks with the Keras functional API


I have a very large neural network that I am building with the Keras functional API. I want to monitor the parameters and output shapes of the layers as I add them, before the model is fully defined, similar to what model.summary() provides.

If I have a model like this:

import keras
from keras.layers import Input, Conv2D, MaxPooling2D

input_img = Input(shape=(256, 256, 3))

tower_1 = Conv2D(64, (1, 1), padding='same', activation='relu')(input_img)
tower_1 = Conv2D(64, (3, 3), padding='same', activation='relu')(tower_1)
#stage1

tower_2 = Conv2D(64, (1, 1), padding='same', activation='relu')(input_img)
tower_2 = Conv2D(64, (5, 5), padding='same', activation='relu')(tower_2)
#stage2

tower_3 = MaxPooling2D((3, 3), strides=(1, 1), padding='same')(input_img)
tower_3 = Conv2D(64, (1, 1), padding='same', activation='relu')(tower_3)
#stage3

output = keras.layers.concatenate([tower_1, tower_2, tower_3], axis=1)

I want a summary() of this in-progress model at each of these stages. I know I can call model.summary() after defining model = Model(input_img, output), but can I do something like this while I am still adding layers?


Solution

  • You can easily get the compile-time shape of any Keras tensor from its _keras_shape attribute, like:

    input_img = Input(shape=(256, 256, 3))
    
    tower_1 = Conv2D(64, (1, 1), padding='same', activation='relu')(input_img)
    tower_1 = Conv2D(64, (3, 3), padding='same', activation='relu')(tower_1)
    #stage1
    
    tower_2 = Conv2D(64, (1, 1), padding='same', activation='relu')(input_img)
    tower_2 = Conv2D(64, (5, 5), padding='same', activation='relu')(tower_2)
    #stage2
    
    tower_3 = MaxPooling2D((3, 3), strides=(1, 1), padding='same')(input_img)
    tower_3 = Conv2D(64, (1, 1), padding='same', activation='relu')(tower_3)
    #stage3
    
    output = keras.layers.concatenate([tower_1, tower_2, tower_3], axis=1)
    
    print("Output shape is: {}".format(output._keras_shape))
    

    You can do this at any point in your computation, as long as you have a Keras tensor (the output of a layer). It's not the same as the full summary, but it helps a lot when debugging.
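
    If you want an actual layer-by-layer summary (including parameter counts) at an intermediate stage, a common trick is to wrap the current tensor in a temporary Model and call summary() on it. The sketch below assumes Keras 2.x; the temporary model reuses your existing layers, so building it does not alter the graph, and you can keep stacking layers afterwards. (In newer tf.keras versions, where _keras_shape no longer exists, printing tensor.shape gives the same static shape information.)

    import keras
    from keras.layers import Input, Conv2D
    from keras.models import Model

    input_img = Input(shape=(256, 256, 3))

    tower_1 = Conv2D(64, (1, 1), padding='same', activation='relu')(input_img)
    tower_1 = Conv2D(64, (3, 3), padding='same', activation='relu')(tower_1)
    #stage1

    # Throwaway model up to the current tensor: summary() lists every layer
    # built so far with its output shape and parameter count, and building it
    # does not affect later construction of the full model.
    Model(inputs=input_img, outputs=tower_1).summary()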