
PyTorch MultiheadAttention error when the query sequence length differs from the key/value sequence length


I am experimenting with PyTorch's implementation of multi-head attention, nn.MultiheadAttention.

The docs state that the query has shape (N, L, E) when batch_first=True, where N is the batch size, L is the target sequence length and E is the embedding dimension.

They also state that the key and value have shape (N, S, E), where S is the source sequence length. Presumably this means S and L need not be equal, which makes sense.
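The documented shape contract can be checked directly. This is a minimal sketch, assuming batch_first=True is passed explicitly (the layout that the (N, L, E) notation refers to) and using torch.rand instead of torch.empty so the values are initialized. The query has length L = 1 and the key/value have length S = 2; the output should come back as (N, L, E) and the head-averaged attention weights as (N, L, S).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

N, L, S, E = 3, 1, 2, 10  # batch size, target length, source length, embedding dim

# batch_first=True: query is read as (N, L, E), key/value as (N, S, E)
attention = nn.MultiheadAttention(E, num_heads=1, batch_first=True)

q = torch.rand(N, L, E)
k = v = torch.rand(N, S, E)

out, weights = attention(q, k, v)  # need_weights defaults to True

print(out.shape)      # torch.Size([3, 1, 10]) -> (N, L, E)
print(weights.shape)  # torch.Size([3, 1, 2])  -> (N, L, S)
```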

However, if one runs the following:

import torch
import torch.nn as nn

input_size = 10   # embedding dimension E
batch_size = 3    # batch size N
window_size = 2   # source sequence length S

attention = nn.MultiheadAttention(input_size, num_heads=1)

q = torch.empty(batch_size, 1, input_size)                # meant as (N, L, E) with L = 1
k = v = torch.empty(batch_size, window_size, input_size)  # meant as (N, S, E)

y = attention(q, k, v, need_weights=False)

The following error is produced:

.../lib/python3.8/site-packages/torch/nn/functional.py:5044, in multi_head_attention_forward(query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training, key_padding_mask, need_weights, attn_mask, use_separate_proj_weight, q_proj_weight, k_proj_weight, v_proj_weight, static_k, static_v)
   5042 q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
   5043 if static_k is None:
-> 5044     k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
   5045 else:
   5046     # TODO finish disentangling control flow so we don't do in-projections when statics are passed
   5047     assert static_k.size(0) == bsz * num_heads, \
   5048         f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"

RuntimeError: shape '[3, 1, 10]' is invalid for input of size 60
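The numbers in the message come from the reshape on the failing line. Judging from line 5042, the module reads dimension 1 of the query as the batch, so bsz = 1, while the projected key keeps the shape (3, 2, 10), i.e. 60 elements. The sketch below reproduces just that reshape in isolation; the variable names mirror the traceback and are otherwise illustrative.

```python
import torch

# Query (3, 1, 10) read as (tgt_len, bsz, E) -> bsz = 1
bsz, num_heads, head_dim = 1, 1, 10

# Projected key keeps its input shape (3, 2, 10): 3 * 2 * 10 = 60 elements
k = torch.empty(3, 2, 10)

message = None
try:
    # Same call as the failing line: asks for 3 * 1 * 10 = 30 elements
    k.contiguous().view(k.shape[0], bsz * num_heads, head_dim)
except RuntimeError as err:
    message = str(err)
    print(message)  # shape '[3, 1, 10]' is invalid for input of size 60
```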

Have I missed something?

I am using torch v1.10.2.


Solution

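  • My bad; I'm posting this in case anyone else runs into the same error.

    The default value of batch_first is False, so the module reads the query as (L, N, E) and the key/value as (S, N, E). With the tensors above, the query is interpreted as L = 3, N = 1 and the key as S = 3, N = 2, so reshaping the key to (3, 1 * num_heads, 10) fails. Setting batch_first=True fixes the issue: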

    attention = nn.MultiheadAttention(input_size, num_heads=1, batch_first=True)
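If you would rather keep the default batch_first=False, the alternative is to move the batch axis to dimension 1 before the call and move it back afterwards. The sketch below is a quick equivalence check under stated assumptions: random inputs, and two modules (the names mha_sf and mha_bf are illustrative) made to share weights via load_state_dict so their outputs can be compared.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

N, L, S, E = 3, 1, 2, 10  # batch size, target length, source length, embedding dim

mha_sf = nn.MultiheadAttention(E, num_heads=1)                    # default: batch_first=False
mha_bf = nn.MultiheadAttention(E, num_heads=1, batch_first=True)
mha_bf.load_state_dict(mha_sf.state_dict())                       # copy the same projection weights

q = torch.rand(N, L, E)
k = v = torch.rand(N, S, E)

# Option 1: batch_first=True consumes (N, L, E) / (N, S, E) directly
out_bf, _ = mha_bf(q, k, v, need_weights=False)

# Option 2: keep the default and pass (L, N, E) / (S, N, E), then transpose back
out_sf, _ = mha_sf(q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1), need_weights=False)
out_sf = out_sf.transpose(0, 1)  # back to (N, L, E)

print(torch.allclose(out_bf, out_sf, atol=1e-5))
```

Note that the module returns a tuple (attn_output, attn_output_weights); with need_weights=False the second element is None, so y in the question is a tuple rather than a tensor.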