pytorch, pytorch-lightning, attention-model, self-attention, vision-transformer

This code runs perfectly, but I wonder what the parameter 'x' in the my_forward function refers to


Referring to the attention maps in the ViT transformers example at https://github.com/huggingface/pytorch-image-models/discussions/1232?sort=old

This code runs perfectly, but I wonder what the parameter 'x' in the my_forward function refers to, and how and where in the code the value of x is passed to my_forward.

def my_forward(x):
    B, N, C = x.shape

    qkv = attn_obj.qkv(x).reshape(
        B, N, 3, attn_obj.num_heads, C // attn_obj.num_heads).permute(2, 0, 3, 1, 4)
    # make torchscript happy (cannot use tensor as tuple)
    q, k, v = qkv.unbind(0)

Solution

  • This requires a little code inspection but you can easily find the implementation if you look in the right places. Let us start with your snippet.

    • The my_forward_wrapper function is a factory (a closure) that defines my_forward and returns it. The returned function overwrites the forward method of the last block's attention layer, blocks[-1].attn, of the loaded "deit_small_distilled_patch16_224" model (a runnable sketch of the full pattern is given at the end of this answer):

      model.blocks[-1].attn.forward = my_forward_wrapper(model.blocks[-1].attn)
      
    • The x is the input of the attention layer: the output of the previous block, passed through the current block's first layer norm before it reaches attn. To see this, you can dive into the source code of timm. The model loaded in the script is deit_small_distilled_patch16_224, which returns a VisionTransformerDistilled instance. The blocks are defined in the VisionTransformer class; there are n=depth blocks applied sequentially. The default block definition is given by Block, in which attn is implemented by Attention; the details are given here:

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          B, N, C = x.shape
          qkv = self.qkv(x) \
                    .reshape(B, N, 3, self.num_heads, self.head_dim) \
                    .permute(2, 0, 3, 1, 4)
          q, k, v = qkv.unbind(0)
          q, k = self.q_norm(q), self.k_norm(k)
      
          if self.fused_attn:
              x = F.scaled_dot_product_attention(
                  q, k, v,
                  dropout_p=self.attn_drop.p if self.training else 0.,
              )
          else:
              q = q * self.scale
              attn = q @ k.transpose(-2, -1)
              attn = attn.softmax(dim=-1)
              attn = self.attn_drop(attn)
              x = attn @ v
      
          x = x.transpose(1, 2).reshape(B, N, C)
          x = self.proj(x)
          x = self.proj_drop(x)
          return x
      

      The implementation you provided, which overwrites it, is:

      def my_forward(x):
          B, N, C = x.shape
          qkv = attn_obj.qkv(x) \
                  .reshape(B, N, 3, attn_obj.num_heads, C // attn_obj.num_heads) \
                  .permute(2, 0, 3, 1, 4)
          q, k, v = qkv.unbind(0)
      
          attn = (q @ k.transpose(-2, -1)) * attn_obj.scale
          attn = attn.softmax(dim=-1)
          attn = attn_obj.attn_drop(attn)
          attn_obj.attn_map = attn
          attn_obj.cls_attn_map = attn[:, :, 0, 2:]
      
          x = (attn @ v).transpose(1, 2).reshape(B, N, C)
          x = attn_obj.proj(x)
          x = attn_obj.proj_drop(x)
          return x
      

      The idea is that the attention map is cached as an attribute on the attention layer with attn_obj.attn_map = attn (and the class-token attention with attn_obj.cls_attn_map), so that it can be inspected after inference; see the sketches below.
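
      Putting the pieces together, here is a minimal, self-contained sketch of the whole pattern, reconstructed from the snippets above. The wrapper body is the my_forward shown earlier; the Block.forward excerpt in the comments paraphrases timm's implementation and may differ slightly between timm versions:

      import timm


      def my_forward_wrapper(attn_obj):
          # attn_obj is the Attention module instance (here model.blocks[-1].attn).
          # The closure keeps a reference to it so my_forward can reuse its qkv/proj
          # weights and cache the attention map on it as an attribute.
          def my_forward(x):
              B, N, C = x.shape
              qkv = attn_obj.qkv(x) \
                      .reshape(B, N, 3, attn_obj.num_heads, C // attn_obj.num_heads) \
                      .permute(2, 0, 3, 1, 4)
              q, k, v = qkv.unbind(0)

              attn = (q @ k.transpose(-2, -1)) * attn_obj.scale
              attn = attn.softmax(dim=-1)
              attn = attn_obj.attn_drop(attn)
              attn_obj.attn_map = attn                    # cache the full attention map
              attn_obj.cls_attn_map = attn[:, :, 0, 2:]   # class-token attention to the patch tokens

              x = (attn @ v).transpose(1, 2).reshape(B, N, C)
              x = attn_obj.proj(x)
              x = attn_obj.proj_drop(x)
              return x

          return my_forward


      model = timm.create_model('deit_small_distilled_patch16_224', pretrained=False).eval()
      # pretrained=True in the original example; False here only avoids the weight download.

      # Replace the bound forward of this single Attention instance:
      model.blocks[-1].attn.forward = my_forward_wrapper(model.blocks[-1].attn)

      # Nothing calls my_forward directly. During model(img), the last Block's forward runs
      # (roughly, in recent timm versions):
      #     x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
      # and self.attn(...) goes through nn.Module.__call__, which dispatches to the patched
      # forward. So the x received by my_forward is the previous block's output after the
      # block's first LayerNorm, with shape (B, N, C).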
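
      Once the model has been loaded and patched as above, inspecting the cached maps after inference could look like the following sketch. The shapes assume a 224x224 input, i.e. 14x14 = 196 patch tokens plus the class and distillation tokens of the distilled DeiT:

      import torch

      img = torch.randn(1, 3, 224, 224)      # stand-in for a preprocessed input image
      with torch.no_grad():
          model(img)                          # triggers the patched forward of blocks[-1].attn

      attn_obj = model.blocks[-1].attn
      print(attn_obj.attn_map.shape)          # torch.Size([1, 6, 198, 198]) -- deit_small has 6 heads
      print(attn_obj.cls_attn_map.shape)      # torch.Size([1, 6, 196])

      # Average the class-token attention over heads and fold it back into a 14x14 grid,
      # the usual starting point for overlaying the attention map on the input image:
      cls_map = attn_obj.cls_attn_map.mean(dim=1).reshape(-1, 14, 14)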