Onnx RuntimeError NOT_IMPLEMENTED Trilu


This model works in PyTorch. However, after exporting it to ONNX format with PyTorch, ONNX Runtime crashes with a Trilu NOT_IMPLEMENTED error when loading it. (I do not have this issue with my other models that use torch.tril().)

How do I make this model run in ONNX Runtime?

This is a visualisation of the ONNX graph of the model:

[Figure: ONNX graph of MyModel, visualised with Netron]

The Model in PyTorch

import torch
import torch.nn as nn
import onnx
import onnxruntime as ort


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, item_seq):
        attention_mask = item_seq < 100  # dtype torch.bool
        tril_mask = torch.tril(attention_mask)  # exported as an ONNX Trilu node with a bool input
        query_layer = torch.rand((1, 2, 2, 32))
        key_layer = torch.rand((1, 2, 32, 2))
        attention_scores = torch.matmul(query_layer, key_layer)
        return attention_scores + tril_mask


model = MyModel()
model.eval()

x_train = torch.ones([1, 2], dtype=torch.long)
# demonstrate that eager mode works
print(model(x_train))

bigmodel_onnx_filename = 'mymodel.onnx'
torch.onnx.export(
    model,
    x_train,
    bigmodel_onnx_filename,
    input_names=['x'],
    output_names=['output'],
)

onnx.load(bigmodel_onnx_filename)  # the onnx package loads the file without complaint

# ONNX Runtime fails here, while creating the inference session
ort_sess = ort.InferenceSession(bigmodel_onnx_filename, providers=['CPUExecutionProvider'])
key = {'x': x_train.numpy()}
print(ort_sess.run(None, key))

This results in the following error for ort.InferenceSession():

NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for Trilu(14) node with name '/net/Trilu'
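
ONNX Runtime chooses a kernel for each node based on its op type, the opset version, and the element types of its inputs; session creation fails with this message when no registered kernel matches. The toy lookup below is a plain-Python sketch of that mechanism. The registry entries (Trilu for float and int64 only) are hypothetical, chosen purely for illustration, and are not ONNX Runtime's actual kernel list; real kernels are also registered for ranges of opset versions rather than one exact version.

```python
# Toy model of type-based kernel lookup. The REGISTRY entries are HYPOTHETICAL:
# they illustrate the mechanism and are not ONNX Runtime's real kernel list.
from typing import Callable, Dict, Tuple

# Simplification: real kernels are registered for a range of opset versions.
KernelKey = Tuple[str, int, str]  # (op type, opset version, input element type)

REGISTRY: Dict[KernelKey, Callable[[object], str]] = {
    ("Trilu", 14, "float"): lambda x: "trilu_float_kernel",
    ("Trilu", 14, "int64"): lambda x: "trilu_int64_kernel",
    # no ("Trilu", 14, "bool") entry in this toy registry
}


class KernelNotFoundError(Exception):
    pass


def find_kernel(op_type: str, opset: int, dtype: str, node_name: str) -> Callable[[object], str]:
    """Return the kernel registered for (op_type, opset, dtype), or raise."""
    key = (op_type, opset, dtype)
    if key not in REGISTRY:
        # Like the real message, this names the op and opset but not the
        # element type that failed to match.
        raise KernelNotFoundError(
            f"Could not find an implementation for {op_type}({opset}) "
            f"node with name '{node_name}'"
        )
    return REGISTRY[key]


if __name__ == "__main__":
    print(find_kernel("Trilu", 14, "float", "/net/Trilu")(None))  # trilu_float_kernel
    try:
        find_kernel("Trilu", 14, "bool", "/net/Trilu")
    except KernelNotFoundError as exc:
        print(exc)  # same wording as the ONNX Runtime error above
```

In this toy, a float input resolves to a kernel while a bool input produces a message that says nothing about the dtype, which is why such an error can read as "Trilu is not supported at all".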

How can I make this model run in ONNX Runtime?

Code to reproduce the error, and the model.onnx file, are on GitHub: https://github.com/bkersbergen/pytorch_onnx_runtime_error/blob/main/main.py

I'm using Python 3.9. These are the project requirements:

torch==1.13.1
jupyter==1.0.0
onnxruntime==1.13.1
onnx==1.13.0

The PyTorch nightly build 2.0.0.dev20230205 gives the same error.

I then decided to implement my own tril function.


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

    def forward(self, item_seq):
        attention_mask = item_seq < 100
        tril_mask = self.my_tril(attention_mask)
        query_layer = torch.rand((1, 2, 2, 32))
        key_layer = torch.rand((1, 2, 32, 2))
        attention_scores = torch.matmul(query_layer, key_layer)
        return attention_scores + tril_mask

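    def my_tril(self, x):
        l = x.size(-1)
        arange = torch.arange(l)
        mask = arange.expand(l, l)     # mask[i][j] == j (column index)
        arange = arange.unsqueeze(-1)  # arange[i][0] == i (row index)
        mask = torch.le(mask, arange)  # True where j <= i: the lower triangle
        # zero out everything above the diagonal (masked_fill is exported as an ONNX Where node)
        return x.masked_fill(mask == 0, 0)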

But then I get a NOT_IMPLEMENTED error for the Where(9) node named '/Where_1' instead. (?!)
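
For reference, my_tril builds its lower-triangular mask by comparing column and row indices. The plain-Python sketch below (no torch; the helper names lower_triangular_mask and masked_fill_zero are hypothetical) mirrors that logic for a square boolean input. Note that masked_fill on a bool tensor still returns a bool tensor, so the exported graph still feeds boolean values into the new Where node, which may be why the error simply moves from Trilu to Where.

```python
# Plain-Python mirror of my_tril for a square, boolean input (no torch needed).
from typing import List

BoolMatrix = List[List[bool]]


def lower_triangular_mask(n: int) -> BoolMatrix:
    # arange.expand(l, l) holds the column index j in every row,
    # arange.unsqueeze(-1) holds the row index i, and torch.le keeps j <= i.
    return [[j <= i for j in range(n)] for i in range(n)]


def masked_fill_zero(x: BoolMatrix, mask: BoolMatrix) -> BoolMatrix:
    # Equivalent of x.masked_fill(mask == 0, 0) on a bool tensor:
    # positions where the mask is False become False (0 as a bool).
    return [
        [xv if keep else False for xv, keep in zip(x_row, m_row)]
        for x_row, m_row in zip(x, mask)
    ]


if __name__ == "__main__":
    ones = [[True] * 3 for _ in range(3)]
    for row in masked_fill_zero(ones, lower_triangular_mask(3)):
        print(row)
    # [True, False, False]
    # [True, True, False]
    # [True, True, True]
```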


Solution

  • Using the boolean output of torch.lt() as the input for torch.tril() works in PyTorch's eager and JIT modes. However, it breaks ONNX Runtime with the "Trilu NOT_IMPLEMENTED" error.

    I was able to work around it by casting the torch.tril() input to float():

    
    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
    
        def forward(self, item_seq):
            attention_mask = torch.lt(item_seq, 100).float()  # 0.0/1.0 float mask instead of bool
            tril_mask = torch.tril(attention_mask)  # Trilu now receives a float input
            query_layer = torch.rand((1, 2, 2, 32))
            key_layer = torch.rand((1, 2, 32, 2))
            attention_scores = torch.matmul(query_layer, key_layer)
            return attention_scores + tril_mask
    

    Based on this experience, my hypothesis is that the Trilu NOT_IMPLEMENTED error only occurs when the input to Trilu is a boolean tensor. ONNX Runtime reports this with a generic "Could not find an implementation" message, which made me believe that it has no Trilu support at all; that is clearly not the case.
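
The cast should not change the model's output: torch.tril only zeroes the entries above the diagonal, .float() maps False to 0.0 and True to 1.0, and adding a bool tensor to a float tensor promotes the bool values to 0.0/1.0 anyway. The plain-Python sketch below (lists stand in for tensors; the helper names tril and add_mask are hypothetical) checks that tril-then-add gives the same scores whether the mask is bool or float.

```python
# Plain-Python check that the float-cast workaround gives the same scores as the
# original bool mask. Lists of lists stand in for tensors (no torch needed).
from typing import List, Sequence, Union

Number = Union[bool, float]


def tril(m: Sequence[Sequence[Number]]) -> List[List[Number]]:
    # Lower triangle with diagonal=0, like torch.tril: keep entries with j <= i,
    # replace the rest with a zero of the same type (False or 0.0).
    return [[v if j <= i else type(v)(0) for j, v in enumerate(row)] for i, row in enumerate(m)]


def add_mask(scores: Sequence[Sequence[float]], mask: Sequence[Sequence[Number]]) -> List[List[float]]:
    # float + bool promotes the bool to 0.0 / 1.0, as PyTorch does for tensors.
    return [[s + float(v) for s, v in zip(s_row, m_row)] for s_row, m_row in zip(scores, mask)]


if __name__ == "__main__":
    scores = [[0.5, 0.25], [0.75, 0.125]]
    bool_mask = [[True, True], [True, True]]
    float_mask = [[1.0, 1.0], [1.0, 1.0]]  # bool_mask after .float()
    print(add_mask(scores, tril(bool_mask)))   # [[1.5, 0.25], [1.75, 1.125]]
    print(add_mask(scores, tril(float_mask)))  # [[1.5, 0.25], [1.75, 1.125]]
```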