I am trying to find the closest matches between two batches of PyTorch tensors. Assuming I have a batch of mxn tensors with batch size b1 and a batch of mxn tensors with batch size b2, I would like to find:
1. The distance between each mxn tensor in batch b1 and each mxn tensor in batch b2. This distance matrix would be of size b1xb2.
2. For each tensor in b1, the batch index of the closest (by distance) tensor in b2.
I define distance as the sum of the squared differences between corresponding elements of the two tensors. For example, if the first tensor in b1 (i.e. batch index = 0) is [[a, b, c], [d, e, f], [g, h, i], [j, k, l]] and the first tensor in b2 (i.e. batch index = 0) is [[z, y, x], [w, v, u], [t, s, r]], then the distance between these two tensors is: (a-z)^2 + (b-y)^2 + (c-x)^2 + (d-w)^2 + (e-v)^2 + (f-u)^2 +...+ (l-r)^2
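To make the definition concrete, here is a minimal sketch for a single pair of tensors. The names x, y and d and the values are placeholders chosen only for illustration:

```python
import torch

# Hypothetical pair of 4x3 tensors, used only to illustrate the definition
x = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.], [10., 11., 12.]])
y = torch.zeros(4, 3)

# Sum of squared differences between corresponding elements
d = ((x - y) ** 2).sum()
```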
Here's what I have tried:
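import torch
a = torch.rand((3, 3, 4))  # batch of 3 tensors, each 3x4
b = torch.rand((5, 3, 4))  # batch of 5 tensors, each 3x4
# Flatten each 3x4 tensor into a vector of length 12
flat_a = torch.flatten(a, start_dim=1)
flat_b = torch.flatten(b, start_dim=1)
# Pairwise Euclidean (p=2) distances, shape [3, 5]
dist = torch.cdist(flat_a, flat_b)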
Which gives me a 3x5 matrix that I hope is correct. I would now like to return the batch indices of the 3x4 tensors in b that are the closest matches to the tensors in a.
Thanks
Let's call the distance tensor dist, i.e. the output of torch.cdist(flat_a, flat_b). All you have to do is:
b_idx = torch.argmin(dist, dim=1)  # returns tensor of shape [3]
which returns, for each tensor in a, the index along dimension 0 of b of its closest match.
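One detail worth noting: torch.cdist defaults to p=2, so it returns the Euclidean distance, which is the square root of the sum of squared differences defined in the question. Since the square root is monotonic, argmin picks the same indices either way. Below is a minimal sketch that checks the result against a brute-force computation. The names sq_dist, brute and closest are introduced here for illustration, and the seed is only there to make the sketch reproducible:

```python
import torch

torch.manual_seed(0)  # only for reproducibility of this sketch
a = torch.rand((3, 3, 4))
b = torch.rand((5, 3, 4))

flat_a = torch.flatten(a, start_dim=1)  # shape [3, 12]
flat_b = torch.flatten(b, start_dim=1)  # shape [5, 12]

dist = torch.cdist(flat_a, flat_b)      # Euclidean distances, shape [3, 5]
sq_dist = dist ** 2                     # sum of squared differences, as defined in the question

# Brute-force check via broadcasting: [3, 1, 3, 4] - [1, 5, 3, 4] -> [3, 5, 3, 4]
brute = ((a[:, None] - b[None, :]) ** 2).sum(dim=(2, 3))

b_idx = torch.argmin(dist, dim=1)       # shape [3], each value in 0..4
closest = b[b_idx]                      # shape [3, 3, 4], the matched tensors from b
```

If the squared distances themselves are needed, sq_dist can be used directly. For only finding the closest match, dist is enough.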