BatchNorm layers forced into training mode by torch.onnx.export when running stats are None


As described in this GitHub issue (which has a very complete description), I'm trying to load an .onnx model with the OpenVINO backend. However, the BatchNorm layers are exported in training mode when track_running_stats=False is set. Here is what I do before converting my torch model to ONNX:

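import torch.nn as nn

model.eval()

# Note: model.children() only visits direct children, not nested submodules
for child in model.children():
    if isinstance(child, nn.BatchNorm2d):
        child.track_running_stats = False
        child.running_mean = None
        child.running_var = None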
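
For context, here is a minimal sketch of the rule that, to my understanding, PyTorch's BatchNorm forward pass applies when choosing between batch statistics and running statistics. The function name `uses_batch_stats` is hypothetical and only mirrors that decision; it is not part of any library.

```python
from typing import Optional, Sequence


def uses_batch_stats(
    training: bool,
    running_mean: Optional[Sequence[float]],
    running_var: Optional[Sequence[float]],
) -> bool:
    """Return True when normalization would rely on statistics of the current batch.

    Hypothetical helper: it mirrors the decision made in PyTorch's BatchNorm
    forward pass, where missing running buffers force batch statistics even
    after model.eval().
    """
    if training:
        return True
    # In eval mode, batch statistics are used only if both buffers are absent.
    return running_mean is None and running_var is None


# Eval mode with buffers present: running statistics are used.
print(uses_batch_stats(False, [0.0], [1.0]))
# Eval mode with buffers set to None, as in the loop above: batch statistics.
print(uses_batch_stats(False, None, None))
```

Under this rule, a layer whose buffers were set to None behaves like a training-mode layer even after model.eval(), which would be consistent with the exporter marking it with training_mode = 1.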

Then I export the model to ONNX:

dummy_input = torch.randn(1, 3, 200, 200, requires_grad=True)
torch.onnx.export(model, dummy_input, model_path, export_params=True,
                  opset_version=16, training=torch.onnx.TrainingMode.PRESERVE)

Finally, I get this error when loading it in OpenVINO:

Error: Check '(node.get_outputs_size() == 1)' failed at src/frontends/onnx/frontend/src/op/batch_norm.cpp:67:
While validating ONNX node '<Node(BatchNormalization): BatchNormalization_10>':
Training mode of BatchNormalization is not supported.

As suggested in the GitHub issue, I looked at the inputs and outputs of the BatchNorm nodes:

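import onnx

onnx_model = onnx.load(model_path)  # assumed: same path as passed to torch.onnx.export

# Print every node whose input or output names mention BatchNorm
for node in onnx_model.graph.node:
    if any("BatchNorm" in s or "bn" in s for s in list(node.input) + list(node.output)):
        print('Node:', node.name)
        print(node)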
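
A more direct check, sketched here as an assumption rather than something that was run on the real model, is to filter on op_type and report each BatchNormalization node's training_mode and output count. The helper name `summarize_batchnorm` is hypothetical, and the stand-in objects only mimic the node fields it reads (op_type, name, attribute, output) so the sketch runs without onnx installed.

```python
from types import SimpleNamespace


def summarize_batchnorm(nodes):
    """Return (name, training_mode, number_of_outputs) for each BatchNormalization node."""
    summary = []
    for node in nodes:
        if node.op_type != "BatchNormalization":
            continue
        # training_mode defaults to 0 when the attribute is absent
        mode = next((a.i for a in node.attribute if a.name == "training_mode"), 0)
        summary.append((node.name, mode, len(node.output)))
    return summary


# Stand-in nodes (not real NodeProto objects); names are copied from the dump below.
demo_nodes = [
    SimpleNamespace(op_type="ReduceMean", name="ReduceMean_5",
                    attribute=[], output=["onnx::BatchNormalization_23"]),
    SimpleNamespace(op_type="BatchNormalization", name="BatchNormalization_10",
                    attribute=[SimpleNamespace(name="training_mode", i=1)],
                    output=["input.4", "29", "30"]),
    SimpleNamespace(op_type="BatchNormalization", name="BatchNormalization_13",
                    attribute=[SimpleNamespace(name="training_mode", i=0)],
                    output=["input.12"]),
]

for row in summarize_batchnorm(demo_nodes):
    print(row)
```

With a loaded model, the same function would presumably be called as summarize_batchnorm(onnx_model.graph.node).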

These are the nodes related to BN:

Node: ReduceMean_5
input: "onnx::ReduceMean_22"
output: "onnx::BatchNormalization_23"
name: "ReduceMean_5"
op_type: "ReduceMean"
attribute {
  name: "axes"
  ints: 0
  ints: 1
  type: INTS
}
attribute {
  name: "keepdims"
  i: 0
  type: INT
}

Node: ReduceMean_9
input: "onnx::ReduceMean_26"
output: "onnx::BatchNormalization_27"
name: "ReduceMean_9"
op_type: "ReduceMean"
attribute {
  name: "axes"
  ints: 0
  ints: 1
  type: INTS
}
attribute {
  name: "keepdims"
  i: 0
  type: INT
}

Node: BatchNormalization_10
input: "input"
input: "bn1.weight"
input: "bn1.bias"
input: "onnx::BatchNormalization_23"
input: "onnx::BatchNormalization_27"
output: "input.4"
output: "29"
output: "30"
name: "BatchNormalization_10"
op_type: "BatchNormalization"
attribute {
  name: "epsilon"
  f: 9.999999747378752e-06
  type: FLOAT
}
attribute {
  name: "momentum"
  f: 0.8999999761581421
  type: FLOAT
}
attribute {
  name: "training_mode"
  i: 1
  type: INT
}

Node: BatchNormalization_13
input: "input.8"
input: "bn2.0.weight"
input: "bn2.0.bias"
input: "bn2.0.running_mean"
input: "bn2.0.running_var"
output: "input.12"
name: "BatchNormalization_13"
op_type: "BatchNormalization"
attribute {
  name: "epsilon"
  f: 9.999999747378752e-06
  type: FLOAT
}
attribute {
  name: "momentum"
  f: 0.8999999761581421
  type: FLOAT
}
attribute {
  name: "training_mode"
  i: 0
  type: INT
}

Node: ReduceMean_18
input: "onnx::ReduceMean_37"
output: "onnx::BatchNormalization_38"
name: "ReduceMean_18"
op_type: "ReduceMean"
attribute {
  name: "axes"
  ints: 0
  ints: 1
  type: INTS
}
attribute {
  name: "keepdims"
  i: 0
  type: INT
}

Node: ReduceMean_22
input: "onnx::ReduceMean_41"
output: "onnx::BatchNormalization_42"
name: "ReduceMean_22"
op_type: "ReduceMean"
attribute {
  name: "axes"
  ints: 0
  ints: 1
  type: INTS
}
attribute {
  name: "keepdims"
  i: 0
  type: INT
}

Node: BatchNormalization_23
input: "input.16"
input: "bn3.weight"
input: "bn3.bias"
input: "onnx::BatchNormalization_38"
input: "onnx::BatchNormalization_42"
output: "43"
output: "44"
output: "45"
name: "BatchNormalization_23"
op_type: "BatchNormalization"
attribute {
  name: "epsilon"
  f: 9.999999747378752e-06
  type: FLOAT
}
attribute {
  name: "momentum"
  f: 0.8999999761581421
  type: FLOAT
}
attribute {
  name: "training_mode"
  i: 1
  type: INT
}

As you can see, two of the three BN layers have training_mode = 1 (i.e. True). How can I make ONNX treat them as being in eval mode while keeping track_running_stats=False?

I'm not very familiar with ONNX, and more generally I'm a beginner in DL, so I'd appreciate any advice!


Solution

  • I finally found a solution based on a suggestion made in the GitHub issue linked in the question.

    After setting track_running_stats to False, the BatchNormalization layers are exported in training mode, as you can see in the ONNX graph.

    I deleted, directly in the graph, the unused outputs of the batch normalization layers that refer to the mean and variance, then manually set each layer to eval mode (training_mode = 0). You must delete the unused outputs and not only set the training_mode attribute to 0, because otherwise the check (node.get_outputs_size() == 1) will still fail.

    for node in onnx_model.graph.node:
        if node.op_type != "BatchNormalization":
            continue
        for attribute in node.attribute:
            if attribute.name == "training_mode" and attribute.i == 1:
                # Drop the two training-only outputs (positions 1 and 2),
                # keeping only the normalized tensor at position 0.
                del node.output[1:]
                # Mark the node as inference mode.
                attribute.i = 0
    

    After that, I'm able to run inference correctly and get the expected results.
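
Since the fix relies on the extra outputs being unused, it may be worth confirming that before deleting them. The following is a minimal sketch under that assumption: the helper name `consumed_extra_outputs` is hypothetical, the Relu consumer is invented for the demo, and the stand-in objects only mimic the node fields the helper reads (op_type, name, input, output).

```python
from types import SimpleNamespace


def consumed_extra_outputs(nodes, graph_output_names=()):
    """Map BatchNormalization node name -> extra outputs that something still reads."""
    nodes = list(nodes)
    # Every tensor name read by some node, plus the graph-level outputs.
    consumed = set(graph_output_names)
    for node in nodes:
        consumed.update(node.input)
    still_used = {}
    for node in nodes:
        if node.op_type == "BatchNormalization":
            # Outputs after position 0 are the training-only mean and variance.
            extras = [name for name in list(node.output)[1:] if name in consumed]
            if extras:
                still_used[node.name] = extras
    return still_used


# Stand-in nodes (not real NodeProto objects); BN names follow the dump in the question.
demo_nodes = [
    SimpleNamespace(op_type="BatchNormalization", name="BatchNormalization_10",
                    input=["input", "bn1.weight", "bn1.bias"],
                    output=["input.4", "29", "30"]),
    # Hypothetical consumer that reads only the normalized tensor.
    SimpleNamespace(op_type="Relu", name="Relu_11",
                    input=["input.4"], output=["input.8"]),
]

print(consumed_extra_outputs(demo_nodes))          # nothing reads "29" or "30"
print(consumed_extra_outputs(demo_nodes, ["29"]))  # "29" exposed as a graph output
```

With a loaded model, this would presumably be called as consumed_extra_outputs(onnx_model.graph.node, [o.name for o in onnx_model.graph.output]). An empty result suggests the extra outputs can be removed without breaking another node.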