Tags: pytorch, numpy-einsum, einsum, matmul

Einsum for shapes of different sizes or ranks


I have two PyTorch tensors: one of rank three and one of rank four. Is there a way to combine them so that the result has the rank and shape of the rank-three tensor? For instance, in this cross-attention snippet:

import torch

q = torch.linspace(1, 192, steps=192)
q = q.reshape(2, 4, 3, 8)
k = torch.linspace(2, 193, steps=192)
k = k.reshape(2, 4, 3, 8)
v = torch.linspace(3, 194, steps=192)
v = v.reshape(2, 4, 24)

k = k.permute(0, 3, 2, 1)  # (2, 8, 3, 4)
attn = torch.einsum("nchw,nwhu->nchu", q, k)  # (2, 4, 3, 4)

# Below is what doesn't work. I would like hidden_states to be a tensor of shape (2, 4, 24)
hidden_states = torch.einsum("chw,whu->chu", attn, v)

Is there a permutation/transpose I could apply to q, k, v, or attn that would allow me to multiply into (2, 4, 24)? I have yet to find one.

I currently receive this error: "RuntimeError: einsum(): the number of subscripts in the equation (3) does not match the number of dimensions (4) for operand 0 and no ellipsis was given", so I'm wondering how to use the ellipsis in this case, if that could be a solution.
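For context, my understanding is that the ellipsis simply stands in for any leading dimensions that are not labelled explicitly, so the labelled trailing axes still have to line up between operands; here is a minimal sketch of that (reusing torch from above, with made-up shapes unrelated to my problem):

a = torch.randn(2, 4, 3, 8)
b = torch.randn(2, 4, 8, 5)
# "..." absorbs the leading (2, 4) batch dimensions; the labelled axes do a matrix multiply
c = torch.einsum("...ij,...jk->...ik", a, b)
print(c.shape)  # torch.Size([2, 4, 3, 5])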

Any explanation as to why this is or isn't possible would also be an accepted answer!


Solution

  • It seems like your q and k are 4D tensors of shape batch-channel-height-width (2x4x3x8). However, when considering the attention mechanism, one disregards the spatial arrangement of the features and treats them as a "bag of features". That is, instead of q and k of shape 2x4x3x8, you should have 2x4x24:

    q = torch.linspace(1, 192, steps=192)
    q = q.reshape(2, 4, 3 * 8)  # collapse the spatial dimensions into a single one
    k = torch.linspace(2, 193, steps=192)
    k = k.reshape(2, 4, 3 * 8)  # collapse the spatial dimensions into a single one
    v = torch.linspace(3, 194, steps=192)
    v = v.reshape(2, 4, 24)
    
    attn = torch.einsum("bcn,bcN->bnN", q, k)  # (2, 24, 24): similarity between every pair of spatial positions, summed over channels
    # it is customary to convert the raw attn into probabilities using softmax
    attn = torch.softmax(attn, dim=-1)
    hidden_states = torch.einsum("bnN,bcN->bcn", attn, v)  # (2, 4, 24): attention-weighted sum over the positions of v
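
    A quick shape check confirms that hidden_states comes out with the (2, 4, 24) shape the question asked for:

    print(attn.shape)           # torch.Size([2, 24, 24]) -- one weight per pair of spatial positions
    print(hidden_states.shape)  # torch.Size([2, 4, 24])  -- same shape as v, as requested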