pytorch, vectorization, tensor

projecting values of tensor A into tensor B at indices C (pytorch)


Given:

A = tensor([[ 0.4821, -0.3484,  0.0915, -0.1870],
            [ 1.3817,  0.3011,  1.0704,  2.1717]])

B = torch.zeros(2,6)

C = torch.tensor([[1,2,2,3], [3,7,2,5]]) (same shape as A)

I want to write the values of A into B at the column indices given by C, skipping any index that is not < 6 (B.size(-1))

-> B =[[0, 0.4821, 0.0915, -0.1870, 0, 0],
       [0, 0, 1.0704, 1.3817, 0, 2.1717]]

Note that the first row of C contains the index 2 twice, at its second and third positions, so two values of A (-0.3484 and 0.0915) target the same slot of B. In that case I want to keep the max (or the sum, if that is easier to do).
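
To pin down the intended behaviour, here is a slow loop-based sketch of the projection. The function name `project_loop` and the `written` mask are made up for illustration, and the sketch assumes all indices in C are non-negative.

```python
import torch

A = torch.tensor([[0.4821, -0.3484, 0.0915, -0.1870],
                  [1.3817, 0.3011, 1.0704, 2.1717]])
C = torch.tensor([[1, 2, 2, 3],
                  [3, 7, 2, 5]])

def project_loop(A, C, width):
    """Slow reference (hypothetical helper): max of A per (row, index), skipping indices >= width."""
    B = torch.zeros(A.size(0), width, dtype=A.dtype)
    # Tracks which slots already hold a value from A, so a negative
    # value is not compared against the initial zero.
    written = torch.zeros(A.size(0), width, dtype=torch.bool)
    for i in range(A.size(0)):
        for j in range(A.size(1)):
            k = int(C[i, j])
            if k >= width:
                continue  # out-of-range index: drop the value
            if not written[i, k] or A[i, j] > B[i, k]:
                B[i, k] = A[i, j]
                written[i, k] = True
    return B

B = project_loop(A, C, 6)
print(B)
```

This is only meant as a baseline to check a vectorized solution against.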


Solution

  • You can do this with a scattering operation, after clipping the indices so that they do not exceed B.size(1) - 1. When several source values target the same index, only one of them is kept; in the run below it was the last one (the second 2), but the PyTorch documentation describes the choice as non-deterministic. There are also specialized functions to accumulate by summation or to reduce to the maximum value. Let's try this:

    torch.zeros(2,6).scatter_(1, C.clip(0,5), A)
    tensor([[ 0.0000,  0.4821,  0.0915, -0.1870,  0.0000,  0.0000],
            [ 0.0000,  0.0000,  1.0704,  1.3817,  0.0000,  2.1717]])
    

    But this won't work all the time, because indices beyond the last column are clipped onto the last column and can overwrite a legitimate value there. One solution is to add an extra buffer column that collects the undesired values, then trim the tensor at the end to discard them:

    torch.zeros(2,6+1).scatter_(1, C.clip(0,6), A)[:,:-1]
    tensor([[ 0.0000,  0.4821,  0.0915, -0.1870,  0.0000,  0.0000],
            [ 0.0000,  0.0000,  1.0704,  1.3817,  0.0000,  2.1717]])
    

    If you want to reduce your scattering differently, you can use torch.Tensor.scatter_reduce_ and specify the reduce argument:

    reduce (str) – the reduction operation to apply for non-unique indices ("sum", "prod", "mean", "amax", "amin")

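    For example, to get the maximum (include_self=False keeps the initial zeros of the target out of the reduction; otherwise a negative value such as -0.1870 would come out as 0):

    torch.zeros(2,7).scatter_reduce_(1, C.clip(0,6), A, reduce='amax', include_self=False)[:,:-1]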
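
Putting the pieces together, here is a minimal sketch of the buffer-column approach as a reusable function. The name `scatter_with_overflow` and its parameters are hypothetical. It assumes non-negative indices and a PyTorch version whose `scatter_reduce_` accepts `include_self` (1.12 or later, as far as I know).

```python
import torch

def scatter_with_overflow(A, C, width, reduce="amax"):
    """Hypothetical helper: scatter A along dim 1 at indices C, dropping indices >= width."""
    # One extra column (index `width`) collects every out-of-range index.
    out = A.new_zeros(A.size(0), width + 1)
    idx = C.clip(0, width)
    # include_self=False keeps the initial zeros out of the reduction,
    # so a negative maximum such as -0.1870 is not replaced by 0.
    out.scatter_reduce_(1, idx, A, reduce=reduce, include_self=False)
    return out[:, :-1]  # drop the overflow column

A = torch.tensor([[0.4821, -0.3484, 0.0915, -0.1870],
                  [1.3817, 0.3011, 1.0704, 2.1717]])
C = torch.tensor([[1, 2, 2, 3],
                  [3, 7, 2, 5]])

B_max = scatter_with_overflow(A, C, 6, reduce="amax")
B_sum = scatter_with_overflow(A, C, 6, reduce="sum")
print(B_max)
print(B_sum)
```

With reduce="sum", the two values that share index 2 in the first row are added (-0.3484 + 0.0915 = -0.2569) instead of keeping the larger one. Slots that receive no value stay at zero in both cases.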