python pytorch

Load Model into Computer Memory and Extract the Model from the Memory in Another Script


I have a Linux server and two Python scripts: one loads two PyTorch models from a checkpoint, and the other runs inference using those models. I want to run the first script whenever the server starts, so that the models stay in memory, and run the second script whenever the server receives an API call. My question is: is it possible to keep the two loaded models in memory and access them from the inference script, so that I can run inference without reloading them from the checkpoint? I don't want to load the models in the inference script itself, simply because loading them takes a long time.

Take for example OpenAI. They handle API calls very fast, meaning they don't load their models every time they do the inference. At least this is what I think...

If it is not possible, what would you suggest doing in this situation?


Solution

  • I was able to solve this problem using Flask (as far as I know you can also do this in Django, but Flask is a bit easier). First, give your server a URL so that it is reachable from the internet (like mytestserver.com). Then create a Python script on the server. In this script:

    • define the Flask app (see code below);
    • load model 1, 2, etc. at module level, so they are loaded once when the script starts;
    • define the function you want to call;
    • add the @app.route decorator to that function;
    • run the Flask app on some port (make sure that this port is accessible from the internet).

    Once your script is ready, run it on the server and you are good to go (python myscript.py). All your models will be loaded into the server memory and you will be able to do inference without loading the models every time you call the inference script.
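
    The reason this works is that the models are module-level objects in a long-running process: they are created once when the script starts, and every later call reuses them. A minimal sketch of that load-once pattern, using only the standard library, might look like this. Here `load_model` is a hypothetical stand-in for your own checkpoint-loading code, and the sleep only simulates a slow load.

```python
import time

LOAD_COUNT = 0  # counts how many times the (hypothetical) loader runs


def load_model(name):
    """Stand-in for a slow checkpoint load; returns a trivial callable 'model'."""
    global LOAD_COUNT
    LOAD_COUNT += 1
    time.sleep(0.05)  # simulate the expensive part of loading
    return lambda x: f"{name}({x})"


# Runs once, when the process starts (at module import time).
MODEL_1 = load_model("model_1")
MODEL_2 = load_model("model_2")


def run_inference(x):
    """Called once per request; reuses the already-loaded models."""
    return MODEL_1(x), MODEL_2(x)


if __name__ == "__main__":
    for i in range(3):
        print(run_inference(i))
    print("loads:", LOAD_COUNT)
```

    However many times `run_inference` is called, the loader runs only twice, once per model. In the Flask script the route function plays the role of `run_inference`.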

    Here is the example code for the inference script:

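    from flask import Flask, request, jsonify

    app = Flask(__name__)

    # Loaded once at startup, so every request reuses the same in-memory models.
    model_1 = None  # load model_1 here
    model_2 = None  # load model_2 here

    # Only POST is allowed, so no extra request.method check is needed.
    @app.route('/api/runInference', methods=['POST'])
    def inference():
        data = request.get_json()  # load data from request (assumes a JSON body)
        # inference code using model_1 and model_2 goes here
        response = jsonify({'result': None})  # placeholder: define response
        return response

    if __name__ == '__main__':
        app.run(host='0.0.0.0', port=5000)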
    
    

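    To reach the inference function, send a POST request to 'mytestserver.com/api/runInference' (or whatever URL you chose for the route). Include the port, e.g. ':5000', unless something in front of the app forwards requests to it.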
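
    For completeness, here is a rough sketch of what the calling side could look like using only the standard library. The host is the placeholder from above, while the JSON payload shape (`{"input": ...}`) and the assumption that the server replies with JSON are mine, so adjust both to your own request and response format.

```python
import json
import urllib.request


def build_request(url, payload):
    """Build a POST request carrying `payload` as a JSON body."""
    body = json.dumps(payload).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def call_inference(url, payload, timeout=30):
    """Send the request and decode the reply (assumes a JSON response)."""
    with urllib.request.urlopen(build_request(url, payload), timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))


if __name__ == "__main__":
    # Placeholder host and hypothetical payload; replace with your own.
    url = "http://mytestserver.com:5000/api/runInference"
    print(call_inference(url, {"input": "hello"}))
```

    Because the models are already in the server's memory, each call only pays for the request round trip and the inference itself.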