I have an ONNX model with some (ideally) boolean inputs that are only used for control flow within the model.
Here is a minimal example of what I am trying to do:
import onnx
import onnxruntime
import torch.onnx

class SumModule(torch.nn.Module):
    def forward(self, x1, x2):
        if x2 is not None:
            x1 *= 1
        return torch.sum(x1)

torch_model = SumModule()
torch_model.eval()
model_inputs = {'x1': torch.tensor([1, 2]), 'x2': torch.tensor([1, 2])}
torch_out = torch_model(**model_inputs)
torch.onnx.export(torch_model,
                  tuple(model_inputs.values()),
                  'model.onnx',
                  export_params=True,
                  opset_version=16,
                  do_constant_folding=True,
                  input_names=list(model_inputs.keys()),
                  output_names=['output'],
                  dynamic_axes={'x1': {0: 'batch_size'}, })
onnx_model = onnx.load('model.onnx')
onnx.checker.check_model(onnx_model)
ort_session = onnxruntime.InferenceSession('model.onnx')

def to_numpy(tensor):
    if isinstance(tensor, torch.Tensor):
        return tensor.detach().cpu().numpy()
    return tensor

model_inputs_np = {k: to_numpy(v) for k, v in model_inputs.items()}
ort_outs = ort_session.run(None, input_feed=model_inputs_np)
The ONNX export goes through, but running inference fails with the error
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid Feed Input Name:x2
I think I misunderstand something fundamental about ONNX here. The argument x2 obviously exists, so why does ONNX discard it? The same happens if I don't use the argument x2 at all (but have it as an input argument), which I also find weird.
In my actual code, the control flow I want to do is the following:
I have 3 inputs that are optional, so ideally they would be Optional[torch.Tensor]. However, ONNX seems to be unable to deal with None. So instead I wanted to have the 3 inputs plus 3 boolean flags (torch.tensor(True), or if need be torch.tensor([True]), or with True replaced by 1 or 1.0; same issue with all of those).
Then within the code I do different things based on those flags.
Why does ONNX not allow this? I found out that having those variables is fine if I include them in some computation, but I can't figure out the rule behind all of this.
Your problem is related to how torch.onnx.export works.
When generating the ONNX model, torch executes (traces) the module once with the given inputs while keeping track of all performed computations, then maps them to the corresponding ONNX operators, and finally simplifies the graph. In your case, the noteworthy detail is that all control flow is evaluated once, at trace time, and Python built-in types are recorded as constants. So the code
if x2 is not None:
    x1 *= 1
return torch.sum(x1)
is saved as
if True:
    x1 *= 1
return torch.sum(x1)
and when torch.onnx.export simplifies the graph, it removes all unused variables, including x2, hence your error.
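The effect of tracing can be seen without ONNX at all. The following is a minimal sketch using torch.jit.trace directly; the Branchy module and the doubling in the branch are hypothetical and only there to make the two paths distinguishable.

```python
import warnings
import torch

class Branchy(torch.nn.Module):
    # Hypothetical module: the branch depends on the value of `flag`
    def forward(self, x, flag):
        # Python-level branch: tracing evaluates it once and records only the taken path
        if flag.sum() > 0:
            return x * 2
        return x

x = torch.ones(2)
with warnings.catch_warnings():
    # Silence the TracerWarning about converting a tensor to a Python bool
    warnings.simplefilter("ignore")
    # Trace with the flag switched on, so the "x * 2" path is baked into the graph
    traced = torch.jit.trace(Branchy(), (x, torch.tensor(1.0)))

eager_off = Branchy()(x, torch.tensor(0.0))  # eager mode honours the flag
traced_off = traced(x, torch.tensor(0.0))    # the traced graph ignores it

print(eager_off)   # tensor([1., 1.])
print(traced_off)  # tensor([2., 2.])
```

The traced module keeps doubling x even when the flag is off, because the flag no longer influences anything in the recorded graph.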
If you want to preserve control flow in your exported model, you need torch to evaluate your model with torch.jit.script instead of torch.jit.trace. As you've already pointed out, ONNX expects a fixed number of tensors as inputs, and does not accept "optional" arguments. Exporting the model with scripting is done like this:
scripted_model = torch.jit.script(torch_model)
torch.onnx.export(scripted_model, ...)
However, with this change alone your model will still not work. Notice that the if statement in your forward pass is a Python-level comparison that doesn't operate on the tensor itself, so x2 will still be discarded during simplification. Changing SumModule to
class SumModule(torch.nn.Module):
    def forward(self, x1, x2):
        if torch.any(x2):
            x1 *= 1
        return torch.sum(x1)
will yield the correct graph, since now x2 is actually operated on. With this, you could use x2 as a boolean flag for control flow.
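As a rough check that scripting keeps the branch, the sketch below scripts a similar module and looks for an If node in the TorchScript graph. FlagModule is a hypothetical variant of SumModule: it doubles x1 instead of multiplying by 1 so that the two paths give different results, and it avoids the in-place update so the input is not modified.

```python
import torch

class FlagModule(torch.nn.Module):
    # Hypothetical variant of SumModule with a visible effect in the branch
    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        # The condition is computed from x2 with a tensor op, so x2 stays in the graph
        if bool(torch.any(x2)):
            x1 = x1 * 2
        return torch.sum(x1)

scripted = torch.jit.script(FlagModule())

# The scripted graph should contain a conditional node
has_branch = "prim::If" in str(scripted.graph)

x1 = torch.tensor([1.0, 2.0])
out_on = scripted(x1, torch.tensor([True]))    # branch taken: sum([2, 4])
out_off = scripted(x1, torch.tensor([False]))  # branch skipped: sum([1, 2])

print(has_branch, out_on.item(), out_off.item())  # True 6.0 3.0
```

This only inspects the TorchScript graph; whether the exported ONNX graph keeps the conditional still depends on the exporter and opset in use.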
I highly recommend looking into the torch documentation, as it explains a lot of common mistakes with regard to exporting.
EDIT
For completeness, I should add that the approach above should generally be avoided. Much hardware acceleration is not designed for conditionals, and running ONNX models that contain a lot of control flow with, for example, CUDA often leads to large parts of the graph falling back to CPU. When presented with a situation like the one described in this question, I would recommend considering
Rather than using the solution presented above
EDIT
Here is an example of avoiding if-else control flow for better hardware acceleration.
import torch

# Assume x2 has a static shape whenever it exists, say [1, 2], and that a real
# x2 is never all zeros (you need to pick a special value that a real x2 can never take).
# When x2 is None, pass torch.zeros(1, 2) instead. This guarantees that x2 is
# always passed into the function and always has a static shape.

# Outside the ONNX model: initialize x2 as all zeros, and update it if x2 exists.
x2 = torch.zeros(1, 2)

class SumModule(torch.nn.Module):
    def forward(self, x1, x2):
        # Create the condition with pure tensor ops and convert it to float:
        # 1.0 if x2 is the all-zeros placeholder, 0.0 otherwise.
        # (torch.equal returns a Python bool, which tracing would record as a constant.)
        condition = torch.all(x2 == torch.zeros_like(x2)).to(x1.device).float()
        # Avoid control flow: blend the two branches arithmetically.
        x1 = condition * (x1 * 1) + (1 - condition) * x1
        return torch.sum(x1)
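To see that an arithmetic blend of this kind survives tracing, here is a small sketch under a few assumptions: BlendModule is a hypothetical variant that doubles x1 when x2 is present (so the two cases give different results), and the condition is built with torch.all so that it stays a tensor op in the traced graph.

```python
import torch

class BlendModule(torch.nn.Module):
    # Hypothetical variant: doubling stands in for the real "x2 is present" work
    def forward(self, x1, x2):
        # 1.0 when x2 is the all-zeros placeholder ("absent"), else 0.0
        absent = torch.all(x2 == 0).to(x1.dtype)
        # Blend both branches arithmetically instead of using if/else
        x1 = absent * x1 + (1 - absent) * (x1 * 2)
        return torch.sum(x1)

x1 = torch.tensor([1.0, 2.0])
# Trace once with the placeholder; the condition is still computed from x2 at run time
traced = torch.jit.trace(BlendModule(), (x1, torch.zeros(1, 2)))

out_absent = traced(x1, torch.zeros(1, 2))  # placeholder: plain sum
out_present = traced(x1, torch.ones(1, 2))  # real x2: doubled sum

print(out_absent.item(), out_present.item())  # 3.0 6.0
```

Both branches are always computed, which costs some extra arithmetic but keeps the graph free of conditionals.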