Referring to the ViT attention-maps example in: https://github.com/huggingface/pytorch-image-models/discussions/1232?sort=old
This code runs perfectly, but I wonder what the parameter `x` in the `my_forward` function refers to, and how and where in the code the value of `x` is passed to `my_forward`.
def my_forward(x):
    B, N, C = x.shape
    qkv = attn_obj.qkv(x).reshape(
        B, N, 3, attn_obj.num_heads, C // attn_obj.num_heads).permute(2, 0, 3, 1, 4)
    # make torchscript happy (cannot use tensor as tuple)
    q, k, v = qkv.unbind(0)
This requires a little code inspection, but you can easily find the answer if you look in the right places. Let us start with your snippet.

`my_forward_wrapper` is a function factory: it defines `my_forward` (capturing `attn_obj` in a closure) and returns it. The returned function then overwrites the forward method of the last block's attention layer, `blocks[-1].attn`, of the loaded `deit_small_distilled_patch16_224` model:
model.blocks[-1].attn.forward = my_forward_wrapper(model.blocks[-1].attn)
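Note that `x` does not appear anywhere in this line: the wrapper only builds and returns the new forward, and `x` is supplied later by whoever calls the module. Here is a small, self-contained analogue of the same monkey-patching pattern (the names `loud_forward_wrapper` and the toy `Linear` layer are made up for illustration):

import torch
import torch.nn as nn

def loud_forward_wrapper(mod):
    # 'mod' plays the role of attn_obj: it is captured by the closure
    def loud_forward(x):
        # 'x' is whatever the caller passes in when the module is invoked
        print("forward received input of shape", tuple(x.shape))
        return nn.functional.linear(x, mod.weight, mod.bias)
    return loud_forward

layer = nn.Linear(4, 2)
layer.forward = loud_forward_wrapper(layer)   # same pattern as in the gist
out = layer(torch.randn(3, 4))                # x is passed here, by the caller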
The `x` passed to `my_forward` is the output of the previous block (more precisely, that output after the block's first LayerNorm, `norm1`). To see how it gets there, you can dive into the timm source code. The model loaded in the script, `deit_small_distilled_patch16_224`, returns a `VisionTransformerDistilled` instance. The blocks are defined in the `VisionTransformer` class: there are `n = depth` of them, applied sequentially. The default block definition is `Block`, in which `attn` is an instance of `Attention`, whose original forward is:
def forward(self, x: torch.Tensor) -> torch.Tensor:
    B, N, C = x.shape
    qkv = self.qkv(x) \
        .reshape(B, N, 3, self.num_heads, self.head_dim) \
        .permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)
    q, k = self.q_norm(q), self.k_norm(k)

    if self.fused_attn:
        x = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_drop.p if self.training else 0.,
        )
    else:
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = attn @ v

    x = x.transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x
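This forward is called from the block itself. In recent timm versions, `Block.forward` looks roughly like the sketch below (simplified; `ls1`/`ls2` are LayerScale and `drop_path1`/`drop_path2` are stochastic depth, all identity by default). This is exactly where `x` is handed to the attention layer, and hence to `my_forward` once it has been patched in:

def forward(self, x: torch.Tensor) -> torch.Tensor:
    # self.attn(...) dispatches to my_forward after the monkey-patching;
    # its input is the output of the previous block, passed through norm1
    x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
    x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
    return x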
The implementation you provided, which overwrites `Attention.forward`, is:
def my_forward(x):
    B, N, C = x.shape
    qkv = attn_obj.qkv(x) \
        .reshape(B, N, 3, attn_obj.num_heads, C // attn_obj.num_heads) \
        .permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)

    attn = (q @ k.transpose(-2, -1)) * attn_obj.scale
    attn = attn.softmax(dim=-1)
    attn = attn_obj.attn_drop(attn)

    # cache the full attention map (B, num_heads, N, N) on the layer
    attn_obj.attn_map = attn
    # attention of the class token (index 0) to the patch tokens
    # (index 2 onwards, skipping the class and distillation tokens)
    attn_obj.cls_attn_map = attn[:, :, 0, 2:]

    x = (attn @ v).transpose(1, 2).reshape(B, N, C)
    x = attn_obj.proj(x)
    x = attn_obj.proj_drop(x)
    return x
The idea is that the attention map is cached as an attribute on the attention layer via `attn_obj.attn_map = attn` (and the class-token attention via `attn_obj.cls_attn_map`), so that it can be inspected after inference.
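As a quick sanity check (a sketch, assuming `my_forward_wrapper` from the gist is defined), you can run a dummy image through the patched model and inspect the cached maps; for `deit_small_distilled_patch16_224` there are 6 heads and 198 tokens (196 patches + class + distillation token):

import timm
import torch

model = timm.create_model('deit_small_distilled_patch16_224', pretrained=True).eval()
model.blocks[-1].attn.forward = my_forward_wrapper(model.blocks[-1].attn)

with torch.no_grad():
    _ = model(torch.randn(1, 3, 224, 224))   # this call is where x gets passed down to my_forward

attn_obj = model.blocks[-1].attn
print(attn_obj.attn_map.shape)       # torch.Size([1, 6, 198, 198]) -> (B, heads, tokens, tokens)
print(attn_obj.cls_attn_map.shape)   # torch.Size([1, 6, 196])      -> class-token attention over the patches
cls_map = attn_obj.cls_attn_map.mean(dim=1).reshape(1, 14, 14)      # average heads, fold back to the 14x14 patch grid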