Search code examples
python mobile pytorch concatenation

RuntimeError when converting a PyTorch model to PyTorch jit script


I am trying to build a simple PyTorch model and convert it to a PyTorch jit script (TorchScript) using the code below. (The final goal is to convert it to PyTorch Mobile.)

class Concat(nn.Module):  # Joins the tensors it receives along dimension 1.
    def __init__(self):
        super(Concat, self).__init__()

    def forward(self, x):  # x has no annotation; the error below reports it was inferred as Tensor
        return torch.cat(x,1)  # this is the call the error below points at

class Net(nn.Module):  # Applies the same two convolutions to the input twice and returns both results.
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 1)  # 3 -> 16 channels, kernel size 3, stride 1
        self.conv2 = nn.Conv2d(16, 32, 3, 1)  # 16 -> 32 channels, kernel size 3, stride 1
        
    def forward(self, x):
        y = self.conv1(x)
        y = self.conv2(y)
        z = self.conv1(x)  # second pass: same layers and same input as used for y
        z = self.conv2(z)
        return (y, z)  # a tuple of two tensors, not a single tensor

net = nn.Sequential(  # Net's output tuple is handed to Concat as its single argument
    Net(),
    Concat()
)
mobile_net = torch.quantization.convert(net)
scripted_net = torch.jit.script(mobile_net)  # this is the line that raises the RuntimeError shown below

But the above code throws the following error.

RuntimeError                              Traceback (most recent call last)
Cell In [2], line 26
     21 net = nn.Sequential(
     22     Net(),
     23     Concat()
     24 )
     25 mobile_net = torch.quantization.convert(net)
---> 26 scripted_net = torch.jit.script(mobile_net)

File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_script.py:1286, in script(obj, optimize, _frames_up, _rcb, example_inputs)
   1284 if isinstance(obj, torch.nn.Module):
   1285     obj = call_prepare_scriptable_func(obj)
-> 1286     return torch.jit._recursive.create_script_module(
   1287         obj, torch.jit._recursive.infer_methods_to_compile
   1288     )
   1290 if isinstance(obj, dict):
   1291     return create_script_dict(obj)

File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_recursive.py:476, in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
    474 if not is_tracing:
    475     AttributeTypeIsSupportedChecker().check(nn_module)
--> 476 return create_script_module_impl(nn_module, concrete_type, stubs_fn)

File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_recursive.py:538, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    535     script_module._concrete_type = concrete_type
    537 # Actually create the ScriptModule, initializing it with the function we just defined
--> 538 script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
    540 # Compile methods if necessary
    541 if concrete_type not in concrete_type_store.methods_compiled:

File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_script.py:615, in RecursiveScriptModule._construct(cpp_module, init_fn)
    602 """
    603 Construct a RecursiveScriptModule that's ready for use. PyTorch
    604 code should use this to construct a RecursiveScriptModule instead
   (...)
    612     init_fn:  Lambda that initializes the RecursiveScriptModule passed to it.
    613 """
    614 script_module = RecursiveScriptModule(cpp_module)
--> 615 init_fn(script_module)
    617 # Finalize the ScriptModule: replace the nn.Module state with our
    618 # custom implementations and flip the _initializing bit.
    619 RecursiveScriptModule._finalize_scriptmodule(script_module)

File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_recursive.py:516, in create_script_module_impl.<locals>.init_fn(script_module)
    513     scripted = orig_value
    514 else:
    515     # always reuse the provided stubs_fn to infer the methods to compile
--> 516     scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
    518 cpp_module.setattr(name, scripted)
    519 script_module._modules[name] = scripted

File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_recursive.py:542, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    540 # Compile methods if necessary
    541 if concrete_type not in concrete_type_store.methods_compiled:
--> 542     create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    543     # Create hooks after methods to ensure no name collisions between hooks and methods.
    544     # If done before, hooks can overshadow methods that aren't exported.
    545     create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs)

File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_recursive.py:393, in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    390 property_defs = [p.def_ for p in property_stubs]
    391 property_rcbs = [p.resolution_callback for p in property_stubs]
--> 393 concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)

RuntimeError: 
Arguments for call are not valid.
The following variants are available:
  
  aten::cat(Tensor[] tensors, int dim=0) -> Tensor:
  Expected a value of type 'List[Tensor]' for argument 'tensors' but instead found type 'Tensor (inferred)'.
  Inferred the value for argument 'tensors' to be of type 'Tensor' because it was not annotated with an explicit type.
  
  aten::cat.names(Tensor[] tensors, str dim) -> Tensor:
  Expected a value of type 'List[Tensor]' for argument 'tensors' but instead found type 'Tensor (inferred)'.
  Inferred the value for argument 'tensors' to be of type 'Tensor' because it was not annotated with an explicit type.
  
  aten::cat.names_out(Tensor[] tensors, str dim, *, Tensor(a!) out) -> Tensor(a!):
  Expected a value of type 'List[Tensor]' for argument 'tensors' but instead found type 'Tensor (inferred)'.
  Inferred the value for argument 'tensors' to be of type 'Tensor' because it was not annotated with an explicit type.
  
  aten::cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!):
  Expected a value of type 'List[Tensor]' for argument 'tensors' but instead found type 'Tensor (inferred)'.
  Inferred the value for argument 'tensors' to be of type 'Tensor' because it was not annotated with an explicit type.

The original call is:
  File "C:\Users\pawan\AppData\Local\Temp\ipykernel_16484\3929675973.py", line 6
    def forward(self, x):
        return torch.cat(x,1)
               ~~~~~~~~~ <--- HERE

I am new to PyTorch and not familiar with its internals, so please suggest a solution. If torch.cat is moved into the forward method of the Net class, i.e. instead of return (y, z) we do return torch.cat((y, z),1), then it works, but I want to do the concatenation in a separate class.


Solution

  • Why the error happens

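    While compiling Concat.forward, torch.jit assumes the un-annotated parameter x is a single Tensor. torch.cat, however, requires a list of tensors (Tensor[]), so torch.jit rejects the call torch.cat(x, 1) with "Arguments for call are not valid" (a Tensor isn't a List[Tensor]). The tuple (y, z) that you actually pass at runtime is never looked at: torch.jit only uses the declared (or, failing that, inferred) type of x.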

  • How to fix it

    Explicitly specify the type of the parameter x in Concat.forward as Tuple[torch.Tensor, torch.Tensor], so that torch.jit knows what you want.

    from typing import Tuple
    
    class Concat(nn.Module):
        """Concatenate a pair of tensors along dimension 1."""

        def __init__(self) -> None:
            super().__init__()

        def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
            # The explicit Tuple annotation on ``x`` is what torch.jit.script
            # needs: an un-annotated parameter is inferred to be a single
            # Tensor, which torch.cat does not accept.
            return torch.cat(x, 1)
    
    class Net(nn.Module):
        """Toy network that runs one two-convolution stack twice on its input."""

        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1)
            self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1)

        def forward(self, x):
            # Both results come from the same layers applied to the same input.
            y = self.conv2(self.conv1(x))
            z = self.conv2(self.conv1(x))
            # Returned as a tuple so a following module can combine them.
            return y, z
    
    # Chain the two modules: the tuple returned by Net is the input of Concat.
    net = nn.Sequential(Net(), Concat())
    # NOTE(review): no qconfig/prepare step is visible before convert() --
    # confirm that quantization is actually intended at this point.
    mobile_net = torch.quantization.convert(net)
    # Compile the model to TorchScript; this relies on Concat.forward's
    # parameter being annotated as a Tuple of tensors.
    scripted_net = torch.jit.script(mobile_net)