I'm using a pretrained model that contains 12 self-attention blocks stacked sequentially, and I need to extract the output of the 4th and 10th blocks. In the following script, Block represents one self-attention layer:
dpr = [x.item() for x in torch.linspace(0, 0.1, 12)]  # stochastic depth decay rule
self.blocks = nn.ModuleList([
    Block(
        dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
        drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, attention_type=self.attention_type)
    for i in range(12)])
The self-attention layers (the stack of Block modules) are applied as follows:
## Attention blocks
for blk in self.blocks:
    x = blk(x, B, T, W)
How can I extract the outputs of the 4th and 10th layers?
To extract the output of a layer, you'll need to use hooks. A forward hook is a function that is called after the module's forward method has been executed.
Here's an example of how to do it:
vit = model(...)  # your model with 12 transformer blocks

features = {l: [] for l in range(len(vit.blocks))}  # placeholder for the extracted features

def make_hook_function(layer):
    def hook(module, input, output):
        features[layer].append(output)  # save the output of the layer to the placeholder
    return hook

# place the hooks on the layers that interest you
# indexing is zero-based, so the 4th and 10th blocks are blocks[3] and blocks[9]
vit.blocks[3].register_forward_hook(make_hook_function(3))
vit.blocks[9].register_forward_hook(make_hook_function(9))

pred = vit(x)  # run an image through the model
# now you can inspect features[3] and features[9]
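Note that register_forward_hook returns a handle, so you can remove the hooks once the features have been collected; otherwise every later forward pass keeps appending to the features dict. A minimal sketch of that variant, assuming the vit, features and make_hook_function objects from above (register the hooks this way instead of the two calls shown earlier):

handles = [
    vit.blocks[3].register_forward_hook(make_hook_function(3)),
    vit.blocks[9].register_forward_hook(make_hook_function(9)),
]

with torch.no_grad():  # no gradients needed for pure feature extraction
    pred = vit(x)

feat_4th = features[3][0]   # output of the 4th block
feat_10th = features[9][0]  # output of the 10th block

for h in handles:
    h.remove()  # detach the hooks once the features have been collected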
A very comprehensive example of a ViT feature extractor can be found here.
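Alternatively, if you can modify the model's forward (the loop over self.blocks shown in the question), you can skip hooks entirely and collect the intermediate outputs directly. A sketch under the assumption that you can subclass or patch that forward method:

## Attention blocks
intermediate = {}
for i, blk in enumerate(self.blocks):
    x = blk(x, B, T, W)
    if i in (3, 9):  # 4th and 10th blocks (zero-based indexing)
        intermediate[i] = x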