
How to load an ONNX file and use it to make an ML prediction in PyTorch?


Below is the source code I use to load a .pth file and run a multi-class image classification prediction.

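import torch
from PIL import Image
from torchvision import transforms

# Assumed setup: use the GPU when available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Classifier()    # The Model Class (defined elsewhere).
model.load_state_dict(torch.load('<PTH-FILE-HERE>.pth'))
model = model.to(device)
model.eval()
# prediction function to test images
def predict(img_path):
    image = Image.open(img_path)
    resize = transforms.Compose(
                    [ transforms.Resize((256,256)), transforms.ToTensor()])
    image = resize(image)
    image = image.to(device)
    y_result = model(image.unsqueeze(0))
    result_idx = y_result.argmax(dim=1)
    print(result_idx)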

I converted the .pth file to an ONNX file using torch.onnx.export.

Now, how can I write a prediction script similar to the one above using only the ONNX file, without the .pth file? Is that possible?


Solution

  • You can use ONNX Runtime.

    # !pip install onnx onnxruntime-gpu
    import numpy as np
    import onnx, onnxruntime
    from PIL import Image
    from torchvision import transforms
    
    model_name = 'model.onnx'
    # Optional sanity check: verify that the ONNX graph is well-formed.
    onnx_model = onnx.load(model_name)
    onnx.checker.check_model(onnx_model)
    
    # Same preprocessing as in the PyTorch script.
    image = Image.open(img_path)   # img_path: path to the image to classify
    resize = transforms.Compose(
                    [ transforms.Resize((256,256)), transforms.ToTensor()])
    image = resize(image)
    image = image.unsqueeze(0) # add fake batch dimension
    
    # Providers are listed in priority order; the CPU provider is the fallback.
    EP_list = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    
    ort_session = onnxruntime.InferenceSession(model_name, providers=EP_list)
    
    def to_numpy(tensor):
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
    
    # compute ONNX Runtime output prediction
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(image)}
    ort_outs = ort_session.run(None, ort_inputs)
    
    # ort_outs[0] has shape (batch, num_classes); pick the highest-scoring class.
    max_index = int(np.argmax(ort_outs[0][0]))
    print(max_index)
    

    You can follow the tutorial for a detailed explanation.
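
ort_session.run returns a list of NumPy arrays, and for a classifier the first entry holds one row of class scores per image. If class probabilities or the few best classes are wanted rather than a single index, something like the following sketch may help. It assumes the exported model outputs raw logits (no softmax layer at the end); the helper names softmax and top_k and the example scores are made up for illustration.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def top_k(logits, k=3):
    """Return (class_index, probability) pairs for the k best classes of one sample."""
    probs = softmax(np.asarray(logits, dtype=np.float32))
    best = np.argsort(probs)[::-1][:k]  # indices sorted by descending probability
    return [(int(i), float(probs[i])) for i in best]

# Stand-in for ort_outs[0][0]: made-up scores for a hypothetical 4-class model.
example_logits = np.array([0.5, 2.0, -1.0, 1.0], dtype=np.float32)
predictions = top_k(example_logits, k=2)
print(predictions[0][0])  # index of the most likely class
```

With a real session, top_k(ort_outs[0][0]) would be called on the first row of the first output.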

    Usually, the purpose of using ONNX is to load the model in a different framework and run inference there, e.g. PyTorch -> ONNX -> TensorRT.
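
In that spirit, the prediction script does not have to import PyTorch or torchvision at all. The sketch below reproduces the Resize((256,256)) + ToTensor() preprocessing with Pillow and NumPy only. The function name preprocess is hypothetical, and the conversion to RGB is an assumption about what the model expects.

```python
import numpy as np
from PIL import Image

def preprocess(image, size=(256, 256)):
    """Turn a PIL image into a float32 array of shape (1, 3, H, W) with values in [0, 1]."""
    image = image.convert("RGB")                # assumption: the model expects 3 channels
    image = image.resize(size, Image.BILINEAR)  # PIL takes (width, height)
    array = np.asarray(image, dtype=np.float32) / 255.0  # HWC, scaled like ToTensor
    array = array.transpose(2, 0, 1)            # HWC -> CHW
    return np.ascontiguousarray(array[np.newaxis, ...])  # add the batch dimension

# Synthetic image so the sketch runs without a file on disk.
dummy = Image.new("RGB", (640, 480), color=(255, 0, 0))
batch = preprocess(dummy)
print(batch.shape, batch.dtype)
```

The resulting array can be passed directly as the input value, e.g. ort_session.run(None, {ort_session.get_inputs()[0].name: batch}), so the to_numpy helper is no longer needed. It is worth comparing a few predictions against the PyTorch script to confirm the preprocessing matches closely enough.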