
Faster way to use a saved PyTorch model (bypassing import torch?)


I'm using a server managed by the Slurm Workload Manager, and import torch takes around 30-40 seconds. The IT people running it said they couldn't do much to improve it and that it was just hardware-related (maybe they missed something? I searched online before asking them and couldn't find much either). By comparison, import numpy takes around 1 second.

I would like to know if there is a way to use the saved weights of a PyTorch model ONLY to predict an output for a given input, without importing torch (so there is no need to import everything related to gradients, etc.). Theoretically, it is just matrix multiplications (I think?), so it should be feasible with numpy alone. I need to do this several times in different jobs, so I cannot cache or pass around the imported torch, which is why I'm actively looking for a solution (and in general, going from 30-40 seconds to a few is pretty cool anyway).
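For the fully connected part, the idea does hold. A minimal sketch, assuming the weights are already available as NumPy arrays (for example exported from the state_dict): torch.nn.Linear stores its weight with shape (out_features, in_features) and computes x @ W.T + b, and ReLU, Sigmoid and Tanh are element-wise one-liners. The function names and the example weights are hypothetical, chosen only for illustration.

```python
import numpy as np


def linear(x: np.ndarray, weight: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """Equivalent of torch.nn.Linear: y = x @ W^T + b.

    weight has shape (out_features, in_features), as stored in the state_dict.
    """
    return x @ weight.T + bias


def relu(x: np.ndarray) -> np.ndarray:
    return np.maximum(x, 0.0)


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


# np.tanh already matches torch.tanh element-wise, so no wrapper is needed.

# Usage with hypothetical weights: 3 input features -> 2 outputs
w = np.array([[1.0, 2.0, 3.0], [0.0, 0.0, 1.0]])
b = np.array([1.0, -1.0])
x = np.ones((1, 3))
print(relu(linear(x, w, b)))  # [[7. 0.]]
```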

If that matters, here is the architecture of my model:

ActionNN(
  (conv_1): Conv2d(5, 16, kernel_size=(3, 3), stride=(1, 1))
  (conv_2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv_3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (norm_layer_1): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (norm_layer_2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (norm_layer_3): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (gap): AdaptiveAvgPool2d(output_size=(1, 1))
  (mlp): Sequential(
    (0): Linear(in_features=71, out_features=128, bias=True)
    (1): ReLU()
  )
  (layer): Linear(in_features=128, out_features=210, bias=True)
  (layer): Linear(in_features=128, out_features=210, bias=True)
  (layer): Linear(in_features=128, out_features=210, bias=True)
  (layer): Linear(in_features=128, out_features=210, bias=True)
  (layer): Linear(in_features=128, out_features=28, bias=True)
  (layer): Linear(in_features=128, out_features=28, bias=True)
  (layer): Linear(in_features=128, out_features=28, bias=True)
  (sigmoid): Sigmoid()
  (tanh): Tanh()
)
Number of parameters: 152284
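As a sanity check on the printout (and on which arrays a numpy-only version would have to load), the parameter count can be reproduced by hand: each Conv2d has c_in*c_out*k*k weights plus c_out biases, each Linear has n_in*n_out weights plus n_out biases, and InstanceNorm2d with affine=False, the pooling layer and the activations have none. The helper names below are hypothetical. The total only comes out to 152284 if each of the seven (layer) lines is a separate Linear with its own weights.

```python
def conv2d_params(c_in: int, c_out: int, k: int) -> int:
    return c_in * c_out * k * k + c_out  # weights + bias


def linear_params(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out  # weights + bias


total = (
    conv2d_params(5, 16, 3)        # conv_1: 736
    + conv2d_params(16, 32, 3)     # conv_2: 4640
    + conv2d_params(32, 64, 3)     # conv_3: 18496
    + linear_params(71, 128)       # mlp.0: 9216
    + 4 * linear_params(128, 210)  # four 210-way heads: 108360
    + 3 * linear_params(128, 28)   # three 28-way heads: 10836
)
print(total)  # 152284
```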

If it were only fully connected layers, it would be "pretty easy", but because my network is a bit more complex, I'm not sure how I should do it.
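The other layer types in the printout are also plain array operations. A sketch under these assumptions: NumPy 1.20 or newer (for sliding_window_view), inputs in PyTorch's (N, C, H, W) layout, stride 1 and padding 0 (PyTorch's printout omits padding when it is the default 0), InstanceNorm2d with affine=False and track_running_stats=False (no learned parameters, statistics computed per sample and channel with the biased variance), and AdaptiveAvgPool2d((1, 1)), which is a mean over H and W. PyTorch's Conv2d is a cross-correlation, so the kernel is not flipped. How these blocks are wired together (activations after the convolutions, how the 64 pooled features become the 71 inputs of the mlp) depends on the forward method, which is not shown, so it is left out here.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv2d(x, weight, bias):
    """torch.nn.Conv2d with stride 1 and padding 0 (cross-correlation, no kernel flip).

    x: (N, C_in, H, W); weight: (C_out, C_in, kH, kW); bias: (C_out,)
    returns: (N, C_out, H - kH + 1, W - kW + 1)
    """
    kh, kw = weight.shape[2:]
    # windows: (N, C_in, H - kH + 1, W - kW + 1, kH, kW), a read-only view of x
    windows = sliding_window_view(x, (kh, kw), axis=(2, 3))
    out = np.einsum("nchwij,ocij->nohw", windows, weight)
    return out + bias[None, :, None, None]


def instance_norm(x, eps=1e-5):
    """InstanceNorm2d(affine=False, track_running_stats=False): per sample and channel."""
    mean = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)  # biased variance, as InstanceNorm uses
    return (x - mean) / np.sqrt(var + eps)


def global_avg_pool(x):
    """AdaptiveAvgPool2d((1, 1)) followed by flattening: mean over H and W."""
    return x.mean(axis=(2, 3))


# Usage on random data shaped like conv_1 (5 input channels -> 16)
rng = np.random.default_rng(0)
x = rng.standard_normal((1, 5, 10, 10))
w1 = rng.standard_normal((16, 5, 3, 3))
b1 = rng.standard_normal(16)
y = conv2d(x, w1, b1)    # (1, 16, 8, 8)
z = instance_norm(y)     # same shape, ~zero mean / unit variance per channel
p = global_avg_pool(y)   # (1, 16)
print(y.shape, p.shape)  # (1, 16, 8, 8) (1, 16)
```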

I saved the parameters using torch.save(my_network.state_dict(), my_path).
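Because the parameters are already in a state_dict, one option is to convert them once to a NumPy .npz archive in a job that does import torch, and let every later job load that archive with numpy only. A sketch; export_state_dict, load_weights and weights.npz are hypothetical names. The keys keep PyTorch's state_dict naming (conv_1.weight, conv_1.bias, mlp.0.weight, and so on).

```python
import numpy as np


def export_state_dict(state_dict, path):
    """Save every entry of a state_dict as a NumPy array in one .npz file.

    Run this once in an environment that has torch; afterwards only numpy is needed.
    """
    arrays = {}
    for name, value in state_dict.items():
        if hasattr(value, "detach"):  # torch.Tensor: detach, move to CPU, convert
            value = value.detach().cpu().numpy()
        arrays[name] = np.asarray(value)
    np.savez(path, **arrays)


def load_weights(path):
    """Load the exported arrays into a plain dict (numpy only, no torch import)."""
    with np.load(path) as archive:
        return {name: archive[name] for name in archive.files}


# One-off export, in a job that imports torch anyway:
#   import torch
#   export_state_dict(torch.load(my_path, map_location="cpu"), "weights.npz")
# Every later job:
#   weights = load_weights("weights.npz")
#   w1, b1 = weights["conv_1.weight"], weights["conv_1.bias"]
```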

My script takes 35 seconds on average in total (import torch included), so without the import I would be able to run it in a second or two on average, which would be great.

Here is the profiling output for import torch:

         1226310 function calls (1209639 primitive calls) in 49.994 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     1273   21.590    0.017   21.590    0.017 {method 'read' of '_io.BufferedReader' objects}
     5276   12.145    0.002   12.145    0.002 {built-in method posix.stat}
     1273    7.427    0.006    7.427    0.006 {built-in method io.open_code}
    45/25    5.631    0.125    9.939    0.398 {built-in method _imp.create_dynamic}
        2    0.564    0.282    0.564    0.282 {built-in method _ctypes.dlopen}
     1273    0.288    0.000    0.288    0.000 {built-in method marshal.loads}
       17    0.286    0.017    0.286    0.017 {method 'readline' of '_io.BufferedReader' objects}
2809/2753    0.098    0.000    0.546    0.000 {built-in method builtins.__build_class__}
   1620/1    0.062    0.000   49.997   49.997 {built-in method builtins.exec}
    50145    0.051    0.000    0.119    0.000 {built-in method builtins.getattr}
     1159    0.048    0.000    0.115    0.000 inspect.py:3245(signature)
      424    0.048    0.000    0.113    0.000 assumptions.py:596(__init__)
       13    0.039    0.003    0.039    0.003 {built-in method io.open}
     1411    0.035    0.000    0.045    0.000 library.py:71(impl)
     1663    0.034    0.000   12.209    0.007 <frozen importlib._bootstrap_external>:1536(find_spec)
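In the output above, the read, posix.stat and io.open_code rows alone add up to about 41 of the 50 seconds (21.590 + 12.145 + 7.427), which suggests the time goes to file-system access on the node rather than to computation. A profile like this can be reproduced with the standard library's cProfile and pstats; profile_import is a hypothetical helper name. As an alternative, CPython's python -X importtime -c "import torch" prints a per-module breakdown of import time.

```python
import cProfile
import importlib
import pstats


def profile_import(module_name, top=15):
    """Profile a single import and print the most expensive calls.

    Run this in a fresh interpreter: if the module is already in sys.modules,
    import_module returns the cached module and the profile is nearly empty.
    """
    profiler = cProfile.Profile()
    profiler.enable()
    importlib.import_module(module_name)
    profiler.disable()
    stats = pstats.Stats(profiler)
    stats.sort_stats("tottime").print_stats(top)  # "Ordered by: internal time"
    return stats


# Usage: profile_import("torch")
```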

Solution

  • There is an easy way to save on import time: spin up a server that imports torch and loads the model only once, at start-up.

    Use Flask, or better yet FastAPI, to run a simple HTTP server that runs inference on each HTTP call.

    The server will take about 40 seconds to start, but after that each inference call only costs the connection time plus the inference itself.

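    from fastapi import FastAPI, Request
    import torch
    
    from my_model import ActionNN  # hypothetical module: wherever your ActionNN class is defined
    
    # Load the weights once, at server start-up (the state_dict saved with torch.save)
    model = ActionNN()  # pass the same constructor arguments used in training
    model.load_state_dict(torch.load("<your model path here>", map_location="cpu"))
    model.eval()
    
    app = FastAPI()
    
    @app.post("/predict")
    async def inference(request: Request):
        data = await request.json()  # request.json() is a coroutine and must be awaited
        x = torch.tensor(data["input"], dtype=torch.float32)  # assumes a body like {"input": [...]}
        with torch.no_grad():  # inference only, no autograd bookkeeping
            prediction = model(x)  # nn.Module has no .predict(); call the module directly
        return {"prediction": prediction.tolist()}  # assumes forward returns a single tensor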
    
    

    Call the server with any HTTP client by POSTing data to http://<your-host>:<port>/predict.

    See https://fastapi.tiangolo.com/tutorial/first-steps/ for more details.
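Any HTTP client works. A minimal standard-library sketch, assuming the server listens on port 8000 (the default of uvicorn, which is commonly used to serve FastAPI apps), expects a JSON body of the form {"input": [...]} and replies with JSON. predict_remote and the "input" key are hypothetical and must match whatever the server actually reads.

```python
import json
import urllib.request


def predict_remote(features, url="http://localhost:8000/predict", timeout=10.0):
    """POST the input as JSON to the inference server and return the decoded JSON reply."""
    payload = json.dumps({"input": features}).encode("utf-8")
    request = urllib.request.Request(
        url,
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))


# Example (hypothetical host and input):
#   result = predict_remote([[0.1, 0.2, 0.3]], url="http://<your-host>:8000/predict")
#   print(result["prediction"])
```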