Given:
A = torch.tensor([[ 0.4821, -0.3484, 0.0915, -0.1870],
                  [ 1.3817, 0.3011, 1.0704, 2.1717]])
B = torch.zeros(2,6)
C = torch.tensor([[1,2,2,3], [3,7,2,5]]) (same shape as A)
I want to write the values of A into B at the column indices given by C, ignoring any index that is >= 6 (B.size(-1)). The expected result is:
B = [[0, 0.4821, 0.0915, -0.1870, 0, 0],
     [0, 0, 1.0704, 1.3817, 0, 2.1717]]
Notice that the first row of C contains the index 2 twice (for the second and third values of A). For such duplicates I want the maximum (or the sum, if that is easier to do).
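To pin down the intended semantics, here is a deliberately slow reference loop. It is a sketch, not a proposed solution: the helper name `scatter_max_reference` is hypothetical, and it assumes C holds non-negative integer indices with the same shape as A. Duplicates keep the maximum, and indices >= the width are dropped.

```python
import torch

def scatter_max_reference(A, C, width):
    # Hypothetical helper: an explicit loop that spells out the intended result.
    # Assumes C holds non-negative integer indices with the same shape as A.
    B = torch.zeros(A.size(0), width, dtype=A.dtype)
    written = torch.zeros(A.size(0), width, dtype=torch.bool)
    for i in range(A.size(0)):
        for j in range(A.size(1)):
            k = int(C[i, j])
            if k >= width:
                continue  # out-of-range index: drop this value
            if written[i, k]:
                # duplicate index: keep the larger value
                B[i, k] = torch.maximum(B[i, k], A[i, j])
            else:
                B[i, k] = A[i, j]
                written[i, k] = True
    return B

A = torch.tensor([[0.4821, -0.3484, 0.0915, -0.1870],
                  [1.3817, 0.3011, 1.0704, 2.1717]])
C = torch.tensor([[1, 2, 2, 3], [3, 7, 2, 5]])
print(scatter_max_reference(A, C, 6))
```

This is only useful for checking a vectorized version against small inputs.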
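One option is to clip the indices so they do not exceed B.size(1) - 1 and then use a scatter operation. With plain scatter_, when several values map to the same index only one of them is kept. Here the value for the second 2 (0.0915) survives, but PyTorch documents the choice among non-unique indices as arbitrary, so you should not rely on "last one wins". There are also specialized functions that accumulate duplicates by summation or reduce them to the maximum. Let's try this: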
torch.zeros(2,6).scatter_(1, C.clip(0,5), A)
tensor([[ 0.0000, 0.4821, 0.0915, -0.1870, 0.0000, 0.0000],
[ 0.0000, 0.0000, 1.0704, 1.3817, 0.0000, 2.1717]])
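But this won't work all the time: out-of-range indices are clipped to the last column, where they collide with any legitimate value at that index. In the second row, 7 is clipped to 5 and competes with the real index 5; the output above is only correct because 2.1717 happened to be the value that was kept. A solution is to allocate one extra buffer column that absorbs the out-of-range values and then slice it off at the end: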
torch.zeros(2,6+1).scatter_(1, C.clip(0,6), A)[:,:-1]
tensor([[ 0.0000, 0.4821, 0.0915, -0.1870, 0.0000, 0.0000],
[ 0.0000, 0.0000, 1.0704, 1.3817, 0.0000, 2.1717]])
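A small counterexample makes the difference between the two versions concrete. The one-row input below is hypothetical: a valid index 5 is followed by an out-of-range 7, so clipping to 5 puts both values on column 5, while the buffer-column version sends the 7 to a column that is discarded.

```python
import torch

# Hypothetical one-row input: a valid index 5 followed by an out-of-range 7.
A = torch.tensor([[1.0, 2.0]])
C = torch.tensor([[5, 7]])

# Clipping to 5: the 7 also lands on column 5, so column 5 holds 1.0 or 2.0
# depending on which duplicate scatter_ keeps (documented as arbitrary).
clipped = torch.zeros(1, 6).scatter_(1, C.clip(0, 5), A)

# Buffer column: the 7 is clipped to 6, which is sliced off,
# so column 5 keeps the legitimate value 1.0.
buffered = torch.zeros(1, 6 + 1).scatter_(1, C.clip(0, 6), A)[:, :-1]

print(clipped)
print(buffered)
```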
If you want to combine duplicates differently, you can use Tensor.scatter_reduce_ and specify its reduce argument:
reduce (str) – the reduction operation to apply for non-unique indices ("sum", "prod", "mean", "amax", "amin")
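For example, to get the maximum (include_self=False keeps the initial zeros of the output out of the reduction; with the default include_self=True, B[0, 3] would come out as 0 instead of -0.1870):
torch.zeros(2,7).scatter_reduce_(1, C.clip(0,6), A, reduce='amax', include_self=False)[:,:-1]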
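The buffer-column trick and the reduction can be wrapped into one small function. This is a sketch under a few assumptions: the helper name `scatter_reduce_in_range` is hypothetical, C is assumed to hold non-negative indices, and PyTorch is assumed to be recent enough (1.12 or later) for `scatter_reduce_` to accept `include_self`. Passing `include_self=False` matters for "amax": otherwise the starting zeros take part in the max, and a column that only receives negative values would be clamped to 0.

```python
import torch

def scatter_reduce_in_range(A, C, width, reduce="amax"):
    # Hypothetical helper. Indices >= width are clipped to the extra buffer
    # column `width`, which is sliced off before returning.
    out = torch.zeros(A.size(0), width + 1, dtype=A.dtype)
    # include_self=False: the initial zeros are not part of the reduction,
    # so column 3 of row 0 keeps -0.1870 under "amax".
    # Columns that receive no value at all simply stay 0.
    out.scatter_reduce_(1, C.clip(0, width), A, reduce=reduce, include_self=False)
    return out[:, :-1]

A = torch.tensor([[0.4821, -0.3484, 0.0915, -0.1870],
                  [1.3817, 0.3011, 1.0704, 2.1717]])
C = torch.tensor([[1, 2, 2, 3], [3, 7, 2, 5]])

print(scatter_reduce_in_range(A, C, 6, "amax"))
print(scatter_reduce_in_range(A, C, 6, "sum"))
```

For the sum case, Tensor.scatter_add_ on the same buffered tensor, torch.zeros(2, 7).scatter_add_(1, C.clip(0, 6), A)[:, :-1], gives the same result, because adding into zeros is unaffected by include_self.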