python pytorch onnx

Crash when trying to export PyTorch model to ONNX: forward() missing 1 required positional argument


I'm trying to convert a PyTorch model to ONNX like this:

import torch

torch.onnx.export(
  model=modnet.module,
  args=example_input,
  f=ONNX_PATH,  # path where the .onnx file will be saved
  verbose=False,
  export_params=True,
  do_constant_folding=False,
  input_names=['input'],
  output_names=['output']
)

modnet is a model from this repo: https://github.com/ZHKKKe/MODNet

example_input is a Tensor of shape [1, 3, 512, 512]

During the conversion I get this error:

TypeError: forward() missing 1 required positional argument: 'inference'
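The error comes from how the exporter calls the model. The `torch.onnx.export` documentation says that a lone tensor passed as `args` is treated as a 1-tuple and the model is then invoked as `model(*args)`. The pure-Python sketch below mimics that behaviour without needing PyTorch; `FakeTensor`, `FakeModNet` and `call_like_export` are hypothetical stand-ins, not real PyTorch APIs, and `FakeModNet.forward` only copies the shape of the MODNet signature (an image plus an `inference` flag).

```python
# Pure-Python sketch: FakeTensor, FakeModNet and call_like_export are
# hypothetical stand-ins, NOT PyTorch APIs.

class FakeTensor:
    """Stand-in for a torch.Tensor of shape [1, 3, 512, 512]."""
    shape = (1, 3, 512, 512)


class FakeModNet:
    """Mimics a module whose forward needs an image AND an inference flag."""

    def forward(self, img, inference):
        return ("matte", img.shape, inference)

    def __call__(self, *args):
        # nn.Module.__call__ ultimately dispatches to forward(*args)
        return self.forward(*args)


def call_like_export(model, args):
    """Mimics the documented handling of `args` in torch.onnx.export:
    a lone tensor becomes a 1-tuple, then the model is called as model(*args)."""
    if not isinstance(args, tuple):
        args = (args,)
    return model(*args)


example_input = FakeTensor()
error_message = None
try:
    # Only the image is supplied, so `inference` is never bound
    call_like_export(FakeModNet(), example_input)
except TypeError as err:
    error_message = str(err)
    print(error_message)
```

On Python 3.10+ the message is prefixed with the qualified name (`FakeModNet.forward()`), but the tail matches the error above.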

Here is my copy of the Colab notebook that reproduces the exception: https://colab.research.google.com/drive/1AE1VAXIXkm26krIOoBaFfhoE53hhuEdf?usp=sharing

Save me please! :)


Solution

  • MODNet's forward method requires a second positional argument, inference, which is a boolean. Indeed, the training code passes it explicitly, like this:

    # forward the main model
    pred_semantic, pred_detail, pred_matte = modnet(image, False)
    

    So you need to modify your example_input to include that flag, set to True because you are exporting the model for inference:

    example_input = (example_input, True)
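Why the tuple works, as a pure-Python sketch: once `example_input` is `(tensor, True)`, the documented `model(*args)` unpacking binds the tensor to the image parameter and `True` to `inference`. `FakeTensor`, `FakeModNet` and `call_like_export` are hypothetical stand-ins, not PyTorch APIs. With the real exporter, the `torch.onnx.export` documentation states that non-Tensor arguments are hard-coded into the exported model, so the bool should become a constant and the single image tensor remains the graph input named by `input_names=['input']`.

```python
# Pure-Python sketch of the fix: FakeTensor, FakeModNet and call_like_export
# are hypothetical stand-ins, NOT PyTorch APIs.

class FakeTensor:
    """Stand-in for a torch.Tensor of shape [1, 3, 512, 512]."""
    shape = (1, 3, 512, 512)


class FakeModNet:
    def forward(self, img, inference):
        # Report which values were bound to each parameter
        return {"shape": img.shape, "inference": inference}

    def __call__(self, *args):
        return self.forward(*args)


def call_like_export(model, args):
    # Lone tensor -> 1-tuple, then model(*args), as torch.onnx.export documents
    if not isinstance(args, tuple):
        args = (args,)
    return model(*args)


example_input = FakeTensor()
example_input = (example_input, True)  # the fix: image plus inference flag
result = call_like_export(FakeModNet(), example_input)
print(result)  # {'shape': (1, 3, 512, 512), 'inference': True}
```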