luatorch

How to slice tensors with a predefined order in torch?


I have a dataset of length 10, train = torch.range(1,10). I want to index it in a random order defined by p = torch.randperm(10).

To slice by a range, one can write a = train[{{1,3}}] to get the first three elements. But let's say I want the 2nd, 3rd and 9th elements. Can I get them without a for loop like this:

-- print the elements of train at the first three positions stored in p
for i = 1,3 do
  print(train[p[i]])
end

where

p[1] = 2, p[2] = 3, p[3] = 9. 

a = train[{{ p[{{1,3}}] }}] doesn't work.


Solution

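  • Well, one option is index, which however requires a LongTensor of indices and takes the dimension as its first argument:

    train = torch.range(1,10)
    -- index() expects a LongTensor, so convert the permutation
    p = torch.randperm(10):long()
    -- gather the entries of train along dimension 1 in the order given by p
    print(train:index(1, p))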
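A slightly fuller sketch, assuming the Torch7 Tensor API where index(dim, indices) copies the selected entries into a new tensor: it first takes only the first three positions of the permutation (using the same p[{{1,3}}] range syntax from the question, but passed to index instead of used inside the brackets), then picks the 2nd, 3rd and 9th elements from an explicit list. The names firstThree and picked are illustrative only.

```lua
require 'torch'

-- Data to index: a DoubleTensor holding 1, 2, ..., 10
local train = torch.range(1, 10)

-- Random ordering of the positions 1..10; index() needs a LongTensor
local p = torch.randperm(10):long()

-- p[{{1,3}}] is a LongTensor view of the first three positions;
-- index() gathers those elements of train along dimension 1
local firstThree = train:index(1, p[{{1,3}}])
print(firstThree)

-- The same call works with an explicit list of positions,
-- here the 2nd, 3rd and 9th elements
local picked = train:index(1, torch.LongTensor({2, 3, 9}))
print(picked)  -- contains 2, 3 and 9, since train[i] == i
```

Note that index returns a copy rather than a view, so writing into firstThree or picked does not modify train.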