
Column sum in torch


How do I sum along a column in torch? I have a 128×1024 tensor, and I want a 1×1024 tensor obtained by summing all the rows.

For example, given a:

1 2 3
4 5 6

I want b:

5 7 9


Solution

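  • Use torch.sum with a dimension argument. In (Lua) Torch, dimensions are numbered from 1, so summing over dimension 1 adds all the rows together and returns a 1×1024 tensor:

    torch.sum(a, 1)

    In general, you can pass whichever dimension you want to sum over; that dimension is reduced to size 1 in the result:

    torch.sum(a, axis)

    (To sum the elements within each row instead, giving a 128×1 tensor, use axis = 2.)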
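A minimal sketch of the answer applied to the 2×3 example from the question. It assumes the classic Lua Torch7 `torch` package, where dimensions are numbered from 1 and summing over a dimension keeps that dimension with size 1; the variable names `a`, `b` and `r` are only illustrative.

```lua
require 'torch'

-- The 2x3 example matrix from the question (a DoubleTensor by default)
local a = torch.Tensor({{1, 2, 3},
                        {4, 5, 6}})

-- Sum over dimension 1 (down each column): a 1x3 tensor holding 5 7 9
local b = torch.sum(a, 1)

-- Sum over dimension 2 (across each row): a 2x1 tensor holding 6 and 15
local r = torch.sum(a, 2)

print(b)
print(r)
```

For the 128×1024 tensor in the question, the same `torch.sum(x, 1)` call returns a 1×1024 tensor. If you are using PyTorch (the Python library) instead, dimensions are numbered from 0, so the equivalent call is `torch.sum(a, 0)`. That call returns a 1-D tensor of 1024 elements; pass `keepdim=True` if you need a 1×1024 tensor.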