I have two tensors, `x` and `y`, with the same shape (B, 1, N). `x` holds the data and `y` holds which of the C classes each datapoint in `x` belongs to. I want to create a new tensor `z` (with shape (B, C)) in which the values of `x` are partitioned into channels according to their classes in `y` and then summed over N.
I can accomplish this if I use a one-hot encoding. However, for large tensors (especially with a large number of classes), PyTorch's one-hot encoding quickly uses up all memory on the GPU.
Is there a more memory-efficient way to do this broadcasting without explicitly allocating a (B, C, N) tensor?
Here's an MWE of what I'm after:
import torch
B, C, N = 2, 10, 1000
x = torch.randn(B, 1, N)
y = torch.randint(low=0, high=C, size=(B, 1, N))
one_hot = torch.nn.functional.one_hot(y, C)  # (B, 1, N, C)
one_hot = one_hot.squeeze(1).permute(0, -1, 1)  # (B, C, N)
z = x * one_hot  # (B, C, N)
z = z.sum(-1)  # (B, C)
If `z` is the desired output tensor, then you will have to allocate a BxCxN tensor in memory one way or another. An alternative to the one-hot encoding is to expand `x` and `y` (expand only creates views, it does not copy) and `scatter_` the values into a zero tensor:
>>> x, y = x.expand(-1, C, -1), y.expand(-1, C, -1)      # both (B, C, N)
>>> z = torch.zeros(B, C, N).scatter_(1, y, x).sum(-1)   # (B, C)
You can check for yourself, but this approach seems to take less memory: the only large allocation is the BxCxN zero tensor itself, whereas the one-hot route also materializes a BxNxC one-hot tensor (in int64) on top of the BxCxN product.
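If you want to measure it, here is a rough sketch of a comparison (the function names are just for illustration; it assumes a CUDA device and reads the peak usage with torch.cuda.max_memory_allocated):

import torch

def one_hot_version(x, y, C):
    # the question's approach: materialize a one-hot mask and multiply
    one_hot = torch.nn.functional.one_hot(y, C).squeeze(1).permute(0, -1, 1)  # (B, C, N)
    return (x * one_hot).sum(-1)  # (B, C)

def scatter_version(x, y, C):
    # this answer's approach: expanded views + scatter_ into a zero tensor
    B, _, N = x.shape
    xe, ye = x.expand(-1, C, -1), y.expand(-1, C, -1)
    return torch.zeros(B, C, N, device=x.device).scatter_(1, ye, xe).sum(-1)

B, C, N = 8, 512, 4096
x = torch.randn(B, 1, N, device='cuda')
y = torch.randint(0, C, (B, 1, N), device='cuda')

# both versions produce the same (B, C) result
assert torch.allclose(one_hot_version(x, y, C), scatter_version(x, y, C))

for fn in (one_hot_version, scatter_version):
    torch.cuda.reset_peak_memory_stats()
    fn(x, y, C)
    print(fn.__name__, torch.cuda.max_memory_allocated() // 2**20, 'MiB')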
Edit: if you are going to reduce over N afterward anyway, then there is no need to expand along C at all. Instead of going through a one-hot encoding, a scatter with sum reduction (`scatter_add_`) applied directly to a BxC zero tensor will suffice. The extra singleton dimensions are not needed either, so assuming `x` and `y` are both BxN:
>>> z = torch.zeros(B, C).scatter_add_(1, y, x)
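For example, starting again from the `x`, `y`, and `z` of the MWE above (the names `x2`, `y2`, `z2` below are just illustrative), this reproduces the same result:

>>> x2, y2 = x.squeeze(1), y.squeeze(1)              # both (B, N)
>>> z2 = torch.zeros(B, C).scatter_add_(1, y2, x2)   # (B, C)
>>> torch.allclose(z, z2)
True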