Tags: python, scikit-learn, random-forest, onnx, onnxruntime

final_types for RandomForestClassifier skl2onnx


What is the best way to define final_types for a RandomForestClassifier?

If I do the following:

initial_type = [('input', FloatTensorType([None, 13]))]
final_type = [('output', FloatTensorType([None, 1]))]

sklonnx = convert_sklearn(rfc, initial_types=initial_type, final_types=final_type)
with open("sklrfc.onnx", "wb") as f:
    f.write(sklonnx.SerializeToString())

I get the following error:

RuntimeError: Number of declared outputs is unexpected, declared 'output' found 'output_label, output_probability'.

So I changed final_type to declare both outputs:

initial_type = [('input', FloatTensorType([None, 13]))]
final_type = [('label', Int64TensorType([None, 1])),
              ('output', FloatTensorType([None, 1]))]

sklonnx = convert_sklearn(rfc, initial_types=initial_type, final_types=final_type)
with open("sklrfc.onnx", "wb") as f:
    f.write(sklonnx.SerializeToString())

This converts without errors. However, when I create an InferenceSession:

import numpy as np
import onnxruntime as rt

sess = rt.InferenceSession("sklrfc.onnx")
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name  # first output is the predicted label
pred_onx = sess.run([label_name], {input_name: X_test.astype(np.float32)})[0]

I get this error instead:

InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from sklrfc.onnx failed:This is an invalid model. Type Error: Type 'seq(map(int64,tensor(float)))' of input parameter (output_probability) of operator (Cast) in node (Cast2) is invalid.

Is there something I have to change in my model, or in the process of converting the model to ONNX?


Solution

  • I was given a fix: pass options={'zipmap': False} to convert_sklearn.

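    from skl2onnx import convert_sklearn
    from skl2onnx.common.data_types import FloatTensorType, Int64TensorType

    initial_type = [('input', FloatTensorType([None, 13]))]
    final_type = [('label', Int64TensorType([None, 1])),
                  ('output', FloatTensorType([None, 1]))]

    # zipmap=False disables the ZipMap step on the probability output
    sklonnx = convert_sklearn(rfc, initial_types=initial_type,
                              final_types=final_type,
                              options={'zipmap': False})
    with open("sklrfc.onnx", "wb") as f:
        f.write(sklonnx.SerializeToString())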
    

    This fixed the error when creating the InferenceSession.
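
The type named in the error, seq(map(int64,tensor(float))), suggests why the option matters: by default the probability output appears to be a sequence holding one class-to-probability map per sample, which cannot be cast to the float tensor declared in final_types. The sketch below is plain Python, not part of skl2onnx or onnxruntime. It uses made-up values and a hypothetical helper, zipmap_to_rows, only to show the difference between that per-sample map layout and a plain probability matrix.

```python
def zipmap_to_rows(prob_maps, classes=None):
    """Turn ZipMap-style output (one {class: probability} dict per sample)
    into a list of probability rows with a fixed column order."""
    if classes is None:
        # Assumption: every sample has the same keys; sort them so the
        # column order is stable across samples.
        classes = sorted(prob_maps[0]) if prob_maps else []
    return [[float(sample[c]) for c in classes] for sample in prob_maps]


# Hypothetical values laid out like a seq(map(int64, tensor(float))) output
# for a two-class model: one dict per input row.
zipmap_output = [{0: 0.9, 1: 0.1}, {0: 0.25, 1: 0.75}]

rows = zipmap_to_rows(zipmap_output)
print(rows)  # [[0.9, 0.1], [0.25, 0.75]]
```

With zipmap disabled, no such conversion should be needed, since the probabilities are expected to come back as a single float array with one column per class. If you keep the default instead, a helper along these lines (wrapped in np.asarray if an array is needed) is one way to get the same layout after inference.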