
Collecting varying element indices from a tensor across multiple dimensions


Assume I have the following tensor:

import torch
arr = torch.randint(0, 9, (100, 50, 3))

What I want to achieve is to collect, for example, 2 elements of that tensor. Let's start with the 6th and the 56th one (indices 5 and 55):

indices = torch.tensor([5, 55])
partial_arr = arr[indices]

This gives me a tensor of shape

torch.Size([2, 50, 3])
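As a quick check of what this first step does, here is a minimal sketch reusing the question's setup. It confirms that indexing dim 0 with an integer tensor picks the same rows as `torch.index_select`, which is why the result keeps the trailing `(50, 3)` dimensions.

```python
import torch

# Same setup as the question: random integers in [0, 9)
arr = torch.randint(0, 9, (100, 50, 3))
indices = torch.tensor([5, 55])

# Integer-tensor indexing along dim 0 ...
partial_arr = arr[indices]
# ... selects the same rows as torch.index_select along dim 0
same = torch.index_select(arr, 0, indices)

print(partial_arr.shape)               # torch.Size([2, 50, 3])
print(torch.equal(partial_arr, same))  # True
```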

Now, let's assume that from the first element I want to collect the elements at indices 5 to 9 (the slice 5:10):

first_result = partial_arr[0, 5:10]

and from the second element, the elements at indices 10 to 14 (the slice 10:15):

second_result = partial_arr[1, 10:15]

Since I want everything in one tensor, I can do:

final_result = torch.cat([first_result, second_result])
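For reference, the two-step baseline can be run end to end as below (a sketch of the question's own code). It makes the target explicit: `torch.cat` joins along dim 0 by default, so `final_result` has shape `(10, 3)`, and its halves are just `arr[5, 5:10]` and `arr[55, 10:15]`.

```python
import torch

arr = torch.randint(0, 9, (100, 50, 3))
partial_arr = arr[torch.tensor([5, 55])]

first_result = partial_arr[0, 5:10]    # rows 5..9 of arr[5], shape (5, 3)
second_result = partial_arr[1, 10:15]  # rows 10..14 of arr[55], shape (5, 3)

# torch.cat joins along dim 0 by default, giving shape (10, 3)
final_result = torch.cat([first_result, second_result])
print(final_result.shape)  # torch.Size([10, 3])

# The same values, read straight from arr
assert torch.equal(final_result[:5], arr[5, 5:10])
assert torch.equal(final_result[5:], arr[55, 10:15])
```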

How can I achieve the final result with a single operation on the original tensor arr = torch.randint(0, 9, (100, 50, 3))?


Solution

  • Assuming the number of sliced elements is the same for every row, you can create a torch.arange tensor of that length and shift it by each row's starting index (here 5 and 10), using broadcasting:

    >>> idx = torch.tensor([5, 10])
    >>> idx_ = torch.arange(5)[None] + idx[:, None]
    >>> idx_
    tensor([[ 5,  6,  7,  8,  9],
            [10, 11, 12, 13, 14]])
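The shift-by-start trick generalizes to any number of rows. The sketch below wraps it in a hypothetical helper, `window_indices` (not part of PyTorch or the original answer), assuming a constant window length: a `(1, length)` row of offsets plus an `(n, 1)` column of starts broadcasts to an `(n, length)` index matrix.

```python
import torch

def window_indices(starts: torch.Tensor, length: int) -> torch.Tensor:
    """Hypothetical helper: row i holds starts[i], starts[i]+1, ..., starts[i]+length-1."""
    offsets = torch.arange(length, device=starts.device)  # shape (length,)
    # (1, length) + (n, 1) broadcasts to (n, length)
    return offsets[None] + starts[:, None]

idx_ = window_indices(torch.tensor([5, 10]), 5)
print(idx_)
# tensor([[ 5,  6,  7,  8,  9],
#         [10, 11, 12, 13, 14]])
```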
    

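    Then add a trailing dimension to idx_ and expand it so that its last dimension matches that of partial_arr (torch.gather requires the index to have the same number of dimensions as the input):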

    >>> idx_ = idx_[..., None].expand(-1, -1, partial_arr.size(-1))
    >>> idx_.shape
    torch.Size([2, 5, 3])
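A note on the design choice: `expand` produces a view rather than a copy, so the extra dimension costs no memory. The sketch below (an illustration, not from the original answer) compares it with `repeat`, which holds the same values but allocates new storage. The expanded dimension has stride 0 and shares the original data pointer.

```python
import torch

partial_arr = torch.randint(0, 9, (2, 50, 3))
idx_ = torch.arange(5)[None] + torch.tensor([5, 10])[:, None]  # shape (2, 5)

# expand: a view whose new last dim has stride 0 (no data copied)
expanded = idx_[..., None].expand(-1, -1, partial_arr.size(-1))
# repeat: same values, but materialized in new memory
repeated = idx_[..., None].repeat(1, 1, partial_arr.size(-1))

print(expanded.shape)                          # torch.Size([2, 5, 3])
print(expanded.stride())                       # (5, 1, 0)
print(torch.equal(expanded, repeated))         # True
print(expanded.data_ptr() == idx_.data_ptr())  # True
```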
    

    Finally, gather the values along dimension 1 using torch.gather:

    >>> partial_arr.gather(1, idx_)
    tensor([[[8, 3, 1],
             [2, 4, 6],
             [4, 4, 5],
             [2, 8, 6],
             [3, 7, 0]],
    
            [[3, 6, 7],
             [5, 7, 4],
             [1, 5, 4],
             [4, 5, 3],
             [7, 1, 2]]])
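The answer still indexes `arr` first and then gathers, so it uses two operations. The sketch below checks that the gather result matches the question's concatenated `final_result` once the first two dimensions are flattened, and adds an alternative that is not from the original answer: a single advanced-indexing expression on `arr`, assuming the same constant window length. There, `indices[:, None]` with shape `(2, 1)` and the `(2, 5)` window matrix broadcast together, and the last dimension is kept whole.

```python
import torch

arr = torch.randint(0, 9, (100, 50, 3))
indices = torch.tensor([5, 55])  # which rows of arr
starts = torch.tensor([5, 10])   # per-row slice start
length = 5                       # constant slice length (assumption)

# The answer's approach: index, then gather along dim 1
partial_arr = arr[indices]
idx_ = torch.arange(length)[None] + starts[:, None]  # (2, 5)
gathered = partial_arr.gather(
    1, idx_[..., None].expand(-1, -1, arr.size(-1))
)  # (2, 5, 3); for dim=1: out[i][j][k] = partial_arr[i][index[i][j][k]][k]

# Alternative: one advanced-indexing op directly on arr.
# indices[:, None] (2, 1) and idx_ (2, 5) broadcast to (2, 5); dim 2 is kept.
direct = arr[indices[:, None], idx_]  # (2, 5, 3)

# The question's two-step result, for comparison
final_result = torch.cat([arr[5, 5:10], arr[55, 10:15]])  # (10, 3)

assert torch.equal(gathered, direct)
# Flattening the first two dims reproduces final_result
assert torch.equal(direct.reshape(-1, arr.size(-1)), final_result)
print(direct.shape)  # torch.Size([2, 5, 3])
```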