Test Code:
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(32, 16)
        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)
        self.fc = nn.Linear(32, 2)

    def forward(self, x):
        x1, x2 = x
        x1 = self.linear(x1)
        x1 = self.relu1(x1)
        x2 = self.linear(x2)
        x2 = self.relu2(x2)
        out = torch.cat((x1, x2), dim=-1)
        out = self.fc(out)
        return out


model = Model()
model.eval()
x1 = torch.randn((2, 10, 32))
x2 = torch.randn((2, 10, 32))
x = (x1, x2)
torch.onnx.export(model,
                  x,
                  'model.onnx',
                  input_names=["input"],
                  output_names=["output"],
                  dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}})
print("Done")
How can I convert the above code to ONNX? The input to the forward method of my model is a tuple, and with the export call shown above it cannot be converted to ONNX format. Can you tell me how to solve this? Thanks!
Looking at this issue and this other issue, the arguments passed to torch.onnx.export are unpacked by default, so you need to wrap your tuple input in an outer tuple when passing it as args:
torch.onnx.export(model,
                  args=(x,),
                  f='model.onnx',
                  input_names=["input"],
                  output_names=["output"],
                  dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}})
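As a side note: because the tuple is flattened into two tensor inputs in the exported graph, only the first of them gets the name "input" above; the second receives an auto-generated name. If you want to name and mark both, you can list two entries. A minimal sketch, assuming the arbitrary names "input1" and "input2" (not from the original answer):

# Sketch: name both flattened tuple inputs and give each a dynamic batch axis.
# "input1"/"input2" are illustrative names, not required by the exporter.
torch.onnx.export(model,
                  args=(x,),
                  f='model.onnx',
                  input_names=["input1", "input2"],
                  output_names=["output"],
                  dynamic_axes={'input1': {0: 'batch'},
                                'input2': {0: 'batch'},
                                'output': {0: 'batch'}})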