Tags: javascript, tensorflow, tensorflow.js, tensorflow-hub, tfjs-node

TensorFlow.js: load a TFHub model locally


I'm loading a model from TFHub via loadGraphModel from the @tensorflow/tfjs-converter package:

const tfconv = require('@tensorflow/tfjs-converter');

// Load the toxicity graph model directly from TFHub.
const loadModel = async function () {
  return tfconv.loadGraphModel(
    'https://tfhub.dev/tensorflow/tfjs-model/toxicity/1/default/1',
    { fromTFHub: true });
};

This works fine. I then downloaded the model files locally for offline prediction:

.
├── group1-shard1of7
├── group1-shard2of7
├── group1-shard3of7
├── group1-shard4of7
├── group1-shard5of7
├── group1-shard6of7
├── group1-shard7of7
├── model.json
└── vocab.json

and I would now like to load this model locally. For other models I was using tfjs.loadLayersModel(this.path), which supports the file:// protocol, but if I try to load this model with it I get the following error:

'className' and 'config' must set.

Indeed, these keys are missing from model.json. I also tried loading the graph from the hub in this way:

var loadGraphModel = function () {
  // tfconv.loadGraphModel already returns a Promise, so no wrapper is needed.
  return tfconv.loadGraphModel(
    'https://tfhub.dev/tensorflow/tfjs-model/toxicity/1/default/1',
    { fromTFHub: true })
    .then(res => {
      console.log(res);
      return res;
    });
};

The resolved model is logged as:

GraphModel {
  modelUrl: 'https://tfhub.dev/tensorflow/tfjs-model/toxicity/1/default/1/model.json?tfjs-format=file',
  loadOptions: { fromTFHub: true },
  version: 'undefined.undefined',
  handler: HTTPRequest {
    DEFAULT_METHOD: 'POST',
    weightPathPrefix: undefined,
    onProgress: undefined,
    fetch: [Function],
    path: 'https://tfhub.dev/tensorflow/tfjs-model/toxicity/1/default/1/model.json?tfjs-format=file',
    requestInit: {}
  },
  artifacts: {
    modelTopology: { node: [Array], library: {}, versions: {} },
    weightSpecs: [
    ...
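
The logged artifacts show a modelTopology with a node array, which looks like the graph-model layout rather than the Keras-style layout that loadLayersModel expects (the one with a class name and a config). That would explain the error above. A minimal sketch of a check that tells the two apart before choosing a loader is below. The function names are hypothetical, and the "format" values are an assumption based on what recent converter versions write into model.json; older files may lack that field, hence the fallback on the topology shape.

```javascript
const fs = require('fs');
const path = require('path');

// Heuristic (an assumption, not an official API): decide which loader a
// converted model needs by inspecting the parsed contents of model.json.
function detectModelKind(modelJson) {
  // Recent converter output is assumed to carry an explicit "format" field.
  if (modelJson.format === 'graph-model') return 'graph';
  if (modelJson.format === 'layers-model') return 'layers';

  const topology = modelJson.modelTopology || {};
  // A graph-style topology lists its ops under "node", as in the log above.
  if (Array.isArray(topology.node)) return 'graph';
  // A Keras-style topology carries a class name and a config.
  if (topology.model_config || topology.class_name || topology.className) {
    return 'layers';
  }
  return 'unknown';
}

// Convenience wrapper: read <dir>/model.json from disk and classify it.
function detectModelKindInDir(dir) {
  const raw = fs.readFileSync(path.join(dir, 'model.json'), 'utf8');
  return detectModelKind(JSON.parse(raw));
}
```

If the result is 'graph', loadGraphModel is the loader to try; if it is 'layers', loadLayersModel should work.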

So how can I load this TFHub GraphModel locally, that is, save it and reload it from the local file system?


Solution

  • For tf_saved_model:

    Host the model files and the binary weight shards on S3 or on the local file system, behind a server with CORS enabled.

    I ended up doing this:

    npm install http-server -g
    
    # where model files are stored
    http-server -p 3004 --cors
    

    On the React side, using TensorFlow.js:

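    import * as tf from "@tensorflow/tfjs";
    import { loadGraphModel } from "@tensorflow/tfjs-converter";
    
    // Load the graph model from the local static server started above.
    const net = await loadGraphModel("http://localhost:3004/tfjsexport/model.json");
    
    // Build the input batch from the current video frame.
    const img = tf.browser.fromPixels(video);
    const resized = tf.image.resizeBilinear(img, [640, 480]); // size is [height, width]
    const casted = resized.cast("int32");
    const expanded = casted.expandDims(0);
    const obj = await net.executeAsync(expanded);
    
    // Release the intermediate tensors once inference is done.
    tf.dispose([img, resized, casted, expanded]);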
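
The approach above serves the files over HTTP for the browser. For the Node case in the question, a local server may not be needed: a minimal sketch, assuming @tensorflow/tfjs-node is installed and recent enough to expose loadGraphModel, could read the model straight from disk through its file-system IO handler. The helper names and the directory layout (model.json next to its shard files) are assumptions for illustration.

```javascript
const path = require('path');

// Hypothetical helper: resolve the model.json inside a local model directory.
function modelJsonPath(modelDir) {
  return path.resolve(modelDir, 'model.json');
}

// Load a graph model from the local file system, without an HTTP server.
// Assumes @tensorflow/tfjs-node is installed; it is required lazily so the
// path helper above stays usable without it.
async function loadLocalGraphModel(modelDir) {
  const tf = require('@tensorflow/tfjs-node');
  // The file-system handler reads model.json and the weight shards it
  // lists, which are expected to sit in the same directory.
  const handler = tf.io.fileSystem(modelJsonPath(modelDir));
  return tf.loadGraphModel(handler);
}
```

Usage would look like loadLocalGraphModel('./toxicity').then(model => console.log(model)), where ./toxicity is a placeholder for the directory that holds the downloaded files.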