torch matrix equality sum operation


I want to do an operation similar to matrix multiplication, except that instead of multiplying the elements I want to check them for equality and count the matches. The effect I want is equivalent to the following:

import torch

a = torch.Tensor([[1, 2, 3], [4, 5, 6]]).to(torch.uint8)
b = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).to(torch.uint8)
# result[i][j] = number of positions where row a[i] equals row b[j]
result = [[sum(a[i] == b[j]) for j in range(len(b))] for i in range(len(a))]

Is there a way to use einsum, or any other PyTorch function, to compute this efficiently?


Solution

  • You can use broadcasting to do this, for instance:

    result = (a[:, None, :] == b[None, :, :]).sum(dim=2)
    

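    Here None inserts a dummy dimension of size 1, so a is viewed with shape (n, 1, k) and b with shape (1, m, k). The comparison then broadcasts to an (n, m, k) boolean tensor, and summing over the last dimension counts the matching positions for every pair of rows. Alternatively, you can use the less visual .unsqueeze() instead.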
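A quick sketch to check the answer on the question's example data (built here with torch.tensor(..., dtype=torch.uint8) rather than torch.Tensor(...).to(...)). It compares the question's loop, the None-indexing version and the .unsqueeze() version. Since the question asks about einsum, it also shows one way einsum can be made to do this; that part is a swapped-in one-hot trick, not part of the answer above. It assumes every entry is a non-negative integer below num_classes (256 covers uint8). It also replaces the (n, m, k) boolean intermediate with two one-hot tensors of shape (n, k, 256) and (m, k, 256), so which route is cheaper depends on the sizes involved.

```python
import torch
import torch.nn.functional as F

a = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.uint8)
b = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.uint8)

# Reference: the explicit double loop from the question, as plain ints.
loop = torch.tensor([[int((a[i] == b[j]).sum()) for j in range(len(b))]
                     for i in range(len(a))])

# Broadcasting with None: (n, 1, k) == (1, m, k) -> (n, m, k) bool tensor.
mask = a[:, None, :] == b[None, :, :]
bcast = mask.sum(dim=2)  # count matches per (i, j) pair; int64 result

# Same computation with unsqueeze instead of None indexing.
unsq = (a.unsqueeze(1) == b.unsqueeze(0)).sum(dim=2)

# einsum alternative (assumption: entries are ints in [0, num_classes)).
# Two entries are equal exactly when their one-hot vectors have dot product 1,
# so contracting over the feature axis k and class axis c counts matches.
num_classes = 256
a_oh = F.one_hot(a.long(), num_classes).float()  # shape (n, k, C)
b_oh = F.one_hot(b.long(), num_classes).float()  # shape (m, k, C)
ein = torch.einsum("ikc,jkc->ij", a_oh, b_oh).long()

assert torch.equal(loop, bcast)
assert torch.equal(bcast, unsq) and torch.equal(bcast, ein)
print(mask.shape)       # torch.Size([2, 3, 3])
print(bcast.tolist())   # [[3, 0, 0], [0, 3, 0]]
```

Row 0 of a matches row 0 of b in all three positions, and row 1 of a matches row 1 of b, which is why only those two entries are non-zero.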