
Torch: how to check whether a variable is a CUDA tensor?


I am looking for a function like type() that tells me whether a variable is a CudaTensor or a normal tensor.

require('cutorch')

x = torch.Tensor(3,3)
x = x:cuda()

if type(x) == 'CudaTensor' then -- What function should be used?
    print('x is CUDA tensor')
else
    print('x is normal tensor')
end

Solution

  • Use the tensor's :type() method, which returns the full class name as a string:

    cutorch = require('cutorch')
    
    x = torch.Tensor(3,3)
    x = x:cuda()
    
    if x:type() == 'torch.CudaTensor' then
        print('x is CUDA tensor')
    else
        print('x is normal tensor')
    end
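A related option is torch.type(x). Unlike x:type(), it can be called on any value: for tensors it returns the Torch class name, and for plain Lua values it falls back to Lua's built-in type(). (Lua's own type(x) returns 'userdata' for every tensor, which is why the check in the question never matches.) The sketch below wraps it in a hypothetical helper, isCudaTensor, that is not part of torch or cutorch. The helper assumes every cutorch tensor class name starts with 'torch.Cuda', as 'torch.CudaTensor' does. The GPU part only runs if cutorch loads on the machine.

```lua
require('torch')

-- Hypothetical helper (not part of torch/cutorch). It assumes every cutorch
-- tensor class name starts with 'torch.Cuda', e.g. 'torch.CudaTensor'.
local function isCudaTensor(t)
  -- torch.type() returns the Torch class name for tensors and falls back
  -- to Lua's type() for plain values such as numbers or strings.
  return torch.type(t):find('^torch%.Cuda') ~= nil
end

local cpu = torch.Tensor(3, 3)
print(torch.type(cpu))      -- torch.DoubleTensor (the default tensor type)
print(isCudaTensor(cpu))    -- false
print(isCudaTensor(42))     -- false: torch.type(42) is 'number'

-- Only try the GPU path if cutorch can be loaded on this machine.
local ok = pcall(require, 'cutorch')
if ok then
  local gpu = cpu:cuda()
  print(isCudaTensor(gpu))                        -- true
  -- torch.isTypeOf also accepts subclasses of the named type.
  print(torch.isTypeOf(gpu, 'torch.CudaTensor'))  -- true
end
```

Matching on the name prefix rather than comparing against exactly 'torch.CudaTensor' is a design choice. It would also accept other CUDA tensor types such as torch.CudaDoubleTensor, if your cutorch version provides them.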