Outer sum, etc. in pytorch


NumPy offers optimized outer operations for any R×R -> R function (any binary ufunc), such as np.multiply.outer or np.subtract.outer, with the following behaviour:

>>> np.subtract.outer([6, 5, 4], [3, 2, 1])
array([[3, 4, 5],
       [2, 3, 4],
       [1, 2, 3]])

PyTorch does not seem to offer such a feature (or I have missed it).
What is the best / usual / fastest / cleanest way to do this with torch tensors?


Solution

  • Per the documentation:

    Many PyTorch operations support NumPy Broadcasting Semantics.

    An outer subtraction is a broadcast subtraction between a 2D array and a 1D array, so you can reshape the first array to a column of shape (3, 1) and then subtract the second array, of shape (3,), from it:

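    import torch

    # Float literals keep the dtype (float32) that the output below shows
    x = torch.tensor([6., 5., 4.])
    y = torch.tensor([3., 2., 1.])

    # x becomes a (3, 1) column; y of shape (3,) broadcasts across each row
    x.reshape(-1, 1) - y
    #tensor([[3., 4., 5.],
    #        [2., 3., 4.],
    #        [1., 2., 3.]])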
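
The same broadcasting trick should work for any elementwise binary operation, not only subtraction. Below is a minimal sketch of a generic helper; the name `outer` and its signature are hypothetical (not part of PyTorch), and it assumes both inputs are 1-D tensors.

```python
import torch

def outer(op, x, y):
    """Hypothetical helper: apply a binary elementwise op to every pair (x[i], y[j])."""
    # x[:, None] has shape (n, 1); y[None, :] has shape (1, m).
    # Broadcasting expands both to (n, m) before op is applied.
    return op(x[:, None], y[None, :])

x = torch.tensor([6., 5., 4.])
y = torch.tensor([3., 2., 1.])

diff = outer(torch.sub, x, y)   # same result as x.reshape(-1, 1) - y
prod = outer(torch.mul, x, y)   # same result as the built-in torch.outer(x, y)
total = outer(torch.add, x, y)  # the "outer sum" from the question title
```

For the multiplication case specifically, the built-in `torch.outer` is probably the clearer choice; the helper is mainly useful for operations that have no dedicated outer function.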