Search code examples
pythonopencvpytorchonnx

How to convert PyTorch graph to ONNX and then inference from OpenCV?


I'm attempting to convert a PyTorch graph to ONNX via the torch.onnx.export function and then use the OpenCV functions blobFromImage, setInput, and forward to inference the converted graph. I think I'm on the right track but I keep running into errors and there are very few helpful examples of how to do this that I could find.

I realize the general stack overflow policy is to post only relevant portions of code, however with the errors I'm getting it seems this is a case where the devil is in the details so I suspect I'll have to post a full example to make the cause of the errors clear.

Here is my training net (pretty standard for MNIST):

# MnistNet.py

# Net Layout:
# batchSize x 1 x 28 x 28
#     conv1 Conv2d(1, 6, 5)
# batchSize x 6 x 24 x 24
#     relu(x)
#     max_pool2d(x, kernel_size=2)
# batchSize x 6 x 12 x 12
#     conv2 Conv2d(6, 16, 5)
# batchSize x 16 x 8 x 8
#     relu(x)
#     max_pool2d(x, kernel_size=2)
# batchSize x 16 x 4 x 4
#     view(-1, 16 * 4 * 4)    Note: 16 * 4 * 4 = 256
# batchSize x 1 x 256
#     fc1 Linear(256, 120)
#     relu(x)
# batchSize x 1 x 120
#     fc2 Linear(120, 84)
#     relu(x)
# batchSize x 1 x 84
#     fc3 Linear(84, 10)
# batchSize x 1 x 10

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

class MnistNet(nn.Module):

    TRANSFORM = torchvision.transforms.Compose([
        torchvision.transforms.Resize((28, 28)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.5], [0.5])
    ])

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(256, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    # end function

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=2)
        x = F.max_pool2d(F.relu(self.conv2(x)), kernel_size=2)
        x = x.view(-1, 256)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    # end function

# end class

Here is my training script (again, pretty standard for MNIST):

# 1_train.py

from MnistNet import MnistNet

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision

from termcolor import colored

BATCH_SIZE = 64
NUM_EPOCHS = 10

GRAPH_NAME = 'MNIST.pt'

def main():
    trainDataset = torchvision.datasets.MNIST('built_in_mnist_download', train=True, transform=MnistNet.TRANSFORM, download=True)

    trainDataLoader = DataLoader(trainDataset, batch_size=BATCH_SIZE, shuffle=True)

    # declare net, loss function, and optimizer
    mnistNet = MnistNet()
    lossFunction = nn.CrossEntropyLoss()
    optimizer = optim.Adam(mnistNet.parameters())

    # get device (cuda or cpu)
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print(colored('using cuda', 'green'))
    else:
        device = torch.device('cpu')
        print(colored('GPU does not seem to be available, using CPU', 'red'))
    # end if

    # set network to device
    mnistNet.to(device)

    # set network to train mode
    mnistNet.train()

    print('beginning training . . .')

    # for each epoch . . .
    for epoch in range(1, NUM_EPOCHS+1):

        # variables to calculate loss and accuracy within the epoch
        epochLosses = []
        epochAccuracies = []

        # for each batch . . .
        for i, element in enumerate(trainDataLoader):
            # break out the input images and labels, note these are Tensors
            inputImages, labels = element

            inputImages = inputImages.to(device)
            labels = labels.to(device)

            # clear gradients from the previous step
            optimizer.zero_grad()

            # get net output
            outputs = mnistNet(inputImages)
            # calculate loss
            loss = lossFunction(outputs, labels)
            # call backward() to compute gradients
            loss.backward()
            # update parameters using gradients
            optimizer.step()

            # append the current classification loss to the list of epoch losses
            epochLosses.append(loss.item())

            # calculate current classification accuracy

            # get the highest scoring classification for each prediction
            _, predictions = torch.max(outputs.data, 1)

            # number of labels and predictions should always be the same, log an error if this is not the case
            if labels.size(0) != predictions.size(0):
                print(colored('ERROR: labels.size(0) != predictions.size(0)', 'red'))
            # end if

            # determine the number of correct predictions for the current batch
            correctPredictions = 0
            for j in range(len(labels)):
                if predictions[j].item() == labels[j].item():
                    correctPredictions += 1
                # end if
            # end for

            # append the current batch accuracy to the list of accuracies
            epochAccuracies.append(correctPredictions / labels.size(0))
        # end for

        # calculate epoch loss and accuracy from the respective lists
        epochLoss = sum(epochLosses) / len(epochLosses)
        epochAccuracy = sum(epochAccuracies) / len(epochAccuracies)

        print('epoch ' + str(epoch) + ', epochLoss = ' + '{:.4f}'.format(epochLoss) +
              ', epochAccuracy = ' + '{:.4f}'.format(epochAccuracy * 100) + '%')
    # end for

    print('finished training')

    # save the model
    torch.save(mnistNet.state_dict(), GRAPH_NAME)

    print('saved graph as ' + str(GRAPH_NAME))

# end function

if __name__ == '__main__':
    main()

Here is my best attempt so far at a script to convert a saved graph from PyTorch to ONNX (I'm not sure if this is correct, I can at least say it runs without error):

# 3_convert_graph_to_onnx.py

from MnistNet import MnistNet

import torch

GRAPH_NAME = 'MNIST.pt'
ONNX_GRAPH_NAME = 'MNIST.onnx'

def main():

    net = MnistNet()
    net.load_state_dict(torch.load(GRAPH_NAME))

    net.eval()

    # make a dummy input with a batch size of 1, 1 channel, 28 x 28
    dummyInput = torch.randn(10, 1, 28, 28)

    torch.onnx.export(net, dummyInput, ONNX_GRAPH_NAME, verbose=True)

# end function

if __name__ == '__main__':
    main()

Here is my attempt to inference the ONNX graph with OpenCV (Note that PyTorch is included, but is only used to load the test MNIST dataset, and the images are converted to OpenCV format before inferencing):

# 4_onnx_opencv_inf.py

from MnistNet import MnistNet

import torchvision

import cv2
import numpy as np
from termcolor import colored

ONNX_GRAPH_NAME = 'MNIST.onnx'

def main():
    testDataset = torchvision.datasets.MNIST('built_in_mnist_download', train=False, transform=MnistNet.TRANSFORM, download=True)

    labels = [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9' ]

    net = cv2.dnn.readNetFromONNX(ONNX_GRAPH_NAME)

    # test on 3 images
    for i in range(3):
        # get PyTorch tensor image and ground truth index from dataset
        ptImage, gndTrIdx = testDataset[i]
        # convert to PIL image
        pilImage = torchvision.transforms.ToPILImage()(ptImage)
        # convert to OpenCV image, would convert RGB to BGR here if image was color
        openCvImage = np.array(pilImage)

        gndTr = labels[gndTrIdx]

        # can show OpenCV image here if desired
        # cv2.imshow('openCvImage', openCvImage)
        # cv2.waitKey()

        blob = cv2.dnn.blobFromImage(image=openCvImage, scalefactor=1.0/255.0, size=(64, 64))

        net.setInput(blob)
        preds = net.forward()

        predIdx = np.array(preds)[0].argmax()

        prediction = str(predIdx)
        if prediction == gndTr:
            print(colored('i = ' + str(i) + ', predIdx = ' + str(predIdx) + ', gndTrIdx = ' + str(gndTrIdx) + ', correct answer', 'green'))
        else:
            print(colored('i = ' + str(i) + ', predIdx = ' + str(predIdx) + ', gndTrIdx = ' + str(gndTrIdx) + ', incorrect answer', 'red'))
        # end if

    # end for

# end function

if __name__ == '__main__':
    main()

Currently this final script crashes with this error:

$ python3 4_onnx_opencv_inf.py 
[ERROR:0] global /tmp/pip-req-build-99ib2vsi/opencv/modules/dnn/src/dnn.cpp (3441) getLayerShapesRecursively OPENCV/DNN: [Reshape]:(18): getMemoryShapes() throws exception. inputs=1 outputs=1/1 blobs=0
[ERROR:0] global /tmp/pip-req-build-99ib2vsi/opencv/modules/dnn/src/dnn.cpp (3447) getLayerShapesRecursively     input[0] = [ 1 16 13 13 ]
[ERROR:0] global /tmp/pip-req-build-99ib2vsi/opencv/modules/dnn/src/dnn.cpp (3451) getLayerShapesRecursively     output[0] = [ 1 256 ]
[ERROR:0] global /tmp/pip-req-build-99ib2vsi/opencv/modules/dnn/src/dnn.cpp (3457) getLayerShapesRecursively Exception message: OpenCV(4.4.0) /tmp/pip-req-build-99ib2vsi/opencv/modules/dnn/src/layers/reshape_layer.cpp:154: error: (-1:Backtrace) Can't infer a dim denoted by -1 in function 'computeShapeByReshapeMask'

Traceback (most recent call last):
  File "4_onnx_opencv_inf.py", line 54, in <module>
    main()
  File "4_onnx_opencv_inf.py", line 38, in main
    preds = net.forward()
cv2.error: OpenCV(4.4.0) /tmp/pip-req-build-99ib2vsi/opencv/modules/dnn/src/layers/reshape_layer.cpp:154: error: (-1:Backtrace) Can't infer a dim denoted by -1 in function 'computeShapeByReshapeMask'

I'm not really sure what to do next based on this error, can anybody please advise on this? I suspect that I'm at least doing the procedure generally correctly and missing a few small details.


Solution

  • I was using the wrong size in the ONNX inference script.

    In 4_onnx_opencv_inf.py changing:

    blob = cv2.dnn.blobFromImage(image=openCvImage, scalefactor=1.0/255.0, size=(64, 64))
    

    to

    blob = cv2.dnn.blobFromImage(image=openCvImage, scalefactor=1.0/255.0, size=(28, 28))
    

    makes it run (I'm using Ubuntu 20.04 and PyTorch 1.7.0), however the accuracy is worse. With regular PyTorch inferencing as above (2nd script) I'm getting 98.5% accuracy, with the OpenCV ONNX version I'm getting 95% accuracy.

    I suspect the difference is due to the parameters in cv2.dnn.blobFromImage not being set to handle the normalization correctly, but that is a different post entirely.