
Inputs to the nn.MultiheadAttention?


I have n vectors that need to attend to each other, producing n output vectors with the same dimensionality d. I believe this is what torch.nn.MultiheadAttention does, but its forward function expects query, key and value as inputs. According to this blog, I need to initialize a random weight matrix of shape (d x d) for each of q, k and v, multiply each of my vectors by these weight matrices, and end up with three (n x d) matrices. Are the q, k and v expected by torch.nn.MultiheadAttention just these three matrices, or do I have it mistaken?


Solution

  • When you want self-attention, just pass your input tensor as the query, key and value to torch.nn.MultiheadAttention. The module creates and learns the q, k and v projection matrices internally, so you do not need to multiply your input by your own weight matrices first.

    
    attention = torch.nn.MultiheadAttention(<embed-dim>, <num-heads>)
    
    x, _ = attention(x, x, x)
    

    The PyTorch class returns the output states (same shape as the input) and the attention weights, averaged across the heads by default.
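
    As a quick sanity check, here is a minimal self-contained sketch of this; the sizes n = 10, d = 64 and num_heads = 8 are just example values, not anything required by the API:

    
    import torch
    
    n, d = 10, 64
    attention = torch.nn.MultiheadAttention(embed_dim=d, num_heads=8)
    
    # By default the module expects inputs of shape (seq_len, batch, embed_dim).
    x = torch.rand(n, 1, d)
    out, weights = attention(x, x, x)
    
    print(out.shape)      # torch.Size([10, 1, 64]), same shape as the input
    print(weights.shape)  # torch.Size([1, 10, 10]), weights averaged over heads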