pytorch, attention-model, autoregressive-models, multihead-attention, causal-inference

Masked self-attention not working as expected when each token also masks itself


I was developing a self-attention module using PyTorch's nn.MultiheadAttention (MHA). My goal was to implement a causal mask that forces each token to attend only to the tokens before it, excluding itself, unlike the standard autoregressive causal mask, where tokens can attend to themselves.

Here's the function to generate my custom causal mask:


    import torch

    def generate_causal_mask(seq_length):
        # True = "cannot attend". diagonal=0 also masks the main diagonal, so each
        # element attends only to elements strictly before it, excluding itself.
        mask = torch.triu(torch.full((seq_length, seq_length), 1, dtype=torch.float32), diagonal=0).bool()

        # Allow the first element to attend to itself: a fully masked row gives NaN
        mask[0, 0] = False
        return mask

The resulting mask looks like this:

    tensor([[False,  True,  True,  True,  True,  True,  True,  True],
            [False,  True,  True,  True,  True,  True,  True,  True],
            [False, False,  True,  True,  True,  True,  True,  True],
            [False, False, False,  True,  True,  True,  True,  True],
            [False, False, False, False,  True,  True,  True,  True],
            [False, False, False, False, False,  True,  True,  True],
            [False, False, False, False, False, False,  True,  True],
            [False, False, False, False, False, False, False,  True]])

Here, True means "cannot attend." The first element attends to itself (False at position [0, 0]) to avoid NaN results.
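To make the "True means cannot attend" convention concrete, here is a minimal sketch of what a boolean mask does to a matrix of raw scores: masked entries are set to -inf before the softmax, so they receive exactly zero weight. The random `scores` tensor is only a stand-in for the real query-key products, and `apply_mask` is a hypothetical helper written for this illustration, not part of PyTorch.

```python
import torch

def generate_causal_mask(seq_length):
    mask = torch.triu(torch.ones(seq_length, seq_length), diagonal=0).bool()
    mask[0, 0] = False
    return mask

def apply_mask(scores, mask):
    # Hypothetical helper: masked (True) entries become -inf, then softmax over keys
    return torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)

torch.manual_seed(0)
mask = generate_causal_mask(8)
scores = torch.randn(8, 8)  # stand-in for the query-key scores
weights = apply_mask(scores, mask)

print(weights[5])  # nonzero only at key positions 0..4

# A fully masked row has no finite score left, so its softmax is NaN.
# This is why mask[0, 0] is set to False above.
all_masked = torch.ones(1, 8, dtype=torch.bool)
print(apply_mask(torch.randn(1, 8), all_masked))
```

Each row of `weights` still sums to 1, but only over the positions the mask leaves open.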

The code to reproduce the issue:


    import torch
    import torch.nn as nn

    if __name__ == "__main__":
        embed_dim = 16
        batch_size = 1
        seq_len = 8

        mha = nn.MultiheadAttention(embed_dim, num_heads=1, batch_first=True)

        x = torch.randn(batch_size, seq_len, embed_dim).requires_grad_(True)

        causal_mask = generate_causal_mask(seq_len)

        print(causal_mask)

        output, _ = mha(x, x, x, attn_mask=causal_mask)

        # Gradient of the output at position t with respect to every input token
        t = 5
        output[:, t].sum().backward()  # backward() returns None, so nothing to assign
        print("Gradient of the token:")
        print(x.grad)

Observed Behavior

When printing the gradient of the input (x.grad) for token t = 5, I noticed that the output at time step t = 5 depends on its own value: row 5 of x.grad is nonzero, while rows 6 and 7 are zero. This is unexpected because, according to the causal mask, tokens should only attend to elements before themselves.

    tensor([[[ 1.7815e-02, 6.0239e-02, 4.4045e-02, -1.7005e-02, -1.2529e-01, -9.8527e-02, -2.5346e-02, 4.4857e-02, -9.7425e-02, 1.0793e-01, 1.4662e-01, 1.0073e-01, -9.0143e-02, -2.5913e-02, 1.3379e-03, -9.0163e-02],
             [ 2.6240e-01, 1.4095e-01, 2.9541e-01, 6.0876e-02, -1.5522e-01, -1.5531e-01, 4.4279e-02, 6.3482e-02, -2.1853e-01, 2.4059e-02, 2.2273e-01, 1.1566e-01, 6.6013e-02, -1.2247e-01, -1.1333e-01, -1.5512e-01],
             [ 5.3024e-02, 4.4725e-02, 6.7385e-02, 5.5258e-03, -6.8150e-02, -5.9587e-02, -1.4061e-04, 2.5825e-02, -7.0633e-02, 3.8935e-02, 8.7158e-02, 5.3142e-02, -1.6992e-02, -3.0389e-02, -2.0005e-02, -5.6871e-02],
             [ 2.9774e-01, 1.1942e-01, 3.1602e-01, 8.5978e-02, -8.4358e-02, -1.0587e-01, 7.2915e-02, 3.9608e-02, -1.8192e-01, -5.7704e-02, 1.4758e-01, 5.6968e-02, 1.5057e-01, -1.2490e-01, -1.3581e-01, -1.1233e-01],
             [ 1.1037e-01, 7.4862e-02, 1.3163e-01, 1.9109e-02, -1.0056e-01, -9.2370e-02, 9.9104e-03, 3.9165e-02, -1.1730e-01, 4.2791e-02, 1.3410e-01, 7.7194e-02, -1.3165e-03, -5.6924e-02, -4.4891e-02, -8.9721e-02],
             [-6.6541e-02, -1.0303e-02, -3.5482e-02, 2.1983e-02, -5.1578e-02, 2.0161e-01, 7.2047e-02, -4.0216e-02, -1.7608e-02, -1.2176e-02, -5.2893e-02, -1.1424e-01, 4.6907e-03, -1.0784e-01, 5.8249e-02, 9.0503e-03],
             [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
             [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]]])

Expected Behavior

Given the custom causal mask, the output at token t should depend only on tokens at earlier time steps (0 to t-1). It should not depend on itself, as the diagonal is masked.

Request for Clarification

Is this behavior a bug in the MultiheadAttention implementation, or am I misunderstanding how attn_mask works? If this is intended behavior, could you please clarify how to correctly achieve the desired masking effect?

Versions of relevant libraries:

    [pip3] numpy==2.1.3
    [pip3] torch==2.5.1
    [pip3] torchaudio==2.5.1
    [pip3] torchvision==0.20.1


Solution

  • Think about where the attention weights come from. Attention computes:

    Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V

    The attention weights are attn_weights = softmax(QK^T / sqrt(d_k)), and the attention mask is applied to the scores QK^T / sqrt(d_k) before the softmax. Row i of attn_weights belongs to query i and column j to key j.

    Your mask stops vector t from attending to itself, so key t and value t no longer contribute to output t. The query for position t, however, is still computed from vector t, and every remaining weight in row t depends on that query, so you still have a nonzero gradient.

    Take your example of t=5. When vector t=5 attends to the first vector t=0, this interaction is weighted by attn_weights[5, 0]. The value of attn_weights[5, 0] depends on both vector t=0 and vector t=5 (plus vectors 1-4 through the softmax denominator). Even though vector t=5 doesn't attend to itself, the attention weights for the other vectors it does attend to depend on the value of vector t=5, so you still have a nonzero gradient.
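    One way to check this explanation, sketched under the assumption that the setup matches the script in the question: feed the same values through two separate leaf tensors, one used as the query and one as key/value, so that autograd reports the two paths separately. The names `q_in` and `kv_in` are made up for this illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, seq_len, t = 16, 8, 5

mha = nn.MultiheadAttention(embed_dim, num_heads=1, batch_first=True)

# Same mask as in the question: True = "cannot attend"
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=0).bool()
mask[0, 0] = False

x = torch.randn(1, seq_len, embed_dim)
q_in = x.clone().requires_grad_(True)   # used only to build the queries
kv_in = x.clone().requires_grad_(True)  # used only to build keys and values

output, weights = mha(q_in, kv_in, kv_in, attn_mask=mask)
output[:, t].sum().backward()

print(weights[0, t])     # weights of query t: zero at key positions t..7
print(kv_in.grad[0, t])  # key/value path of token t: all zeros
print(q_in.grad[0, t])   # query path of token t: nonzero
```

    With this split, the mask is seen to do its job on the key/value side, and the gradient for row t arrives only through `q_in`. It follows that output t can only be independent of token t if its query is also built without token t.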