I can't find anyone who explains to a layman how to load an onnx model into a python script, then use that model to make a prediction when fed an image. All I could find were these lines of code:
sess = rt.InferenceSession("onnx_model.onnx")
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
pred = sess.run([label_name], {input_name: X.astype(np.float32)})[0]
But I don't know what any of that means. And everywhere I look, everybody already seems to know what they mean, so nobody's explaining it. That would be one thing if I could just run this code, but I can't. It gives me this error:
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid rank for input: Input3 Got: 2 Expected: 4 Please fix either the inputs or the model.
So I need to actually know what those things mean so I can figure out how to fix the error. Will someone knowledgeable please explain?
Let's start by going over the code you provided, so that everything is clear.
sess = ort.InferenceSession("onnx_model.onnx")
This line loads the model into a session object. This means that the layers, functions and weights used in the model are made ready to perform inferences.
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
The two methods get_inputs and get_outputs each retrieve some metadata about the model: which inputs the model expects and which outputs it can provide. From this metadata, these lines use only the first input and the first output, and from each of those only the name, which is saved into a variable.
For the last line, let's tackle that part by part.
pred = sess.run(...)[0]
This runs an inference on the model. We'll go over the arguments to this method next; for now, note that it returns a list of outputs, each of which is a numpy array. In this case only the first output in that list is used and saved to the pred variable.
([label_name], {input_name: X.astype(np.float32)})
These are the arguments to sess.run. The first is a list of the names of the outputs that you want the session to compute. The second is a dict that maps each input's name to a numpy array. These arrays are expected to have the same dimensions as the inputs defined when the model was created. Similarly, the data types of these arrays should match the types used when the model was created.
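As a small illustration of the type requirement, here is a sketch of what X.astype(np.float32) does. The array values and the "input_name_here" key are made up for the example; in the real code the key is the input_name variable obtained from the session.

```python
import numpy as np

# numpy creates float64 arrays by default when given Python floats.
X = np.array([[0.5, 1.0], [1.5, 2.0]])
print(X.dtype)    # float64

# astype returns a converted copy; many ONNX models expect float32 inputs.
X32 = X.astype(np.float32)
print(X32.dtype)  # float32

# The dict passed to sess.run maps an input name to the converted array.
# "input_name_here" is a placeholder, not a real input name.
feed = {"input_name_here": X32}
```

Note that astype only changes the data type. It does not change the shape of the array, so it cannot fix a mismatch in the number of dimensions.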
The error you encountered indicates that the supplied array doesn't have the expected number of dimensions: the input named Input3 received an array with 2 dimensions, while the model expects 4.
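A common way to get from 2 to 4 dimensions is to add a batch axis and a channel axis to a single grayscale image. The sketch below assumes the model wants the layout (batch, channels, height, width); that layout and the 28x28 size are assumptions for illustration, so check what your model actually expects before copying this.

```python
import numpy as np

# Stand-in for a loaded grayscale image: 2 dimensions (height, width).
# The 28x28 size is made up; use the size your model expects.
image = np.zeros((28, 28), dtype=np.uint8)

# Add a batch axis and a channel axis in front: (H, W) -> (1, 1, H, W),
# then convert to float32 as in the original snippet.
batch = image[np.newaxis, np.newaxis, :, :].astype(np.float32)

print(image.ndim, batch.ndim)  # 2 4
print(batch.shape)             # (1, 1, 28, 28)
```

The resulting batch array is what would be passed as the value in the dict given to sess.run. If the model expects the channel axis last instead, the new axes would have to be placed differently.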
To find out the exact shape and data type the input array should have, you can use a visualization tool such as Netron.
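The same information can also be read from the session itself, since the objects returned by get_inputs and get_outputs carry a shape and a type next to the name. The helper names describe_io and print_model_io below are hypothetical, and the model path is the one from your snippet; treat this as a sketch rather than the only way to do it.

```python
def describe_io(node_args):
    """Return one 'name: shape, type' line per input or output."""
    return [f"{arg.name}: shape={arg.shape}, type={arg.type}" for arg in node_args]


def print_model_io(model_path):
    """Print the inputs and outputs that the model at model_path declares."""
    # Imported here so that describe_io can be used without onnxruntime.
    import onnxruntime as ort

    # Explicitly asking for the CPU provider avoids provider errors on
    # some onnxruntime builds.
    sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    print("Inputs:")
    for line in describe_io(sess.get_inputs()):
        print("  " + line)
    print("Outputs:")
    for line in describe_io(sess.get_outputs()):
        print("  " + line)


# Example call (requires onnxruntime and the model file):
# print_model_io("onnx_model.onnx")
```

The printed input shape tells you how many dimensions the array needs, and the type (for example tensor(float)) tells you which numpy data type to convert to.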