
PyTorch: find the closest 2x2 tensor in a batch of 2x2 tensors


I have a 2x2 reference tensor and a batch of candidate 2x2 tensors. I would like to find the candidate closest to the reference tensor, measured by the sum of squared differences (the squared Euclidean distance) between identically indexed elements, ignoring the batch index.

For example:

import torch

ref = torch.as_tensor([[1, 2], [3, 4]])
candidates = torch.rand(100, 2, 2)

I would like to find the index of the 2x2 tensor in candidates that minimizes:

(ref[0][0] - candidates[index][0][0])**2 + 
(ref[0][1] - candidates[index][0][1])**2 + 
(ref[1][0] - candidates[index][1][0])**2 + 
(ref[1][1] - candidates[index][1][1])**2
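As a baseline for checking vectorized solutions, the expression above can be transcribed literally into a Python loop over the batch. This is a slow sketch for illustration only; the fixed seed is an addition so the run is reproducible.

```python
import torch

torch.manual_seed(0)  # fixed seed only so the example is reproducible
ref = torch.as_tensor([[1, 2], [3, 4]])
candidates = torch.rand(100, 2, 2)

# Evaluate the four squared differences for every index and keep the smallest sum
best_index, best_score = None, float("inf")
for index in range(candidates.shape[0]):
    score = ((ref[0][0] - candidates[index][0][0]) ** 2
             + (ref[0][1] - candidates[index][0][1]) ** 2
             + (ref[1][0] - candidates[index][1][0]) ** 2
             + (ref[1][1] - candidates[index][1][1]) ** 2).item()
    if score < best_score:
        best_index, best_score = index, score

print(best_index, best_score)
```

The loop does one small tensor operation per element per candidate, so it is only useful for verifying faster versions on small batches.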

Ideally, the solution would work for a reference tensor of arbitrary shape (b, c, d, ..., z) and for any batch_size of candidate tensors whose shape matches the reference, i.e. (batch_size, b, c, d, ..., z).
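Written for an arbitrary shape, the objective can be stated as below. The multi-index notation is introduced here for illustration: bold j ranges over every index tuple (j_b, j_c, ..., j_z) of the shape (b, c, ..., z), and vec(.) flattens a tensor into a vector. For a 2x2 reference the sum has exactly the four terms listed above. Because the square root is monotonic, minimizing the squared distance selects the same index as minimizing the Euclidean distance itself:

```latex
\hat{\imath}
  = \operatorname*{arg\,min}_{0 \le i < \mathrm{batch\_size}}
    \sum_{\mathbf{j}} \left( \mathrm{ref}_{\mathbf{j}} - \mathrm{candidates}_{i,\mathbf{j}} \right)^2
  = \operatorname*{arg\,min}_{0 \le i < \mathrm{batch\_size}}
    \left\lVert \operatorname{vec}(\mathrm{ref}) - \operatorname{vec}(\mathrm{candidates}_i) \right\rVert_2
```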


Solution

  • Elaborating on @ndrwnaguib's answer, it should rather be (the outputs below come from a run with 10 candidates, not 100):

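    # Flatten each tensor into a row vector: ref -> (1, n), candidates -> (N, n)
    dist = torch.cdist(ref.float().flatten().unsqueeze(0), candidates.flatten(start_dim=1))
    print(torch.square(dist))  # squared Euclidean distances, shape (1, N)
    torch.argmin(dist)         # index of the closest candidate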
    

    tensor([[23.3516, 21.8078, 25.5247, 26.3465, 21.3161, 17.7537, 24.1075, 22.4388,
             22.7513, 16.8489]])
    
    tensor(9)
    

    Other options worth noting:

    torch.square(ref.float() - candidates).sum(dim=(1, 2))
    

    tensor([23.3516, 21.8078, 25.5247, 26.3465, 21.3161, 17.7537, 24.1075, 22.4388,
            22.7513, 16.8489])
    

    diff = ref.float() - candidates
    torch.einsum("abc,abc->a", diff, diff)
    

    tensor([23.3516, 21.8078, 25.5247, 26.3465, 21.3161, 17.7537, 24.1075, 22.4388,
            22.7513, 16.8489])
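The cdist version already handles any shape because it flattens first, but the sum over dim=(1, 2) and the "abc,abc->a" einsum hard-code two non-batch dimensions. The sketch below generalizes both to the arbitrary-shape case from the question and checks that all three formulations pick the same index. The helper names (sq_dists_cdist, sq_dists_sum, sq_dists_einsum, closest_index) are hypothetical, and the near-copy planted at index 17 is made-up test data so the expected answer is known.

```python
import torch


def sq_dists_cdist(ref, candidates):
    # Option 1: flatten every non-batch dimension, then Euclidean distance via cdist
    ref_flat = ref.float().flatten().unsqueeze(0)        # (1, n_elements)
    cand_flat = candidates.float().flatten(start_dim=1)  # (batch_size, n_elements)
    return torch.square(torch.cdist(ref_flat, cand_flat)).squeeze(0)


def sq_dists_sum(ref, candidates):
    # Option 2: sum over all dimensions except the batch dimension 0
    non_batch_dims = tuple(range(1, candidates.dim()))
    return torch.square(ref.float() - candidates).sum(dim=non_batch_dims)


def sq_dists_einsum(ref, candidates):
    # Option 3: flatten the non-batch dimensions into one, then take a
    # row-wise dot product of the difference with itself
    diff = (ref.float() - candidates).flatten(start_dim=1)
    return torch.einsum("ij,ij->i", diff, diff)


def closest_index(ref, candidates):
    if candidates.shape[1:] != ref.shape:
        raise ValueError("candidates must have shape (batch_size, *ref.shape)")
    return int(torch.argmin(sq_dists_sum(ref, candidates)))


torch.manual_seed(0)
ref = torch.rand(3, 4, 5)
candidates = torch.rand(50, 3, 4, 5)
candidates[17] = ref + 0.01  # plant a near-copy so the expected answer is known

for fn in (sq_dists_cdist, sq_dists_sum, sq_dists_einsum):
    print(fn.__name__, int(torch.argmin(fn(ref, candidates))))
print(closest_index(ref, candidates))  # 17
```

All three return a 1-D tensor of length batch_size; the sum-based variant is used in closest_index only because it needs no reshaping, and any of the three could be swapped in.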