When I try to generate a list of permuted integer indices with randperm
using the C++ PyTorch API, the resulting tensor has the type CPUFloatType{10} (a float tensor with 10 elements)
instead of an integer type:
int N_SAMPLES = 10;
torch::Tensor shuffled_indices = torch::randperm(N_SAMPLES);
cout << shuffled_indices << endl;
prints:
9
3
8
6
2
5
4
7
1
0
[ CPUFloatType{10} ]
This cannot be used for indexing tensors, because the element type is float rather than an integer type. When trying to use my_tensor.index(shuffled_indices),
I get:
terminate called after throwing an instance of 'c10::IndexError'
what(): tensors used as indices must be long, byte or bool tensors
Why does this happen?
That's because, in the libtorch version used here, randperm falls back to the default dtype for newly created tensors, which is float.
If you want a different type, you have to specify it explicitly through the TensorOptions
parameter struct:
int N_SAMPLES = 10;
torch::Tensor shuffled_indices = torch::randperm(N_SAMPLES, torch::TensorOptions().dtype(at::kLong));
cout << shuffled_indices.dtype() << endl;  // prints: long