python numpy pytorch numpy-slicing

pytorch split array by list of indices


I want to split a torch tensor by a list of indices.

For example, say my input array is torch.arange(20):

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])

and my list of indices is splits = [1,2,5,10]

Then my result would be:

(tensor([0]),
 tensor([1, 2]),
 tensor([3, 4, 5, 6, 7]),
 tensor([ 8,  9, 10, 11, 12, 13, 14, 15, 16, 17]))

Assume my input array is always longer than the sum of my list of indices.


Solution

  • You could use tensor_split on the cumulative sum of the splits (e.g. with np.cumsum), then drop the last chunk, which holds the leftover elements past the final split point:

    import torch
    import numpy as np
    
    t = torch.arange(20)
    splits = [1,2,5,10]
    
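    # cumsum turns the chunk sizes into split points [1, 3, 8, 18];
    # [:-1] drops the leftover chunk t[18:]
    t.tensor_split(np.cumsum(splits).tolist())[:-1]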
    

    Output:

    (tensor([0]),
     tensor([1, 2]),
     tensor([3, 4, 5, 6, 7]),
     tensor([ 8,  9, 10, 11, 12, 13, 14, 15, 16, 17]))
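
If you would rather avoid the NumPy dependency, a possible alternative is to trim the tensor to the first sum(splits) elements and pass the chunk sizes straight to torch.split. This is only a minimal sketch: the helper name split_by_sizes is made up for illustration, and it assumes, as in the question, that the tensor is at least as long as the sum of the sizes.

```python
import torch


def split_by_sizes(t, sizes):
    """Split the first sum(sizes) elements of t into chunks of the given sizes.

    Hypothetical helper for illustration; assumes len(t) >= sum(sizes).
    """
    total = sum(sizes)
    # torch.split with a list requires the sizes to add up to the tensor
    # length exactly, so slice off the leftover elements first.
    return torch.split(t[:total], sizes)


t = torch.arange(20)
splits = [1, 2, 5, 10]
chunks = split_by_sizes(t, splits)
```

Here chunks is a tuple of four tensors with lengths 1, 2, 5 and 10, and the trailing elements 18 and 19 are discarded by the slice rather than by dropping a final chunk.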