Tags: flask, google-cloud-platform, pytorch, google-cloud-run, gpt-2

Flask app serving GPT2 on Google Cloud Run not persisting downloaded files?


I have a Flask app running on Google Cloud Run that needs to download a large model (GPT-2 from Hugging Face). The download takes a while, so I am trying to set things up so that the model is downloaded only once, on deployment, and then served from disk for subsequent visits. To do this, I have the following code in a script that is imported by my main Flask app, app.py:

import torch
# from transformers import GPT2Tokenizer, GPT2LMHeadModel
# AutoModelForCausalLM replaces the deprecated AutoModelWithLMHead for GPT-2
from transformers import AutoTokenizer, AutoModelForCausalLM

# Disable gradient calculation (useful for inference)
torch.set_grad_enabled(False)

# Use the GPU if one is available, otherwise the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the tokenizer and model from the local directory
try:
    tokenizer = AutoTokenizer.from_pretrained("./gpt2-xl")
    model = AutoModelForCausalLM.from_pretrained("./gpt2-xl")
except Exception:
    # Nothing saved yet on this instance: download and save a local copy.
    # Note: this fetches the small 'gpt2' checkpoint, although it is saved under ./gpt2-xl.
    print('no model found! Downloading....')
    AutoTokenizer.from_pretrained('gpt2').save_pretrained('./gpt2-xl')
    AutoModelForCausalLM.from_pretrained('gpt2').save_pretrained('./gpt2-xl')
    tokenizer = AutoTokenizer.from_pretrained("./gpt2-xl")
    model = AutoModelForCausalLM.from_pretrained("./gpt2-xl")

model = model.to(device)

This tries to load the previously downloaded model and, if that fails, downloads a new copy. I have autoscaling set to a minimum of 1 instance, which I thought meant something would always be running, so the downloaded files would persist even while the app is idle. But the app keeps re-downloading the model, which freezes it for anyone who tries to use it at that moment. I am trying to recreate something like this app https://text-generator-gpt2-app-6q7gvhilqq-lz.a.run.app/, which does not appear to have the same load-time issue. In the Flask app itself I have the following:

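import os
import random

import flask
from flask_cors import cross_origin

# `app`, `wp` (the list of prompts) and `generate` are defined elsewhere in app.py.
# Cloud Run tells the container which port to listen on through $PORT.
PORT = int(os.environ.get('PORT', 8080))

@app.route('/')
@cross_origin()
def index():
    prompt = random.choice(wp)
    res = generate(prompt, size=75)
    generated = res.split(prompt)[-1] + '\n \n...TO BE CONTINUED'
    return flask.render_template('main.html', prompt=prompt, output=generated)

if __name__ == "__main__":
    # debug=True is meant for local testing only
    app.run(host='0.0.0.0',
            debug=True,
            port=PORT)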

But it seems to re-download the model every few hours. How can I stop the app from re-downloading the model, so that it does not freeze for people who want to try it?
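One way to check whether each re-download coincides with a new container instance is to tag every Python process with an identifier when it starts and log it. The sketch below is a hypothetical stdlib-only helper (`INSTANCE_ID`, `STARTED_AT` and `instance_info` are illustrative names, not part of the original app); it assumes only that Cloud Run sets the `K_REVISION` environment variable, and falls back to "local" elsewhere.

```python
import logging
import os
import time
import uuid

logging.basicConfig(level=logging.INFO)

# Hypothetical helper: these run once per Python process at import time.
# On Cloud Run that is once per container instance (or once per worker
# process when a server such as gunicorn starts several workers).
INSTANCE_ID = uuid.uuid4().hex[:8]
STARTED_AT = time.monotonic()


def instance_info() -> dict:
    """Return a small summary of the current process for log lines."""
    return {
        "instance": INSTANCE_ID,
        # Cloud Run sets K_REVISION to the revision being served.
        "revision": os.environ.get("K_REVISION", "local"),
        "uptime_s": round(time.monotonic() - STARTED_AT, 1),
    }


logging.info("process started: %s", instance_info())
```

If the re-downloads are caused by new instances, logging `instance_info()` next to the `no model found! Downloading....` message should show a new `instance` value with an uptime close to zero each time.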


Solution

  • Data written to the filesystem does not persist when the container instance is stopped.

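    The lifetime of a Cloud Run container instance is the time between an HTTP request and the HTTP response. Overlapping requests extend this lifetime. Once the final HTTP response is sent, your container can be stopped. A minimum-instances setting keeps that many instances available, but it does not pin a particular instance: an instance can still be shut down and replaced, and the replacement starts without the files its predecessor downloaded.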

    Cloud Run instances can run on different hardware (clusters), and instances can be moved. One instance does not have the same temporary data as another, so your strategy of downloading a large file at runtime and saving it to the in-memory file system will not work consistently.

    Reference: "Filesystem access" in the Cloud Run container runtime contract.

    Also note that the file system is in-memory: files you write count against the instance's memory, so a saved copy of the model needs memory on top of what the loaded model itself uses.
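To put rough numbers on the memory point, the back-of-the-envelope sketch below assumes fp32 weights (4 bytes per parameter) and approximate parameter counts of about 124 million for `gpt2` and about 1.5 billion for `gpt2-xl`. With the runtime-download approach, the saved files sit in the in-memory filesystem while a second copy is loaded into RAM, so the weights are counted twice. Activations, the tokenizer and Python itself come on top, so treat the result as a lower bound; `weights_memory_gb` is a hypothetical helper name.

```python
def weights_memory_gb(num_params: int, bytes_per_param: int = 4,
                      saved_copy_in_memory: bool = True) -> float:
    """Rough lower bound (in GB) on memory used by model weights.

    bytes_per_param=4 assumes fp32 weights. saved_copy_in_memory=True
    counts the files written to Cloud Run's in-memory filesystem in
    addition to the weights loaded into RAM.
    """
    weights_bytes = num_params * bytes_per_param
    copies = 2 if saved_copy_in_memory else 1
    return weights_bytes * copies / 1e9


# Approximate parameter counts (assumptions, rounded).
GPT2_PARAMS = 124_000_000
GPT2_XL_PARAMS = 1_500_000_000

for name, n in [("gpt2", GPT2_PARAMS), ("gpt2-xl", GPT2_XL_PARAMS)]:
    print(f"{name}: ~{weights_memory_gb(n):.1f} GB saved + loaded, "
          f"~{weights_memory_gb(n, saved_copy_in_memory=False):.1f} GB loaded only")
```

Under these assumptions that is roughly 1 GB for `gpt2` and 12 GB for `gpt2-xl` when the files are both saved and loaded, which is why the in-memory filesystem matters when sizing the instance.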
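Because data written at runtime does not survive a new instance, a common alternative is to download the model once while the container image is built, so that every instance starts with the files already in its image. The sketch below assumes a hypothetical script `download_model.py` run from a Dockerfile build step; `MODEL_NAME`, `MODEL_DIR`, `model_is_present` and `main` are illustrative names, and the presence check relies on a model's `save_pretrained` writing a `config.json` file.

```python
"""Hypothetical build-time script: download_model.py."""
import os

MODEL_NAME = "gpt2"      # checkpoint the question downloads
MODEL_DIR = "./gpt2-xl"  # directory the app loads from


def model_is_present(model_dir: str = MODEL_DIR) -> bool:
    """True if a model saved with save_pretrained() already exists in model_dir."""
    # A model's save_pretrained() writes config.json alongside the weights.
    return os.path.isfile(os.path.join(model_dir, "config.json"))


def main(model_name: str = MODEL_NAME, model_dir: str = MODEL_DIR) -> None:
    """Download the tokenizer and model once and save them into model_dir."""
    if model_is_present(model_dir):
        print(f"Model already present in {model_dir}, skipping download")
        return
    # Imported here so the presence check works without transformers installed.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    AutoTokenizer.from_pretrained(model_name).save_pretrained(model_dir)
    AutoModelForCausalLM.from_pretrained(model_name).save_pretrained(model_dir)
    print(f"Saved {model_name} to {model_dir}")
```

A build step such as `RUN python -c "import download_model; download_model.main()"` would then bake the files into the image, assuming the build step and the app share the same working directory, and the app can load them with `from_pretrained(MODEL_DIR)` without any download fallback. The trade-off is a larger image and a rebuild whenever the model changes.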