
Calculating pairwise distances between entries in a `torch.tensor`

I'm trying to implement a manifold alignment type of loss illustrated here.

Given a tensor representing a batch of embeddings with shape (L, N), for example with L=256:

    tensor([[ 0.0178,  0.0004, -0.0217,  ..., -0.0724,  0.0698, -0.0180],
            [ 0.0160,  0.0002, -0.0217,  ..., -0.0725,  0.0655, -0.0207],
            [ 0.0155, -0.0010, -0.0153,  ..., -0.0750,  0.0688, -0.0253],
            [ 0.0130, -0.0113, -0.0078,  ..., -0.0805,  0.0634, -0.0241],
            [ 0.0120, -0.0047, -0.0135,  ..., -0.0846,  0.0722, -0.0230],
            [ 0.0120, -0.0048, -0.0142,  ..., -0.0843,  0.0734, -0.0246]],

I want to compute all pairwise distances between the rows, resulting in an output of shape (L, L).
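For reference, the (L, L) matrix described above can also be computed by hand. This is a minimal sketch assuming Euclidean distance and a random stand-in batch (N=64 is arbitrary); the function names `pairwise_dists_broadcast` and `pairwise_dists_gram` are hypothetical. The first broadcasts every row against every other row; the second uses the expansion ||a - b||² = ||a||² + ||b||² - 2 a·b so it never materializes the (L, L, N) difference tensor.

```python
import torch

def pairwise_dists_broadcast(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: Euclidean distances between all rows of x (L, N) -> (L, L).

    Builds an (L, L, N) difference tensor, so memory grows as L * L * N.
    """
    diff = x[:, None, :] - x[None, :, :]   # (L, 1, N) - (1, L, N) -> (L, L, N)
    return diff.norm(p=2, dim=-1)          # Euclidean norm over the feature axis

def pairwise_dists_gram(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: same result via ||a||^2 + ||b||^2 - 2 a.b, using (L, L) memory."""
    sq = (x * x).sum(dim=1)                          # (L,) squared row norms
    d2 = sq[:, None] + sq[None, :] - 2.0 * x @ x.T   # (L, L) squared distances
    return d2.clamp_min(0.0).sqrt()                  # clamp tiny negatives caused by round-off

x = torch.randn(256, 64)   # stand-in batch: L=256 rows, N=64 features (assumed)
d_broadcast = pairwise_dists_broadcast(x)
d_gram = pairwise_dists_gram(x)
print(d_broadcast.shape)   # torch.Size([256, 256])
```

The Gram-matrix variant trades memory for some round-off: its diagonal may come out as small positive values rather than exact zeros. If you backpropagate through either version for a loss, keep in mind that the square root is not differentiable at zero distance (the diagonal); adding a small epsilon inside it is a common workaround.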

I've tried torch.nn.PairwiseDistance, but it isn't clear to me whether it does what I'm looking for.
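torch.nn.PairwiseDistance compares *matching* rows of two tensors, so it yields one distance per row rather than an (L, L) matrix. A small check, using random stand-in tensors `x` and `y` of an assumed shape (256, 64):

```python
import torch
import torch.nn as nn

x = torch.randn(256, 64)   # stand-in embeddings, shape (L, N)
y = torch.randn(256, 64)   # a second batch of the same shape

pdist = nn.PairwiseDistance(p=2.0)   # default eps=1e-6 is added to the difference
rowwise = pdist(x, y)                # distance between x[i] and y[i] only
print(rowwise.shape)                 # torch.Size([256]); one value per row, not (L, L)

# rowwise[i] is approximately (because of eps) the norm of x[i] - y[i]
manual = (x - y).norm(p=2, dim=1)
print(torch.allclose(rowwise, manual, atol=1e-4))   # True
```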


  • I thought it was strange that there was no such function. There is one: it is called torch.cdist, and it is somewhat "hidden" in the top-level torch namespace rather than under torch.nn.

    >>> import torch
    >>> a = torch.rand((5,3))
    >>> a
    tensor([[0.0215, 0.0843, 0.3414],
            [0.9878, 0.5835, 0.3052],
            [0.0903, 0.7347, 0.0711],
            [0.9774, 0.8202, 0.7721],
            [0.7877, 0.9891, 0.4619]])
    >>> torch.cdist(a,a)
    tensor([[0.0000, 1.0883, 0.7077, 1.2809, 1.1918],
            [1.0883, 0.0000, 0.9398, 0.5236, 0.4787],
            [0.7077, 0.9398, 0.0000, 1.1339, 0.8390],
            [1.2809, 0.5236, 1.1339, 0.0000, 0.4010],
            [1.1918, 0.4787, 0.8390, 0.4010, 0.0000]])
    >>> torch.nn.functional.pairwise_distance(a[0], a[2])  # rows 0 and 2 only; matches entry [0, 2] above, up to the small eps it adds
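Applied to a batch shaped like the one in the question, here is a sketch (random stand-in data, N=64 chosen arbitrarily) that also relates cdist to pairwise_distance and to torch.nn.functional.pdist, which returns only the condensed upper triangle:

```python
import torch
import torch.nn.functional as F

x = torch.randn(256, 64)  # stand-in for the (L, N) embedding batch; N=64 is assumed

# Full (L, L) distance matrix; p selects the p-norm (p=2.0 is the default).
d = torch.cdist(x, x, p=2.0)
print(d.shape)  # torch.Size([256, 256])

# Each entry d[i, j] agrees with pairwise_distance on rows i and j,
# up to pairwise_distance's eps and floating-point round-off.
i, j = 0, 2
print(torch.allclose(d[i, j], F.pairwise_distance(x[i], x[j]), atol=1e-3))

# F.pdist returns only the strict upper triangle (i < j), flattened row by row,
# of length L * (L - 1) / 2. Scatter it back into a symmetric square matrix:
L = x.shape[0]
condensed = F.pdist(x)
rows, cols = torch.triu_indices(L, L, offset=1)
square = torch.zeros(L, L, dtype=x.dtype)
square[rows, cols] = condensed
square = square + square.T  # mirror into the lower triangle; diagonal stays zero
```

For p=2 with more than 25 rows, cdist's default `compute_mode` uses a matrix-multiplication formulation that is faster but can leave tiny non-zero values on the diagonal; passing `compute_mode="donot_use_mm_for_euclid_dist"` computes the differences directly.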