I am playing around with the PyTorch implementation of MultiheadAttention.
In the docs it states that the query dimensions are [N, L, E] (assuming batch_first=True), where N is the batch dimension, L is the target sequence length and E is the embedding dimension. It then states that the key and value dimensions are [N, S, E], where S is the source sequence length. Presumably this means S and L don't need to be equal, which makes sense.
However, if one runs the following:
import torch
import torch.nn as nn

input_size = 10   # embedding dimension E
batch_size = 3    # batch size N
window_size = 2   # source sequence length S

attention = nn.MultiheadAttention(input_size, num_heads=1)
q = torch.empty(batch_size, 1, input_size)                 # intended as (N, L, E) with L = 1
k = v = torch.empty(batch_size, window_size, input_size)   # intended as (N, S, E)
y = attention(q, k, v, need_weights=False)
The following error is produced:
.../lib/python3.8/site-packages/torch/nn/functional.py:5044, in multi_head_attention_forward(query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training, key_padding_mask, need_weights, attn_mask, use_separate_proj_weight, q_proj_weight, k_proj_weight, v_proj_weight, static_k, static_v)
5042 q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
5043 if static_k is None:
-> 5044 k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
5045 else:
5046 # TODO finish disentangling control flow so we don't do in-projections when statics are passed
5047 assert static_k.size(0) == bsz * num_heads, \
5048 f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
RuntimeError: shape '[3, 1, 10]' is invalid for input of size 60
Have I missed something?
I am using torch v1.10.2.
My bad, I will post this in case anyone else encounters the error.
The default value of batch_first is False; setting it to True fixes the issue. With batch_first=False, the inputs are interpreted as (seq_len, batch, embed), so the batch size inferred from q (1) doesn't match the one implied by k and v (2), which is what triggers the reshape error above.
attention = nn.MultiheadAttention(input_size, num_heads=1, batch_first=True)
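For completeness, here is a minimal working sketch of the fixed snippet. The forward call returns a tuple of (output, attention weights), so it is unpacked here, and I've swapped torch.empty for torch.randn to avoid uninitialized values; the printed shape is what I observe, showing that L and S can indeed differ:

import torch
import torch.nn as nn

input_size = 10   # embedding dimension E
batch_size = 3    # batch size N
window_size = 2   # source sequence length S

attention = nn.MultiheadAttention(input_size, num_heads=1, batch_first=True)

q = torch.randn(batch_size, 1, input_size)                 # (N, L, E) with L = 1
k = v = torch.randn(batch_size, window_size, input_size)   # (N, S, E) with S = 2

# forward returns (output, attention weights); weights are None when need_weights=False
y, _ = attention(q, k, v, need_weights=False)
print(y.shape)  # torch.Size([3, 1, 10]), i.e. (N, L, E)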