Tags: python, java, android, tensorflow, tensorflow-lite

TFLite model gives different output in an Android app than in Python. For most inputs the TFLite model returns the same output on Android. Why, and how can I fix it?


I'm building a very simple TensorFlow model that predicts x+1 for an input x. I want to deploy this model in an Android application, so I convert it to TFLite format.

Building the model

Python

import tensorflow as tf
# Create a simple Keras model.      
x = [1,2,3,4,5,6,7,8,9,10]
y = [2,3,4,5,6,7,8,9,10,11]

model = tf.keras.models.Sequential([tf.keras.layers.Dense(units=1, input_shape=[1])])
model.compile(optimizer='sgd', loss='mean_squared_error')
model.fit(x, y, epochs=50)

path_file = 'saved_model/hello_world_tensorflow'
tf.saved_model.save(model, path_file)

import pathlib
# Convert the model.
converter = tf.lite.TFLiteConverter.from_saved_model(path_file)
tflite_model = converter.convert()
tflite_model_file = pathlib.Path('model1.tflite')
tflite_model_file.write_bytes(tflite_model)

Using the model in Python to get the output

import numpy as np
import tensorflow as tf

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="model1.tflite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Test the model on a fixed input value.
input_shape = input_details[0]['shape']
print(input_shape)
input_data = np.array([[3]], dtype=np.float32) # 3 is the input here
interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()

# The function `get_tensor()` returns a copy of the tensor data.
# Use `tensor()` in order to get a pointer to the tensor.
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data,input_data)

Using the model in Java on Android (MainActivity.java)

package ar.labs.androidml;

import androidx.appcompat.app.AppCompatActivity;

import android.os.Bundle;
import android.view.View;
import android.widget.Button;
import android.widget.EditText;
import android.widget.TextView;
import android.widget.Toast;

import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

import java.nio.ByteBuffer;

import ar.labs.androidml.ml.Model1;

public class MainActivity extends AppCompatActivity {

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        Button btn= findViewById(R.id.button);
        btn.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View view) {
                try{
                    EditText inputEditText;

                    inputEditText = findViewById(R.id.editTextNumberDecimal);
                    Float data= Float.valueOf(inputEditText.getText().toString());
                    ByteBuffer byteBuffer= ByteBuffer.allocateDirect(1*4);
                    byteBuffer.putFloat(data);

                    Model1 model = Model1.newInstance(getApplicationContext());

                    // Creates inputs for reference.
                    TensorBuffer inputFeature0 = TensorBuffer.createFixedSize(new int[]{1, 1}, DataType.FLOAT32);
                    inputFeature0.loadBuffer(byteBuffer);

                    // Runs model inference and gets result.
                    Model1.Outputs outputs = model.process(inputFeature0);
                    TensorBuffer outputFeature0 = outputs.getOutputFeature0AsTensorBuffer();

                    // Show the first (and only) output value.
                    TextView tv= findViewById(R.id.textView);
                    float[] data1=outputFeature0.getFloatArray();

                    tv.setText(String.valueOf(data1[0]));

                    // Releases model resources if no longer used.
                    model.close();

                }
                catch (Exception e)
                {
                    Toast.makeText(getApplicationContext(),"Issue...",Toast.LENGTH_LONG).show();
                }
            }
        });
    }
}

Python output:

  • Input -> Output
  • 2 -> 2.5395..
  • 3 -> 3.6323..

Java output:

  • Input -> Output
  • 1 -> 0.3540...
  • 2 -> 0.3540..
  • 2.1 -> 2.967..E23
  • 2.11 -> 0.39083
  • 41 -> 0.3540

Why do the outputs behave this way in the Java code?


Solution

  • Solved it myself. By default, a ByteBuffer uses BIG_ENDIAN byte order, so putFloat wrote the float's bytes in the opposite order from the one in which they are later read. Calling order(ByteOrder.nativeOrder()) before putFloat switches the buffer to the platform's native byte order, which is LITTLE_ENDIAN on typical Android hardware. Note that order does not create a new buffer: it changes the byte order of the existing buffer and returns that same buffer. The fix is the one added line below (it also needs import java.nio.ByteOrder;):

    ByteBuffer byteBuffer= ByteBuffer.allocateDirect(1*4);
    byteBuffer.order(ByteOrder.nativeOrder()); // new line added
    byteBuffer.putFloat(data);
    

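    TFLite reads the raw bytes of the input buffer in the device's native byte order, which is little-endian on typical Android devices, so the buffer has to be filled in that order.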
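
To see why the wrong byte order produces exactly the kind of outputs listed in the question, here is a small Python sketch that packs a float32 in big-endian order (the ByteBuffer default) and decodes the same four bytes as little-endian. The helper name misread_as_little_endian is hypothetical, and WEIGHT and BIAS are only estimates inferred from the two truncated Python outputs above (2 -> 2.5395.., 3 -> 3.6323..), not values read from the model.

```python
import struct


def misread_as_little_endian(value):
    """Pack value as a big-endian float32, then decode the same bytes as little-endian."""
    return struct.unpack("<f", struct.pack(">f", value))[0]


# Estimates only: a one-unit Dense layer computes WEIGHT * x + BIAS, so the two
# (truncated) Python outputs from the question are enough to approximate both.
WEIGHT = 3.6323 - 2.5395       # roughly 1.09
BIAS = 2.5395 - 2.0 * WEIGHT   # roughly 0.35

for x in (1.0, 2.0, 2.1, 2.11, 41.0):
    seen = misread_as_little_endian(x)
    prediction = WEIGHT * seen + BIAS
    print(f"input {x:>5}: model sees {seen:.4e}, predicts about {prediction:.5g}")
```

Under these assumptions, inputs such as 1, 2, and 41 decode to denormal values very close to zero, so the prediction collapses to roughly the bias (about 0.35), while 2.1 decodes to a value on the order of 1e23. That matches the pattern in the Java output list.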