Tags: indexing, pytorch

How do I assign values to a tensor according to a given index matrix?


Here I have an index tensor, a target tensor a, and a value tensor b:

import torch

# Indices must be an integer (int64) tensor; torch.Tensor(...) would create float32
index = torch.tensor([[0, 1],
                      [1, 2],
                      [4, 1]])
a = torch.zeros(3, 6)  # target tensor, all zeros
b = torch.tensor([5., 6., 7., 8., 9., 10.])  # values to copy into a

How can I assign the values in b to a according to index, in a vectorized way without a loop? That is, for every k listed in row i of index, a[i, k] should become b[k]. The expected result is:

a = torch.tensor([[5., 6., 0., 0., 0., 0.],
                  [0., 6., 7., 0., 0., 0.],
                  [0., 6., 0., 0., 9., 0.]])
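To pin down the intended semantics, here is a plain loop reference that any vectorized version should match. It is only a sketch for checking results; the helper name assign_with_loop is hypothetical. For each k listed in row i of index, it sets a[i, k] = b[k].

```python
import torch

# Inputs from the question; index must be an integer tensor to be usable as indices
index = torch.tensor([[0, 1],
                      [1, 2],
                      [4, 1]])
b = torch.tensor([5., 6., 7., 8., 9., 10.])

def assign_with_loop(index, b, num_cols):
    """Hypothetical loop-based reference: a[i, k] = b[k] for every k in row i of index."""
    a = torch.zeros(index.shape[0], num_cols, dtype=b.dtype)
    for i in range(index.shape[0]):
        for j in range(index.shape[1]):
            k = int(index[i, j])  # column to fill in row i
            a[i, k] = b[k]        # value depends only on the column
    return a

expected = assign_with_loop(index, b, num_cols=6)
print(expected)
```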

I have tried torch.index_put() and torch.scatter(), but got unexpected results. Most likely I used them incorrectly.
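For comparison, a minimal sketch of the index_put_ route, assuming index is built as an int64 tensor. torch.Tensor(...) produces a float32 tensor, and PyTorch rejects float tensors as indices (scatter_ likewise requires an int64 index), which may be one cause of the unexpected results. The helper tensors rows and values are introduced here only for illustration.

```python
import torch

index = torch.tensor([[0, 1],
                      [1, 2],
                      [4, 1]])  # int64, as required for indexing
a = torch.zeros(3, 6)
b = torch.tensor([5., 6., 7., 8., 9., 10.])

# Row index for every entry of `index`: rows[i, j] == i
rows = torch.arange(index.shape[0]).unsqueeze(1).expand_as(index)

# Value for every entry of `index`: values[i, j] == b[index[i, j]]
values = b[index]

# Write a[rows[i, j], index[i, j]] = values[i, j] for all i, j
# (equivalent to the advanced-indexing assignment a[rows, index] = values)
a.index_put_((rows, index), values)
print(a)
```

The trick is that index alone only gives column positions; index_put_ also needs the matching row positions, which is what rows supplies.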


Solution

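  • a.scatter_(1, index, b[index])

    Here b[index] has the same shape as index, with b[index][i, j] equal to b[index[i, j]], and scatter_ along dim 1 writes each of those values to a[i, index[i, j]]. Note that scatter_ requires index to be an int64 (long) tensor.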

    PS. ChatGPT is the best!
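A runnable walk-through of the one-liner above, split into its two steps so the intermediate tensor is visible. This is a sketch using the same inputs as the question.

```python
import torch

index = torch.tensor([[0, 1],
                      [1, 2],
                      [4, 1]])  # scatter_ requires an int64 index tensor
a = torch.zeros(3, 6)
b = torch.tensor([5., 6., 7., 8., 9., 10.])

# Step 1: look up the source values; src[i, j] == b[index[i, j]]
src = b[index]
print(src)  # values: [[5., 6.], [6., 7.], [9., 6.]]

# Step 2: along dim 1, write a[i, index[i, j]] = src[i, j] (in place)
a.scatter_(1, index, src)
print(a)
```

If the same column appeared twice in one row of index, scatter_ would write that slot twice, but both writes carry the same value b[k], so the result is unaffected here. For an out-of-place version, torch.zeros(3, 6).scatter(1, index, b[index]) returns a new tensor instead of modifying a.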