I am using a Keras pretrained model with include_top=False, but I also want to remove one resblock from the top and one from the bottom. For a VGG net this is simple because the layers are linked in a straightforward chain, but in a ResNet the architecture is complicated by the skip connections, so a direct approach doesn't fit well.
Can somebody recommend a resource or script to do this?
import tensorflow as tf

resnet = tf.keras.applications.resnet50.ResNet50(include_top=False, weights='imagenet')
If I understand correctly, you want to eliminate the first block and the last one.
My advice is to call resnet.summary() so you can see the names of all the layers in the model. Or, even better, use TensorBoard to see the relationships between them clearly.
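If you just want a quick dump of the layer names, something like this works (a sketch: weights=None is used here only to skip the ImageNet download, and the exact names vary between Keras versions):

```python
import tensorflow as tf

# Build the architecture only; pass weights='imagenet' as in the question
# if you also need the pretrained weights.
resnet = tf.keras.applications.resnet50.ResNet50(include_top=False, weights=None)

# Print every layer name to locate the block boundaries.
for layer in resnet.layers:
    print(layer.name)
```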
Note that a block in a Residual Network ends with a sum (an Add layer) followed by an activation; that Activation is the layer you want to obtain.
The block names look like res2a, res2b, ...: the number indicates the stage and the letter the block within it.
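Because of that pattern, the block boundaries can also be collected programmatically (a sketch: it assumes the model's layer list is topologically ordered, which holds for functional models like the Keras applications, so the layer right after each Add is the activation that closes the block):

```python
import tensorflow as tf

resnet = tf.keras.applications.resnet50.ResNet50(include_top=False, weights=None)

# The layer immediately following each Add in the (topologically ordered)
# layer list is the activation that ends a residual block.
block_ends = [resnet.layers[i + 1].name
              for i, layer in enumerate(resnet.layers[:-1])
              if isinstance(layer, tf.keras.layers.Add)]

print(block_ends)  # 16 names: ResNet50 has 3 + 4 + 6 + 3 residual blocks
```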
Based on the Resnet50 architecture:
If I want to remove the first residual block, I must look for the end of res2c. In this case I found this:
activation_57 (Activation) (None, 56, 56, 64) 0 bn2c_branch2a[0][0]
__________________________________________________________________________________________________
res2c_branch2b (Conv2D) (None, 56, 56, 64) 36928 activation_57[0][0]
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, 56, 56, 64) 256 res2c_branch2b[0][0]
__________________________________________________________________________________________________
activation_58 (Activation) (None, 56, 56, 64) 0 bn2c_branch2b[0][0]
__________________________________________________________________________________________________
res2c_branch2c (Conv2D) (None, 56, 56, 256) 16640 activation_58[0][0]
__________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizati (None, 56, 56, 256) 1024 res2c_branch2c[0][0]
__________________________________________________________________________________________________
add_19 (Add) (None, 56, 56, 256) 0 bn2c_branch2c[0][0]
                                   activation_56[0][0]
__________________________________________________________________________________________________
activation_59 (Activation) (None, 56, 56, 256) 0 add_19[0][0]
__________________________________________________________________________________________________
res3a_branch2a (Conv2D) (None, 28, 28, 128) 32896 activation_59[0][0]
The new input layer is res3a_branch2a. This way we jump over the first block of residuals.
activation_87 (Activation) (None, 14, 14, 256) 0 bn4f_branch2a[0][0]
__________________________________________________________________________________________________
res4f_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_87[0][0]
__________________________________________________________________________________________________
bn4f_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4f_branch2b[0][0]
__________________________________________________________________________________________________
activation_88 (Activation) (None, 14, 14, 256) 0 bn4f_branch2b[0][0]
__________________________________________________________________________________________________
res4f_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_88[0][0]
__________________________________________________________________________________________________
bn4f_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4f_branch2c[0][0]
__________________________________________________________________________________________________
add_29 (Add) (None, 14, 14, 1024) 0 bn4f_branch2c[0][0]
activation_86[0][0]
__________________________________________________________________________________________________
activation_89 (Activation) (None, 14, 14, 1024) 0 add_29[0][0]
If I want to remove the last block of residuals, I should look for the end of res4. That is activation_89.
Making these cuts we would have this model (note that Model expects tensors rather than layers, hence the .input and .output):

from tensorflow.keras.models import Model

resnet_cut = Model(inputs=resnet.get_layer('res3a_branch2a').input, outputs=resnet.get_layer('activation_89').output)
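The cut at the top can also be written with the layer names used by recent tf.keras releases (a sketch: it assumes a TF 2.x build where the end of stage 4 is named conv4_block6_out instead of activation_89, and uses weights=None plus a fixed input_shape only to keep the example light):

```python
import tensorflow as tf
from tensorflow.keras import Model

# Build the backbone; weights=None only to avoid the ImageNet download here.
resnet = tf.keras.applications.resnet50.ResNet50(
    include_top=False, weights=None, input_shape=(224, 224, 3))

# In recent tf.keras versions the activation closing stage 4 is named
# 'conv4_block6_out' (the equivalent of activation_89 in older Keras).
# Cutting only the top is straightforward because the model's own input
# tensor is reused:
top_cut = Model(inputs=resnet.input,
                outputs=resnet.get_layer('conv4_block6_out').output)

print(top_cut.output_shape)  # (None, 14, 14, 1024)
```

Cutting the bottom as well is trickier: newer Keras versions may refuse an intermediate tensor as a model input, in which case you have to define a fresh tf.keras.Input with the shape produced by the removed block and rewire the remaining layers onto it.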