
Torch sum a tensor along an axis


How do I sum over the columns of a tensor?

torch.Size([10, 100])    --->    torch.Size([10])

Solution

  • The simplest and best solution is to use torch.sum().

    To sum all elements of a tensor:

    torch.sum(x)  # returns a 0-dim (scalar) tensor; use .item() to get a Python number
    

    To sum over all rows (i.e. for each column):

    torch.sum(x, dim=0) # size = [ncol]
    

    To sum over all columns (i.e. for each row), which turns [10, 100] into [10] as asked:

    torch.sum(x, dim=1) # size = [nrow]
    

    Note that the summed dimension is removed from the result. Pass keepdim=True to keep it as a dimension of size 1 instead.
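A minimal, self-contained sketch of the calls above. The 10 x 100 tensor built with torch.arange is only illustrative data matching the shape in the question; x.sum(...) is the tensor-method form of torch.sum(x, ...).

```python
import torch

# Illustrative tensor with the question's shape: 10 rows, 100 columns (values 0..999).
x = torch.arange(1000).reshape(10, 100)

total = torch.sum(x)                        # sum of every element, 0-dim tensor
per_column = torch.sum(x, dim=0)            # collapses the rows    -> shape [100]
per_row = torch.sum(x, dim=1)               # collapses the columns -> shape [10]
per_row_kept = x.sum(dim=1, keepdim=True)   # same sums, summed dim kept -> shape [10, 1]

print(total.shape, per_column.shape, per_row.shape, per_row_kept.shape)
```

Use keepdim=True when the result has to broadcast back against x, for example x / x.sum(dim=1, keepdim=True) to normalize each row.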