Tags: python, amazon-web-services, pytorch, torchserve

Extremely slow Bert inference on TorchServe for random requests


I have deployed Bert Hugging Face models via TorchServe on an AWS EC2 GPU instance. There are enough resources provisioned; usage of everything is consistently below 50%.

TorchServe performs inference on the Bert models quickly, most of the time in under 50 ms. But occasionally it takes a ridiculously long time: sometimes 1-10 SECONDS, and a few times even 120 s, at which point it timed out inside TorchServe itself! Interestingly, this never happens when the load is high, only when there are few requests.

I have identified the requests that took a long time or timed out and reran them. This time there was no problem and TorchServe responded within 50 ms.

I have tried replicating the issue on my local setup with long-running testing scripts via k6, but to no avail. This behavior doesn't seem to happen locally.

The only thing I can still think of is that I am using PyTorch TorchScript (.jit) models and loading them into TorchServe. I have experienced performance issues with TorchScript before and resolved them by setting the profiling mode and profiling executor to False. This may be connected.


self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Load TorchScript model
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
self.model = torch.jit.load(serialized_file)
self.model.eval()



# The "@torch.no_grad()" decorator is used on the Python methods that call the torch model
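
For completeness, a minimal sketch of how these pieces fit together in a custom handler; the class and method names below are illustrative assumptions, not the actual project code:

import torch

class BertHandler:
    def initialize(self, serialized_file: str):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Disable the JIT profiling executor before loading the TorchScript model
        torch._C._jit_set_profiling_mode(False)
        torch._C._jit_set_profiling_executor(False)
        self.model = torch.jit.load(serialized_file, map_location=self.device)
        self.model.eval()

    @torch.no_grad()  # inference only, no gradients needed
    def predict(self, input_ids, attention_mask):
        return self.model(input_ids.to(self.device), attention_mask.to(self.device))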

Attaching some TorchServe access logs (a few normal requests, followed by two slow ones) as well as a Grafana screenshot for illustration purposes.

2022-10-28T11:40:15,240 [INFO ] W-9004-A_categorization_1.4 ACCESS_LOG - /<some IP address>:<some_port> "POST /predictions/A_categorization HTTP/1.1" 200 43
2022-10-28T11:40:15,246 [INFO ] W-9005-A_text_quality_1.4 ACCESS_LOG - /<some IP address>:<some_port> "POST /predictions/A_text_quality HTTP/1.1" 200 39
2022-10-28T11:41:04,114 [INFO ] W-9004-B_categorization_1.4 ACCESS_LOG - /<some IP address>:<some_port> "POST /predictions/B_categorization HTTP/1.1" 200 51
2022-10-28T11:41:04,119 [INFO ] W-9005-B_text_quality_1.4 ACCESS_LOG - /<some IP address>:<some_port> "POST /predictions/B_text_quality HTTP/1.1" 200 43


2022-10-28T11:39:16,713 [INFO ] W-9004-A_categorization_1.4 ACCESS_LOG - /<some IP address>:<some_port> "POST /predictions/A_categorization HTTP/1.1" 200 1032
2022-10-28T11:39:23,465 [INFO ] W-9005-A_text_quality_1.4 ACCESS_LOG - /<some IP address>:<some_port> "POST /predictions/A_text_quality HTTP/1.1" 200 7784

Grafana screenshot

Additional project information:

Nvidia-smi:

NVIDIA-SMI 470.57.02    
Driver Version: 470.57.02    
CUDA Version: 11.4

Docker image:

pytorch/torchserve:0.6.0-gpu

PIP requirements:

torch==1.9.1
transformers==3.1.0
pydantic==1.9.1
boto3==1.24.8
loguru==0.6.0
APScheduler==3.9.1

TorchServe config.properties:

load_models=all
async_logging=true

models={\
  "categorization": {\
    "1.4": {\
        "defaultVersion": true,\
        "marName": "categorization.mar",\
        "minWorkers": 1,\
        "maxWorkers": 1,\
        "batchSize": 1,\
        "maxBatchDelay": 100\
    }\
  },\
  "text_quality": {\
    "1.4": {\
        "defaultVersion": true,\
        "marName": "text_quality.mar",\
        "minWorkers": 1,\
        "maxWorkers": 1,\
        "batchSize": 1,\
        "maxBatchDelay": 100\
    }\
  }\
}

Additionally, how can I turn off logging for TorchServe health checks?


Solution

  • The problem was a single extremely long request (~1.5 MB of text), which the tokenizer processed for more than 120 seconds. The worker was consequently killed by TorchServe for failing to respond to health checks for more than 120 seconds (the default value; see the note at the end of this answer). It was hard to detect why the worker died, because requests with normal text lengths also failed until a new worker was brought back online, giving the impression that TorchServe was failing for no good reason. We finally caught this long request when we integrated Sentry into our microservice architecture.

    Solution: We reduced the worker processing time by limiting the text length to a maximum of 10,000 characters per field. In our case, this still gives good results. Check if this is true for you before implementing this solution.

    def limit_string_length(text: str, max_length: int = 10000) -> str:
        return text[:max_length]
    

    Another misleading thing was the tokenizer's max_length parameter.

    encoding = self.tokenizer.encode_plus(
        f"{request.title} {request.description}",
        add_special_tokens=True,
        max_length=160,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_tensors="pt",
        truncation=True,
    )
    

    We set this to 160, wrongly assuming it would limit the length of the raw text before the request was tokenized. This is NOT the case: the tokenizer tokenizes ALL of the text first and only then truncates the output to 160 tokens, so a huge input still pays the full tokenization cost (see the sketch below).
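
    To make the difference concrete, here is a small illustration of our own (bert-base-uncased is used as a stand-in model, and the exact timings will vary): max_length / truncation only shortens the token output, so tokenizing a huge string is still slow, while slicing the raw string first bounds the work.

    import time
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # stand-in model
    huge_text = "some words " * 150_000  # roughly 1.5 MB of text, like the bad request

    # max_length/truncation: the WHOLE string is tokenized first, then cut to 160 tokens
    start = time.perf_counter()
    tokenizer.encode_plus(huge_text, max_length=160, truncation=True, return_tensors="pt")
    print(f"full text:     {time.perf_counter() - start:.2f}s")

    # pre-truncating the raw string bounds the tokenization work itself
    start = time.perf_counter()
    tokenizer.encode_plus(huge_text[:10000], max_length=160, truncation=True, return_tensors="pt")
    print(f"pre-truncated: {time.perf_counter() - start:.2f}s")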
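
    One more note, as an alternative we did not take and have not verified in production: the 120-second limit corresponds to TorchServe's default_response_timeout setting, so it should also be possible to raise it in config.properties if long requests must be tolerated rather than truncated.

    # config.properties (sketch): raise the worker response timeout above the 120 s default
    default_response_timeout=300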