
How to restore a TensorFlow model from a Google Cloud Storage bucket without writing to the file system?


I have a 2 GB TensorFlow model that I'd like to add to a Flask project running on App Engine, but I can't find any documentation saying whether what I'm trying to do is possible.

Since App Engine doesn't allow writing to the file system, I'm storing the model's files in a Google Cloud Storage bucket and trying to restore the model from there. The bucket contains these files:

  • model.ckpt.data-00000-of-00001
  • model.ckpt.index
  • model.ckpt.meta
  • checkpoint

Working locally, I can just use

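import logging
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.train.Saver)
logger = logging.getLogger(__name__)
with tf.Session() as sess:
    logger.info("Importing model into TF")
    saver = tf.train.import_meta_graph('model.ckpt.meta')
    # restore() takes the checkpoint prefix as a string, not a variable
    saver.restore(sess, 'model.ckpt')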

The model is loaded into memory in a function registered with Flask's @before_first_request.

Once it's on App Engine, I assumed I could do this:

import os

# bucket is a google.cloud.storage Bucket; model_dir is a local directory
blob = bucket.get_blob('blob_name')
filename = os.path.join(model_dir, blob.name)
blob.download_to_filename(filename)

and then run the same restore. But App Engine won't let me write the downloaded file to disk.

Is there a way to stream these files into TensorFlow's restore functions so they never have to be written to the file system?


Solution

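  • After some tips from Dan Cornilescu and some digging, I found that tf.train.import_meta_graph accepts a MetaGraphDef protocol buffer as well as a file path, and that a MetaGraphDef can be built from the raw bytes of the .meta file with the protobuf method ParseFromString. Here's what I ended up doing: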

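    import tensorflow as tf  # TensorFlow 1.x API
    from google.cloud import storage
    from tensorflow import MetaGraphDef

    client = storage.Client()
    bucket = client.get_bucket(Config.MODEL_BUCKET)  # bucket name from the app's config
    blob = bucket.get_blob('model.ckpt.meta')
    # Raw bytes of the .meta file; newer client releases prefer download_as_bytes()
    model_graph = blob.download_as_string()

    # Parse the serialized protocol buffer in memory; nothing is written to disk
    mgd = MetaGraphDef()
    mgd.ParseFromString(model_graph)

    with tf.Session() as sess:
        # import_meta_graph accepts a MetaGraphDef object as well as a file name
        saver = tf.train.import_meta_graph(mgd)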
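The snippet above rebuilds the graph but stops before the variable values (model.ckpt.index and model.ckpt.data-00000-of-00001) are restored. One possible way to finish, sketched under assumptions: TensorFlow 1.x's own file layer understands gs:// paths (the standard Linux pip builds include Google Cloud Storage support), so Saver.restore can be pointed at the checkpoint prefix in the bucket directly. The function names, the bucket name, and the availability of default Google Cloud credentials on the instance are all assumptions, not part of the original answer.

```python
def gcs_checkpoint_prefix(bucket_name, prefix="model.ckpt"):
    """Return the gs:// checkpoint prefix that Saver.restore() expects."""
    return "gs://{}/{}".format(bucket_name, prefix)


def load_model_from_gcs(bucket_name, prefix="model.ckpt"):
    """Sketch: rebuild the graph from the .meta bytes, then restore variables from GCS.

    Assumes TensorFlow 1.x with GCS file system support and default
    Google Cloud credentials; bucket_name and prefix are placeholders.
    """
    # Imported inside the function so the helper above is usable without TF installed.
    import tensorflow as tf
    from google.cloud import storage

    # Read model.ckpt.meta into memory (older clients only have download_as_string()).
    blob = storage.Client().bucket(bucket_name).blob(prefix + ".meta")
    mgd = tf.MetaGraphDef()
    mgd.ParseFromString(blob.download_as_bytes())

    graph = tf.Graph()
    with graph.as_default():
        saver = tf.train.import_meta_graph(mgd)
    sess = tf.Session(graph=graph)
    # The Saver reads the .index and .data-* files itself; with a gs:// prefix
    # TensorFlow fetches them from the bucket instead of the local disk.
    saver.restore(sess, gcs_checkpoint_prefix(bucket_name, prefix))
    return sess
```

The returned session stays open so it can serve many requests. If gs:// paths work in your environment, tf.train.import_meta_graph should also accept gcs_checkpoint_prefix(bucket_name) + ".meta" directly, which would make the manual ParseFromString step unnecessary.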