pytorch, onnx

My model's forward() takes a tuple as input. How can I export it to ONNX?


Test Code:

    #!/usr/bin/env python
    # -*- coding:utf-8 -*-
    import torch
    import torch.nn as nn


    class Model(nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.linear = nn.Linear(32, 16)
            self.relu1 = nn.ReLU(inplace=True)
            self.relu2 = nn.ReLU(inplace=True)
            self.fc = nn.Linear(32, 2)

        def forward(self, x):
            x1, x2 = x
            x1 = self.linear(x1)
            x1 = self.relu1(x1)
            x2 = self.linear(x2)
            x2 = self.relu2(x2)
            out = torch.cat((x1, x2), dim=-1)
            out = self.fc(out)
            return out


    model = Model()
    model.eval()

    x1 = torch.randn((2, 10, 32))
    x2 = torch.randn((2, 10, 32))
    x = (x1, x2)

    torch.onnx.export(model,
                      x,
                      'model.onnx',
                      input_names=["input"],
                      output_names=["output"],
                      dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}}
                      )
    print("Done")

How can I export the model above to ONNX? Its forward() takes a single tuple as input, and I could not convert it to ONNX format with the existing methods when passing that tuple to torch.onnx.export as shown. Can you tell me how to solve this? Thanks!


Solution

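  • Looking at this issue and this other issue, torch.onnx.export unpacks a tuple passed as args into separate positional arguments, so the model is called as model(*args). Your forward() expects a single argument that is itself a tuple, so you need to wrap x in another one-element tuple when passing it to torch.onnx.export: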

    torch.onnx.export(model,
                      args=(x,),
                      f='model.onnx',
                      input_names=["input"],
                      output_names=["output"],
                      dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}})
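
To see why the extra wrapping matters, here is a minimal plain-Python sketch of the unpacking behaviour, with no PyTorch required. The helper `call_like_exporter` is a hypothetical stand-in that only mimics the documented `model(*args)` call convention; it is not the exporter's actual code, and `forward` here is a toy function rather than the model above.

```python
def forward(x):
    # Like the model's forward(): expects ONE argument that is a pair.
    x1, x2 = x
    return x1 + x2


def call_like_exporter(fn, args):
    # Hypothetical stand-in: mimics the model(*args) call convention.
    # A non-tuple argument is treated as a single positional argument.
    if not isinstance(args, tuple):
        args = (args,)
    return fn(*args)


x = (1.0, 2.0)

# Passing the pair directly: it is unpacked into two positional arguments,
# so forward() receives two arguments instead of one and raises TypeError.
try:
    call_like_exporter(forward, x)
    direct_ok = True
except TypeError:
    direct_ok = False

# Wrapping the pair in a one-element tuple: forward() receives the pair itself.
wrapped_result = call_like_exporter(forward, (x,))

print(direct_ok, wrapped_result)
```

The same reasoning applies to args=(x,) in the export call. Note that the exported graph will likely expose x1 and x2 as two separate tensor inputs, so it may be worth passing two names in input_names (and matching keys in dynamic_axes) if both inputs need a dynamic batch axis.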