Tags: lua, deep-learning, torch

torch: extract a sub-tensor using indices


I would like to extract a sub-tensor from an original tensor, containing only the elements at the indices stored in another tensor.

For example:

th> ls = torch.linspace(1, 10, 10)
th> ls
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
[torch.DoubleTensor of size 10]
th> i = torch.FloatTensor(3)
th> i[1] = 2
th> i[2] = 7
th> i[3] = 9
th> ls.eq(i)
[string "_RESULT={ls.eq(i)}"]:1: invalid arguments: FloatTensor
expected arguments: [*ByteTensor*] DoubleTensor double | *DoubleTensor* DoubleTensor double | [*ByteTensor*] DoubleTensor DoubleTensor | *DoubleTensor* DoubleTensor DoubleTensor
stack traceback:
        [C]: in function 'eq'
        [string "_RESULT={ls.eq(i)}"]:1: in main chunk
        [C]: in function 'xpcall'
        /home/ubuntu/torch/install/share/lua/5.1/trepl/init.lua:651: in function 'repl'
        ...untu/torch/install/lib/luarocks/rocks/trepl/scm-1/bin/th:199: in main chunk
        [C]: at 0x00406670
th>

How can I query ls using the indices in i?


Solution

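  • sub = ls:index(1, torch.LongTensor{2, 7, 9})

    See [Tensor] index(dim, index). The index argument must be a torch.LongTensor, so the FloatTensor i from the question has to be converted first, e.g. ls:index(1, i:long()). The result is a new tensor holding 2, 7 and 9; it does not share storage with ls.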
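A minimal sketch, assuming a working Torch7 install (run with `th` or a LuaJIT that can `require 'torch'`). It reuses the tensors from the question, converts the FloatTensor of indices with `:long()` before calling `index`, and contrasts this with `eq`, which is an elementwise comparison returning a ByteTensor mask rather than a lookup by position. The names `sub` and `mask` are illustrative only.

```lua
require 'torch'

local ls = torch.linspace(1, 10, 10)   -- DoubleTensor holding 1, 2, ..., 10

-- Indices stored in a FloatTensor, as in the question.
local i = torch.FloatTensor{2, 7, 9}

-- index() expects a LongTensor of (1-based) indices, so convert first.
local sub = ls:index(1, i:long())
print(sub)                             -- DoubleTensor of size 3: 2, 7, 9

-- index() returns a copy: writing to sub does not change ls.
sub[1] = 100
print(ls[2])

-- For comparison: eq() compares values elementwise and returns a
-- ByteTensor mask (1 where ls == 7, 0 elsewhere).
local mask = ls:eq(7)
print(ls:maskedSelect(mask))           -- selects by value, not by position
```

The original call `ls.eq(i)` failed for two reasons: the dot syntax passes `i` as the first argument instead of `ls`, and `eq` compares values elementwise, so even `ls:eq(...)` would not pick out elements by their positions.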