Tags: python, tensorflow, batch-normalization

Which TensorFlow batch norm code path is used when the input rank is 4?


I am using slim.batch_norm from tf.contrib.layers and trying to understand the code flow for my use case. It looks to me like the logic that chooses between _fused_batch_norm() and the base class will only call _fused_batch_norm() in my case when the input rank is 2. The code comment says it should also be used when the rank is 4, and the function itself (_fused_batch_norm()) supports rank 4, but the logic seems to prevent it from being called. Below is a snippet of the code showing what I am referring to:

  # Only use _fused_batch_norm (1) if fused is set True or if it is
  # possible to use (currently it doesn't support batch weights,
  # renorm, and the case when rank is neither 2 nor 4),
  # and (2) if used with zero_debias_moving_mean, or an input shape of rank 2,
  # or non-default updates_collections (not implemented in
  # normalization_layers.BatchNormalization yet); otherwise use the fused
  # implementation in normalization_layers.BatchNormalization.
  inputs = ops.convert_to_tensor(inputs)
  rank = inputs.get_shape().ndims
  feature_supported = batch_weights is None and not renorm and rank in [2, 4]
  possible_to_fuse = fused is None and feature_supported
  if (fused or possible_to_fuse) and (
      zero_debias_moving_mean or rank == 2 or
      updates_collections is not ops.GraphKeys.UPDATE_OPS):
    return _fused_batch_norm(...)
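
For context, my usage is essentially the stock call on a rank-4 feature map, something like this (shapes illustrative, not my exact code):

  import tensorflow as tf

  slim = tf.contrib.slim

  # A rank-4 NHWC feature map, e.g. the output of a conv layer.
  images = tf.placeholder(tf.float32, [None, 32, 32, 16])

  # Everything is left at its default, so the logic above decides
  # which batch norm implementation actually runs.
  net = slim.batch_norm(images, is_training=True)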

For my use case, I have the following parameters all at default settings:

batch_weights=None
fused=False
renorm=False
zero_debias_moving_mean=False
updates_collections=ops.GraphKeys.UPDATE_OPS

If my input is rank 4, it looks like the code will use the fused implementation in normalization_layers.BatchNormalization. Is my understanding of the logic correct?

Is this the expected and proper behavior? I am wondering whether the condition rank == 2 should actually be rank in [2, 4]. If the latter is correct, then this is a potential bug. If the original is correct, then why use rank in [2, 4] when determining feature_supported?
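
To make the rank dependence concrete, the predicate can be pulled out into a standalone snippet and traced for the fused=None case (a plain string stands in for ops.GraphKeys.UPDATE_OPS, since only object identity matters):

  # Standalone re-implementation of the predicate above, for tracing only.
  UPDATE_OPS = "update_ops"  # stand-in for ops.GraphKeys.UPDATE_OPS

  def uses_fused_batch_norm(rank, batch_weights=None, fused=None, renorm=False,
                            zero_debias_moving_mean=False,
                            updates_collections=UPDATE_OPS):
      feature_supported = batch_weights is None and not renorm and rank in [2, 4]
      possible_to_fuse = fused is None and feature_supported
      return (fused or possible_to_fuse) and (
          zero_debias_moving_mean or rank == 2 or
          updates_collections is not UPDATE_OPS)

  print(uses_fused_batch_norm(rank=2))  # True  -> _fused_batch_norm
  print(uses_fused_batch_norm(rank=4))  # False -> normalization_layers.BatchNormalization

So with everything else at its default, a rank-2 input takes the _fused_batch_norm path but a rank-4 input does not, which is the asymmetry I am asking about.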


Solution

  • You are right, it's a bug. When rank=4 and fused=None (or True), the optimized _fused_batch_norm can and should be used. This is consistent with tf.nn.fused_batch_norm, which itself operates on 4-D inputs (see the example at the end of this answer).

    It looks like they mixed up the logic expression: it should trigger whenever possible_to_fuse=True, no matter what everything else is. Moreover, if feature_supported=True and fused is not explicitly False, _fused_batch_norm is eligible as well. A sketch of the corrected predicate is below.

    You should report it to the TensorFlow issue tracker.
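
    A sketch of the corrected predicate (my reading of the intended logic, not an actual TensorFlow patch): fire whenever possible_to_fuse is True, and also whenever the caller did not explicitly opt out with fused=False and the features are supported.

      # Hypothetical fixed predicate (illustrative only, not the real patch).
      # In this reading, zero_debias_moving_mean and updates_collections no
      # longer gate the choice of implementation.
      def should_fuse(rank, batch_weights=None, fused=None, renorm=False):
          feature_supported = batch_weights is None and not renorm and rank in [2, 4]
          possible_to_fuse = fused is None and feature_supported
          # possible_to_fuse fires regardless of the other options, and an
          # explicit fused=True also qualifies when the features allow it.
          return possible_to_fuse or (fused is not False and feature_supported)

      print(should_fuse(rank=4))                # True: rank 4 now takes the fused path
      print(should_fuse(rank=2))                # True, as before
      print(should_fuse(rank=4, fused=False))   # False: explicit opt-out is respected

    And to confirm that the fused kernel itself is happy with rank 4, tf.nn.fused_batch_norm can be called directly on an NHWC tensor (TF 1.x style, shapes illustrative):

      import tensorflow as tf

      x = tf.random_normal([8, 32, 32, 16])  # rank-4 NHWC input
      scale = tf.ones([16])
      offset = tf.zeros([16])

      # tf.nn.fused_batch_norm expects a 4-D input, so rank 4 is exactly
      # what the fused kernel is designed for.
      y, batch_mean, batch_var = tf.nn.fused_batch_norm(
          x, scale, offset, is_training=True)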