
Shouldn't `randperm` in the PyTorch C++ API return a tensor with default type int?


When I try to generate a list of permuted integer indices with randperm using the C++ PyTorch API, the resulting tensor is reported as CPUFloatType{10}, i.e. its element type is float rather than an integer type:

#include <torch/torch.h>
#include <iostream>

int N_SAMPLES = 10;
torch::Tensor shuffled_indices = torch::randperm(N_SAMPLES);
std::cout << shuffled_indices << std::endl;

prints

 9
 3
 8
 6
 2
 5
 4
 7
 1
 0
[ CPUFloatType{10} ]

This cannot be used for indexing tensors, because the element type is float rather than an integer type. When trying to use my_tensor.index(shuffled_indices) I get

terminate called after throwing an instance of 'c10::IndexError'
  what():  tensors used as indices must be long, byte or bool tensors

Environment:

  • python-pytorch, version 1.6.0-2 on Arch Linux
  • g++ (GCC) 10.1.0

Why does this happen?


Solution

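  • That's because in this version of the C++ API, a factory function such as randperm that is called without an explicit dtype falls back to the default dtype, which is float (the Python torch.randperm, by contrast, documents torch.int64 as its default). If you want an integer type, you have to specify it with the TensorOptions parameter struct: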

    int N_SAMPLES = 10;
    torch::Tensor shuffled_indices = torch::randperm(N_SAMPLES, torch::TensorOptions().dtype(torch::kLong));
    std::cout << shuffled_indices.dtype() << std::endl;
    // prints: long