Tags: python, tensorflow, kotlin, tensorflow-lite, firebase-mlkit

Firebase ML Kit with a custom TFLite model produces the same inference for varied inputs on Android


I'm working on an audio classification model that classifies audio by genre.

The model takes a few audio features (spectral centroid, etc.) as input and outputs a genre such as classical or rock. The input shape is [1, 26]. It's a multi-class classifier with one softmax output per genre. I have a Keras model which I've converted to a TFLite model for use on mobile platforms. I have tested the original model and it achieves pretty decent accuracy; the TFLite model works just as well when run with Python on my PC.

When I deploy this to Firebase ML Kit and use it through the Android API, it produces the same single label/class for every kind of input. I don't think it's a problem with the model, since it works fine in my Jupyter notebook. I don't understand how the same model can produce different inferences for the same input.

Keras model:

# The test model
# X_train, y_train, X_test, y_test are assumed to be defined already
# (features scaled with StandardScaler, integer genre labels 0-9)
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models

model = models.Sequential()
model.add(layers.Dense(256, activation='relu', input_shape=(X_train.shape[1],)))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))  # one probability per genre
model.compile(optimizer='rmsprop',
              loss='sparse_categorical_crossentropy',
              metrics=['sparse_categorical_accuracy'])
history = model.fit(X_train,
                    y_train,
                    epochs=10)
#print(X_test[:1],y_test)
# predict_classes() was deprecated and later removed from TensorFlow;
# argmax over the softmax output gives the same class ids
pred = np.argmax(model.predict(X_test), axis=1)
print(pred)
print(y_test)

Conversion code:

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_keras_model(model)
# Dense layers only need TFLite builtin ops; SELECT_TF_OPS is only needed
# for ops that have no TFLite builtin equivalent
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                       tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
with open("converted_model.tflite", "wb") as f:
    f.write(tflite_model)

Input/Output shapes:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="converted_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])
print(interpreter.get_input_details()[0]['dtype'])

# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])
print(interpreter.get_output_details()[0]['dtype'])

Output:

[ 1 26]
<class 'numpy.float32'>
[ 1 10]
<class 'numpy.float32'>
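Given the [1, 10] float32 output, picking a genre is an argmax over the ten probabilities, which is what the best-match loop in the Kotlin code does. A minimal Python sketch for checking the same decoding on the desktop, assuming the label order of the Kotlin labelArray matches the integer class ids used in training (the helper name decode is hypothetical):

```python
from typing import Tuple

import numpy as np

# Same label order as labelArray in the Kotlin code; assumed to match the
# integer class ids (0-9) used when the model was trained.
LABELS = "blues classical country disco hiphop jazz metal pop reggae rock".split()


def decode(output) -> Tuple[str, float]:
    """Return (genre, confidence) for one model output of shape [1, 10]."""
    probs = np.asarray(output, dtype=np.float32).reshape(-1)
    if probs.shape != (len(LABELS),):
        raise ValueError(f"expected {len(LABELS)} scores, got shape {probs.shape}")
    best = int(np.argmax(probs))  # index of the highest probability
    return LABELS[best], float(probs[best])


# Probabilities printed by the Jupyter notebook in the question
jupyter_output = np.array([[0.257037, 0.000705, 0.429687, 0.030933, 0.009291,
                            0.004909, 1.734001e-03, 0.000912, 0.203305, 0.061488]])
print(decode(jupyter_output))  # top class: country
```

Running the Android output through the same function is a quick way to confirm that the decoding step is not where the two environments diverge.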

The demo Kotlin code for testing:

listenButton.setOnClickListener {
    incorrecttagButton.alpha = 1f
    incorrecttagButton.isClickable = true
    // Code for listening to music
    FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener { isDownloaded ->
            // Prefer the downloaded remote model, fall back to the bundled local model
            val options =
                if (isDownloaded) {
                    FirebaseModelInterpreterOptions.Builder(remoteModel).build()
                } else {
                    FirebaseModelInterpreterOptions.Builder(localModel).build()
                }
            Log.d("HUSKY", "Downloaded? $isDownloaded")
            val interpreter = FirebaseModelInterpreter.getInstance(options)
            // One float32 input of shape [1, 26], one float32 output of shape [1, 10]
            val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 26))
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 10))
                .build()
            if (songNum == 5) {
                songNum = 0
            }
            val testSong = testsongs[songNum]
            Log.d("HUSKY", "Song num = $songNum F = $testSong ")
            // Parse the comma-separated feature string into a [1][26] float array
            val input = Array(1) { FloatArray(26) }
            val preInput = testSong.split(",").map { it.trim().toFloat() }
            preInput.forEachIndexed { i, value -> input[0][i] = value }
            Log.d("HUSKY", "${input[0][1]}")
            val inputs = FirebaseModelInputs.Builder()
                .add(input) // add() as many input arrays as your model requires
                .build()

            val labelArray = "blues classical country disco hiphop jazz metal pop reggae rock".split(" ").toTypedArray()
            Log.d("HUSKY2", "GG")
            interpreter?.run(inputs, inputOutputOptions)?.addOnSuccessListener { result ->
                Log.d("HUSKY2", "GGWP")
                val output = result.getOutput<Array<FloatArray>>(0)
                val probabilities = output[0]
                // Pick the class with the highest probability
                var bestMatch = 0f
                var bestMatchIndex = 0
                for (i in probabilities.indices) {
                    if (probabilities[i] > bestMatch) {
                        bestMatch = probabilities[i]
                        bestMatchIndex = i
                    }
                    Log.d("HUSKY2", "${labelArray[i]} ${probabilities[i]}")
                }
                genreLabel.text = labelArray[bestMatchIndex].replaceFirstChar { it.uppercase() }
                confidenceLabel.text = probabilities[bestMatchIndex].toString()
            }?.addOnFailureListener { e ->
                // Task failed with an exception
                Log.d("HUSKY2", "GGWP :( $e")
            }
        }
}

I'm using songNum to index into the testsongs string array to switch between songs. Each song's features are stored as a single string with a comma as the delimiter.
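To feed exactly the same features to the desktop TFLite interpreter and compare results, the string can be parsed the same way on the Python side. A minimal sketch, assuming one comma-separated string of 26 raw feature values per song as in the log below (the helper name parse_features is hypothetical):

```python
import numpy as np

N_FEATURES = 26  # the model's input shape is [1, 26]


def parse_features(line: str) -> np.ndarray:
    """Parse a comma-separated feature string into a float32 array of shape [1, 26]."""
    values = [float(tok) for tok in line.split(",") if tok.strip()]
    if len(values) != N_FEATURES:
        raise ValueError(f"expected {N_FEATURES} features, got {len(values)}")
    return np.array(values, dtype=np.float32).reshape(1, N_FEATURES)


# Features of song 0, copied from the log in the question
song0 = (
    "0.3595172803692916,0.04380025714635849,1365.710742222286,1643.935571084307,"
    "2725.445556640625,0.06513807508680555,-273.0061247040518,132.66331747988934,"
    "-31.86709317807114,44.21442952318603,4.335704872427025,32.32360339344842,"
    "-2.4662076330637714,20.458242724823684,-4.760171779927926,20.413702740993585,"
    "3.69545905318442,8.581128171784677,-15.601809275025104,5.295758930950924,"
    "-5.270195074271744,5.895109210872318,-6.1406603018722645,-2.9278519508415286,"
    "-1.9189588023091468,5.954495267889836"
)

x = parse_features(song0)
print(x.shape, x.dtype)
print(x.min(), x.max())  # the raw features span very different ranges
```

Printing the min and max makes it easy to see whether the values are raw (hundreds or thousands) or already scaled to roughly unit range.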

The output below is the same regardless of the input features (songNum selects songs 0-4), and the confidence for pop is always 1.0:

2020-02-25 00:11:21.014 17434-17434/com.rohanbojja.audient D/HUSKY: Downloaded? true
2020-02-25 00:11:21.015 17434-17434/com.rohanbojja.audient D/HUSKY: Song num = 0 F = 0.3595172803692916,0.04380025714635849,1365.710742222286,1643.935571084307,2725.445556640625,0.06513807508680555,-273.0061247040518,132.66331747988934,-31.86709317807114,44.21442952318603,4.335704872427025,32.32360339344842,-2.4662076330637714,20.458242724823684,-4.760171779927926,20.413702740993585,3.69545905318442,8.581128171784677,-15.601809275025104,5.295758930950924,-5.270195074271744,5.895109210872318,-6.1406603018722645,-2.9278519508415286,-1.9189588023091468,5.954495267889836 
2020-02-25 00:11:21.016 17434-17434/com.rohanbojja.audient D/HUSKY: 0.043800257
2020-02-25 00:11:21.016 17434-17434/com.rohanbojja.audient D/HUSKY2: GG
2020-02-25 00:11:21.021 17434-17434/com.rohanbojja.audient D/HUSKY2: GGWP
2020-02-25 00:11:21.021 17434-17434/com.rohanbojja.audient D/HUSKY2: blues 0.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: classical 0.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: country 0.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: disco 0.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: hiphop 0.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: jazz 0.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: metal 0.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: pop 1.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: reggae 0.0
2020-02-25 00:11:21.022 17434-17434/com.rohanbojja.audient D/HUSKY2: rock 0.0
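A probability of exactly 1.0 for one class with exactly 0.0 everywhere else is what a float32 softmax returns when one logit exceeds all the others by a very large margin (roughly 100 or more), because exp of such a large negative difference underflows to zero. A minimal numpy sketch of that saturation, using arbitrary illustrative logits rather than anything produced by this model (the softmax helper is written here, not taken from a library):

```python
import numpy as np


def softmax(logits) -> np.ndarray:
    """Numerically stable softmax over the last axis, computed in float32."""
    z = np.asarray(logits, dtype=np.float32)
    z = z - z.max(axis=-1, keepdims=True)  # shift so the largest logit is 0
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


# Arbitrary illustrative logits for 10 classes (not produced by the model)
logits = np.array([1.0, 0.5, 2.0, -1.0, 0.0, 0.3, -0.5, 1.5, 0.8, 0.2], dtype=np.float32)

for scale in (1.0, 10.0, 1000.0):
    # Scaling the logits directly stands in for feeding much larger inputs
    print(scale, softmax(logits * np.float32(scale)).round(4))
```

At scale 1 every class gets a non-zero share; at scale 1000 the output is one-hot. Raw inputs in the hundreds or thousands (the log shows values such as 1365.71 and -273.0) can plausibly push a ReLU network's logits that far apart.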

The output in my Jupyter notebook looks like this:

genre       probability
blues       0.257037
classical   0.000705
country     0.429687
disco       0.030933
hiphop      0.009291
jazz        0.004909
metal       1.734001e-03
pop         0.000912
reggae      0.203305
rock        0.061488

My guess is that I'm misusing the ML Kit API, or perhaps passing the input data or retrieving the output data incorrectly? I'm new to Android development.

Output: 'pop' always has a confidence of 1.0! Expected output: every genre should get some confidence between 0 and 1.0, not 'pop' every time, like my result from the Jupyter notebook.

Sorry for the messy code.

Any help would be greatly appreciated!

Update 1: I swapped the relu activation functions for sigmoid and I can notice a difference: it's still almost always "pop", but with about 0.30 confidence. It's super mysterious now. By the way, this happens only with ML Kit; I haven't tried running the model natively (without ML Kit) yet.

Update 2: I don't understand how I can get different inferences with the same model. I'm lost.


Solution

  • I had not normalized the features extracted during the prediction phase, i.e., the extracted features were not transformed the way the training data had been.

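    I had transformed the training data with:

    import numpy as np
    from sklearn.preprocessing import StandardScaler

    X = StandardScaler().fit_transform(np.array(data.iloc[:, 1:-1]))

    To solve this issue, I had to fit the scaler on the same training features and use it to transform the new input features:

    # data holds the training set; input_data2 holds the raw features to classify
    scaler = StandardScaler().fit(np.array(data.iloc[:, 1:-1]))
    input_data = scaler.transform(input_data2)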
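The key point of the fix is that the per-feature mean and standard deviation come from the training data and must be applied unchanged to every new input, including on the phone; fitting a fresh scaler on a single song would not reproduce the training transform. A minimal numpy sketch of one way to carry those parameters over to the app, assuming they are exported as JSON and bundled with it (fit_standardizer and standardize are hypothetical helpers that reproduce StandardScaler's formula z = (x - mean) / std with the population standard deviation, rather than calling scikit-learn):

```python
import json

import numpy as np


def fit_standardizer(X) -> dict:
    """Per-feature mean and scale from the TRAINING matrix (population std, ddof=0, as in StandardScaler)."""
    X = np.asarray(X, dtype=np.float64)
    mean = X.mean(axis=0)
    scale = X.std(axis=0)
    scale[scale == 0.0] = 1.0  # leave constant features unscaled instead of dividing by zero
    return {"mean": mean.tolist(), "scale": scale.tolist()}


def standardize(x, params: dict) -> np.ndarray:
    """Apply z = (x - mean) / scale with stored parameters; float32 to match the TFLite input."""
    mean = np.asarray(params["mean"], dtype=np.float64)
    scale = np.asarray(params["scale"], dtype=np.float64)
    return ((np.asarray(x, dtype=np.float64) - mean) / scale).astype(np.float32)


# Stand-in for np.array(data.iloc[:, 1:-1]); random values only for illustration
rng = np.random.default_rng(0)
X_train_raw = rng.normal(loc=500.0, scale=200.0, size=(100, 26))

params = fit_standardizer(X_train_raw)
blob = json.dumps(params)   # e.g. saved as an asset shipped with the Android app
loaded = json.loads(blob)   # the app would parse these two 26-element lists

X_scaled = standardize(X_train_raw, loaded)
print(X_scaled.mean(axis=0)[:3], X_scaled.std(axis=0)[:3])  # close to 0 and 1
```

On Android, the same per-feature (x - mean) / scale would then be applied to the 26 parsed values before building FirebaseModelInputs. An alternative design is to put the normalization inside the model itself (recent TensorFlow versions provide a tf.keras.layers.Normalization layer), so the converted TFLite model takes raw features and the app cannot forget the step.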