I am trying to build a simple PyTorch model and convert it to TorchScript with `torch.jit.script` using the code below. (My final goal is to deploy it with PyTorch Mobile.)
class Concat(nn.Module):
    """Join a sequence of tensors along dimension 1.

    ``forward`` leaves ``x`` without a type annotation, so ``torch.jit.script``
    infers it as a single ``Tensor`` and rejects the ``torch.cat`` call
    (which needs a list of tensors). Eager-mode calls are unaffected.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.cat(x, dim=1)
class Net(nn.Module):
    """Run the same conv1 -> conv2 stack twice on the input and return both results."""

    def __init__(self) -> None:
        super().__init__()
        # 3 -> 16 channels, 3x3 kernel, stride 1
        self.conv1 = nn.Conv2d(3, 16, 3, 1)
        # 16 -> 32 channels, 3x3 kernel, stride 1
        self.conv2 = nn.Conv2d(16, 32, 3, 1)

    def forward(self, x):
        first = self.conv2(self.conv1(x))
        second = self.conv2(self.conv1(x))
        return first, second
# Chain the two-branch network with the concatenation layer.
net = nn.Sequential(Net(), Concat())
# NOTE(review): no qconfig/prepare step appears before this call, so
# convert() presumably leaves the float layers unchanged — confirm.
mobile_net = torch.quantization.convert(net)
# With Concat.forward's parameter left unannotated, this call raises the
# RuntimeError reproduced below.
scripted_net = torch.jit.script(mobile_net)
However, the code above throws the following error:
RuntimeError Traceback (most recent call last)
Cell In [2], line 26
21 net = nn.Sequential(
22 Net(),
23 Concat()
24 )
25 mobile_net = torch.quantization.convert(net)
---> 26 scripted_net = torch.jit.script(mobile_net)
File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_script.py:1286, in script(obj, optimize, _frames_up, _rcb, example_inputs)
1284 if isinstance(obj, torch.nn.Module):
1285 obj = call_prepare_scriptable_func(obj)
-> 1286 return torch.jit._recursive.create_script_module(
1287 obj, torch.jit._recursive.infer_methods_to_compile
1288 )
1290 if isinstance(obj, dict):
1291 return create_script_dict(obj)
File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_recursive.py:476, in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
474 if not is_tracing:
475 AttributeTypeIsSupportedChecker().check(nn_module)
--> 476 return create_script_module_impl(nn_module, concrete_type, stubs_fn)
File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_recursive.py:538, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
535 script_module._concrete_type = concrete_type
537 # Actually create the ScriptModule, initializing it with the function we just defined
--> 538 script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
540 # Compile methods if necessary
541 if concrete_type not in concrete_type_store.methods_compiled:
File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_script.py:615, in RecursiveScriptModule._construct(cpp_module, init_fn)
602 """
603 Construct a RecursiveScriptModule that's ready for use. PyTorch
604 code should use this to construct a RecursiveScriptModule instead
(...)
612 init_fn: Lambda that initializes the RecursiveScriptModule passed to it.
613 """
614 script_module = RecursiveScriptModule(cpp_module)
--> 615 init_fn(script_module)
617 # Finalize the ScriptModule: replace the nn.Module state with our
618 # custom implementations and flip the _initializing bit.
619 RecursiveScriptModule._finalize_scriptmodule(script_module)
File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_recursive.py:516, in create_script_module_impl.<locals>.init_fn(script_module)
513 scripted = orig_value
514 else:
515 # always reuse the provided stubs_fn to infer the methods to compile
--> 516 scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
518 cpp_module.setattr(name, scripted)
519 script_module._modules[name] = scripted
File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_recursive.py:542, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
540 # Compile methods if necessary
541 if concrete_type not in concrete_type_store.methods_compiled:
--> 542 create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
543 # Create hooks after methods to ensure no name collisions between hooks and methods.
544 # If done before, hooks can overshadow methods that aren't exported.
545 create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs)
File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_recursive.py:393, in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
390 property_defs = [p.def_ for p in property_stubs]
391 property_rcbs = [p.resolution_callback for p in property_stubs]
--> 393 concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
RuntimeError:
Arguments for call are not valid.
The following variants are available:
aten::cat(Tensor[] tensors, int dim=0) -> Tensor:
Expected a value of type 'List[Tensor]' for argument 'tensors' but instead found type 'Tensor (inferred)'.
Inferred the value for argument 'tensors' to be of type 'Tensor' because it was not annotated with an explicit type.
aten::cat.names(Tensor[] tensors, str dim) -> Tensor:
Expected a value of type 'List[Tensor]' for argument 'tensors' but instead found type 'Tensor (inferred)'.
Inferred the value for argument 'tensors' to be of type 'Tensor' because it was not annotated with an explicit type.
aten::cat.names_out(Tensor[] tensors, str dim, *, Tensor(a!) out) -> Tensor(a!):
Expected a value of type 'List[Tensor]' for argument 'tensors' but instead found type 'Tensor (inferred)'.
Inferred the value for argument 'tensors' to be of type 'Tensor' because it was not annotated with an explicit type.
aten::cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!):
Expected a value of type 'List[Tensor]' for argument 'tensors' but instead found type 'Tensor (inferred)'.
Inferred the value for argument 'tensors' to be of type 'Tensor' because it was not annotated with an explicit type.
The original call is:
File "C:\Users\pawan\AppData\Local\Temp\ipykernel_16484\3929675973.py", line 6
def forward(self, x):
return torch.cat(x,1)
~~~~~~~~~ <--- HERE
I am new to PyTorch and not familiar with its internals, so I would appreciate a solution. If I do the concatenation inside `Net.forward` — i.e. `return torch.cat((y, z), 1)` instead of `return (y, z)` — scripting works, but I want to keep the concatenation in a separate class.
Why the error happens
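When compiling `Concat.forward`, `torch.jit.script` finds no type annotation on the parameter `x`, so it assumes `x` is a single `Tensor` (TorchScript's default for unannotated arguments). Every overload of `torch.cat` expects a list of tensors (`Tensor[]`) as its first argument. The call `torch.cat(x, 1)` therefore fails to type-check, and the compiler reports "Arguments for call are not valid". This happens at compile time, regardless of what `Net` returns. In eager mode the same code works because Python does not check types: the tuple `(y, z)` returned by `Net` is accepted by `torch.cat` at runtime.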
How to fix it
Explicitly annotate the parameter `x` of `Concat.forward` as `Tuple[torch.Tensor, torch.Tensor]`. This tells `torch.jit.script` that it receives the pair of tensors returned by `Net.forward`.
from typing import Tuple
class Concat(nn.Module):
    """Concatenate a pair of tensors along a configurable dimension.

    Args:
        dim: dimension passed to ``torch.cat``. Defaults to 1 (the channel
            dimension of the ``Conv2d`` outputs produced by ``Net``), which
            matches the previous hard-coded behavior.
    """

    def __init__(self, dim: int = 1) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        # torch.jit.script treats unannotated parameters as a single Tensor;
        # the explicit Tuple annotation lets the torch.cat call type-check.
        return torch.cat(x, self.dim)
class Net(nn.Module):
    """Two-output network: applies conv1 then conv2 to the input, twice.

    Both outputs come from the same layers and the same input, so they are
    identical; the tuple feeds the ``Concat`` layer that follows in the
    ``nn.Sequential``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 1)  # 3 -> 16 channels, 3x3, stride 1
        self.conv2 = nn.Conv2d(16, 32, 3, 1)  # 16 -> 32 channels, 3x3, stride 1

    def forward(self, x):
        left = self.conv1(x)
        left = self.conv2(left)
        right = self.conv2(self.conv1(x))
        return left, right
# Two-branch network followed by the (now scriptable) concatenation layer.
net = nn.Sequential(Net(), Concat())
# NOTE(review): no qconfig/prepare step appears before this call, so
# convert() presumably leaves the float layers unchanged — confirm.
mobile_net = torch.quantization.convert(net)
# Compiles now that Concat.forward declares x as a tuple of two tensors.
scripted_net = torch.jit.script(mobile_net)