
How do I run inference with a TensorFlow Lite model that has multiple inputs and outputs?


I created a simple TensorFlow classification model, converted it, and exported it as a .tflite file. To integrate the model into my Android app I followed a tutorial, but for the inference part it only covers models with a single input and a single output. After reading the documentation and some other sources, I implemented the following:

        // acc and gyro X, Y, Z are my features
        float[] accX = new float[1];
        float[] accY = new float[1];
        float[] accZ = new float[1];

        float[] gyroX = new float[1];
        float[] gyroY = new float[1];
        float[] gyroZ = new float[1];


        Object[] inputs = new Object[]{accX, accY, accZ, gyroX, gyroY, gyroZ};
        
        // And I have 4 classes
        float[] output1 = new float[1];
        float[] output2 = new float[1];
        float[] output3 = new float[1];
        float[] output4 = new float[1];

        Map<Integer, Object> outputs = new HashMap<>();
        outputs.put(0, output1);
        outputs.put(1, output2);
        outputs.put(2, output3);
        outputs.put(3, output4);

        interpreter.runForMultipleInputsOutputs(inputs, outputs);

but this code throws an exception:

java.lang.IllegalArgumentException: Invalid input Tensor index: 1

I'm not sure what is wrong at this point.

Here is my model's architecture:

        import tensorflow as tf

        # hp_units is a tuned hyperparameter (the size of the first hidden layer)
        model = tf.keras.Sequential([
            tf.keras.layers.Dense(units=hp_units, input_shape=(6,), activation='relu'),
            tf.keras.layers.Dense(240, activation='relu'),
            tf.keras.layers.Dense(4, activation='softmax')
        ])
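Judging from this architecture, the converted model most likely has a single input tensor of shape `[1, 6]` (batch dimension plus the six features) and a single output tensor of shape `[1, 4]`. That would explain the exception: `runForMultipleInputsOutputs` maps each element of the `inputs` array to a separate input tensor, so the second array (`accY`) asks for input tensor index 1, which does not exist. Below is a minimal sketch of packing the six readings into one row instead. The class and method names are hypothetical, and the feature order is assumed to match the column order used during training.

```java
import java.util.Arrays;

// Hypothetical helper: builds the buffers for a model assumed to have
// ONE input tensor of shape [1, 6] and ONE output tensor of shape [1, 4].
final class SensorInput {
    private SensorInput() {}

    // Packs the six readings into a single row. The order is an assumption:
    // it must match the feature order used when training the Keras model.
    static float[][] pack(float accX, float accY, float accZ,
                          float gyroX, float gyroY, float gyroZ) {
        return new float[][] {{accX, accY, accZ, gyroX, gyroY, gyroZ}};
    }

    // Allocates one row of class scores for the softmax output.
    static float[][] newOutputBuffer(int numClasses) {
        return new float[1][numClasses];
    }

    public static void main(String[] args) {
        float[][] input = pack(0.1f, 9.8f, 0.2f, 0.01f, 0.02f, 0.03f);
        float[][] output = newOutputBuffer(4);
        // With the TensorFlow Lite Interpreter this single pair would be passed as
        //   interpreter.run(input, output);
        // rather than six separate arrays to runForMultipleInputsOutputs.
        System.out.println(Arrays.toString(input[0]) + " -> " + output[0].length + " classes");
    }
}
```

The point of the sketch is only the shape: one `float[1][6]` in, one `float[1][4]` out, which lines up with the single input and single output that a `Sequential` model with `input_shape=(6,)` produces.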

Solution:

Based on @Karim Nosseir's answer, I used the signature API to access my model's inputs and outputs by name. If your model was built in Python, you can print its signature as shown in the answer and then use the names as shown below:

Python signature:

{'serving_default': {'inputs': ['dense_6_input'], 'outputs': ['dense_8']}}

Usage in Android (Java):

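        import java.io.IOException;
        import java.util.HashMap;
        import java.util.Map;
        import org.tensorflow.lite.Interpreter;

        // The six features (accX, accY, accZ, gyroX, gyroY, gyroZ), in the same order as during training
        float[] input = new float[6];
        // One row of four class probabilities
        float[][] output = new float[1][4];

        // Run inference through the "serving_default" signature.
        // loadModelFile() is the helper from the tutorial that loads the .tflite file.
        try (Interpreter interpreter = new Interpreter(loadModelFile())) {
            Map<String, Object> inputs = new HashMap<>();
            inputs.put("dense_6_input", input);

            Map<String, Object> outputs = new HashMap<>();
            outputs.put("dense_8", output);

            interpreter.runSignature(inputs, outputs, "serving_default");
        } catch (IOException e) {
            e.printStackTrace();
        }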
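After `runSignature` returns, `output[0]` holds the four softmax probabilities. A minimal sketch of turning them into a predicted class index follows; `Prediction.argmax` is a hypothetical helper, and mapping the index to a label string is left to the app.

```java
// Hypothetical helper: picks the most likely class from a [1][numClasses]
// softmax output, such as the float[1][4] buffer filled by runSignature.
final class Prediction {
    private Prediction() {}

    static int argmax(float[][] output) {
        float[] probs = output[0]; // batch size is 1, so only row 0 is used
        int best = 0;
        for (int i = 1; i < probs.length; i++) {
            if (probs[i] > probs[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Made-up values standing in for a real model output.
        float[][] output = {{0.05f, 0.70f, 0.15f, 0.10f}};
        int cls = argmax(output);
        System.out.println("class " + cls + " p=" + output[0][cls]);
    }
}
```

Ties resolve to the lowest index because the comparison is strict (`>`); with softmax outputs exact ties are rare, but the behavior is deterministic either way.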

Answer

  • The easiest approach is to use the signature API and refer to inputs and outputs by their signature names.

    If you converted the model with the TF2 TFLite converter, a signature should already be defined.

    The example below prints the signatures defined in the model:

    import tensorflow as tf

    model = tf.keras.Sequential([
        tf.keras.layers.Dense(4, input_shape=(6,), activation='relu'),
        tf.keras.layers.Dense(240, activation='relu'),
        tf.keras.layers.Dense(4, activation='softmax')
    ])

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = converter.convert()
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    # Prints a dict of the form {'serving_default': {'inputs': [...], 'outputs': [...]}}
    print(interpreter.get_signature_list())

    See the TensorFlow Lite guide on signatures for how to run a model by signature from Java and other languages.