
Broadcast pytorch array across channels based on another array


I have two arrays, x and y, with the same shape (B 1 N).

x holds the data and y holds the class (one of C classes, indexed 0 to C-1 as in the code below) that each datapoint in x belongs to.

I want to create a new tensor z (with shape B C) where

  1. the data in x are partitioned into channels based on their classes in y
  2. and summed over N

I can accomplish this if I use a one-hot encoding. However, for large tensors (especially with a large number of classes), PyTorch's one-hot encoding quickly uses up all memory on the GPU.
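To put a rough number on that (the sizes here are hypothetical, just to illustrate the blow-up; one_hot returns an int64 tensor, i.e. 8 bytes per element):

B, C, N = 8, 10_000, 100_000

one_hot_bytes = B * C * N * 8  # B 1 N C one-hot tensor, int64
print(one_hot_bytes / 1e9)     # ~64 GB before the product with x is even formed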

Is there a more memory-efficient way to do this broadcasting without explicitly allocating a B C N tensor?

Here's an MWE of what I'm after:

import torch

B, C, N = 2, 10, 1000

x = torch.randn(B, 1, N)
y = torch.randint(low=0, high=C, size=(B, 1, N))

one_hot = torch.nn.functional.one_hot(y, C)  # B 1 N C
one_hot = one_hot.squeeze(1).permute(0, -1, 1)  # B C N (squeeze(1) so a batch of size 1 is not collapsed too)

z = x * one_hot  # B C N
z = z.sum(-1)  # B C

Solution

  • If you build z through a B C N intermediate, you will have to allocate that tensor in memory one way or another. An alternative to the one-hot encoding is to expand x and y (expand returns views, so no copies are made) and scatter the values into a zero tensor:

    >>> x, y = x.expand(-1,C,-1), y.expand(-1,C,-1)      # B C N views, no extra memory
    >>> z = torch.zeros(B,C,N).scatter_(1,y,x).sum(-1)   # write each value into its class channel, then reduce N
    

    You can check for yourself, but this approach seems to take less memory: the expanded tensors are views, and there is no int64 one-hot tensor or separate B C N product to hold on to.
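    One way to verify this is to compare peak GPU allocations of the two approaches. A minimal sketch, assuming a CUDA device is available; the names xg, yg, oh, buf, peak_onehot and peak_scatter are mine, and torch.cuda.reset_peak_memory_stats / torch.cuda.max_memory_allocated are standard PyTorch utilities:

    >>> xg = torch.randn(B, 1, N, device='cuda')
    >>> yg = torch.randint(0, C, (B, 1, N), device='cuda')
    >>> torch.cuda.reset_peak_memory_stats()
    >>> oh = torch.nn.functional.one_hot(yg, C).squeeze(1).permute(0, 2, 1)  # B C N, int64
    >>> z_onehot = (xg * oh).sum(-1)
    >>> peak_onehot = torch.cuda.max_memory_allocated()
    >>> del oh, z_onehot
    >>> torch.cuda.reset_peak_memory_stats()
    >>> buf = torch.zeros(B, C, N, device='cuda')
    >>> z_scatter = buf.scatter_(1, yg.expand(-1, C, -1), xg.expand(-1, C, -1)).sum(-1)
    >>> peak_scatter = torch.cuda.max_memory_allocated()  # typically smaller: no int64 one-hot, no separate product tensor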


    Edit: If you are going to reduce over N anyway, there is no need for the B C N intermediate at all. Note that several datapoints can share the same class, so a plain scatter_ would overwrite rather than sum them; use a scatter-add to reproduce the one-hot result. The extra singleton dimensions are not needed either, so assuming x and y are both B N:

    >>> z = torch.zeros(B,C).scatter_add_(1,y,x)  # z[b, c] = sum of x[b, n] over n with y[b, n] == c
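
    As a quick sanity check (a sketch with hypothetical names x2, y2 and ref, recomputing the one-hot reference from the question; the comparison should hold up to floating-point tolerance):

    >>> x2, y2 = torch.randn(B, N), torch.randint(0, C, (B, N))
    >>> ref = (x2.unsqueeze(1) * torch.nn.functional.one_hot(y2, C).permute(0, 2, 1)).sum(-1)  # one-hot version, B C
    >>> torch.allclose(torch.zeros(B, C).scatter_add_(1, y2, x2), ref)
    True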