
Batch normalization when batch size=1

What will happen when I use batch normalization but set batch_size = 1?

Because I am using 3D medical images as the training dataset, the batch size can only be set to 1 because of GPU limitations. As I understand it, when batch_size = 1 the variance will be 0, and (x - mean) / variance should then fail because of division by zero.

But why did no errors occur when I set batch_size = 1? And why was my network trained as well as I expected? Could anyone explain it?

Some people argued that:

The ZeroDivisionError may not be encountered for one of two reasons. First, the exception is caught in a try/catch block. Second, a small rational number (1e-19) is added to the variance term so that it is never zero.

But others disagreed, saying:

You should calculate the mean and std across all pixels of the images in the batch. (So even with batch_size = 1, there are still many pixels in the batch, and the reason batch_size = 1 still works is not the 1e-19 term.)

I have checked the PyTorch source code, and from the code I think the latter view is right.
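The reduction described in that view can be sketched in NumPy. This illustrates the idea and is not PyTorch's code: it assumes a channels-first (N, C, D, H, W) volume like the one BatchNorm3d expects, a made-up random input, and batch_size = 1. It only computes the per-channel statistics and counts how many values each one averages over.

```python
import numpy as np

# Made-up 3D-volume batch in channels-first layout (N, C, D, H, W), with N = 1.
rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=3.0, size=(1, 4, 8, 16, 16))

# Per-channel statistics: reduce over every axis except the channel axis (1).
reduce_axes = (0, 2, 3, 4)
mean = x.mean(axis=reduce_axes)
var = x.var(axis=reduce_axes)  # population (biased) variance over those axes

# Each statistic averages over N * D * H * W values, even when N = 1.
values_per_stat = x.size // x.shape[1]

print(mean.shape, var.shape)  # (4,) (4,)
print(values_per_stat)        # 2048
print(var.min() > 0)          # True
```

With a single volume, each channel statistic still pools 1 * 8 * 16 * 16 = 2048 voxels, so there is nothing degenerate about the variance.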

Does anyone have a different opinion?


  • variance will be 0

    No, it won't. BatchNormalization computes statistics only with respect to a single axis (usually the channels axis, axis=-1, i.e. the last, by default); every other axis is collapsed, i.e. averaged over. Details below.

    More importantly, however, unless you can explicitly justify it, I advise against using BatchNormalization with batch_size=1; there are strong theoretical reasons against it, and multiple publications have shown BN performance degrading for batch_size under 32, and severely for <=8. In a nutshell, batch statistics "averaged" over a single sample vary greatly from sample to sample (high variance), so the BN mechanisms don't work as intended.
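A toy simulation can make "high variance" concrete. It is a sketch under a made-up data model (each sample gets its own random offset plus pixel noise), not a reproduction of the published results: it measures how much the BN batch mean jumps between batches of size 1 versus size 32.

```python
import numpy as np

# Toy data model (an assumption, not real images): each sample has its own
# offset, so per-sample statistics differ from sample to sample.
rng = np.random.default_rng(0)
n_samples, n_pixels = 4096, 64
offsets = rng.normal(0.0, 1.0, size=(n_samples, 1))
data = offsets + rng.normal(0.0, 0.5, size=(n_samples, n_pixels))

def batch_mean_spread(batch_size):
    """Std. dev. of the BN batch mean across consecutive, non-overlapping batches."""
    n_batches = n_samples // batch_size
    batches = data[: n_batches * batch_size].reshape(n_batches, batch_size, n_pixels)
    batch_means = batches.mean(axis=(1, 2))  # one mean statistic per batch
    return batch_means.std()

spread_1 = batch_mean_spread(1)
spread_32 = batch_mean_spread(32)
print(spread_1, spread_32)  # spread_1 is several times larger
```

Under this model the batch mean's spread shrinks roughly as 1/sqrt(batch_size), so with batch_size = 1 the statistics BN normalizes by change noticeably from one step to the next.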

    Small mini-batch alternatives: Batch Renormalization, Layer Normalization, Weight Normalization.
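To illustrate why Layer Normalization sidesteps the batch-size issue, the sketch below compares it with training-mode BN in NumPy. It assumes LayerNorm computes statistics per sample over all non-batch axes (as in the original formulation; framework defaults for which axes are normalized vary), and it omits the learned scale and offset. Batch Renormalization and Weight Normalization are not shown.

```python
import numpy as np

EPS = 1e-3

def batch_norm_train(x):
    # Per-channel statistics over batch + spatial axes (channels-last input).
    axes = tuple(range(x.ndim - 1))
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + EPS)

def layer_norm(x):
    # Per-sample statistics over all non-batch axes.
    axes = tuple(range(1, x.ndim))
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + EPS)

rng = np.random.default_rng(0)
batch = rng.normal(size=(4, 6, 6, 3))  # (N, H, W, C)
single = batch[:1]                     # the first sample alone: batch_size = 1

# LayerNorm output for sample 0 does not depend on the rest of the batch...
ln_same = np.allclose(layer_norm(single)[0], layer_norm(batch)[0])
# ...while training-mode BN output for sample 0 does.
bn_same = np.allclose(batch_norm_train(single)[0], batch_norm_train(batch)[0])
print(ln_same, bn_same)  # True False
```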

    Implementation details, from the Keras source code:

    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
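The snippet above depends on the layer's `self.axis`. A standalone version (the function name `reduction_axes` is made up for illustration) shows which axes get reduced for the shapes discussed here:

```python
def reduction_axes(input_shape, axis=-1):
    """Axes BN averages over: every axis except the normalized (channel) axis."""
    axes = list(range(len(input_shape)))
    del axes[axis]  # negative indices work, so axis=-1 drops the last (channels) axis
    return axes

# (batch_size, height, width, depth, channels) with batch_size = 1
print(reduction_axes((1, 64, 64, 32, 16)))          # [0, 1, 2, 3]
# channels-first layout, e.g. (N, C, D, H, W), normalizing axis 1
print(reduction_axes((1, 16, 32, 64, 64), axis=1))  # [0, 2, 3, 4]
```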

    Eventually, tf.nn.moments is called with axes=reduction_axes, which averages over those axes to compute the mean and variance. Then, in the TensorFlow backend, the mean and variance are passed to tf.nn.batch_normalization, which returns the train- or inference-normalized inputs.
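The two TensorFlow calls can be mirrored in NumPy to show what is computed. This is a hedged sketch: `moments` and `batch_normalization` below are hypothetical re-implementations, not TensorFlow's functions, applying the formula scale * (x - mean) / sqrt(variance + epsilon) + offset to a random channels-last input with batch_size = 1.

```python
import numpy as np

def moments(x, axes):
    # NumPy analogue of tf.nn.moments with keepdims: mean and (biased) variance.
    mean = x.mean(axis=axes, keepdims=True)
    variance = ((x - mean) ** 2).mean(axis=axes, keepdims=True)
    return mean, variance

def batch_normalization(x, mean, variance, offset, scale, epsilon):
    # scale * (x - mean) / sqrt(variance + epsilon) + offset
    return scale * (x - mean) / np.sqrt(variance + epsilon) + offset

rng = np.random.default_rng(0)
x = rng.normal(5.0, 2.0, size=(1, 8, 8, 8, 3))  # (1, H, W, D, C): batch_size = 1
axes = (0, 1, 2, 3)                              # all but the channel axis
mean, variance = moments(x, axes)

gamma = np.ones(3)   # learned scale, at its usual initial value of 1
beta = np.zeros(3)   # learned offset, at its usual initial value of 0
y = batch_normalization(x, mean, variance, beta, gamma, epsilon=1e-3)

print(y.mean(axis=axes))  # close to 0 for every channel
print(y.std(axis=axes))   # close to 1 for every channel
```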

    In other words, if your input is (batch_size, height, width, depth, channels), i.e. (1, height, width, depth, channels), then BN will compute its statistics over the batch (of size 1), height, width, and depth dimensions.

    Can the variance ever be zero? Yes, if every single value in a given channel slice (across all other dimensions) is the same. But this should be near-impossible for real data.
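A quick NumPy check of that condition, on made-up data: a channel whose values are all identical has variance exactly 0, while a channel of random values does not.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(1, 4, 4, 4, 2))  # (1, H, W, D, C)
x[..., 0] = 7.0                       # channel 0: every value identical

var = x.var(axis=(0, 1, 2, 3))        # per-channel variance
print(var)  # channel 0 -> 0.0, channel 1 -> positive
```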

    Regarding the other answers, the first one is misleading:

    a small rational number is added (1e-19) to the variance

    This doesn't happen when computing the variance; rather, epsilon is added to the variance when normalizing. Even so, it is rarely necessary, as the variance is far from zero. Also, Keras's default epsilon is actually 1e-3, and it plays a regularizing role beyond merely avoiding division by zero.
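The role of epsilon can be checked numerically. A minimal sketch, assuming channels-last input, random data, and a hypothetical `normalize` helper (no learned scale or offset): epsilon enters only inside the square root, it keeps a constant channel finite, and for a channel with variance near 1 it changes the output only slightly.

```python
import numpy as np

def normalize(x, epsilon):
    axes = tuple(range(x.ndim - 1))  # all axes except channels (last)
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    # epsilon is added to the variance only here, inside the square root
    return (x - mean) / np.sqrt(var + epsilon)

rng = np.random.default_rng(0)
x = rng.normal(size=(1, 4, 4, 4, 2))
x[..., 0] = 7.0  # a constant channel: its variance is exactly 0

y = normalize(x, epsilon=1e-3)  # Keras's default epsilon
print(np.isfinite(y).all())     # True: no division by zero
print(np.abs(y[..., 0]).max())  # 0.0: the constant channel maps to 0

# For a channel with variance near 1, epsilon barely changes the output.
y_no_eps = normalize(x[..., 1:], epsilon=0.0)
eps_effect = np.abs(y[..., 1:] - y_no_eps).max()
print(eps_effect)  # small
```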

    Update: I failed to address an important piece of intuition behind suspecting the variance to be 0. The variance of the batch statistics is indeed zero, since there is only one statistic, but that "statistic" itself consists of the mean and variance over the channel's spatial dimensions. In other words, the variance of the mean & variance (of the single training sample) is zero, but the mean & variance themselves aren't.
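The distinction can be shown with a single random sample (made-up data, channels-last, batch_size = 1): the per-channel variance over spatial positions is non-zero, while the variance of the per-sample statistic across the batch is zero because there is only one sample.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(3.0, 2.0, size=(1, 8, 8, 8, 4))  # a single sample: batch_size = 1

# Per-channel variance of the one sample, computed over its spatial positions.
var = x.var(axis=(0, 1, 2, 3))

# Per-sample means, one row per sample in the batch (here: only one row).
per_sample_means = x.mean(axis=(1, 2, 3))     # shape (1, 4)
spread_across_batch = per_sample_means.var(axis=0)

print(var)                  # clearly non-zero (data drawn with variance 4)
print(spread_across_batch)  # exactly 0: a single statistic has no spread
```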