Tags: pytorch, torch, array-broadcasting, numpy-einsum

How to implement the MaxSim operator using torch operations?


Let T and L be two batches of M×N matrices, and let f(ti, lj) be a function that computes a score for the pair of matrices ti and lj. For instance, if

T, L = torch.rand(4, 3, 2), torch.rand(4, 3, 2)
# T = tensor([[[0.0017, 0.5781],
#          [0.8136, 0.5971],
#          [0.7697, 0.0795]],

#         [[0.2794, 0.7285],
#          [0.1528, 0.8503],
#          [0.9714, 0.1060]],

#         [[0.6907, 0.8831],
#          [0.4691, 0.4254],
#          [0.2539, 0.7538]],

#         [[0.3717, 0.2229],
#          [0.6134, 0.4810],
#          [0.7595, 0.9449]]])

and the score function is defined as shown in the following code snippet:

def score(ti, lj):
    """MaxSim score of matrices ti and lj."""
    # Pairwise dot products of the rows of ti and lj: (3, 2) @ (2, 3) -> (3, 3)
    m = torch.matmul(ti, torch.transpose(lj, 0, 1))
    # For each row of ti, keep its best match among the rows of lj, then sum
    return torch.sum(torch.max(m, 1).values, dim=-1)

How to return a score matrix S, where S[i,j] represents the score between T[i] and L[j]?

#S = tensor([[2.3405, 2.2594, 2.0989, 1.6450],
#            [2.5939, 2.4186, 2.3946, 2.0648],
#            [2.9447, 2.3652, 2.3829, 2.1536],
#            [2.8195, 2.3105, 2.2563, 1.8388]])

NOTE: This operation must be differentiable.
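
For reference, a straightforward (slow but differentiable) baseline is to apply score to every pair in a double loop; this sketch only restates the intended semantics and is useful to verify a vectorized answer against:

import torch

# Naive baseline: apply score() to every (T[i], L[j]) pair.
# Differentiable, but runs B*B separate matmuls instead of one batched op.
S = torch.stack([torch.stack([score(ti, lj) for lj in L]) for ti in T])  # (4, 4)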


Solution

  • I'd recommend using einsum for the pairwise matrix multiplication

    m = torch.einsum('b i j, c k j -> b c i k', T, L)
    

    which results in

    >>> m.shape
    torch.Size([4, 4, 3, 3])
    

    that is, a tensor that contains all 4 × 4 = 16 pairwise matrix products. Then the rest is simply

    out = torch.max(m, -1).values.sum(dim=-1)
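
    Putting both steps together as a helper (the name maxsim is my own choice, not from the question), you can sanity-check the result against the pairwise score defined above:

    def maxsim(T, L):
        # All pairwise row dot products: (b, i, j) x (c, k, j) -> (b, c, i, k)
        m = torch.einsum('b i j, c k j -> b c i k', T, L)
        # Max over the rows of L, then sum over the rows of T
        return torch.max(m, -1).values.sum(dim=-1)

    S = maxsim(T, L)
    assert torch.allclose(S[0, 1], score(T[0], L[1]))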
    

    Alternatively, you could use broadcasting for the matrix multiplications, as sketched below, but I think it is quite a bit more cumbersome than the einsum solution.
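
    For completeness, a broadcasting version might look like the following sketch; the extra unsqueeze/transpose bookkeeping is what makes it more cumbersome:

    # T: (4, 3, 2) -> (4, 1, 3, 2); L: (4, 3, 2) -> (1, 4, 2, 3)
    m = T.unsqueeze(1) @ L.unsqueeze(0).transpose(-1, -2)  # (4, 4, 3, 3)
    out = torch.max(m, -1).values.sum(dim=-1)              # (4, 4)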