Tags: python, pytorch, transformer-model

Extracting specific blocks from a module list


I'm using a pretrained model that contains 12 self-attention blocks stacked sequentially, one after another. I need to extract the outputs of the 4th and the 10th of these blocks. In the following script, Block represents a single self-attention layer:

    dpr = [x.item() for x in torch.linspace(0, 0.1, 12)]  # stochastic depth decay rule
    
    self.blocks = nn.ModuleList([
        Block(
            dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, attention_type=self.attention_type)
        for i in range(12)])

The self-attention layers (the stack of Blocks) are applied in the forward pass as follows:

    ## Attention blocks
    for blk in self.blocks:
        x = blk(x, B, T, W)

How can I extract the outputs of the 4th and the 10th blocks?


Solution

  • To extract the output of an intermediate layer, you'll need to use hooks. A forward hook is a function that is called after a module's forward method has been executed; you attach it to a module with register_forward_hook.

    Here's an example of how to do it:

        vit = model(...)  # your model with 12 transformer blocks

        features = {l: [] for l in range(len(vit.blocks))}  # placeholder for the extracted features

        def make_hook_function(layer):
            def hook(module, input, output):
                features[layer].append(output)  # save the layer's output into the placeholder
            return hook

        # place the hooks on the layers that interest you
        # (indexing is zero-based, so blocks[3] is the 4th block and blocks[9] is the 10th)
        vit.blocks[3].register_forward_hook(make_hook_function(3))
        vit.blocks[9].register_forward_hook(make_hook_function(9))

        pred = vit(x)  # run an image through the model

        # now you can inspect features[3] and features[9]
    

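    register_forward_hook also returns a removable handle. If you keep the handles when registering the hooks (a small variant of the snippet above), you can detach them after one forward pass so that later calls don't keep appending to features:

        handles = [
            vit.blocks[3].register_forward_hook(make_hook_function(3)),
            vit.blocks[9].register_forward_hook(make_hook_function(9)),
        ]

        pred = vit(x)  # features[3] and features[9] are filled during this call

        for h in handles:
            h.remove()  # detach the hook so it no longer fires
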
    A very comprehensive example of a ViT feature extractor can be found here.
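
  • If you are building the model yourself rather than using it as a black box, you can skip hooks entirely and collect the intermediate activations directly in the forward loop shown in the question. A minimal sketch, assuming B, T, W are whatever extra arguments Block expects:

        intermediate = {}
        for i, blk in enumerate(self.blocks):
            x = blk(x, B, T, W)
            if i in (3, 9):  # the 4th and 10th blocks, zero-based
                intermediate[i] = x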