
How to split resnet50 model from top as well as from bottom?


I am using a Keras pretrained model with include_top=False, but I also want to remove one residual block from the top and one from the bottom. For VGG this is simple, because the layers are linked in a straight chain, but ResNet's architecture is complicated by skip connections, so a direct approach doesn't fit well.

Can somebody recommend any resource or scripts to do it?

resnet = tf.keras.applications.resnet50.ResNet50(include_top=False, weights='imagenet')

Solution

  • If I understand correctly, you want to eliminate the first residual block and the last one.

    My advice is to use resnet.summary() to see the names of all the layers in the model. Even better, use TensorBoard to see the relationships between layers clearly.

    It also helps to know that a residual block in a ResNet ends with a sum (an Add layer) followed by an activation; that Activation is the layer you want to cut at.
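    If you'd rather not scan summary() by eye, the block boundaries can be found programmatically. A minimal sketch (it matches on layer type rather than name, since layer names vary across Keras versions, and relies on Keras listing each Add immediately before the Activation that closes the block):

    ```python
    import tensorflow as tf

    # weights=None avoids downloading ImageNet weights for this sketch
    resnet = tf.keras.applications.resnet50.ResNet50(include_top=False, weights=None)

    block_ends = []
    for i, layer in enumerate(resnet.layers):
        if isinstance(layer, tf.keras.layers.Add):
            # the Activation right after the Add closes the residual block
            block_ends.append(resnet.layers[i + 1].name)

    print(block_ends)  # 16 names, one per residual block in ResNet50
    ```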

    The block names look like res2a...: the number 2 indicates the stage and the letter the "sub-block".
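    That naming scheme can also be parsed mechanically. A small sketch (the `parse_resnet_layer_name` helper is hypothetical, not part of Keras, and the `res…` names come from older Keras versions; newer ones use names like `conv2_block3_out`):

    ```python
    import re

    def parse_resnet_layer_name(name):
        """Split an old-style ResNet50 layer name like 'res2c_branch2b'
        into its stage number, sub-block letter, and branch id."""
        m = re.match(r"res(\d)([a-z])_branch(\w+)$", name)
        if not m:
            return None  # not a residual-branch layer (e.g. 'add_19')
        stage, subblock, branch = m.groups()
        return {"stage": int(stage), "subblock": subblock, "branch": branch}

    print(parse_resnet_layer_name("res2c_branch2b"))
    # → {'stage': 2, 'subblock': 'c', 'branch': '2b'}
    ```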

    Based on the Resnet50 architecture:

    [Image: ResNet50 architecture diagram]

    If I want to remove the first residual block, I must look for the end of res2c. In this case I find this:

    activation_57 (Activation)      (None, 56, 56, 64)   0           bn2c_branch2a[0][0]
    __________________________________________________________________________________________________
    res2c_branch2b (Conv2D)         (None, 56, 56, 64)   36928       activation_57[0][0]
    __________________________________________________________________________________________________
    bn2c_branch2b (BatchNormalizati (None, 56, 56, 64)   256         res2c_branch2b[0][0]
    __________________________________________________________________________________________________
    activation_58 (Activation)      (None, 56, 56, 64)   0           bn2c_branch2b[0][0]
    __________________________________________________________________________________________________
    res2c_branch2c (Conv2D)         (None, 56, 56, 256)  16640       activation_58[0][0]
    __________________________________________________________________________________________________
    bn2c_branch2c (BatchNormalizati (None, 56, 56, 256)  1024        res2c_branch2c[0][0]
    __________________________________________________________________________________________________
    add_19 (Add)                    (None, 56, 56, 256)  0           bn2c_branch2c[0][0]
                                                                     activation_56[0][0]
    __________________________________________________________________________________________________
    activation_59 (Activation)      (None, 56, 56, 256)  0           add_19[0][0]
    __________________________________________________________________________________________________
    res3a_branch2a (Conv2D)         (None, 28, 28, 128)  32896       activation_59[0][0]
    

    The new input layer is res3a_branch2a; this way I skip the first residual block.

    activation_87 (Activation)      (None, 14, 14, 256)  0           bn4f_branch2a[0][0]              
    __________________________________________________________________________________________________
    res4f_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_87[0][0]              
    __________________________________________________________________________________________________
    bn4f_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4f_branch2b[0][0]             
    __________________________________________________________________________________________________
    activation_88 (Activation)      (None, 14, 14, 256)  0           bn4f_branch2b[0][0]              
    __________________________________________________________________________________________________
    res4f_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_88[0][0]              
    __________________________________________________________________________________________________
    bn4f_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4f_branch2c[0][0]             
    __________________________________________________________________________________________________
    add_29 (Add)                    (None, 14, 14, 1024) 0           bn4f_branch2c[0][0]              
                                                                     activation_86[0][0]              
    __________________________________________________________________________________________________
    activation_89 (Activation)      (None, 14, 14, 1024) 0           add_29[0][0]   
    

    If I want to remove the last residual block, I should look for the end of res4. That is activation_89.

    Making these cuts, we end up with this model:

    [Image: diagram of the truncated model]

    from tensorflow.keras.models import Model

    resnet_cut = Model(inputs=resnet.get_layer('res3a_branch2a').input,
                       outputs=resnet.get_layer('activation_89').output)
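    A caveat: the `res…`/`activation_…` names above come from an older Keras, and not every Keras version accepts an intermediate tensor as a `Model` input. Cutting from the output side, however, is always straightforward; a minimal sketch with the layer names current tf.keras uses (check `resnet.summary()` for your version):

    ```python
    import tensorflow as tf
    from tensorflow.keras.models import Model

    # weights=None avoids downloading ImageNet weights for this sketch
    resnet = tf.keras.applications.resnet50.ResNet50(include_top=False, weights=None)

    # Drop the last residual stage (conv5) by ending the model at the layer
    # that closes conv4. 'conv4_block6_out' is the name current tf.keras gives
    # that Activation (the analogue of 'activation_89' in the listing above).
    resnet_top_cut = Model(inputs=resnet.input,
                           outputs=resnet.get_layer('conv4_block6_out').output)

    print(resnet_top_cut.output_shape)  # (None, None, None, 1024)
    ```

    Cutting from the input side (skipping the first residual stage) may require re-applying the remaining layers to a fresh `tf.keras.Input` if your Keras version rejects an intermediate tensor as `inputs`.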