I'm using PyTorch to train neural networks and export them to ONNX. I use these models in a Vespa index, which loads the ONNX models through TensorRT. I need one-hot encoding for some features, but this is really hard to achieve within the Vespa framework.
Is it possible to embed a one-hot encoding for some given features inside my ONNX net (e.g. as a preprocessing step at the front of the network)? If so, how should I achieve this based on a PyTorch model?
I already noticed two things:
EDIT 2021/03/11: Here is my workflow:
So, according to my testing, PyTorch does support exporting one-hot encoding to ONNX. With the following model:
#! /usr/bin/env python3

import torch
import torch.onnx
import torch.nn.functional as F


class MyModel(torch.nn.Module):
    def __init__(self, classes=5):
        super(MyModel, self).__init__()
        self._classes = classes
        self.linear = torch.nn.Linear(in_features=self._classes, out_features=1)
        self.logistic = torch.nn.Sigmoid()

    def forward(self, input):
        # One-hot encode the integer indices inside the model, so the
        # encoding becomes part of the exported ONNX graph.
        one_hot = F.one_hot(input, num_classes=self._classes).float()
        return self.logistic(self.linear(one_hot))


def main():
    model = MyModel()
    # training omitted
    data = torch.tensor([0, 4, 2])
    torch.onnx.export(model, data, "test.onnx",
                      input_names=["input"], output_names=["output"])
    result = model.forward(data)
    print(result)


if __name__ == "__main__":
    main()
This model doesn't do any training; it just takes a vector of indices in, one-hot encodes them using PyTorch's one_hot, and sends that through the simple NN layer. The weights are randomly initialised, and the output here for me was:
tensor([[0.5749],
        [0.5081],
        [0.5581]], grad_fn=<SigmoidBackward>)
This model is exported to ONNX in the "test.onnx" file. Testing this model using ONNX Runtime (which is what Vespa uses in the backend, not TensorRT):
In [1]: import onnxruntime as ort
In [2]: m = ort.InferenceSession("test.onnx")
In [3]: m.run(input_feed={"input":[0,4,2]}, output_names=["output"])
Out[3]:
[array([[0.57486993],
        [0.5081395 ],
        [0.5580716 ]], dtype=float32)]
This is the same output as PyTorch gives for the same input, so PyTorch does export the OneHot ONNX operator. This was with PyTorch 1.7.1.
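If you want to double-check that the OneHot operator actually ended up in the graph, you can inspect the exported file with the onnx package; a minimal sketch, assuming onnx is installed and the "test.onnx" file from above:

import onnx

model = onnx.load("test.onnx")
# Print the operator types in the exported graph; "OneHot" should appear.
print([node.op_type for node in model.graph.node])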
If the input to the one-hot encoding is indexed in Vespa as integers, you can then just use these directly as inputs.
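For instance, on the ONNX Runtime side the integer indices can be fed in as-is; a minimal sketch, again assuming the "test.onnx" file from above (note that onnxruntime expects a numpy int64 array, matching PyTorch's default integer dtype):

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("test.onnx")
# The raw integer category indices go straight into the model; the
# one-hot encoding happens inside the exported graph.
indices = np.array([0, 4, 2], dtype=np.int64)
print(session.run(output_names=["output"], input_feed={"input": indices}))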