Tags: python, neural-network, pytorch, onnx, onnxruntime

PyTorch normalization in ONNX model


I am doing image classification in PyTorch, and I used this transform:

transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

and completed the training. Afterwards, I converted the .pth model file to an .onnx file.

Now, at inference time, how should I apply this transform to a NumPy array, since ONNX Runtime takes its input as NumPy arrays?


Solution

  • You have a couple of options.

    Since Normalize is trivial to write yourself, you could just do

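    import numpy as np

    # x: image array of shape (3, H, W) with values in [0, 1]
    # float32 statistics keep the result float32, which is what an exported model usually expects
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(-1, 1, 1)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(-1, 1, 1)
    x_normalized = (x - mean) / std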
    

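    which doesn't require PyTorch or torchvision at all. Assuming your training pipeline applied transforms.ToTensor before Normalize, x must first be brought into the same form ToTensor produced: a float array of shape (3, H, W) with values scaled to [0, 1].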

    If you are still using your PyTorch dataset, you could use the following transform:

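    import torch
    from torchvision import transforms

    # Expects a float tensor of shape (C, H, W), e.g. the output of transforms.ToTensor()
    transform = transforms.Compose([
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        torch.Tensor.numpy,  # or equivalently transforms.Lambda(lambda x: x.numpy())
    ])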
    

    which applies the normalization to the tensor and then converts it to a NumPy array.
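Putting the pieces together for inference: when the image comes straight from a loader rather than a torchvision dataset, the scaling and layout change that transforms.ToTensor performed during training also have to be reproduced, and ONNX Runtime needs a batched float32 array. The following is a minimal sketch under several assumptions: training used ToTensor before the Normalize above, the image is an RGB uint8 array of shape (H, W, 3) (what np.asarray gives for a PIL RGB image; OpenCV loads BGR, so reverse the channels first), and the exported model takes a single float32 input of shape (N, 3, H, W). The helper names preprocess and run_onnx and the file name model.onnx are hypothetical.

```python
import numpy as np

# ImageNet statistics from the training-time transforms.Normalize call
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess(image: np.ndarray) -> np.ndarray:
    """Hypothetical helper: RGB uint8 (H, W, 3) image -> float32 (1, 3, H, W) batch.

    Mirrors ToTensor (scale to [0, 1], HWC -> CHW) followed by Normalize.
    """
    x = image.astype(np.float32) / 255.0  # ToTensor's scaling to [0, 1]
    x = (x - MEAN) / STD  # (3,) statistics broadcast over the trailing channel axis
    x = x.transpose(2, 0, 1)  # HWC -> CHW, the layout ToTensor produces
    return np.ascontiguousarray(x[np.newaxis])  # add batch dim -> (1, 3, H, W)


def run_onnx(model_path: str, batch: np.ndarray) -> np.ndarray:
    """Hypothetical helper: run the exported classifier and return its first output."""
    import onnxruntime as ort  # imported lazily so preprocess() works without it

    session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    # The array must match the exported input dtype; float64 is rejected for a float32 input
    return session.run(None, {input_name: batch.astype(np.float32, copy=False)})[0]


# Dummy 4x4 image in which every pixel is (255, 0, 128)
img = np.zeros((4, 4, 3), dtype=np.uint8)
img[..., 0] = 255
img[..., 2] = 128
batch = preprocess(img)
print(batch.shape, batch.dtype)  # (1, 3, 4, 4) float32
# logits = run_onnx("model.onnx", batch)  # needs onnxruntime and an exported model
```

Normalizing before the transpose lets the (3,) statistics broadcast over the last axis, so no reshape is needed; if you normalize after switching to (3, H, W), use the (3, 1, 1) shapes from the first snippet instead.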