Finding closest matches (by distance metric) in two batches of pytorch tensors


I am trying to find the closest matches between two batches of pytorch tensors. Assuming I have a batch of mxn tensors with batch size b1 and a batch of mxn tensors with batch size b2, I would like to find:

  • The distance between each mxn tensor in batch b1 and each mxn tensor in batch b2. This distance matrix would be of size b1xb2.
  • For each tensor in b1, I would like the batch index of the closest (by distance) tensor in b2.

I define distance as the sum of the squared differences between corresponding elements of the two tensors (i.e. the squared Euclidean distance). For example, if the first tensor in b1 (i.e. batch index = 0) is [[a, b, c], [d, e, f], [g, h, i], [j, k, l]] and the first tensor in b2 (i.e. batch index = 0) is [[z, y, x], [w, v, u], [t, s, r], [q, p, o]], then the distance between these two tensors is: (a-z)^2 + (b-y)^2 + (c-x)^2 + (d-w)^2 + (e-v)^2 + (f-u)^2 + ... + (l-o)^2
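
As a reference point, this definition can be written directly with broadcasting. The following is a minimal sketch: the function name pairwise_sq_dist and the sample values are illustrative assumptions, not part of the original post.

```python
import torch

def pairwise_sq_dist(x, y):
    # x has shape (b1, m, n), y has shape (b2, m, n).
    # Broadcasting gives every pairwise difference, shape (b1, b2, m, n).
    diff = x[:, None] - y[None, :]
    # Sum the squared differences over the last two dims -> shape (b1, b2).
    return diff.pow(2).sum(dim=(-2, -1))

x = torch.tensor([[[1., 2.], [3., 4.]]])   # b1 = 1, each tensor is 2x2
y = torch.tensor([[[1., 2.], [3., 4.]],
                  [[0., 0.], [0., 0.]]])   # b2 = 2
d = pairwise_sq_dist(x, y)                 # shape (1, 2): distances 0 and 30
```

This materialises a (b1, b2, m, n) intermediate, so it is mainly useful as a correctness check on small inputs rather than as a memory-efficient approach.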

Here's what I have tried:

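import torch

a = torch.rand((3, 3, 4))  # batch of 3 tensors, each 3x4
b = torch.rand((5, 3, 4))  # batch of 5 tensors, each 3x4
# Flatten each 3x4 tensor into a vector of length 12
flat_a = torch.flatten(a, start_dim=1)
flat_b = torch.flatten(b, start_dim=1)
# Pairwise Euclidean (p=2) distances, shape (3, 5)
torch.cdist(flat_a, flat_b)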

This gives me a 3x5 matrix that I hope is correct. I would now like to get the batch indices of the 3x4 tensors in b that are the closest matches to the tensors in a.

Thanks


Solution

  • Let's call the distance tensor dist (the 3x5 output of torch.cdist). All you have to do is:

    b_idx = torch.argmin(dist, dim=1)  # returns tensor of shape [3]

    which returns, for each tensor in a, the index along dimension 0 of b of its closest match.
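
Putting the question's code and this answer together, here is a hedged end-to-end sketch. Note that torch.cdist with its default p=2 returns the Euclidean distance, i.e. the square root of the sum of squares defined in the question. The argmin is the same either way, and squaring the result recovers the question's definition. The names sq_dist, closest and brute, and the fixed seed, are assumptions added for illustration.

```python
import torch

torch.manual_seed(0)  # fixed seed only so the sketch is reproducible
a = torch.rand((3, 3, 4))
b = torch.rand((5, 3, 4))

flat_a = torch.flatten(a, start_dim=1)  # shape (3, 12)
flat_b = torch.flatten(b, start_dim=1)  # shape (5, 12)

dist = torch.cdist(flat_a, flat_b)      # shape (3, 5), Euclidean distance
sq_dist = dist.pow(2)                   # squared distance, as defined in the question

b_idx = torch.argmin(dist, dim=1)       # shape (3,), index into b for each tensor in a
closest = b[b_idx]                      # shape (3, 3, 4), the matched tensors from b

# Brute-force check of the squared distance via broadcasting
brute = (a[:, None] - b[None, :]).pow(2).sum(dim=(-2, -1))
print(torch.allclose(sq_dist, brute, atol=1e-5))
```

Indexing b with b_idx is one way to retrieve the matched tensors themselves rather than just their indices.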