
How to stack Transfer Learning models in a Sequential


To build a richer architecture, I wanted to stack transfer-learning models on top of one another.

The three models I wanted to stack were:

  • VGG16
  • InceptionV3
  • ResNet50

So, I defined the three models as follows:

model_vgg = tf.keras.applications.VGG16(
    weights='imagenet', 
    include_top=False,
    input_shape=(SIZE, SIZE, 3)
)

model_inc = tf.keras.applications.inception_v3.InceptionV3(
    weights='imagenet', 
    include_top=False,
    input_shape=(SIZE, SIZE, 3)
)

model_res = tf.keras.applications.ResNet50(
    weights='imagenet', 
    include_top=False,
    input_shape=(SIZE, SIZE, 3)
)

SIZE was set to 100.

After this, I set trainable = False on each of them.

Now, how would I stack these models in a Sequential model, i.e. what changes do I have to make so that the output shape of each model matches the input shape of the next?

model = tf.keras.Sequential([
    
    model_vgg,
    model_inc,
    model_res,
    tf.keras.layers.Flatten()
    
])

Solution

  • Since each model has a different output shape, you will have to reshape each output before feeding it to the next model, and this will probably hurt performance:

    import tensorflow as tf
    
    SIZE = 100
    model_vgg = tf.keras.applications.VGG16(
        weights='imagenet', 
        include_top=False,
        input_shape=(SIZE, SIZE, 3)
    )
    
    model_inc = tf.keras.applications.inception_v3.InceptionV3(
        weights='imagenet', 
        include_top=False,
        input_shape=(SIZE, SIZE, 3)
    )
    
    model_res = tf.keras.applications.ResNet50(
        weights='imagenet', 
        include_top=False,
        input_shape=(SIZE, SIZE, 3)
    )
    
    model_vgg.trainable = False
    model_inc.trainable = False
    model_res.trainable = False
    
    model = tf.keras.Sequential([
        
        model_vgg,
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(SIZE*SIZE*3),
        tf.keras.layers.Reshape((SIZE, SIZE, 3)),
        model_inc,
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(SIZE*SIZE*3),
        tf.keras.layers.Reshape((SIZE, SIZE, 3)),
        model_res,
        tf.keras.layers.Flatten()
        
    ])
    print(model(tf.random.normal((1, 100, 100, 3))).shape)
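
    To see why the performance warning matters, here is a back-of-the-envelope count of the trainable parameters added by just the first Flatten → Dense → Reshape bridge. This sketch assumes VGG16's feature map for a 100x100 input is 3x3x512 (five stride-2 poolings: 100 → 50 → 25 → 12 → 6 → 3):

    ```python
    # Rough parameter count of the Flatten -> Dense(SIZE*SIZE*3) bridge
    # inserted after VGG16, assuming a 3x3x512 feature map for 100x100 input.
    SIZE = 100
    vgg_features = 3 * 3 * 512          # flattened VGG16 output: 4608 units
    bridge_units = SIZE * SIZE * 3      # Dense layer recreating an "image": 30000 units
    bridge_params = vgg_features * bridge_units + bridge_units  # weights + biases
    print(bridge_params)  # 138270000 trainable parameters in one bridge alone
    ```

    So each bridge adds roughly 138 million trainable parameters, far more than the frozen backbones themselves contain.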
    

    You will also have to decide whether to use a nonlinear activation function on each Dense layer. Oh, and you could also apply each model's preprocessing method like this:

    model = tf.keras.Sequential([
        
        tf.keras.layers.Lambda(lambda x: tf.keras.applications.vgg16.preprocess_input(x)),
        model_vgg,
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(SIZE*SIZE*3),
        tf.keras.layers.Reshape((SIZE, SIZE, 3)),
        tf.keras.layers.Lambda(lambda x: tf.keras.applications.inception_v3.preprocess_input(x)),
        model_inc,
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(SIZE*SIZE*3),
        tf.keras.layers.Reshape((SIZE, SIZE, 3)),
        tf.keras.layers.Lambda(lambda x: tf.keras.applications.resnet50.preprocess_input(x)),
        model_res,
        tf.keras.layers.Flatten()
        
    ])
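
    Note that the three preprocess_input functions expect pixel values in the [0, 255] range and apply different conventions: VGG16 and ResNet50 use caffe-style preprocessing (RGB → BGR plus per-channel ImageNet mean subtraction), while InceptionV3 scales pixels to [-1, 1]. A minimal numpy sketch of the two conventions (use the actual preprocess_input functions in practice):

    ```python
    import numpy as np

    def caffe_style(x):
        """Roughly what vgg16/resnet50 preprocess_input does:
        RGB -> BGR, then subtract the ImageNet channel means."""
        x = x[..., ::-1]                                   # RGB -> BGR
        return x - np.array([103.939, 116.779, 123.68])   # per-channel means (BGR order)

    def tf_style(x):
        """Roughly what inception_v3 preprocess_input does:
        scale [0, 255] pixels to [-1, 1]."""
        return x / 127.5 - 1.0

    img = np.full((1, 4, 4, 3), 255.0)   # dummy all-white "image"
    print(tf_style(img).max())           # 1.0
    ```

    This also hints at a problem with the stacked Sequential above: the Dense outputs feeding the later Lambda layers are not in [0, 255], so the later preprocessing steps no longer do anything meaningful.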
    

    My personal suggestion would be to feed the same input into each model individually, concatenate their outputs, and run any downstream operations on the combined features:

    inputs = tf.keras.layers.Input((SIZE, SIZE, 3))
        
    vgg = tf.keras.layers.Lambda(lambda x: tf.keras.applications.vgg16.preprocess_input(x))(inputs)
    vgg = tf.keras.layers.GlobalAvgPool2D()(model_vgg(vgg))
    
    inc = tf.keras.layers.Lambda(lambda x: tf.keras.applications.inception_v3.preprocess_input(x))(inputs)
    inc = tf.keras.layers.GlobalAvgPool2D()(model_inc(inc))
    
    res = tf.keras.layers.Lambda(lambda x: tf.keras.applications.resnet50.preprocess_input(x))(inputs)
    res = tf.keras.layers.GlobalAvgPool2D()(model_res(res))
    
    outputs = tf.keras.layers.Concatenate(axis=-1)([vgg, inc, res])
    model = tf.keras.Model(inputs, outputs)
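
    With this approach, the width of the concatenated feature vector is just the sum of the three backbones' channel widths after GlobalAvgPool2D (assuming the standard include_top=False channel counts: 512 for VGG16, 2048 for InceptionV3, 2048 for ResNet50):

    ```python
    # Feature width after concatenating the three globally pooled backbones,
    # assuming the standard include_top=False channel counts.
    vgg_channels, inc_channels, res_channels = 512, 2048, 2048
    concat_width = vgg_channels + inc_channels + res_channels
    print(concat_width)  # 4608 features per image

    # A classification head could then be attached to the concatenated outputs, e.g.:
    # x = tf.keras.layers.Dense(256, activation='relu')(outputs)
    # preds = tf.keras.layers.Dense(num_classes, activation='softmax')(x)
    # model = tf.keras.Model(inputs, preds)
    ```

    This keeps the three frozen backbones in parallel rather than in series, so no lossy reshaping bridges are needed.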