How to add an auxiliary head to an intermediary layer of a pretrained keras model?

This is my first question on Stack Overflow, so I'll try to give as much context as I can. Thank you for taking the time to read my question!

I'm currently using EfficientNet for a classification problem. I want to add an auxiliary head on an intermediate layer. By auxiliary head, I mean another set of layers that produces a second output, so the model has 2 final outputs.

So far, I have managed to add an additional head at the end of the model with the following code:

# Assumes `import tensorflow as tf` and an EfficientNet implementation imported as `efn`
# (for example `import efficientnet.tfkeras as efn`; the question does not show its imports).
inputs = tf.keras.Input(shape=(img_size, img_size, 3), name='input')
x = efn.EfficientNetB7(input_shape=(img_size, img_size, 3), include_top=False)(inputs)

classification_head = tf.keras.layers.GlobalAveragePooling2D()(x)
classification_head = tf.keras.layers.Dense(4, activation='softmax', name='classification')(classification_head)

aux_head = tf.keras.layers.Conv2D(128, kernel_size=3, padding='same')(x)
aux_head = tf.keras.layers.BatchNormalization()(aux_head)
aux_head = tf.keras.layers.ReLU()(aux_head)
aux_head = tf.keras.layers.Conv2D(1, kernel_size=1, padding='valid', name='aux_head')(aux_head)

model = tf.keras.Model(inputs, [classification_head, aux_head])

I want to do something similar, but attach the aux_head directly to an intermediate layer (here, the one named block5a_expand_conv). This is what I've tried:

inputs = tf.keras.Input(shape=(img_size, img_size, 3), name='input')
x = efn.EfficientNetB7(input_shape=(img_size, img_size, 3), include_top=False)(inputs)

classification_head = tf.keras.layers.GlobalAveragePooling2D()(x)
classification_head = tf.keras.layers.Dense(4, activation='softmax', name = 'classification')(classification_head)
intermediary_layer = x(
                input_shape=(img_sisze, img_sisze, 3),
                include_top=False).get_layer(name = 'block5a_expand_conv')

aux_head = tf.keras.layers.Conv2D(128, kernel_size = 3, padding='same')(intermediary_layer.output)
aux_head = tf.keras.layers.BatchNormalization()(aux_head)
aux_head = tf.keras.layers.ReLU()(aux_head)
aux_head = tf.keras.layers.Conv2D(1, kernel_size=1, padding= 'valid', name = 'aux_head')(aux_head)
model = tf.keras.Model(inputs, [classification_head,aux_head])

But this code produces the following error:

Graph disconnected

Does anyone know how to make this work?


  • This error means TensorFlow could not build a graph (model) connecting the layers you declared as inputs to the layers you declared as outputs: somewhere in your model, the path through the layers is broken. So check the path between your input and output layers.

    In your case, you defined one input layer, inputs. You feed it through EfficientNet and some layers of your own, and get classification_head as the output. So far there is no problem. For the second path, you want to take the output of a hidden layer named block5a_expand_conv, feed it through some other layers, and get another output, aux_head.

    So the question is: where is the input of this second path? Keras finds none. Because the whole EfficientNet model was called on inputs as if it were a single layer, the output of block5a_expand_conv still belongs to EfficientNet's own internal graph, which starts at EfficientNet's own input tensor and not at inputs. The intermediate layer therefore cannot be traced back to the input you declared, and this is where the graph gets disconnected.

    Here is the modified code:

    # Build the backbone once and keep a handle on it so its inner layers stay reachable.
    en_model = efn.EfficientNetB7(input_shape=(img_size, img_size, 3), include_top=False)

    # Main head, attached to the backbone's final output.
    classification_head = tf.keras.layers.GlobalAveragePooling2D()(en_model.output)
    classification_head = tf.keras.layers.Dense(4, activation='softmax', name='classification')(classification_head)

    # Auxiliary head, attached to the output tensor of an intermediate layer.
    intermediary_layer = en_model.get_layer('block5a_expand_conv').output
    aux_head = tf.keras.layers.Conv2D(128, kernel_size=3, padding='same')(intermediary_layer)
    aux_head = tf.keras.layers.BatchNormalization()(aux_head)
    aux_head = tf.keras.layers.ReLU()(aux_head)
    aux_head = tf.keras.layers.Conv2D(1, kernel_size=1, padding='valid', name='aux_head')(aux_head)

    # Use the backbone's own input so both heads trace back to the same input tensor.
    model = tf.keras.Model(en_model.input, [classification_head, aux_head])
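
To make the "where is the input of this path?" reasoning concrete, here is a toy sketch in plain Python. It is not how Keras is implemented. It only mimics the idea of walking backwards from each output until a declared input is reached. The function name `unreachable_sources` and the tensor names `efn_input` and `efn_output` are made up for illustration; only `inputs`, `classification`, `aux_head` and `block5a_expand_conv` come from the code above.

```python
def unreachable_sources(parents, declared_inputs, outputs):
    """Walk backwards from the outputs and return every source tensor
    (a tensor with no parents) that is not one of the declared inputs."""
    missing, seen, stack = set(), set(), list(outputs)
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        if node in declared_inputs:
            continue  # reached a declared input: this path is connected
        node_parents = parents.get(node, [])
        if not node_parents:
            missing.add(node)  # dead end that is not a declared input
        stack.extend(node_parents)
    return missing


# Attempt from the question: EfficientNet is called on `inputs` as one block,
# but block5a_expand_conv still hangs off the backbone's own (hypothetical) input.
nested = {
    "efn_output": ["inputs"],
    "classification": ["efn_output"],
    "block5a_expand_conv": ["efn_input"],
    "aux_head": ["block5a_expand_conv"],
}
broken = unreachable_sources(nested, {"inputs"}, ["classification", "aux_head"])

# Fixed version: both heads are built on the backbone's own tensors,
# and the backbone's input is declared as the model input.
flat = {
    "block5a_expand_conv": ["efn_input"],
    "efn_output": ["block5a_expand_conv"],
    "classification": ["efn_output"],
    "aux_head": ["block5a_expand_conv"],
}
fixed = unreachable_sources(flat, {"efn_input"}, ["classification", "aux_head"])

print("broken:", broken)
print("fixed:", fixed)
```

In the first case the walk from aux_head ends at a tensor that was never declared as a model input, which is the toy equivalent of "Graph disconnected". In the second case every path ends at the declared input, so nothing is reported.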