Tags: deep-learning, neural-network, pytorch, onnx

How to keep the Batch Normalization layers when exporting to ONNX?


I need to print the structure of a neural network defined in PyTorch code, so I resorted to the ONNX format. I used the torch.onnx.export function, but the Batch Normalization layers are not recorded, because they are folded into the preceding convolutional layers (see here and here, for example). After some searching, I found this Q&A on StackOverflow, which seems to solve the problem by passing the option training=TrainingMode.TRAINING. Unfortunately, as others have also noted, this option does not seem to work, because it is not recognized:

---------------------------------------------------------------------------

NameError                                 Traceback (most recent call last)

<ipython-input-208-8820886c111a> in <module>
     10                   input_names = ['input'],
     11                   output_names = ['output'],
---> 12                   training=TrainingMode.TRAINING)
     13 
     14 

NameError: name 'TrainingMode' is not defined

Below is example code that reproduces the problem. I am working on Colab.

import torch
import torchvision
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

class NeuralNetwork(nn.Module):
  def __init__(self):
    super(NeuralNetwork, self).__init__()
    self.Network = nn.Sequential(
        nn.Conv2d(1,1,3),
        nn.BatchNorm2d(1,track_running_stats=False),
        nn.ReLU(),
    )

  def forward(self,x):
    output = self.Network(x)
    return output

model = NeuralNetwork().to(device)

print(model)

import onnx

net = NeuralNetwork()
dummy_input = torch.randn(1, 1, 28, 28)  # assumed input shape: (batch, channels, height, width)
torch.onnx.export(net, dummy_input, "test.onnx", 
                  verbose=True, 
                  export_params=True, 
                  opset_version=12,
                  do_constant_folding=True,
                  input_names = ['input'], 
                  output_names = ['output'],
                  training=TrainingMode.TRAINING)
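The question attributes the missing layers to Batch Normalization being folded into the convolution. As a rough illustration of why such folding is possible, here is a minimal pure-Python sketch of the standard Conv+BatchNorm folding arithmetic for a single channel, matching the shape of the question's Conv2d(1,1,3) followed by BatchNorm2d(1). The helper names (conv2d_valid, batchnorm_eval, fold_bn_into_conv) and all numbers are hypothetical. The sketch assumes BatchNorm uses fixed running statistics, as in eval mode. Note that with track_running_stats=False, as in the question's model, PyTorch's BatchNorm normalizes with batch statistics even in eval mode, so this folding would not apply to it directly.

```python
import math


def conv2d_valid(image, kernel, bias):
    """Single-channel 2-D cross-correlation (what nn.Conv2d computes), stride 1, no padding."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    return [
        [
            sum(image[i + di][j + dj] * kernel[di][dj]
                for di in range(kh) for dj in range(kw)) + bias
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]


def batchnorm_eval(fmap, gamma, beta, mean, var, eps=1e-5):
    """Eval-mode BatchNorm for one channel, using fixed (running) statistics."""
    scale = gamma / math.sqrt(var + eps)
    return [[(x - mean) * scale + beta for x in row] for row in fmap]


def fold_bn_into_conv(kernel, bias, gamma, beta, mean, var, eps=1e-5):
    """Return a kernel and bias so that conv alone equals BatchNorm(conv(x))."""
    scale = gamma / math.sqrt(var + eps)
    folded_kernel = [[w * scale for w in row] for row in kernel]
    folded_bias = (bias - mean) * scale + beta
    return folded_kernel, folded_bias


# Hypothetical example values, for illustration only.
image = [[1.0, 2.0, 0.5, -1.0],
         [0.0, 1.5, 2.5, 1.0],
         [3.0, -0.5, 1.0, 2.0],
         [1.0, 0.0, -2.0, 0.5]]
kernel = [[0.2, -0.1, 0.0],
          [0.5, 0.3, -0.4],
          [0.1, 0.0, 0.2]]
bias = 0.1
gamma, beta, mean, var = 1.5, -0.2, 0.3, 2.0

# Conv followed by a separate BatchNorm step...
conv_then_bn = batchnorm_eval(conv2d_valid(image, kernel, bias), gamma, beta, mean, var)
# ...versus a single convolution with BatchNorm folded into its weights and bias.
fk, fb = fold_bn_into_conv(kernel, bias, gamma, beta, mean, var)
folded_conv = conv2d_valid(image, fk, fb)

max_diff = max(abs(a - b)
               for row_a, row_b in zip(conv_then_bn, folded_conv)
               for a, b in zip(row_a, row_b))
print(f"max difference: {max_diff:.2e}")
```

If the printed difference is at floating-point rounding level, the folded convolution reproduces Conv followed by BatchNorm, which is why an inference-oriented export can drop the separate BatchNorm node.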

Solution

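  • You need to import the definition of TrainingMode, which lives in the torch.onnx module.

    With an explicit import:

    from torch.onnx import TrainingMode


    Without a separate import, by using the fully qualified name: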

    torch.onnx.TrainingMode.TRAINING
    

    After that, you can successfully export the ONNX model, and the Batch Normalization layer is kept in the graph (visualized with Netron):

    [Image: Netron visualization of the exported ONNX model]