Tags: pytorch, torchscript

TorchScript trace "must be on the current device" error despite model and input both being on the same device


I am failing to run torch.jit.trace despite my best efforts, encountering RuntimeError: Input, output and indices must be on the current device.

I have a (fairly complex) model which I have already put on GPU, along with a set of inputs, also on GPU. I can verify that all input tensors and model parameters & buffers are on the same device:

(Pdb) {p.device for p in self.parameters()}
{device(type='cuda', index=0)}
(Pdb) {p.device for p in self.buffers()}
{device(type='cuda', index=0)}
(Pdb) in_ = (<several tensors here>)
(Pdb) {p.device for p in in_}
{device(type='cuda', index=0)}
(Pdb) torch.cuda.current_device()
0

I can confirm that the model runs and that its output is on the correct device:

(Pdb) self(*in_).device
device(type='cuda', index=0)

Despite all this, tracing fails:

(Pdb) generator_script = torch.jit.trace(self, example_inputs=in_)
*** RuntimeError: Input, output and indices must be on the current device
  1. I understand about inputs and outputs, but what are these "indices" that must also be on the same device?
  2. What other elements that I am not accounting for could be causing trace to fail?

Solution

  • If you're not yet mapping the device during loading, doing so could be the solution.[1] That is, the device mapping should happen as part of the jit.load call itself, not as a separate .to(device) call after jit.load has already finished. See the torch.jit.load documentation for more info. (A fuller end-to-end sketch of the trace/save/load workflow is given below.)

    As an example of what to do:

    import torch
    from torch import jit

    model = jit.load("your_traced_model.pt", map_location=torch.device("cuda"))
    

    This is different from how it works for typical/non-JIT models, where you can simply do:

    import torch

    model = some_model_creation_function()  # placeholder for however you normally build the model
    _ = model.to(torch.device("cuda"))
    

    [1] This does not currently work for the MPS device.
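For context, here is a minimal sketch of the whole trace, save, and load-with-map_location workflow, assuming a CUDA device is available. SimpleModel, the tensor shapes, and the file name are illustrative stand-ins for the asker's (much more complex) model and inputs, not their actual code:

    # Minimal sketch: trace on the GPU, save, then load with map_location.
    import torch
    import torch.nn as nn


    class SimpleModel(nn.Module):
        """Stand-in for the asker's far more complex model."""

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(16, 4)

        def forward(self, x):
            return self.linear(x)


    device = torch.device("cuda")

    # Put both the model and the example inputs on the GPU before tracing.
    model = SimpleModel().to(device)
    example_input = torch.randn(8, 16, device=device)

    traced = torch.jit.trace(model, example_inputs=(example_input,))
    traced.save("traced_model.pt")

    # Map the device at load time rather than calling .to(device) afterwards.
    loaded = torch.jit.load("traced_model.pt", map_location=device)
    print(loaded(example_input).device)  # cuda:0

With the device mapped at load time, every tensor the traced module uses should end up on the CUDA device, which is the situation the answer above describes as avoiding the "must be on the current device" error.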