I have the following equations:
$e_{ij} = \frac{X_i W^Q (X_j W^K + A^K_{ij})^T}{\sqrt{D_z}}$
$\alpha_{ij} = \mathrm{softmax}_j(e_{ij})$
$z_{i} = \sum_j \alpha_{ij} (X_j W^V + A^V_{ij})$
where the sizes are:
X: [B, S, H, D]
each W: [H, D, D]
each A: [S, S, H, D]
How can I calculate this with matrix operations?
I have a partial solution:
import torch
import torch.nn.functional as F
B, S, H, D = X.shape
d_z = D # Assuming d_z is equal to D for simplicity
W_Q = torch.randn(H, D, D)
W_K = torch.randn(H, D, D)
W_V = torch.randn(H, D, D)
a_K = torch.randn(S, S, H, D)
a_V = torch.randn(S, S, H, D)
XW_Q = torch.einsum('bshd,hde->bshe', X, W_Q) # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]
XW_K = torch.einsum('bshd,hde->bshe', X, W_K) # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]
e_ij_numerator = XW_Q.unsqueeze(2) @ (XW_K.unsqueeze(1) + a_K).transpose(-1, -2) # [B, S, 1, H, D] @ [B, 1, S, H, D] -> [B, S, S, H, D]
e_ij = e_ij_numerator / torch.sqrt(torch.tensor(d_z, dtype=torch.float32)) # [B, S, S, H, D]
XW_V = torch.einsum('bshd,hde->bshe', X, W_V) # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]
alpha = F.softmax(e_ij, dim=2) # [B, S, S, H, D]
z_i = torch.einsum('bshij,bshjd->bshid', alpha, XW_V.unsqueeze(1) + a_V) # [B, S, S, H, D] @ [B, 1, S, H, D] -> [B, S, S, H, D]
But z should be [B, S, H, D].
If I understood your question correctly, you're implementing an attention mechanism between the i-th and j-th sequence positions in a batch. First you linearly project your data (X) to get the queries, XW_Q, and then you project it again to get the keys, XW_K. You then add the bias a_K to the keys, and finally you want to compute the dot product (similarity) between XW_Q and (XW_K + a_K).
In this case, each D-dimensional query embedding is multiplied (in the dot-product sense) with every D-dimensional key embedding. The dot product of two vectors is a scalar, so the shape of e_ij should be [B, S, S, H], rather than [B, S, S, H, D].
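To illustrate why the trailing D axis disappears, here is a small sketch with toy sizes. The sizes and the random tensors `q`, `k`, and `a_K` are placeholders standing in for the projected queries, keys, and key bias, not values from the question. It computes the scores once by broadcasting and summing over D, and once as a single einsum, and checks that the two agree.

```python
import torch

torch.manual_seed(0)
B, S, H, D = 2, 3, 4, 5  # toy sizes, chosen only for illustration

q = torch.randn(B, S, H, D)    # stands in for X @ W_Q
k = torch.randn(B, S, H, D)    # stands in for X @ W_K
a_K = torch.randn(S, S, H, D)  # key bias, indexed [i, j, h, d]

# keys[b, i, j, h, :] = k[b, j, h, :] + a_K[i, j, h, :]
keys = k.unsqueeze(1) + a_K  # [B, S, S, H, D]

# Dot product over d: multiply elementwise, then sum the last axis
scores_sum = (q.unsqueeze(2) * keys).sum(dim=-1)  # [B, S, S, H]

# The same contraction written as one einsum
scores_einsum = torch.einsum('bihd,bijhd->bijh', q, keys)

print(scores_sum.shape)  # torch.Size([2, 3, 3, 4])
print(torch.allclose(scores_sum, scores_einsum, atol=1e-5))
```

Either form works; the einsum version just makes the index bookkeeping explicit.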
Then, after scaling by sqrt(d_z), you apply softmax over j so that, for every i, the weights sum to 1. This gives the scaling matrix alpha, which is also [B, S, S, H].
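As a quick sanity check on the softmax axis, the following sketch assumes the scores are laid out as [b, i, j, h], so the key index j is dim 2. The tensor `e` is random placeholder data with toy sizes.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, S, H = 2, 3, 4            # toy sizes
e = torch.randn(B, S, S, H)  # placeholder scores, laid out as [b, i, j, h]

alpha = F.softmax(e, dim=2)  # normalise over j, the key axis

# For every (b, i, h), the weights over j should sum to 1
row_sums = alpha.sum(dim=2)  # [B, S, H]
print(torch.allclose(row_sums, torch.ones(B, S, H)))
```

If the softmax were taken over dim 1 instead, the weights would be normalised over the query index i, which is not what the equations describe.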
Now you project your input to get the values: X @ W_V. This should result in a [B, S, H, D] tensor.
Finally, you get the new i-th output (z_i) by scaling every j-th value vector (XW_V plus the bias a_V) by the j-th scaling factor in alpha_i and summing over j, resulting in a [B, S, H, D] tensor as you expected. See the modified code below.
Hopefully this explanation is clear enough. I hope I got your intention right and that I didn't mix up any indices.
import torch
import torch.nn.functional as F
X = torch.randn((10, 20, 30, 40))
B, S, H, D = X.shape
d_z = D # Assuming d_z is equal to D for simplicity
W_Q = torch.randn(H, D, D)
W_K = torch.randn(H, D, D)
W_V = torch.randn(H, D, D)
a_K = torch.randn(S, S, H, D)
a_V = torch.randn(S, S, H, D)
XW_Q = torch.einsum('bshd,hde->bshe', X, W_Q) # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]
XW_K = torch.einsum('bshd,hde->bshe', X, W_K) # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]
e_ij_numerator = (XW_Q.unsqueeze(2) * (XW_K.unsqueeze(1) + a_K)).sum(dim=-1) # [B, S, S, H]
e_ij = e_ij_numerator / torch.sqrt(torch.tensor(d_z, dtype=torch.float32)) # [B, S, S, H]
alpha = F.softmax(e_ij, dim=2) # [B, S, S, H]
XW_V = torch.einsum('bshd,hde->bshe', X, W_V) # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]
z_i = torch.einsum('bijh,bijhd -> bihd', alpha, (XW_V.unsqueeze(1) + a_V)) # [B, S, S, H] * [B, S, S, H, D] -> [B, S, H, D]
print(z_i.shape) # [B, S, H, D].
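One possible refinement, which goes beyond what the question asked: the dot product and the weighted sum are both linear, so the a_K and a_V terms can be split off and the [B, S, S, H, D] intermediates never need to be built. The sketch below uses toy sizes and a hypothetical helper name, `relative_attention`, and assumes d_z equals D as above. It compares the result against a plain loop that follows the equations literally.

```python
import torch
import torch.nn.functional as F

def relative_attention(X, W_Q, W_K, W_V, a_K, a_V):
    """Hypothetical helper: attention with additive key/value biases."""
    D = X.shape[-1]  # assumes d_z == D
    q = torch.einsum('bshd,hde->bshe', X, W_Q)  # [B, S, H, D]
    k = torch.einsum('bshd,hde->bshe', X, W_K)  # [B, S, H, D]
    v = torch.einsum('bshd,hde->bshe', X, W_V)  # [B, S, H, D]
    # q_i . (k_j + a_K[i, j]) = q_i . k_j + q_i . a_K[i, j]
    e = (torch.einsum('bihd,bjhd->bijh', q, k)
         + torch.einsum('bihd,ijhd->bijh', q, a_K))  # [B, S, S, H]
    alpha = F.softmax(e / D ** 0.5, dim=2)           # softmax over j
    # sum_j alpha_ij (v_j + a_V[i, j]), split into two sums
    return (torch.einsum('bijh,bjhd->bihd', alpha, v)
            + torch.einsum('bijh,ijhd->bihd', alpha, a_V))  # [B, S, H, D]

torch.manual_seed(0)
B, S, H, D = 2, 3, 2, 4  # toy sizes so the loop below stays cheap
X = torch.randn(B, S, H, D)
W_Q, W_K, W_V = (torch.randn(H, D, D) for _ in range(3))
a_K = torch.randn(S, S, H, D)
a_V = torch.randn(S, S, H, D)

z = relative_attention(X, W_Q, W_K, W_V, a_K, a_V)

# Reference: evaluate the equations one (b, h, i) at a time
z_ref = torch.zeros(B, S, H, D)
for b in range(B):
    for h in range(H):
        q = X[b, :, h] @ W_Q[h]  # [S, D]
        k = X[b, :, h] @ W_K[h]  # [S, D]
        v = X[b, :, h] @ W_V[h]  # [S, D]
        for i in range(S):
            e_i = ((k + a_K[i, :, h]) @ q[i]) / D ** 0.5   # [S]
            alpha_i = torch.softmax(e_i, dim=0)            # [S]
            z_ref[b, i, h] = alpha_i @ (v + a_V[i, :, h])  # [D]

print(z.shape)  # torch.Size([2, 3, 2, 4])
print(torch.allclose(z, z_ref, atol=1e-4))
```

The split version trades two extra einsum calls for lower peak memory, which may matter when S and D are large; for small inputs the broadcast version above is simpler to read.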