Tags: pytorch, hook, transformer-model, self-attention

Store intermediate values of a PyTorch module


I am trying to plot attention maps for a ViT. I know that I can do something like
h_attn = model.blocks[-1].attn.register_forward_hook(get_activations('attention')) to register a hook that captures the output of some nn.Module in my model.
The ViT's attention layer has the following forward method:

def forward(self, x):
    B, N, C = x.shape
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

    attn = (q @ k.transpose(-2, -1)) * self.scale
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)

    x = (attn @ v).transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)

    return x

Can I somehow attach a hook so that I get the attn value instead of the return value of forward (e.g. by using some kind of dummy module)?


Solution

  • In this case, you want to capture an intermediate value inside a module's forward method, namely the attn tensor. A forward hook only receives the inputs and the output of the module it is registered on, so a hook on the attention module sees the returned x, not local variables such as attn.
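A minimal sketch of why a hook on the attention module is not enough. It assumes a simplified stand-in Attention class written for this example (it mirrors the forward in the question and is not imported from timm), and get_activations is a hypothetical helper, since the question does not show its definition. The hook receives (module, inputs, output), and output is the projected x of shape (B, N, C):

```python
import torch
import torch.nn as nn


class Attention(nn.Module):
    """Simplified stand-in for timm's ViT Attention, written for this sketch."""

    def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(x))


activations = {}


def get_activations(name):
    # Hypothetical helper (not shown in the question): saves the module's output
    def hook(module, inputs, output):
        activations[name] = output.detach()
    return hook


attn_module = Attention(dim=64, num_heads=4).eval()
handle = attn_module.register_forward_hook(get_activations('attention'))
with torch.no_grad():
    attn_module(torch.randn(2, 10, 64))  # (B, N, C) = (2, 10, 64)
handle.remove()

# The hook saw the returned x, not the (B, num_heads, N, N) attention matrix
print(activations['attention'].shape)  # torch.Size([2, 10, 64])
```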

    However, you can create a new subclass of the attention module where you modify the forward method to store the attn value as an attribute, which you can access later.

    Here's an example of how you can create such a subclass:

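    # Assumes a timm version whose Attention.forward matches the one in the question
    from timm.models.vision_transformer import Attention

    class AttentionWithStoredAttn(Attention):
        """Same computation as Attention, but keeps the last attention map in self.stored_attn."""

        def forward(self, x):
            B, N, C = x.shape
            # Project to q, k, v and split heads: (3, B, num_heads, N, C // num_heads)
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)

            # Keep a detached CPU copy of shape (B, num_heads, N, N) for later plotting
            self.stored_attn = attn.detach().cpu()

            x = (attn @ v).transpose(1, 2).reshape(B, N, C)
            x = self.proj(x)
            x = self.proj_drop(x)

            return x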
    

    Then replace the original attention module in your model with this subclass, making sure the replacement keeps the pretrained weights. After each forward pass, the attention weights are available in the stored_attn attribute of that module.
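A minimal sketch of the replacement step, assuming a simplified stand-in Attention class (written for this example, mirroring the forward in the question) and a toy stack of layers instead of a real ViT; dim=64, num_heads=4 and the input size are arbitrary. Rather than building a new module and copying weights, it changes the class of the existing instance, so the layer keeps its parameters:

```python
import torch
import torch.nn as nn


class Attention(nn.Module):
    """Simplified stand-in for timm's ViT Attention, written for this sketch."""

    def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(x))


class AttentionWithStoredAttn(Attention):
    """Same computation, but keeps the last attention map in self.stored_attn."""

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        attn = self.attn_drop(attn)
        self.stored_attn = attn.detach().cpu()  # (B, num_heads, N, N)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(x))


# Toy "model": a stack of attention layers (in a timm ViT the target would be model.blocks[-1].attn)
layers = nn.ModuleList([Attention(dim=64, num_heads=4) for _ in range(2)]).eval()

# Patch the last layer in place: its parameters stay untouched,
# only the forward method it dispatches to changes
layers[-1].__class__ = AttentionWithStoredAttn

x = torch.randn(2, 10, 64)  # (B, N, C): batch of 2, 10 tokens, 64 channels
with torch.no_grad():
    for layer in layers:
        x = layer(x)

attn_map = layers[-1].stored_attn
print(attn_map.shape)  # torch.Size([2, 4, 10, 10])
```

The __class__ swap works here because AttentionWithStoredAttn adds no new parameters or __init__ arguments. If you prefer a fresh instance, construct it with the same arguments as the original layer, copy the weights with load_state_dict(old.state_dict()), and assign it back to model.blocks[-1].attn.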

    This increases memory usage, because a tensor of shape (B, num_heads, N, N) is kept for every patched layer. Also, call model.eval() before inference: stored_attn is captured after self.attn_drop, so in training mode dropout would zero out and rescale some of the attention weights. stored_attn holds the attention maps for every sample in the batch, but it is overwritten on each forward pass, so copy or save it after every pass if you process several batches.
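The dummy-module idea from the question can also work without subclassing, provided attn passes through a submodule. In the forward shown in the question it goes through self.attn_drop (an nn.Dropout), so a forward hook on attn_drop receives the attention matrix as its input. The sketch below uses a simplified stand-in Attention class written for this example and appends one map per forward pass to a list, which also sidesteps the overwriting issue. This relies on an assumption about the layer's code: if your installed timm version computes attention through a fused kernel (for example torch.nn.functional.scaled_dot_product_attention) instead of calling self.attn_drop, the hook will not fire.

```python
import torch
import torch.nn as nn


class Attention(nn.Module):
    """Simplified stand-in for timm's ViT Attention, written for this sketch."""

    def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        attn = self.attn_drop(attn)  # the hook below fires here
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(x))


stored = []  # one (B, num_heads, N, N) tensor per forward pass


def save_attn(module, inputs, output):
    # inputs is a tuple; inputs[0] is the softmaxed attention matrix passed to attn_drop
    stored.append(inputs[0].detach().cpu())


attn_module = Attention(dim=64, num_heads=4).eval()  # eval(): dropout is a no-op
handle = attn_module.attn_drop.register_forward_hook(save_attn)

with torch.no_grad():
    for _ in range(3):  # e.g. three batches of 2 inputs with 10 tokens each
        attn_module(torch.randn(2, 10, 64))

handle.remove()  # detach the hook once you are done
all_attn = torch.cat(stored)  # (6, 4, 10, 10): all batches stacked along dim 0
```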