Tags: python, deep-learning, pytorch, onnx

Issue while exporting a PyTorch model to ONNX format


I'm trying to export my PyTorch model to ONNX format, but I keep getting this error:

TypeError: forward() missing 1 required positional argument: 'text'

This is my code:

import torch

# Model and opt come from the ViTSTR code base (not shown here).
model = Model(opt)
dummy_input = torch.randn(1, 3, 224, 224)
file_path = '/content/drive/MyDrive/VitSTR/vitstr_tiny_patch16_224_aug.pth'
# Note: this overwrites file_path with the model's current weights, then reloads them.
torch.save(model.state_dict(), file_path)
model.load_state_dict(torch.load(file_path))
#model = torch.nn.DataParallel(model).to(device)
#print(model)
torch.onnx.export(model, dummy_input, "vitstr.onnx", verbose=True)
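Per the torch.onnx.export documentation, the example inputs must be such that model(*args) is a valid call, and a single tensor is treated as a one-element tuple. So the exporter effectively runs model(dummy_input), and forward() never receives text. A minimal reproduction, using a hypothetical stand-in module ToyViTSTR (not the real ViTSTR; only its forward signature is copied, the body is a placeholder):

```python
import torch
import torch.nn as nn


class ToyViTSTR(nn.Module):
    """Hypothetical stand-in: same forward signature as ViTSTR's Model, trivial body."""

    def __init__(self, num_classes: int = 96):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(3, num_classes)

    def forward(self, input, text, is_train=True, seqlen=25):
        # `text` is required by the signature even though this toy ignores it.
        feats = self.pool(input).flatten(1)                 # (B, 3)
        logits = self.head(feats)                           # (B, num_classes)
        return logits.unsqueeze(1).expand(-1, seqlen, -1)   # (B, seqlen, num_classes)


model = ToyViTSTR().eval()
dummy_input = torch.randn(1, 3, 224, 224)

# A bare tensor passed to torch.onnx.export is treated as the tuple
# (dummy_input,), and the model is run as model(*args) == model(dummy_input).
args = (dummy_input,)
try:
    model(*args)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)
```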

Solution

  • ViTSTR's forward() requires two positional arguments, input and text:

    def forward(self, input, text, is_train=True, seqlen=25):
        # ...
    

    Therefore, you need to pass a dummy text tensor as well. torch.onnx.export runs the model as model(*args), so pass both example inputs as a tuple:

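    # ...
    # Assumed text shape: (batch, max_length + 1) token indices, with max_length = 25.
    # Match the shape and dtype to how your own code builds `text`.
    dummy_text = torch.zeros(1, 26, dtype=torch.long)
    torch.onnx.export(model, (dummy_input, dummy_text), "vitstr.onnx", verbose=True)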
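The tuple form can be checked on a hypothetical stand-in with the same forward signature (ToyViTSTR, not the real model). The sketch also shows one alternative that is not part of the original answer: a small wrapper module (ImageOnlyWrapper, a hypothetical name) that builds the dummy text itself, so the exported model takes only the image. The text shape (batch, max_length + 1), is_train=False and seqlen=max_length are assumptions; copy whatever your own inference code passes to the model.

```python
import torch
import torch.nn as nn


class ToyViTSTR(nn.Module):
    """Hypothetical stand-in with ViTSTR's forward signature (not the real model)."""

    def __init__(self, num_classes: int = 96):
        super().__init__()
        self.head = nn.Linear(3, num_classes)

    def forward(self, input, text, is_train=True, seqlen=25):
        feats = input.mean(dim=(2, 3))                            # (B, 3); `text` ignored
        return self.head(feats).unsqueeze(1).repeat(1, seqlen, 1)  # (B, seqlen, C)


class ImageOnlyWrapper(nn.Module):
    """Hypothetical wrapper: exposes forward(image) and supplies a dummy text itself."""

    def __init__(self, model: nn.Module, max_length: int = 25):
        super().__init__()
        self.model = model
        self.max_length = max_length  # assumed to match the training max label length

    def forward(self, image):
        # Assumed text layout: (batch, max_length + 1) token indices, all zeros.
        text = torch.zeros(image.size(0), self.max_length + 1,
                           dtype=torch.long, device=image.device)
        return self.model(image, text, is_train=False, seqlen=self.max_length)


def export_onnx(model: nn.Module, path: str, max_length: int = 25) -> None:
    """Export through the wrapper; args is a one-element tuple, so the call is wrapper(image)."""
    wrapper = ImageOnlyWrapper(model, max_length).eval()
    dummy_input = torch.randn(1, 3, 224, 224)
    torch.onnx.export(wrapper, (dummy_input,), path,
                      input_names=["image"], output_names=["logits"])


model = ToyViTSTR().eval()
dummy_input = torch.randn(1, 3, 224, 224)
dummy_text = torch.zeros(1, 26, dtype=torch.long)

with torch.no_grad():
    # Same calls the exporter would make for each form of `args`.
    direct = model(*(dummy_input, dummy_text))       # tuple form from the answer
    wrapped = ImageOnlyWrapper(model)(dummy_input)   # wrapper form
```

With the real model, export_onnx(model, "vitstr.onnx") would produce a file whose only declared input is the image. The tuple form is the smaller change; the wrapper keeps the placeholder text out of the exported interface by construction.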