Let's say I have my original .pt weights and I export them to ONNX, OpenVINO and TFLite. Is there a way of loading these models without needing to write a custom class that checks its type of instance and loads it accordingly?
OpenVINO model loading example:
from openvino.runtime import Core
ie = Core()
classification_model_xml = "model/classification.xml"
model = ie.read_model(model=classification_model_xml)
compiled_model = ie.compile_model(model=model, device_name="CPU")
TFLite model loading example:
import tensorflow as tf

class TestModel(tf.Module):
    def __init__(self):
        super(TestModel, self).__init__()

    @tf.function(input_signature=[tf.TensorSpec(shape=[1, 10], dtype=tf.float32)])
    def add(self, x):
        """Simple method that accepts single input 'x' and returns 'x' + 4."""
        # Name the output 'result' for convenience.
        return {'result' : x + 4}

SAVED_MODEL_PATH = 'content/saved_models/test_variable'
TFLITE_FILE_PATH = 'content/test_variable.tflite'
# Save the model
module = TestModel()
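# You can omit the signatures argument and a default signature name will be
# created with name 'serving_default'.
tf.saved_model.save(
    module, SAVED_MODEL_PATH,
    signatures={'my_signature': module.add.get_concrete_function()})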
# Convert the model using TFLiteConverter
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_PATH)
tflite_model = converter.convert()
with open(TFLITE_FILE_PATH, 'wb') as f:
    f.write(tflite_model)
# Load the TFLite model in TFLite Interpreter
interpreter = tf.lite.Interpreter(TFLITE_FILE_PATH)
# There is only 1 signature defined in the model,
# so it will return it by default.
# If there are multiple signatures then we can pass the name.
my_signature = interpreter.get_signature_runner()
# my_signature is callable with input as arguments.
output = my_signature(x=tf.constant([1.0], shape=(1,10), dtype=tf.float32))
# 'output' is dictionary with all outputs from the inference.
# In this case we have single output 'result'.
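PyTorch (.pt) model loading example:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))  # PATH is a placeholder for the .pt weights file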
The closest solution I could find for this, at the moment, is Ivy. However, you have to write your model using their framework-agnostic operations, and the set of available frameworks is quite limited: jnp, tf, np, mx, torch.
I found a quite good adapter that I will base my own on, here