Tags: python, python-3.x, pytorch

Selecting different ranges of values from each row in a 2D Tensor


Let's say there's a 2D Tensor:

data = tensor([[ 1.,  2.,  3.,  4.,  5.],
               [ 6.,  7.,  8.,  9., 10.],
               [11., 12., 13., 14., 15.]])

There are also two 1D Tensors for start and end indices of values in each row to be selected:

start = tensor([0., 3., 1.])
end = start + 2               # end is always at a +2 offset from start

Is there a way to select data[i, start[i]:end[i]] for each row i without looping over the rows of data? For the above example, the expected output is:

tensor([[ 1.,  2.],
        [ 9., 10.],
        [12., 13.]])

Solution

  • This can be done with gather, provided the chunk size (end - start) is the same for every row, so that the index tensor is rectangular.

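    import torch

    def index_function(data, start_index, chunksize, dim):
        # One row of indices per start value: [start, ..., start + chunksize - 1]
        index_tensor = torch.stack([torch.arange(i, i+chunksize) for i in start_index])
        # With dim == 1, gather picks data[i][index_tensor[i][j]]
        result = data.gather(dim, index_tensor)
        return result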
    
    data = torch.tensor([[ 1.,  2.,  3.,  4.,  5.],
                         [ 6.,  7.,  8.,  9., 10.],
                         [11., 12., 13., 14., 15.]])
    
    start_idx = torch.tensor([0, 3, 1]) # start index tensor must be int, not float
    index_function(data, start_idx, 2, 1)
    
    >tensor([[ 1.,  2.],
             [ 9., 10.],
             [12., 13.]])
    
    start_idx = torch.tensor([0, 2, 1])
    index_function(data, start_idx, 3, 1)
    
    >tensor([[ 1.,  2.,  3.],
             [ 8.,  9., 10.],
             [12., 13., 14.]])