python, machine-learning, math, pytorch

Calculating formulas in PyTorch using matrix operations


I have equations:

$e_{ij} = \frac{X_i W^Q \left(X_j W^K + A^K_{ij}\right)^\top}{\sqrt{D_z}}$
$\alpha_{ij} = \mathrm{softmax}_j(e_{ij})$
$z_{i} = \sum_j \alpha_{ij} \left(X_j W^V + A^V_{ij}\right)$

where the sizes are:

X: [B, S, H, D]
each W: [H, D, D]
each A: [S, S, H, D]

How can I calculate this using matrix operations?

I have a partial solution:

import torch
import torch.nn.functional as F

B, S, H, D = X.shape  # X is the input tensor of shape [B, S, H, D]
d_z = D  # Assuming d_z is equal to D for simplicity

W_Q = torch.randn(H, D, D)
W_K = torch.randn(H, D, D)
W_V = torch.randn(H, D, D)

a_K = torch.randn(S, S, H, D)
a_V = torch.randn(S, S, H, D)
XW_Q = torch.einsum('bshd,hde->bshe', X, W_Q)  # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]
XW_K = torch.einsum('bshd,hde->bshe', X, W_K)  # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]

e_ij_numerator = XW_Q.unsqueeze(2) @ (XW_K.unsqueeze(1) + a_K).transpose(-1, -2)  # [B, S, 1, H, D] @ [B, 1, S, H, D] -> [B, S, S, H, D]
e_ij = e_ij_numerator / torch.sqrt(torch.tensor(d_z, dtype=torch.float32))  # [B, S, S, H, D]

XW_V = torch.einsum('bshd,hde->bshe', X, W_V)  # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]
alpha = F.softmax(e_ij, dim=2)  # [B, S, S, H, D]

z_i = torch.einsum('bshij,bshjd->bshid', alpha, XW_V.unsqueeze(1) + a_V)  # [B, S, S, H, D] @ [B, 1, S, H, D] -> [B, S, S, H, D]

but z should be [B, S, H, D].


Solution

  • So, if I understood your question correctly, you're implementing an attention mechanism between the ith and jth positions of a sequence in a batch. First you linearly project your data (X) to get the queries: XW_Q, then you linearly project your data to get the keys: XW_K. You then add the bias a_K and finally you want to compute the dot product (similarity) between XW_Q and (XW_K + a_K).

    In this case, each D-dimensional embedding from the queries is multiplied (in the dot-product sense) with every D-dimensional embedding from the keys. The output of a dot product of two vectors is a scalar, so the shape of e_ij should be [B, S, S, H], rather than [B, S, H, D].
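
    For illustration, here is a minimal sketch of that per-pair dot product written directly as an einsum (reusing the XW_Q, XW_K and a_K tensors defined in the code below); it is equivalent to the broadcast-multiply-and-sum used in the full solution:

    # Dot product over the embedding dimension D for every (i, j) pair and every head:
    # [B, S, H, D] x [B, S, S, H, D] -> [B, S, S, H]
    e_ij_numerator = torch.einsum('bihd,bijhd->bijh', XW_Q, XW_K.unsqueeze(1) + a_K)
    print(e_ij_numerator.shape)  # torch.Size([B, S, S, H])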

    Then, after normalization, you apply softmax such that every ith row sums to 1, giving the scaling matrix alpha, which is also [B, S, S, H].
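
    As a quick sanity check (a sketch, assuming e_ij has been computed as in the code below), softmax over dim=2 makes every ith row of attention weights sum to one:

    alpha = F.softmax(e_ij, dim=2)  # [B, S, S, H]
    # For each batch element, query position i and head, the weights over j sum to 1.
    row_sums = alpha.sum(dim=2)     # [B, S, H]
    print(torch.allclose(row_sums, torch.ones_like(row_sums)))  # expected: True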

    Now, you project your input to get the values: X @ W_V. This should result in a [B, S, H, D] tensor.
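
    The per-head projection is written as an einsum in the code below; purely to clarify what 'bshd,hde->bshe' does, here is a sketch of the equivalent explicit loop over heads (the name XW_V_loop is only for this illustration):

    XW_V = torch.einsum('bshd,hde->bshe', X, W_V)  # [B, S, H, D]
    # Equivalent: multiply each head's [B, S, D] slice by that head's [D, D] weight matrix.
    XW_V_loop = torch.stack([X[:, :, h, :] @ W_V[h] for h in range(H)], dim=2)
    print(torch.allclose(XW_V, XW_V_loop))  # expected: True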

    Finally, you get the new ith sequence (z_i) by scaling every jth sequence column of XW_V by the jth scaling factor in alpha_i and summing over j, resulting in a [B, S, H, D] tensor, as you expected. See the modified code below.

    Hopefully, this is a clear enough explanation. I hope I got your intention right and that I didn't mix up any indices.

    import torch
    import torch.nn.functional as F
    X = torch.randn((10, 20, 30, 40))
    B, S, H, D = X.shape
    d_z = D  # Assuming d_z is equal to D for simplicity
    
    W_Q = torch.randn(H, D, D)
    W_K = torch.randn(H, D, D)
    W_V = torch.randn(H, D, D)
    
    a_K = torch.randn(S, S, H, D)
    a_V = torch.randn(S, S, H, D)
    
    XW_Q = torch.einsum('bshd,hde->bshe', X, W_Q)  # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]
    XW_K = torch.einsum('bshd,hde->bshe', X, W_K)  # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]
    
    e_ij_numerator = (XW_Q.unsqueeze(2) * (XW_K.unsqueeze(1) + a_K)).sum(dim=-1)  # [B, S, S, H]
    e_ij = e_ij_numerator / torch.sqrt(torch.tensor(d_z, dtype=torch.float32))  # [B, S, S, H]
    alpha = F.softmax(e_ij, dim=2)  # [B, S, S, H]
    XW_V = torch.einsum('bshd,hde->bshe', X, W_V)  # [B, S, H, D] @ [H, D, D] -> [B, S, H, D]
    
    
    z_i = torch.einsum('bijh,bijhd -> bihd', alpha, (XW_V.unsqueeze(1) + a_V))  # [B, S, S, H] * [B, S, S, H, D] -> [B, S, H, D]
    print(z_i.shape) # [B, S, H, D].
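
    As a final check (a sketch under the same assumptions, not part of the original derivation), the vectorized einsum for z_i can be compared against a naive double loop that follows the summation in the formula directly:

    # Naive reference: z[:, i] = sum_j alpha[:, i, j] * (XW_V[:, j] + a_V[i, j])
    z_loop = torch.zeros_like(z_i)
    for i in range(S):
        for j in range(S):
            # alpha[:, i, j] is [B, H]; broadcast it over D against the [B, H, D] values.
            z_loop[:, i] += alpha[:, i, j].unsqueeze(-1) * (XW_V[:, j] + a_V[i, j])
    print(torch.allclose(z_i, z_loop, atol=1e-5))  # expected: True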