Tags: machine-learning, keras, tensorflow.js, tensorflowjs-converter

Tensorflow.js Loaded Model Performing Significantly Worse than Keras Model


So, I am attempting to create a dog vs. cat image classification model using Keras. Part of my goal is to build a website that runs the model in the browser with TensorFlow.js. I have successfully deployed the model, using Flask as the server.

The main issue is that the model performs much worse in TensorFlow.js than in plain Keras. In plain Keras, my model achieved around 90% accuracy on the test data. In TensorFlow.js, however, it did not classify a single test image correctly. I would appreciate any help or tips on fixing this issue.
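When a converted model disagrees with the original, it usually helps to compare the two runtimes on the *same* preprocessed images, not just their overall accuracy.

The sketch below assumes you have copied the Keras scores for a few test images (for example from `model.predict` in Python) and the matching TF.js scores from `model.predict(...).dataSync()`. `compareScores` is a hypothetical helper, not part of TensorFlow.js, and the example numbers are made up.

```javascript
// Hypothetical debugging helper: compare per-image sigmoid scores produced
// by Keras and by TensorFlow.js for the same preprocessed inputs.
function compareScores(kerasScores, tfjsScores, tolerance = 1e-3) {
  if (kerasScores.length !== tfjsScores.length) {
    throw new Error("Score arrays must have the same length");
  }
  let maxAbsDiff = 0;
  let sameClass = 0;
  for (let i = 0; i < kerasScores.length; i++) {
    const diff = Math.abs(kerasScores[i] - tfjsScores[i]);
    maxAbsDiff = Math.max(maxAbsDiff, diff);
    // Both runtimes should put each image on the same side of 0.5.
    if ((kerasScores[i] >= 0.5) === (tfjsScores[i] >= 0.5)) {
      sameClass++;
    }
  }
  return {
    maxAbsDiff,
    classAgreement: kerasScores.length === 0 ? 1 : sameClass / kerasScores.length,
    withinTolerance: maxAbsDiff <= tolerance,
  };
}

// Illustrative (made-up) scores for three images.
const report = compareScores([0.91, 0.12, 0.77], [0.9101, 0.1199, 0.7702]);
console.log(report);
```

If the scores differ wildly on identical inputs, the problem lies in the runtime or the conversion rather than in the preprocessing.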

templates/index.html

<!DOCTYPE html>
<html>
  <head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width">
    <title>repl.it</title>

    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>

    <link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
    <link href="{{ url_for('static', filename='index.css') }}" rel="stylesheet" type="text/css" />
    <link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
  </head>
  <body onload="$('#result').hide();$('#continue').hide();">
    <div class="container-fluid">
      <!-- START HEADER -->
      <div class="row" id="headerRow">
        <div class="col-md d-flex align-items-center" id="headerColumn">
          <h2>Cat<span class='or'>or</span>Dog</h2>
        </div>
      </div>
      <!-- END HEADER -->

      <!-- START BODY -->
      <div class="row bodyRow" id='bodyRow'>
        <div class="col-md d-flex align-items-center bodyColumn">
          <div class="body">
            <form class="d-flex align-items-center  justify-content-center imageSubmitForm" method="POST" enctype="multipart/form-data">
              <label class="d-flex align-items-center justify-content-center" for='imageInputField'>
                <i class="material-icons">file_upload</i>

                <p id='result'></p>
                <br/>
                <p id='continue'>Press Anywhere to continue...</p>
              </label>
              <input class="imageInputField" id='imageInputField' type='file' onchange='getPrediction(url)'/>
            </form>
          </div>
        </div>
      </div>
      <!-- END BODY -->

      <!-- START RESULT -->
      <div class="row resultRow">
        <div class="col-md-6 classResultColumn">
            <div class="d-flex align-items-center justify-content-center classResultBox">
                <p id='classResult'></p>
            </div>
        </div>
        <div class="col-md-6 scoreResultColumn">
            <div class="d-flex align-items-center justify-content-center scoreResultBox">
                <p id='scoreResult'></p>
            </div>
        </div>
      </div>
      <!-- END RESULT -->

      <!-- START FOOTER -->
      <!--
      <div class="row d-flex align-items-center footerRow" id='footerRow'>
        <center><a src="#">Source Code</a></center>
      </div>
      -->
      <!-- FOOTER -->
    </div>

    <!-- START SCRIPTS -->
    <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.4.1/jquery.min.js"></script>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.12.9/umd/popper.min.js" integrity="sha384-ApNbgh9B+Y1QKtv3Rn7W3mgPxhU9K/ScQsAP7hUibX39j7fakFPskvXusvfa0b4Q" crossorigin="anonymous"></script>
    <script src="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/js/bootstrap.min.js" integrity="sha384-JZR6Spejh4U02d8jOt6vLEHfe/JQGiRRSQQxSfFWpi1MquVdAyjUar5+76PVCmYl" crossorigin="anonymous"></script>
    <script src="{{ url_for('static', filename='index.js') }}"></script>   
    <!-- END SCRIPTS -->
  </body>
</html>

static/index.js

let fileInput = document.getElementById("imageInputField");
let classResultElement = document.getElementById("classResult");
let scoreResultElement = document.getElementById("scoreResult");
let url = "/model";

let model;
let file;
let data;
let responseContent;
let features;
let predictedClass;

let getPrediction = async(url) => {
    if (!model)
        model = await tf.loadLayersModel(url);

    file = fileInput.files[0];
    data = new FormData();
    data.append("file", file);

    $.ajax({
        url : "/api/preprocess",
        type: 'POST',
        data: data,
        traditional: true,
        processData: false,
        contentType: false,

        success: function(response)
        {
            responseContent = JSON.parse(response)['image'];

            if (responseContent != "False")
            {
                features = tf.tensor(responseContent);
                // dataSync() returns a Float32Array; the model has a single sigmoid output.
                const score = model.predict(features).dataSync()[0];

                alert(score);

                if (score >= 0.5) {
                    predictedClass = "Dog";

                    classResultElement.innerHTML = "<b>Predicted Class:</b> " + predictedClass;
                    scoreResultElement.innerHTML = "<b>Certainty:</b> " + score*100.0 + "%";
                } else {
                    predictedClass = "Cat";

                    classResultElement.innerHTML = "<b>Predicted Class:</b> " + predictedClass;
                    scoreResultElement.innerHTML = "<b>Certainty:</b> " + (1.0 - score) * 100.0 + "%";
                }

                alert(predictedClass);
            }
        }
    });
}
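`dataSync()` returns a `Float32Array`, not a number. Comparing that array with `0.5` or multiplying it by `100.0` only works because a one-element typed array happens to coerce to its single value.

The sketch below reads the element explicitly and formats the certainty. `interpretScore` is a hypothetical name. The 0.5 threshold assumes the model ends in one sigmoid unit, as the original code implies.

```javascript
// Hypothetical helper: map the model's single sigmoid output to a label and
// a certainty string, using a 0.5 threshold (>= 0.5 means "Dog").
function interpretScore(output) {
  // Accept either a plain number or the typed array from dataSync();
  // for a one-image batch with one output unit, element 0 is the score.
  const score = typeof output === "number" ? output : output[0];
  const isDog = score >= 0.5;
  const certainty = (isDog ? score : 1 - score) * 100;
  return {
    predictedClass: isDog ? "Dog" : "Cat",
    certainty: certainty.toFixed(2) + "%",
  };
}

console.log(interpretScore(new Float32Array([0.8]))); // { predictedClass: 'Dog', certainty: '80.00%' }
console.log(interpretScore(0.25)); // { predictedClass: 'Cat', certainty: '75.00%' }
```

Rounding with `toFixed(2)` also hides the float32 noise (e.g. `80.0000011920929`) that the raw multiplication would show in the page.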

app.py

import flask
from flask_cors import CORS
from werkzeug.utils import secure_filename
import time
import os
import keras
import numpy as np
import json

app = flask.Flask(__name__)
CORS(app)

UPLOADS_DIR = "uploads/"

@app.route("/")
def index():
  """
    Fetch and return the main homepage. 
  """
  return flask.render_template("index.html")

@app.route("/favicon.ico")
def get_favicon():
  """
    Return a fake message in order to silence the error caused by a favicon not being found.
  """
  return "Favicon Does Not Exist"

@app.route("/model")
def get_modeljson():
  """
    Get the model.json file and return its contents.
  """
  with open("model/model.json", "r") as f:
    return f.read()

@app.route("/<path:path>")
def get_shard(path):
  """
    get the binary weight file for the model (also known as a shard).

    path    =>    the filename of the binary weight file.
  """
  return flask.send_from_directory("model/", path)

@app.route("/api/preprocess", methods=['POST'])
def preprocess():
  """
    takes an image object from an AJAX request and returns a normalized list of the values.
  """
  if flask.request.method == 'POST':
    file = flask.request.files['file']
    filename = secure_filename(file.filename)
    new_filename = "{}_{}".format(time.time(), filename)
    file.save(os.path.join(UPLOADS_DIR, new_filename))

    img_obj = keras.preprocessing.image.load_img(os.path.join(UPLOADS_DIR, new_filename), target_size=(224, 224))
    img_arr = keras.preprocessing.image.img_to_array(img_obj).reshape(1, 224, 224, 3)
    img_arr = np.divide(img_arr, 255.)

    os.remove(os.path.join(UPLOADS_DIR, new_filename))
    return json.dumps({"image":img_arr.tolist()})
  return json.dumps({"image":"False"})

if __name__ == "__main__":
  app.run()
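The `/api/preprocess` route resizes the upload to 224x224, converts it to an RGB array and divides by 255 before the JSON reaches TF.js. The browser therefore only ever sees already-normalized data. If you later drop that round trip and preprocess in the browser, the input must match the same layout exactly.

The sketch below covers the normalization part only. It assumes the pixels come from a canvas `ImageData` of an image already resized to 224x224. `rgbaToNormalizedRgb` is a hypothetical helper.

```javascript
// Hypothetical helper: mirror the server's img_to_array(...) / 255. step.
// Input: RGBA bytes as in canvas ImageData.data (row-major, 4 bytes/pixel).
// Output: Float32Array in height x width x 3 (RGB) order, scaled to 0..1.
function rgbaToNormalizedRgb(rgba, width, height) {
  if (rgba.length !== width * height * 4) {
    throw new Error("Expected width * height * 4 RGBA bytes");
  }
  const out = new Float32Array(width * height * 3);
  for (let p = 0; p < width * height; p++) {
    // Copy R, G, B and drop the alpha channel.
    out[p * 3] = rgba[p * 4] / 255;
    out[p * 3 + 1] = rgba[p * 4 + 1] / 255;
    out[p * 3 + 2] = rgba[p * 4 + 2] / 255;
  }
  return out;
}

// A 2x1 image: one pure-red pixel and one mid-grey pixel (alpha is ignored).
const pixels = new Uint8ClampedArray([255, 0, 0, 255, 128, 128, 128, 0]);
const normalized = rgbaToNormalizedRgb(pixels, 2, 1);
console.log(Array.from(normalized));
```

In the browser, the result could be wrapped with `tf.tensor4d(normalized, [1, height, width, 3])`. This only matches the server if the resize also matches. `keras.preprocessing.image.load_img` uses nearest-neighbour interpolation by default, which a canvas `drawImage` resize does not necessarily reproduce.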

The Kaggle notebook used to train the model can be found here, and the notebook used to test the code can be found here.

Any help or tips are greatly appreciated.


Solution

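  • After a ton of coffee and barely any sleep, I came across a solution. Apparently, the TensorFlow.js WebGL backend computes some operations differently from TensorFlow in Python. The workaround here is to turn off WebGL packing.

    Just before you load the model, add:

    tf.env().set("WEBGL_PACK", false);

    (Versions of TensorFlow.js that still export the environment as tf.ENV also accept the original form, tf.ENV.set("WEBGL_PACK", false);.)

    This does not disable WebGL itself. TF.js still runs on the WebGL backend, just without packed kernels, and in this case that made it behave like the Keras model in Python. To take WebGL out of the picture entirely, switch to the CPU backend with tf.setBackend("cpu").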
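The workaround can be wrapped so the flag is always set before the model loads. The CPU backend serves as an alternative when you want to rule WebGL out completely.

This sketch assumes the TF.js bundle from the CDN, which registers the WebGL flags. `loadModelWithWorkaround` and its `useCpu` option are hypothetical names. Comparing the scores from both settings against Keras is a quick way to confirm whether the WebGL backend is the culprit.

```javascript
// Hypothetical loader: apply the WEBGL_PACK workaround (or switch to the CPU
// backend) before loading the Layers model. `tf` is passed in explicitly so
// the function can use the global from the CDN <script> tag.
async function loadModelWithWorkaround(tf, modelUrl, { useCpu = false } = {}) {
  if (useCpu) {
    // The CPU backend avoids WebGL entirely: slower, but a useful reference.
    await tf.setBackend("cpu");
  } else {
    // Turn off packed WebGL kernels, as in the answer above.
    tf.env().set("WEBGL_PACK", false);
  }
  // Wait until the chosen backend is initialized.
  await tf.ready();
  const model = await tf.loadLayersModel(modelUrl);
  return { model, backend: tf.getBackend() };
}

// Possible use inside getPrediction in index.js:
//   if (!model) model = (await loadModelWithWorkaround(tf, url)).model;
```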