
For a given condition, get indices of values in 2D tensor A, use those to index a 3D tensor B


For a given 2D tensor I want to retrieve all indices where the value is 1. I expected to be able to simply use torch.nonzero(a == 1).squeeze(), which would return tensor([1, 3, 2]). However, instead, torch.nonzero(a == 1) returns a 2D tensor (that's okay), with two values per row (that's not what I expected). The returned indices should then be used to index the second dimension (index 1) of a 3D tensor, again returning a 2D tensor.

import torch

a = torch.Tensor([[12, 1, 0, 0],
                  [4, 9, 21, 1],
                  [10, 2, 1, 0]])

b = torch.rand(3, 4, 8)

print('a_size', a.size())
# a_size torch.Size([3, 4])
print('b_size', b.size())
# b_size torch.Size([3, 4, 8])

idxs = torch.nonzero(a == 1)
print('idxs_size', idxs.size())
# idxs_size torch.Size([3, 2])

print(b.gather(1, idxs))

Evidently, this does not work and raises a RuntimeError:

RuntimeError: invalid argument 4: Index tensor must have same dimensions as input tensor at C:\w\1\s\windows\pytorch\aten\src\TH/generic/THTensorEvenMoreMath.cpp:453

It seems that idxs is not what I expect it to be, nor can I use it the way I thought. idxs is

tensor([[0, 1],
        [1, 3],
        [2, 2]])

but reading through the documentation I don't understand why I also get back the row indices in the resulting tensor. Now, I know I can get the correct idxs by slicing idxs[:, 1] but then still, I cannot use those values as indices for the 3D tensor because the same error as before is raised. Is it possible to use the 1D tensor of indices to select items across a given dimension?


Solution

  • You could simply slice idxs by column and pass the resulting index tensors directly:

    In [193]: idxs = torch.nonzero(a == 1)     
    In [194]: c = b[idxs[:, 0], idxs[:, 1]]  
    
    In [195]: c   
    Out[195]: 
    tensor([[0.3411, 0.3944, 0.8108, 0.3986, 0.3917, 0.1176, 0.6252, 0.4885],
            [0.5698, 0.3140, 0.6525, 0.7724, 0.3751, 0.3376, 0.5425, 0.1062],
            [0.7780, 0.4572, 0.5645, 0.5759, 0.5957, 0.2750, 0.6429, 0.1029]])
    

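    To see why two columns come back in the first place: for an n-dimensional input, torch.nonzero returns one row per matching element, with n columns giving that element's coordinates. A minimal sketch (reusing the tensors from the question), showing how the column slice recovers the indices the question expected:

    ```python
    import torch

    a = torch.tensor([[12, 1, 0, 0],
                      [4, 9, 21, 1],
                      [10, 2, 1, 0]])
    b = torch.rand(3, 4, 8)

    idxs = torch.nonzero(a == 1)   # shape (3, 2): one row per match, one column per dim
    rows = idxs[:, 0]              # row coordinates: tensor([0, 1, 2])
    cols = idxs[:, 1]              # column coordinates: tensor([1, 3, 2])

    # Pairing both coordinate tensors selects one (8,)-slice of b per match.
    c = b[rows, cols]              # shape (3, 8)
    ```

    Note that rows equals torch.arange(3) here only because every row of a happens to contain exactly one 1; in general both coordinate columns are needed.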
    Alternatively, an even simpler approach, and my preferred one, is to use torch.where() and index directly into the tensor b:

    In [196]: b[torch.where(a == 1)]  
    Out[196]: 
    tensor([[0.3411, 0.3944, 0.8108, 0.3986, 0.3917, 0.1176, 0.6252, 0.4885],
            [0.5698, 0.3140, 0.6525, 0.7724, 0.3751, 0.3376, 0.5425, 0.1062],
            [0.7780, 0.4572, 0.5645, 0.5759, 0.5957, 0.2750, 0.6429, 0.1029]])
    

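    To make the connection explicit (assuming a reasonably recent PyTorch version): called with a single argument, torch.where(cond) is equivalent to torch.nonzero(cond, as_tuple=True). It returns a tuple of 1-D index tensors, one per dimension, which is exactly the form that advanced indexing consumes:

    ```python
    import torch

    a = torch.tensor([[12, 1, 0, 0],
                      [4, 9, 21, 1],
                      [10, 2, 1, 0]])

    # A tuple of per-dimension index tensors, not a single 2-D tensor.
    rows, cols = torch.where(a == 1)
    # rows: tensor([0, 1, 2]), cols: tensor([1, 3, 2])

    # The same tuple form via nonzero:
    rows2, cols2 = torch.nonzero(a == 1, as_tuple=True)
    ```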
    A bit more explanation about the torch.where() approach: it relies on the concept of advanced indexing, which is triggered when we index into a tensor with a tuple of sequence objects, such as a tuple of tensors, a tuple of lists, or a tuple of tuples.

    # some input tensor
    In [207]: a  
    Out[207]: 
    tensor([[12.,  1.,  0.,  0.],
            [ 4.,  9., 21.,  1.],
            [10.,  2.,  1.,  0.]])
    

    For basic slicing, we would need a tuple of integer indices:

    In [212]: a[(1, 2)] 
    Out[212]: tensor(21.)
    

    To achieve the same using advanced indexing, we would need a tuple of sequence objects:

    # adv. indexing using a tuple of lists
    In [213]: a[([1,], [2,])] 
    Out[213]: tensor([21.])
    
    # adv. indexing using a tuple of tuples
    In [215]: a[((1,), (2,))]  
    Out[215]: tensor([21.])
    
    # adv. indexing using a tuple of tensors
    In [214]: a[(torch.tensor([1,]), torch.tensor([2,]))] 
    Out[214]: tensor([21.])
    

    More precisely, indexing k dimensions with a tuple of k equal-length 1-D index sequences replaces those k dimensions with a single one, so the result has ndim - k + 1 dimensions. In all of the examples above, k = 2 on a 2-D tensor, so the result is one dimension less than the input; the same rule gives the 2-D result when the 3-D tensor b is indexed with the two index tensors from torch.where.
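    A quick check of that dimensionality behavior (index values chosen purely for illustration): with k equal-length 1-D index tensors on the first k dimensions, the result has ndim - k + 1 dimensions, so "one dimension less" holds exactly when k = 2:

    ```python
    import torch

    b = torch.rand(3, 4, 8)
    rows = torch.tensor([0, 1, 2])
    cols = torch.tensor([1, 3, 2])

    # k = 2 index tensors on a 3-D input: result is 2-D (3 - 2 + 1).
    out2 = b[rows, cols]
    # out2.shape == (3, 8)

    # k = 3 index tensors on the same input: result is 1-D (3 - 3 + 1).
    depth = torch.tensor([0, 0, 0])
    out1 = b[rows, cols, depth]
    # out1.shape == (3,)
    ```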