This model works in PyTorch. However, after exporting it to ONNX format with PyTorch, ONNX Runtime crashes with a 'Trilu NOT_IMPLEMENTED' error when loading it. (I do not have this issue with my other models that use torch.tril().)
How do I make this model run in the Onnxruntime?
This is a visualisation of the Onnx graph of the Model.
The Model in PyTorch
import onnx
import onnxruntime as ort
import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

    def forward(self, item_seq):
        # Boolean mask: True where the item id is below 100
        attention_mask = item_seq < 100
        tril_mask = torch.tril(attention_mask)
        query_layer = torch.rand((1, 2, 2, 32))
        key_layer = torch.rand((1, 2, 32, 2))
        attention_scores = torch.matmul(query_layer, key_layer)
        return attention_scores + tril_mask


model = MyModel()
model.eval()
x_train = torch.ones([1, 2], dtype=torch.long)

# demonstrate that eager works
print(model.forward(x_train))

bigmodel_onnx_filename = 'mymodel.onnx'
torch.onnx.export(
    model,
    x_train,
    bigmodel_onnx_filename,
    input_names=['x'],
    output_names=['output'],
)
onnx.load(bigmodel_onnx_filename)

# Onnxruntime crashes when loading in the model
ort_sess = ort.InferenceSession(bigmodel_onnx_filename, providers=['CPUExecutionProvider'])
key = {'x': x_train.numpy()}
print(ort_sess.run(None, key))
This results in the following error for ort.InferenceSession():
NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for Trilu(14) node with name '/net/Trilu'
How can I make this model run in the Onnxruntime?
[github: code to reproduce the error and the model.onnx file](https://github.com/bkersbergen/pytorch_onnx_runtime_error/blob/main/main.py)
I'm using Python 3.9. These are the project requirements:
torch==1.13.1
jupyter==1.0.0
onnxruntime==1.13.1
onnx==1.13.0
The Torch nightly version 2.0.0.dev20230205 gave the same error.
I then decided to implement my own tril function.
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

    def forward(self, item_seq):
        attention_mask = item_seq < 100
        tril_mask = self.my_tril(attention_mask)
        query_layer = torch.rand((1, 2, 2, 32))
        key_layer = torch.rand((1, 2, 32, 2))
        attention_scores = torch.matmul(query_layer, key_layer)
        return attention_scores + tril_mask

    def my_tril(self, x):
        l = x.size(-1)
        arange = torch.arange(l)
        mask = arange.expand(l, l)
        arange = arange.unsqueeze(-1)
        mask = torch.le(mask, arange)
        return x.masked_fill(mask == 0, 0)
But then I get a NOT_IMPLEMENTED error for a Where(9) node with name '/Where_1'. (?!)
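One variant that might sidestep the Where node is to combine the boolean input with the triangular mask using a logical AND instead of masked_fill. This is only a minimal sketch: the helper name my_tril_bool is hypothetical, it assumes the last two dimensions of x are square, and it has not been verified here that the exported graph loads in onnxruntime 1.13.1.

```python
import torch


def my_tril_bool(x: torch.Tensor) -> torch.Tensor:
    """Lower-triangular part of a boolean tensor, built without masked_fill.

    Hypothetical helper; assumes the last two dimensions of x are square.
    """
    l = x.size(-1)
    arange = torch.arange(l, device=x.device)
    # mask[i, j] is True where column j <= row i, i.e. on or below the diagonal
    mask = arange.unsqueeze(0) <= arange.unsqueeze(-1)
    # Logical AND keeps the values of x only inside the lower triangle
    return torch.logical_and(x, mask)


x = torch.ones((3, 3), dtype=torch.bool)
print(my_tril_bool(x))
```

Because the input stays boolean throughout, this version never fills a bool tensor with a scalar, which is the operation that was exported as Where in the version above.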
The boolean output of torch.lt() as input for torch.tril() works in PyTorch's eager and JIT modes. However, it breaks ONNX Runtime with the "Trilu NOT_IMPLEMENTED" error.
I was able to work around it by casting the torch.tril() input to float():
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

    def forward(self, item_seq):
        attention_mask = torch.lt(item_seq, 100).float()
        tril_mask = torch.tril(attention_mask)
        query_layer = torch.rand((1, 2, 2, 32))
        key_layer = torch.rand((1, 2, 32, 2))
        attention_scores = torch.matmul(query_layer, key_layer)
        return attention_scores + tril_mask
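If downstream code still expects a boolean mask, the same cast can be wrapped in a small helper that converts back afterwards. This is only a sketch of that idea: the name tril_via_float is hypothetical, and it assumes the input is a mask (or has values that float32 represents exactly).

```python
import torch


def tril_via_float(x: torch.Tensor) -> torch.Tensor:
    """Apply torch.tril in float32 and cast the result back to x's dtype.

    Hypothetical helper; intended for boolean masks.
    """
    # Cast away from bool so that tril itself runs on a float tensor
    lowered = torch.tril(x.to(torch.float32))
    # Restore the original dtype (e.g. torch.bool) for downstream code
    return lowered.to(x.dtype)


mask = torch.lt(torch.ones((2, 2), dtype=torch.long), 100)
print(tril_via_float(mask))
```

The extra casts add two nodes to the exported graph, so keeping the mask as float, as in the model above, is the simpler choice when nothing downstream needs a bool.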
Based on this experience, my hypothesis is that the Trilu NOT_IMPLEMENTED error only occurs when the input is a boolean tensor. ONNX Runtime then throws this generic Trilu NOT_IMPLEMENTED error, which made me believe that ONNX has no Trilu support at all. That is clearly not the case.