I have n vectors which need to be influenced by each other, and I want to output n vectors with the same dimensionality d. I believe this is what torch.nn.MultiheadAttention does, but its forward function expects query, key and value as inputs. According to this blog, I need to initialize a random weight matrix of shape (d x d) for each of q, k and v, multiply each of my vectors by these weight matrices, and get three (n x d) matrices. Are the q, k and v expected by torch.nn.MultiheadAttention just these three matrices, or do I have it mistaken?
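In code, my understanding of that projection step is something like this (the names and sizes are just my own illustration):

import torch

n, d = 10, 64                  # illustrative sizes
x = torch.randn(n, d)          # my n input vectors

# one random (d x d) weight matrix per projection, as I understand the blog
W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))

q = x @ W_q                    # (n x d)
k = x @ W_k                    # (n x d)
v = x @ W_v                    # (n x d)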
When you want to use self-attention, just pass your input tensor to torch.nn.MultiheadAttention as the query, key and value. The module creates and learns its own q, k and v projection weights internally, so you do not need to build those weight matrices yourself.
attention = torch.nn.MultiheadAttention(embed_dim, num_heads)
x, _ = attention(x, x, x)
The PyTorch module returns the output states (same shape as the input) and the attention weights used in the attention process, averaged over the heads.
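For completeness, here is a minimal runnable sketch with illustrative sizes (n = 10, d = 64, 8 heads, all my own choices). Note that by default the module expects input of shape (seq_len, batch, embed_dim), so the n vectors need a batch dimension (or construct the module with batch_first=True on recent PyTorch versions):

import torch

n, d = 10, 64                              # n input vectors of dimensionality d
x = torch.randn(n, 1, d)                   # default layout: (seq_len, batch, embed_dim)

# embed_dim must be divisible by num_heads
attention = torch.nn.MultiheadAttention(embed_dim=d, num_heads=8)

# Self-attention: the same tensor serves as query, key and value;
# the module applies its own learned q/k/v projections internally.
out, weights = attention(x, x, x)

print(out.shape)      # torch.Size([10, 1, 64]) -- same shape as the input
print(weights.shape)  # torch.Size([1, 10, 10]) -- (batch, n, n), averaged over heads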