Tags: python-3.x, pytorch

PyTorch L2-norm between 2 tensors of different shapes


I have 2 tensors in PyTorch:

a.shape, b.shape
# (torch.Size([1600, 2]), torch.Size([128, 2]))

I want to compute the L2 distance between each of the 128 two-dimensional points in 'b' and each of the 1600 two-dimensional points in 'a'. Currently, I use an inefficient for loop over the points in 'b':

import torch

# Compute the (non-squared) L2 distance from each point in b to every point in a
l2_dist = list()

for bmu in b:
    l2_dist.append(torch.norm(a.to(torch.float32) - bmu, p=2, dim=1))

l2_dist = torch.stack(l2_dist)

# l2_dist.shape
# torch.Size([128, 1600])
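Written as a formula, the quantity the question asks for is a 128×1600 distance matrix. In this sketch $b_i$ denotes the $i$-th row of 'b' and $a_j$ the $j$-th row of 'a' (notation chosen here for illustration):

```latex
D_{ij} = \lVert b_i - a_j \rVert_2
       = \sqrt{(b_{i,1} - a_{j,1})^2 + (b_{i,2} - a_{j,2})^2},
\qquad i = 1, \dots, 128, \quad j = 1, \dots, 1600
```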

Is there a better way to do this, ideally as a one-liner?


Solution

  • You can compute batched pairwise p-norm distances with torch.cdist. It takes x1 of shape B×P×M and x2 of shape B×R×M and returns a tensor of shape B×P×R, so the shared last dimension M is the one reduced. First, unsqueeze a leading singleton (batch) dimension on both inputs to make them 3D, then apply the function:

    >>> torch.cdist(a[None], b[None]).shape
    torch.Size([1, 1600, 128])
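A minimal sketch that checks the cdist result against the loop, assuming random stand-in data since the real 'a' and 'b' are not shown. Swapping the arguments gives the 128×1600 layout the loop produces. It also shows a plain broadcasting one-liner as an alternative technique, and passes 2D inputs directly to cdist, which current PyTorch accepts. The comparison uses compute_mode='donot_use_mm_for_euclid_dist' because the default mode switches to a matrix-multiplication formula when P or R exceeds 25, which can be slightly less precise for points that are very close together.

```python
import torch

# Random stand-ins for the question's tensors (shapes match, values are made up).
torch.manual_seed(0)
a = torch.randn(1600, 2)
b = torch.randn(128, 2)

# As in the answer: batched 3D inputs, result has shape (1, 1600, 128).
d_ab = torch.cdist(a[None], b[None])

# Swap the arguments and drop the batch dimension to get (128, 1600),
# the same layout as the loop. Exact mode avoids the matmul shortcut.
d_ba = torch.cdist(b[None], a[None], compute_mode="donot_use_mm_for_euclid_dist")[0]

# cdist also accepts plain 2D inputs, giving (128, 1600) directly.
d_2d = torch.cdist(b, a, compute_mode="donot_use_mm_for_euclid_dist")

# Reference: the loop from the question, iterating over the points in b.
loop = torch.stack([torch.norm(a - point, p=2, dim=1) for point in b])

# Alternative one-liner via broadcasting:
# (128, 1, 2) - (1, 1600, 2) -> (128, 1600, 2), then L2 norm over the last dim.
bcast = (b[:, None, :] - a[None, :, :]).norm(dim=-1)

print(d_ab.shape)  # torch.Size([1, 1600, 128])
print(d_ba.shape)  # torch.Size([128, 1600])
print(torch.allclose(d_ba, loop, atol=1e-5))
print(torch.allclose(bcast, loop, atol=1e-5))
```

The broadcasting version materializes a 128×1600×2 intermediate tensor, which is fine at these sizes; cdist avoids spelling that out and supports other p values through its p argument.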