python · node.js · tensorflow · tensorflow.js

Using JS backend with Python for Machine Learning


I need some senior advice here. I want to create an API using JS, but keep all of the ML functionality in Python. I don't want to get rid of the awesome JS libraries like GraphQL, but I don't want to sacrifice the Python performance. I know I can use TensorFlow.js, but as I said, in terms of performance, Python is way better.

What I have in mind is something like deploying an ML model to the cloud using Python and then fetching the predictions from my JS API, or something like that.

Another idea is to build the model for inference using Python, save it as .h5 or .json, and then load it directly with TensorFlow.js in my API.

##### LOCAL #####
from tensorflow.keras.layers import Input, Conv2D
from tensorflow.keras.models import Model

inputs = Input(shape=(trainX.shape[1], trainX.shape[2], trainX.shape[3]))

...
Conv2D
Conv2D
Conv2D
...

model = Model(inputs=inputs, outputs=predictions)

model.compile(...)

model.fit(...)

model.save('model.h5')  # I don't think I can save the weights in Python in the .json format
##### API #####
const tf = require('@tensorflow/tfjs-node');
const https = require('https');

https.get('SOMEURL', (resp) => {
  let body = '';
  resp.on('data', (chunk) => { body += chunk; });
  resp.on('end', async () => {
    // load the converted model and run the fetched data through it
    const model = await tf.loadLayersModel('file://path/to/my-model/model.json');
    const input = tf.tensor(JSON.parse(body));
    const prediction = model.predict(input);
    prediction.print();
  });
}).on("error", (err) => {
  console.log("Error: " + err.message);
});

I don't really know if this could ever work, or whether there is a better way to do this (or if it is even possible).

All ideas and advice are appreciated. Thank you.


Solution

  • You have pointed out the two methods you can use to perform predictions with your ML/DL model. I will list the steps needed for each, along with my own personal recommendations.


    Local:

    Here you would have to build and train the model using TensorFlow and Python. Then, to use the model in your web application, you would need to convert it to the correct format using tfjs-converter. For example, you would get back a model.json and a group1-shard1of1.bin file, which you can then use to make predictions on data from the client side. To improve performance you can quantize the model when converting it.
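
    As a rough sketch of the conversion step (assuming the trained Keras model was saved as my-model.h5, a hypothetical file name), the tensorflowjs Python package can write out the TensorFlow.js artifacts directly:

    # pip install tensorflowjs
    import tensorflowjs as tfjs
    from tensorflow.keras.models import load_model

    model = load_model('my-model.h5')  # hypothetical file name for the model trained above

    # Writes model.json plus one or more group*-shard*of*.bin weight files
    # into the target directory, ready for tf.loadLayersModel() on the JS side.
    tfjs.converters.save_keras_model(model, 'my-model-tfjs/')

    Quantization happens at this same conversion step; recent versions of the tensorflowjs_converter CLI expose flags such as --quantize_float16 for it, though the exact flag names have varied between releases.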

    1. I find it easier to make predictions this way, as the whole process is not difficult.
    2. The model is always on the client side, so it should be the best option if you are looking for very quick predictions.
    3. Security-wise, if this model is used in production then no user data is ever passed to the server side, so users do not have to worry about their data being used inappropriately. For example, if you are in the European Union you would have to abide by the General Data Protection Regulation (GDPR), which really complicates things.
    4. If you want to improve the model then you would need to train a new model, followed by an update to the web application to change the model files. You are unable to perform online learning (training the model on new data it sees and improving it on the fly).

    API:

    Here you would have to use some sort of library for building REST APIs. I would recommend FastAPI, which is quite easy to get up and running. You would create routes that you make POST requests to; these routes receive the data from the client side, use your model to make predictions on that data, and send the predictions back in the response body. The API and the prediction code would have to be hosted somewhere for you to query them from the client side; you could use Heroku for this. This article goes over the entire process.
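
    As a minimal sketch of such an endpoint (assuming the model was saved as my-model.h5 and the client sends a JSON batch of input rows; the names and shapes here are hypothetical, not a production-ready implementation):

    # pip install fastapi uvicorn
    from typing import List

    import numpy as np
    from fastapi import FastAPI
    from pydantic import BaseModel
    from tensorflow.keras.models import load_model

    app = FastAPI()
    model = load_model('my-model.h5')  # load once at startup, not per request

    class PredictRequest(BaseModel):
        data: List[List[float]]  # a batch of input rows; adjust to your model's input shape

    @app.post('/predict')
    def predict(req: PredictRequest):
        inputs = np.array(req.data)
        preds = model.predict(inputs)
        return {'predictions': preds.tolist()}  # sent back in the response body as JSON

    You would run this with uvicorn (e.g. uvicorn main:app), so the model stays in memory between requests.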

    1. The process is convoluted in comparison to the local method.
    2. Data needs to be sent to the server, so if you need very fast predictions on a lot of data this will be slower than the local method (see the request sketch after this list).
    3. For production use cases this is the preferred method, unless the user data cannot be sent to the server.
    4. These are REST APIs, so to get them to work with GraphQL you would have to wrap the REST API with GraphQL using the steps detailed here.
    5. You can continuously improve the model without having to touch the client-side code.
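
    For completeness, querying such an endpoint is a single POST request (shown here in Python with the requests library and a hypothetical Heroku URL; the equivalent https/fetch call from your JS API is a direct translation):

    import requests

    resp = requests.post(
        'https://my-ml-api.herokuapp.com/predict',  # hypothetical host
        json={'data': [[0.1, 0.2, 0.3]]},           # one input row; match your model's shape
    )
    print(resp.json()['predictions'])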

    I don't want to get rid of the awesome JS libraries like GraphQL, but I don't want to sacrifice the Python performance. I know I can use TensorFlow.js, but as I said, in terms of performance, Python is way better.

    One thing I would like to point out is that the prediction speed of the model is going to be the same regardless of whether you use Python or JavaScript. The only way you can improve it is through quantization, which reduces the model size while also improving CPU and hardware-accelerator latency, with little degradation in model accuracy; all you are doing here is making predictions with the model. Unless you are sending huge amounts of data to the endpoint in an area with slow internet speeds, the difference between the two methods is negligible.