
Obtain the output of intermediate layer (Functional API) and use it in SubClassed API

In the Keras docs, it says that if we want to get the output of an intermediate layer of a model (Sequential or Functional), all we need to do is the following:

model = ...  # create the original model

layer_name = 'my_layer'
intermediate_layer_model = keras.Model(inputs=model.input,
                                       outputs=model.get_layer(layer_name).output)
intermediate_output = intermediate_layer_model(data)
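A minimal, self-contained sketch of that pattern, assuming TensorFlow 2.x (tf.keras). A tiny toy model stands in for "the original model", and the layer names "my_layer" and "head" are made up for illustration. The last line checks that the sub-model holds the same layer object as its parent, so the two share weights:

```python
import numpy as np
import tensorflow as tf

# Toy stand-in for "the original model"; layer names are illustrative only.
inputs = tf.keras.Input(shape=(16,))
hidden = tf.keras.layers.Dense(8, activation="relu", name="my_layer")(inputs)
outputs = tf.keras.layers.Dense(2, name="head")(hidden)
model = tf.keras.Model(inputs, outputs)

# Sub-model that stops at the named intermediate layer.
layer_name = "my_layer"
intermediate_layer_model = tf.keras.Model(
    inputs=model.inputs, outputs=model.get_layer(layer_name).output)

data = np.random.rand(4, 16).astype("float32")
intermediate_output = intermediate_layer_model(data)
print(intermediate_output.shape)  # (4, 8)

# Both models hold the very same layer object (and therefore the same weights).
print(intermediate_layer_model.get_layer(layer_name) is model.get_layer(layer_name))  # True
```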

So here we get two models: intermediate_layer_model is a sub-model of its parent model, and the two are separate models (although they share the same layer objects and weights). Likewise, if we take the intermediate layer's output feature maps of the parent (or base) model, apply some operation to them, and get new feature maps from that operation, we can also feed these output feature maps back into the parent model.

input = tf.keras.Input(shape=(size,size,3))
model = tf.keras.applications.DenseNet121(input_tensor = input)

layer_name = "conv1_block1" # for example 
output_feat_maps = SomeOperationLayer()(model.get_layer(layer_name).output)  

# assume, they're able to add up
base = Add()([model.output, output_feat_maps])

# bind all 
imputed_model = tf.keras.Model(inputs=[model.input], outputs=base)
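The same idea as a runnable toy sketch, assuming TensorFlow 2.x: a small functional model stands in for DenseNet121, and a plain Dense layer named "side_branch" stands in for the hypothetical SomeOperationLayer, sized so its output can be added to the model's output. Because everything is bound into one graph, the shared "mid" layer appears only once in the merged model:

```python
import numpy as np
import tensorflow as tf

# Toy stand-in for the pretrained base model (layer names are illustrative).
inp = tf.keras.Input(shape=(16,))
mid = tf.keras.layers.Dense(32, activation="relu", name="mid")(inp)
top = tf.keras.layers.Dense(10, name="top")(mid)
model = tf.keras.Model(inp, top)

# Stand-in for SomeOperationLayer: maps the intermediate features to the
# same shape as the model output so the two can be added.
side = tf.keras.layers.Dense(10, name="side_branch")(model.get_layer("mid").output)

# Bind everything into one graph that starts at the original input.
merged = tf.keras.layers.Add()([model.outputs[0], side])
imputed_model = tf.keras.Model(inputs=model.inputs, outputs=merged)

y = imputed_model(np.zeros((2, 16), dtype="float32"))
print(y.shape)  # (2, 10)
print([layer.name for layer in imputed_model.layers].count("mid"))  # 1
```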

So, in this way we get one modified model. This is quite easy with the Functional API, and the Keras ImageNet models are (mostly) written with the Functional API. We can also use these models inside the model subclassing API. My concern is what to do if we need the intermediate output feature maps of these Functional API models inside the call method.

class Subclass(tf.keras.Model):
    def __init__(self, dim):
        super(Subclass, self).__init__()
        self.dim = dim
        self.base = DenseNet121(input_shape=self.dim)

        # building new model with the desired output layer of base model
        self.mid_layer_model = tf.keras.Model(self.base.inputs,
                                              self.base.get_layer(layer_name).output)

        # create the extra layer once here, not inside call()
        self.some_op = SomeOperationLayer()

    def call(self, inputs):
        # forward with base model
        x = self.base(inputs)

        # forward with mid_layer_model
        mid_feat = self.mid_layer_model(inputs)

        # do some op with it
        mid_x = self.some_op(mid_feat)
        # assume they're able to add up
        out = tf.keras.layers.add([x, mid_x])

        return out

The issue is that here we technically have two models joined together. Unlike building a model this way, we simply want the base model's intermediate output feature maps (for some inputs) from its own forward pass, so we can use them somewhere else and get some output, like this:

mid_x = SomeOperationLayer()(self.base.get_layer(layer_name).output)

But this gives ValueError: Graph disconnected. So currently we have to build a new model from the base model that ends at the desired intermediate layer. In the __init__ method we create a new model, self.mid_layer_model, that gives the desired output feature maps: mid_feat = self.mid_layer_model(inputs). Next, we take mid_feat, apply some operation to it to get some output, and finally add the two with tf.keras.layers.add([x, mid_x]). Creating a new model with the desired intermediate output works, but at the same time we repeat the same computation twice, i.e. once in the base model and once in its sub-model. Maybe I'm missing something obvious; if so, please point it out. Is this just how it is, or are there strategies we can adopt? I've also asked on the forum, with no response yet.


Here is a working example. Let's say we have a custom layer like this:

import tensorflow as tf
from tensorflow.keras.applications import DenseNet121
from tensorflow.keras.layers import Add
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Flatten

class ConvBlock(tf.keras.layers.Layer):
    def __init__(self, kernel_num=32, kernel_size=(3,3), strides=(1,1), padding='same'):
        super(ConvBlock, self).__init__()
        # conv layer
        self.conv = tf.keras.layers.Conv2D(kernel_num,
                        kernel_size=kernel_size,
                        strides=strides, padding=padding)
        # batch norm layer
        self.bn = tf.keras.layers.BatchNormalization()

    def call(self, input_tensor, training=False):
        x = self.conv(input_tensor)
        x = self.bn(x, training=training)
        return tf.nn.relu(x)

And we want to insert this layer into an ImageNet model and construct a model like this:

input = tf.keras.Input(shape=(32, 32, 3))
base = DenseNet121(weights=None, input_tensor = input)

# get output feature maps of at certain layer, ie. conv2_block1_0_relu
cb = ConvBlock()(base.get_layer("conv2_block1_0_relu").output)
flat = Flatten()(cb)
dense = Dense(1000)(flat)

# adding up
adding = Add()([base.output, dense])
model = tf.keras.Model(inputs=[base.input], outputs=adding)

from tensorflow.keras.utils import plot_model
plot_model(model, show_shapes=True, show_dtype=True)


Here the computation from the input to layer conv2_block1_0_relu is done only once. Next, if we want to translate this Functional API model to the subclassing API, we first have to build a model from the base model's input to layer conv2_block1_0_relu, like this:

class ModelWithMidLayer(tf.keras.Model):
    def __init__(self, dim=(32, 32, 3)):
        super(ModelWithMidLayer, self).__init__()
        self.dim = dim
        self.base = DenseNet121(input_shape=self.dim, weights=None)
        # building sub-model from self.base which gives
        # desired output feature maps: ie. conv2_block1_0_relu
        self.mid_layer = tf.keras.Model(self.base.inputs,
                                        self.base.get_layer("conv2_block1_0_relu").output)
        self.flat = Flatten()
        self.dense = Dense(1000)
        self.add = Add()
        self.cb = ConvBlock()

    def call(self, x):
        # forward with base model
        bx = self.base(x)

        # forward with mid layer (runs the shared layers a second time)
        mx = self.mid_layer(x)

        # make same shape or do whatever
        mx = self.dense(self.flat(mx))
        # combine
        out = self.add([bx, mx])
        return out

    def build_graph(self):
        x = tf.keras.layers.Input(shape=(self.dim))
        return tf.keras.Model(inputs=[x], outputs=self.call(x))

mwml = ModelWithMidLayer()
plot_model(mwml.build_graph(), show_shapes=True, show_dtype=True)


Here model_1 is actually a sub-model of DenseNet, which probably makes the whole model (ModelWithMidLayer) compute the same operations twice. If this observation is correct, it is a real concern.
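One way to check this observation, as a sketch assuming TensorFlow 2.x in eager mode: wrap the shared layer in a hypothetical CountingDense layer that increments a Python counter every time its call() runs, then run both the base model and a sub-model that ends at that layer, as the two-model approach does. Building a functional graph may itself invoke call() symbolically, so the counter is reset before the forward passes. With this pattern the shared layer executes twice per input:

```python
import numpy as np
import tensorflow as tf


class CountingDense(tf.keras.layers.Layer):
    """Hypothetical helper: a Dense layer that counts how often call() runs."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(units)
        self.calls = 0

    def call(self, x):
        self.calls += 1  # Python side effect; only meaningful in eager mode
        return self.dense(x)


# Toy base model: a shared "prefix" layer followed by a "head".
inp = tf.keras.Input(shape=(16,))
prefix = CountingDense(8, name="prefix")
out = tf.keras.layers.Dense(4, name="head")(prefix(inp))
base = tf.keras.Model(inp, out)

# The two-model pattern: a second model that ends at the intermediate layer.
mid_model = tf.keras.Model(base.inputs, base.get_layer("prefix").output)

x = np.ones((2, 16), dtype="float32")
prefix.calls = 0       # ignore symbolic calls made while building the graphs
bx = base(x)           # pass 1 through the prefix
mx = mid_model(x)      # pass 2 through the same prefix
print(prefix.calls)    # 2
```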


  • I thought it might be more complex, but it's actually quite simple. We just need to build a single model with all the desired output layers (here, the intermediate layer and the final output) in the __init__ method and use it normally in the call method.

    import tensorflow as tf
    from tensorflow.keras.applications import DenseNet121
    from tensorflow.keras.layers import Add
    from tensorflow.keras.layers import Dense
    from tensorflow.keras.layers import Flatten
    from tensorflow.keras.utils import plot_model

    class ConvBlock(tf.keras.layers.Layer):
        def __init__(self, kernel_num=32, kernel_size=(3,3), strides=(1,1), padding='same'):
            super(ConvBlock, self).__init__()
            # conv layer
            self.conv = tf.keras.layers.Conv2D(kernel_num,
                            kernel_size=kernel_size,
                            strides=strides, padding=padding)
            # batch norm layer
            self.bn = tf.keras.layers.BatchNormalization()

        def call(self, input_tensor, training=False):
            x = self.conv(input_tensor)
            x = self.bn(x, training=training)
            return tf.nn.relu(x)

    class ModelWithMidLayer(tf.keras.Model):
        def __init__(self, dim=(32, 32, 3)):
            super(ModelWithMidLayer, self).__init__()
            self.dim = dim
            self.base = DenseNet121(input_shape=self.dim, weights=None)
            # building one sub-model from self.base which gives both the
            # desired intermediate feature maps (ie. conv2_block1_0_relu)
            # and the final output, so the shared layers run only once
            self.mid_layer = tf.keras.Model(
                self.base.inputs,
                [self.base.get_layer("conv2_block1_0_relu").output, self.base.output])
            self.flat = Flatten()
            self.dense = Dense(1000)
            self.add = Add()
            self.cb = ConvBlock()

        def call(self, x):
            # single forward pass: intermediate (mx) and final (bx) outputs
            mx, bx = self.mid_layer(x)
            # make same shape or do whatever
            mx = self.dense(self.flat(mx))
            # combine
            out = self.add([bx, mx])
            return out

        def build_graph(self):
            x = tf.keras.layers.Input(shape=(self.dim))
            return tf.keras.Model(inputs=[x], outputs=self.call(x))

    mwml = ModelWithMidLayer()
    plot_model(mwml.build_graph(), show_shapes=True, show_dtype=True)

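To check that the multi-output approach really avoids the repeated computation, here is a toy sketch assuming TensorFlow 2.x in eager mode. A small functional model stands in for DenseNet121, and the hypothetical CountingDense helper counts how often the shared "prefix" layer runs. The first call builds the layers (which may trigger symbolic calls), so the counter is reset before the measured call:

```python
import numpy as np
import tensorflow as tf


class CountingDense(tf.keras.layers.Layer):
    """Hypothetical helper: a Dense layer that counts how often call() runs."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(units)
        self.calls = 0

    def call(self, x):
        self.calls += 1  # Python side effect; only meaningful in eager mode
        return self.dense(x)


def make_base():
    # Toy stand-in for DenseNet121: a shared "prefix" layer feeding a "head".
    inp = tf.keras.Input(shape=(16,))
    mid = CountingDense(8, name="prefix")(inp)
    out = tf.keras.layers.Dense(4, name="head")(mid)
    return tf.keras.Model(inp, out)


class ModelWithMidLayer(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.base = make_base()
        # One sub-model with two outputs: intermediate features and final output.
        self.mid_layer = tf.keras.Model(
            self.base.inputs,
            [self.base.get_layer("prefix").output, self.base.outputs[0]])
        self.flat = tf.keras.layers.Flatten()
        self.dense = tf.keras.layers.Dense(4)
        self.merge = tf.keras.layers.Add()

    def call(self, x):
        mx, bx = self.mid_layer(x)       # single pass through the shared prefix
        mx = self.dense(self.flat(mx))   # bring the side branch to the head's shape
        return self.merge([bx, mx])


model = ModelWithMidLayer()
x = np.ones((2, 16), dtype="float32")
model(x)                                 # first call builds the remaining layers
prefix = model.base.get_layer("prefix")
prefix.calls = 0
y = model(x)
print(prefix.calls, y.shape)             # 1 (2, 4)
```

The design choice is the one from the answer: because both outputs come from one graph, Keras evaluates the shared layers once and hands both tensors to call().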