
Subtract the elements of every possible pair of a torch Tensor efficiently


I have a huge torch Tensor and I'm looking for an efficient way to compute the difference of every pair of its elements. Of course I could use two nested for loops, but that wouldn't be efficient.

For example, given

[1, 2, 3, 4]

The output I want is

[1-2, 1-3, 1-4, 2-3, 2-4, 3-4]
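For reference, here is a minimal sketch of the nested-loop version the question wants to avoid. It pins down the exact output order: x[i] - x[j] for every i < j, row by row. It assumes a 1-D tensor, and the function name `pairwise_diffs_loop` is hypothetical. It is only useful as a correctness baseline, since it runs O(n^2) Python-level iterations.

```python
import torch

def pairwise_diffs_loop(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference implementation; assumes x is 1-D.
    n = x.numel()
    if n < 2:
        # No pairs exist; return an empty tensor of the same dtype/device.
        return x.new_empty((0,))
    out = []
    for i in range(n):
        for j in range(i + 1, n):
            # x[i] - x[j] is a 0-d tensor; collect them in (i, j) order.
            out.append(x[i] - x[j])
    return torch.stack(out)

x = torch.tensor([1, 2, 3, 4])
print(pairwise_diffs_loop(x))  # tensor([-1, -2, -3, -1, -2, -1])
```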

Solution

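  • You can do this with broadcasting: subtracting a row view from a column view computes every pairwise difference at once, and entry [i, j] of the result is x[i] - x[j]: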

    >>> x = torch.tensor([1, 2, 3, 4])
    >>> x[:, None] - x[None, :]
    tensor([[ 0, -1, -2, -3],
            [ 1,  0, -1, -2],
            [ 2,  1,  0, -1],
            [ 3,  2,  1,  0]])
    

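    The differences asked for in the question are the entries above the main diagonal (j > i), read row by row: -1, -2, -3, -1, -2, -1. Note that this builds the full n-by-n matrix, so memory grows quadratically with the length of the tensor.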
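To get exactly the flat list from the question without materializing the full n-by-n matrix, one option is to index only the strict upper triangle with `torch.triu_indices` (offset=1 excludes the diagonal). The sketch below assumes a 1-D tensor, and the helper name `pairwise_diffs` is hypothetical. `torch.combinations` is shown as an equivalent alternative that enumerates the pairs in the same order.

```python
import torch

def pairwise_diffs(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper; assumes x is 1-D.
    n = x.numel()
    # 2 x n*(n-1)/2 tensor of (row, col) indices with col > row, in row-major order.
    i, j = torch.triu_indices(n, n, offset=1, device=x.device)
    # Gather both operands and subtract: result[k] = x[i[k]] - x[j[k]].
    return x[i] - x[j]

x = torch.tensor([1, 2, 3, 4])
print(pairwise_diffs(x))  # tensor([-1, -2, -3, -1, -2, -1])

# Alternative: build the (a, b) pairs explicitly, then subtract the columns.
pairs = torch.combinations(x, r=2)  # shape (n*(n-1)/2, 2)
print(pairs[:, 0] - pairs[:, 1])    # tensor([-1, -2, -3, -1, -2, -1])
```

Either way the output has n*(n-1)/2 elements, so for a truly huge tensor the result itself (plus the int64 index or pair tensors) still grows quadratically. This roughly halves memory compared with the full matrix but does not change the asymptotics.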