I have created a simple TensorFlow classification model, converted it and exported it as a .tflite file. To integrate the model into my Android app I followed this tutorial, but it only covers the single-input/single-output case for inference. After looking at the documentation and some other sources, I implemented the following solution:
// acc and gyro X, Y, Z are my features
float[] accX = new float[1];
float[] accY = new float[1];
float[] accZ = new float[1];
float[] gyroX = new float[1];
float[] gyroY = new float[1];
float[] gyroZ = new float[1];
Object[] inputs = new Object[]{accX, accY, accZ, gyroX, gyroY, gyroZ};
// And I have 4 classes
float[] output1 = new float[1];
float[] output2 = new float[1];
float[] output3 = new float[1];
float[] output4 = new float[1];
Map<Integer, Object> outputs = new HashMap<>();
outputs.put(0, output1);
outputs.put(1, output2);
outputs.put(2, output3);
outputs.put(3, output4);
interpreter.runForMultipleInputsOutputs(inputs, outputs);
But this code throws an exception:
java.lang.IllegalArgumentException: Invalid input Tensor index: 1
At this point I'm not sure what's wrong.
Here is my model's architecture:
model = tf.keras.Sequential([
tf.keras.layers.Dense(units=hp_units, input_shape=(6,), activation='relu'),
tf.keras.layers.Dense(240, activation='relu'),
tf.keras.layers.Dense(4, activation='softmax')
])
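For reference: because the first Dense layer is declared with input_shape=(6,), the converted model has a single input tensor of shape [1, 6] (batch of one, six features), and the softmax layer gives a single output tensor of shape [1, 4]. The six sensor readings therefore have to go into one array, not six. Below is a minimal plain-Java sketch of that packing. It assumes the features are fed in the same order as the training columns (accX, accY, accZ, gyroX, gyroY, gyroZ), and the class and method names are hypothetical:

```java
import java.util.Arrays;

class FeaturePacking {

    // Hypothetical helper: packs the six sensor readings into the single
    // [1, 6] input the model expects (batch of one, six features).
    // Assumes the same feature order as the columns used in training.
    static float[][] packFeatures(float accX, float accY, float accZ,
                                  float gyroX, float gyroY, float gyroZ) {
        return new float[][] {{accX, accY, accZ, gyroX, gyroY, gyroZ}};
    }

    public static void main(String[] args) {
        // Placeholder readings, only to show the resulting layout.
        float[][] input = packFeatures(0.1f, -9.8f, 0.3f, 0.01f, 0.02f, -0.03f);
        // Buffer for the [1, 4] softmax output (one probability per class).
        float[][] output = new float[1][4];

        // With the TFLite Java API, this pair could be passed to a
        // single-input, single-output call such as interpreter.run(input, output).
        System.out.println(input.length + " x " + input[0].length);   // 1 x 6
        System.out.println(output.length + " x " + output[0].length); // 1 x 4
        System.out.println(Arrays.toString(input[0]));
    }
}
```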
Solution:
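Based on @Karim Nosseir's answer, I used the signature API to access my model's inputs and outputs by name. The original error came from the model itself. The Sequential model has a single input tensor holding all six features and a single output tensor holding all four class probabilities, so there is no input tensor at index 1. If your model was built in Python, you can look up its signature as shown in the answer and then use it as follows. The tensor names (dense_6_input and dense_8 here) are specific to each model.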
Python signature:
{'serving_default': {'inputs': ['dense_6_input'], 'outputs': ['dense_8']}}
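The signature only provides tensor names. The Java arrays bound to those names must also match the tensor shapes. The TFLite Java bindings work out a shape from how the Java array is nested, which is why the output buffer for this model is a float[1][4] rather than a float[4]. The following is a rough plain-Java illustration of that kind of shape inference. It uses a hypothetical helper built on java.lang.reflect.Array and is not TFLite's own code:

```java
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class ArrayShape {

    // Hypothetical helper: returns the dimensions of a rectangular nested
    // Java array by following the first element at each nesting level.
    static int[] shapeOf(Object array) {
        List<Integer> dims = new ArrayList<>();
        Object current = array;
        while (current != null && current.getClass().isArray()) {
            int length = Array.getLength(current);
            dims.add(length);
            if (length == 0) {
                break; // cannot look deeper into an empty array
            }
            current = Array.get(current, 0); // boxes primitives at the last level
        }
        return dims.stream().mapToInt(Integer::intValue).toArray();
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(shapeOf(new float[1][6]))); // [1, 6]
        System.out.println(Arrays.toString(shapeOf(new float[1][4]))); // [1, 4]
        System.out.println(Arrays.toString(shapeOf(new float[4])));    // [4]
    }
}
```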
Android (Java) usage:
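// Input: one [1, 6] tensor with the six features in training order (batch of one).
float[][] input = new float[1][6];
// Output: one [1, 4] tensor with the class probabilities.
float[][] output = new float[1][4];
// Run inference through the "serving_default" signature.
// loadModelFile() is the model-loading helper from the tutorial (not shown here).
try (Interpreter interpreter = new Interpreter(loadModelFile())) {
    Map<String, Object> inputs = new HashMap<>();
    inputs.put("dense_6_input", input);
    Map<String, Object> outputs = new HashMap<>();
    outputs.put("dense_8", output);
    interpreter.runSignature(inputs, outputs, "serving_default");
} catch (IOException e) {
    e.printStackTrace();
}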
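After runSignature returns, output[0] holds the four softmax probabilities, one per class. The following minimal sketch turns them into a predicted class index with a plain argmax. The mapping from index to class label follows whatever order was used in training, which the question does not show, and the probabilities in main are made-up placeholder values:

```java
class PredictClass {

    // Returns the index of the largest value, i.e. the most probable class
    // for one softmax output row. Ties resolve to the lowest index.
    static int argmax(float[] probabilities) {
        if (probabilities.length == 0) {
            throw new IllegalArgumentException("empty output row");
        }
        int best = 0;
        for (int i = 1; i < probabilities.length; i++) {
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Stand-in for the [1, 4] buffer filled by the interpreter (placeholder values).
        float[][] output = {{0.05f, 0.10f, 0.80f, 0.05f}};
        int predicted = argmax(output[0]);
        // prints "predicted class: 2 (p=0.8)"
        System.out.println("predicted class: " + predicted + " (p=" + output[0][predicted] + ")");
    }
}
```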
The easiest approach is to use the signature API and refer to inputs and outputs by their signature names.
If you used the TF2 TFLite converter, your model should already have a signature defined.
The following example prints the signatures defined in a converted model:
import tensorflow as tf

model = tf.keras.Sequential([
tf.keras.layers.Dense(4, input_shape=(6,), activation='relu'),
tf.keras.layers.Dense(240, activation='relu'),
tf.keras.layers.Dense(4, activation='softmax')
])
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite_model)
print(interpreter.get_signature_list())
See the guide here on how to run signatures from Java and other languages.