I was developing a self-attention module using PyTorch's nn.MultiheadAttention (MHA). My goal was to implement a causal mask that forces each token to attend only to the tokens strictly before it, i.e. excluding itself, unlike the standard autoregressive causal mask, where tokens can also attend to themselves.
Here's the function to generate my custom causal mask:
import torch

def generate_causal_mask(seq_length):
    # diagonal=0 keeps the diagonal and everything above it, so each element
    # attends only to elements before it, excluding itself (True = "cannot attend")
    mask = torch.triu(torch.full((seq_length, seq_length), 1, dtype=torch.float32), diagonal=0).bool()
    # Allow the first element to attend to itself to avoid NaN results
    mask[0, 0] = False
    return mask
The resulting mask looks like this:
tensor([[False,  True,  True,  True,  True,  True,  True,  True],
        [False,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False,  True]])
Here, True means "cannot attend." The first element is allowed to attend to itself (False at position [0, 0]) to avoid NaN results.
The code to reproduce the issue:
import torch
import torch.nn as nn

if __name__ == "__main__":
    embed_dim = 16
    batch_size = 1
    seq_len = 8
    mha = nn.MultiheadAttention(embed_dim, num_heads=1, batch_first=True)
    x = torch.randn(batch_size, seq_len, embed_dim).requires_grad_(True)
    causal_mask = generate_causal_mask(seq_len)
    print(causal_mask)
    output, _ = mha(x, x, x, attn_mask=causal_mask)
    # Gradient of the output at position t with respect to the input x
    t = 5
    loss = output[:, t].sum()
    loss.backward()
    print("Gradient of the token:")
    print(x.grad)
When printing the gradient of the input (x.grad) after backpropagating from the output at t = 5, I noticed that the output at time step t = 5 depends on its own value: row 5 of x.grad is nonzero. This is unexpected because, according to the causal mask, tokens should only attend to elements before themselves.
tensor([[[ 1.7815e-02, 6.0239e-02, 4.4045e-02, -1.7005e-02, -1.2529e-01, -9.8527e-02, -2.5346e-02, 4.4857e-02, -9.7425e-02, 1.0793e-01, 1.4662e-01, 1.0073e-01, -9.0143e-02, -2.5913e-02, 1.3379e-03, -9.0163e-02],
[ 2.6240e-01, 1.4095e-01, 2.9541e-01, 6.0876e-02, -1.5522e-01, -1.5531e-01, 4.4279e-02, 6.3482e-02, -2.1853e-01, 2.4059e-02, 2.2273e-01, 1.1566e-01, 6.6013e-02, -1.2247e-01, -1.1333e-01, -1.5512e-01],
[ 5.3024e-02, 4.4725e-02, 6.7385e-02, 5.5258e-03, -6.8150e-02, -5.9587e-02, -1.4061e-04, 2.5825e-02, -7.0633e-02, 3.8935e-02, 8.7158e-02, 5.3142e-02, -1.6992e-02, -3.0389e-02, -2.0005e-02, -5.6871e-02],
[ 2.9774e-01, 1.1942e-01, 3.1602e-01, 8.5978e-02, -8.4358e-02, -1.0587e-01, 7.2915e-02, 3.9608e-02, -1.8192e-01, -5.7704e-02, 1.4758e-01, 5.6968e-02, 1.5057e-01, -1.2490e-01, -1.3581e-01, -1.1233e-01],
[ 1.1037e-01, 7.4862e-02, 1.3163e-01, 1.9109e-02, -1.0056e-01, -9.2370e-02, 9.9104e-03, 3.9165e-02, -1.1730e-01, 4.2791e-02, 1.3410e-01, 7.7194e-02, -1.3165e-03, -5.6924e-02, -4.4891e-02, -8.9721e-02],
[-6.6541e-02, -1.0303e-02, -3.5482e-02, 2.1983e-02, -5.1578e-02, 2.0161e-01, 7.2047e-02, -4.0216e-02, -1.7608e-02, -1.2176e-02, -5.2893e-02, -1.1424e-01, 4.6907e-03, -1.0784e-01, 5.8249e-02, 9.0503e-03],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]]])
Given the custom causal mask, the output at token t should depend only on tokens at earlier time steps (0 to t-1). It should not depend on itself, as the diagonal is masked.
Is this behavior a bug in the MultiheadAttention implementation, or am I misunderstanding how attn_mask works? If this is intended behavior, could you please clarify how to correctly achieve the desired masking effect?
Versions of relevant libraries:
[pip3] numpy==2.1.3
[pip3] torch==2.5.1
[pip3] torchaudio==2.5.1
[pip3] torchvision==0.20.1
Think about where the attention weights come from. Attention computes:
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
Your attention weights are attn_weights = softmax(QK^T / sqrt(d_k)), and attn_mask masks entries of QK^T (by setting them to -inf) prior to the softmax.
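To make the masking step concrete, here is a minimal sketch of single-head masked attention. masked_attention is a hypothetical helper for illustration, not MultiheadAttention's actual internals, but the masking logic is the same in spirit:

import math
import torch

def masked_attention(q, k, v, attn_mask):
    # q: (L, d), k and v: (S, d), attn_mask: (L, S) bool with True = "cannot attend"
    d = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)         # QK^T / sqrt(d_k)
    scores = scores.masked_fill(attn_mask, float("-inf"))   # masked entries never survive the softmax
    attn_weights = torch.softmax(scores, dim=-1)            # row t: distribution over the unmasked keys
    return attn_weights @ v, attn_weights

Row t of attn_weights is a softmax over q_t . k_j for the unmasked keys j, so every entry in that row is a function of q_t, and therefore of input vector t.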
You mask the weights so that vector t doesn't attend to itself, but all the other attention weights in row t still depend on vector t, so you still have a nonzero gradient.
Take your example of t = 5. When vector t = 5 attends to the first vector t = 0, this interaction is weighted by attn_weights[5, 0]. That weight is

attn_weights[5, 0] = exp(q_5 . k_0 / sqrt(d_k)) / sum_{j=0..4} exp(q_5 . k_j / sqrt(d_k))

so it depends on vector t = 5 (through the query q_5) as well as on vector t = 0, plus vectors 1-4 through the softmax denominator. Even though vector t = 5 doesn't attend to itself, the attention weights for the vectors it does attend to depend on the value of vector t = 5, so you still have a nonzero gradient.
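If you want to see this directly, here is a quick sanity-check sketch (reusing generate_causal_mask and the same setup as your script; exact numbers depend on the seed): backpropagate from a single attention weight in row 5 and look at which input rows receive a gradient.

import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=16, num_heads=1, batch_first=True)
x = torch.randn(1, 8, 16, requires_grad=True)
mask = generate_causal_mask(8)

# need_weights=True returns the (head-averaged) attention weights, shape (N, L, S)
_, attn_weights = mha(x, x, x, attn_mask=mask, need_weights=True)
attn_weights[0, 5, 0].backward()   # weight for query 5 attending to key 0
print(x.grad[0, 5].abs().sum())    # nonzero: the weight depends on x[5] via q_5
print(x.grad[0, 6].abs().sum())    # zero: x[6] never enters row 5 of the weights

This matches your printout: rows 6 and 7 of x.grad are zero, while rows 0 through 5 are not.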