Search code examples
pythondeep-learningpytorchtensorrttensorrt-python

Inference speed isn't improved with TensorRT compared to regular CUDA


I'm trying to use the TensorRT framework to improve the inference speed of my deep learning model. I've written a very simple Python script to test TensorRT with PyTorch.

import torch
import argparse
import time
import numpy as np
import torch_tensorrt

# Define a simple PyTorch model
class MyModel(torch.nn.Module):
    """Small convolutional classifier used as the benchmark workload.

    Two 3x3 convolutions (3 -> 32 -> 64 channels, padding 1) with ReLU
    activations are followed by a 2x2 max-pool and a two-layer fully
    connected head (64 * 16 * 16 -> 512 -> 10). The head is sized for the
    feature map a 32x32 input produces after the pool.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn
        flat_features = 64 * 16 * 16

        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Fully connected classification head.
        self.fc1 = nn.Linear(flat_features, 512)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Run the conv stack, flatten to rows of 64 * 16 * 16, apply the head."""
        for stage in (self.conv1, self.relu1, self.conv2, self.relu2, self.pool):
            x = stage(x)

        # -1 lets the leading dimension absorb whatever is left after fixing
        # the row width to the size fc1 expects.
        x = x.view(-1, 64 * 16 * 16)

        for stage in (self.fc1, self.relu3, self.fc2):
            x = stage(x)
        return x

def compute(use_tensorrt=False):
    """Benchmark forward-pass throughput of ``MyModel`` and print it.

    Runs 100 warm-up passes followed by 100 timed passes over one fixed
    random batch of shape (8192, 3, 32, 32), then prints the number of
    passes completed per second.

    Args:
        use_tensorrt: When true, wrap the model with ``torch.compile`` using
            the ``"torch_tensorrt"`` backend before benchmarking.

    Returns:
        None. The measured rate is printed as ``PyTorch FPS: <value>``.
    """
    force_cpu = False
    use_cuda = torch.cuda.is_available() and not force_cpu
    if use_cuda:
        print('Using CUDA.')
        device = torch.device("cuda:0")
    else:
        print('No CUDA available.')
        device = torch.device("cpu")

    # Use a real torch.dtype. The legacy tensor types
    # (torch.cuda.FloatTensor / torch.FloatTensor) are tensor classes, not
    # dtypes, so they are not a meaningful value for a precision setting.
    dtype = torch.float32

    def _wait_for_device():
        # CUDA kernels are launched asynchronously; without this barrier the
        # wall clock would be read before the queued GPU work has finished.
        if use_cuda:
            torch.cuda.synchronize(device)

    # Put the model in its final state (device, eval mode) before compiling.
    model = MyModel().to(device).eval()

    input_shape = (8192, 3, 32, 32)

    if use_tensorrt:
        # NOTE(review): the option keys accepted by the "torch_tensorrt"
        # backend differ between torch_tensorrt releases -- confirm them
        # against the installed version.
        model = torch.compile(
            model,
            backend="torch_tensorrt",
            options={
                "truncate_long_and_double": True,
                "precision": dtype,
                "workspace_size": 20 << 30,
            },
            dynamic=False,
        )

    num_warmup = 100
    num_iterations = 100
    with torch.no_grad():
        input_data = torch.randn(input_shape, device=device, dtype=dtype)

        # Warm-up: the first calls absorb compilation and lazy initialisation.
        for _ in range(num_warmup):
            model(input_data)

        # Time the whole loop between two device barriers so that the
        # measurement covers execution, not just kernel submission.
        _wait_for_device()
        start_time = time.perf_counter()
        for _ in range(num_iterations):
            model(input_data)
        _wait_for_device()
        total_time = time.perf_counter() - start_time

    pytorch_fps = num_iterations / total_time
    print(f"PyTorch FPS: {pytorch_fps:.2f}")

if __name__ == "__main__":
    # Run the baseline first, then the TensorRT-compiled variant.
    for banner, with_trt in (("Without TensorRT", False), ("With TensorRT", True)):
        print(banner)
        compute(use_tensorrt=with_trt)

Unfortunately, when I run this code, I get approximately the same FPS with and without TensorRT (~14.2), even with a significant warmup. Does anyone know what the issue could be? Is there something I'm missing?

Here is some more information about my setup:

libraries:

torch 2.0.1
torch_tensorrt 1.4.0

GPU:

nvcc: NVIDIA (R) Cuda compiler driver
Cuda compilation tools, release 11.5, V11.5.119
Build cuda_11.5.r11.5/compiler.30672275_0

Solution

  • As torch.compile was first released in PyTorch 2.0, most of it was still experimental and wasn't documented very thoroughly.

    I have absolutely no idea why your implementation doesn't work, but you get the desired improvement by using torch_tensorrt.compile instead of torch.compile.

        input_shape = (8192, 3, 32, 32)
        # Sample batch created on the target device; it is handed to the
        # compiler below through the `inputs` argument.
        inputs = [torch.randn(input_shape).to(device)]
    
        if use_tensorrt:
            # Call torch_tensorrt.compile directly instead of going through
            # torch.compile(..., backend="torch_tensorrt").
            model = torch_tensorrt.compile(
                model,
                inputs=inputs,
                # 20 << 30 == 20 * 2**30; presumably a size in bytes (20 GiB)
                # -- confirm against the torch_tensorrt documentation.
                workspace_size = 20 << 30,
                enabled_precisions = {torch.float},
            )
    

    On my system, with the same library versions, this yields a roughly 3x improvement:

    >>> Without TensorRT
    >>> Using CUDA.
    >>> PyTorch FPS: 12.75
    >>> With TensorRT
    >>> Using CUDA.
    >>> PyTorch FPS: 42.96
    

    If I had to guess, it probably has something to do with torch.compile failing with the options provided, and defaulting to some other IR.