python torch

How to set a value at a tensor index for each item in a batch


I have a batched tensor t containing N rows:

t = [[...], [....], [....] .... ]

A second tensor, indices, holds N indices, one per row, each giving the element I want to change in that row:

indices = [i0, i1, i2 .... ]

I want to create t0 from t as follows:

t0 = [[ set X at i0 ], [ set X at i1 ], [ set X at i2 ] .... ]

How can I do this in PyTorch?


Solution

  • It seems you're looking for integer array indexing, which pairs row k with column indices[k]:

    t[torch.arange(N), indices] = X
    

    As an example:

    import torch

    a = torch.zeros((3, 3))
    # Row k is paired with column [0, 2, 1][k], so this writes a[0, 0], a[1, 2], a[2, 1].
    a[torch.arange(3), [0, 2, 1]] = 0.2
    print(a)
    

    Output:

    tensor([[0.2000, 0.0000, 0.0000],
            [0.0000, 0.0000, 0.2000],
            [0.0000, 0.2000, 0.0000]])
    

    Note: This behaves the same as NumPy's integer array indexing.
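
    Since the question asks for a new tensor t0 rather than an in-place change, the indexed assignment can be applied to a copy. The following is a minimal sketch under that assumption; the helper name set_at_indices is hypothetical, and the Tensor.scatter call is an alternative that is not part of the original answer:

```python
import torch

def set_at_indices(t, indices, value):
    """Return a copy of `t` with `value` written at column indices[k] of row k.

    Hypothetical helper for illustration; assumes `t` is 2-D with one index per row.
    """
    indices = torch.as_tensor(indices, dtype=torch.long, device=t.device)
    out = t.clone()  # leave the original tensor untouched
    rows = torch.arange(t.size(0), device=t.device)
    out[rows, indices] = value
    return out

t = torch.zeros(3, 3)
t0 = set_at_indices(t, [0, 2, 1], 0.2)

# Alternative: the out-of-place Tensor.scatter along dim 1.
# The index tensor needs shape (N, 1) so each row gets exactly one column.
idx = torch.tensor([0, 2, 1]).unsqueeze(1)
t1 = t.scatter(1, idx, 0.2)

print(t0)
print(torch.equal(t0, t1))
```

    Both forms leave t unchanged. The indexing form is usually easier to read; scatter may be convenient when the indices already have a trailing dimension of size 1.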