Tags: cpu, device, tensor, libtorch

How do I get a tensor's device type in libtorch?


Is there a method in libtorch to get the device type of a tensor or a model?

Something like xxx.device in PyTorch:

import torch
tensor = torch.rand(3,4)
print(tensor.device)

Solution

  • libtorch was designed to provide almost exactly the same features in C++ as in Python, so when in doubt, try the direct equivalent. The one difference here is that device is a method in C++, so it needs parentheses:

    #include <torch/torch.h>
    #include <iostream>
    int main() {
      torch::Tensor tensor = torch::rand({3, 4});  // created on the CPU by default
      std::cout << tensor.device() << std::endl;   // a CPU tensor prints: cpu
    }
    

    Plot twist : it works \o/