A deep neural network f consists of one fully connected network followed by one batch normalization layer (i.e., f = FCN + BN).
Given a dataset with inputs x and y, can the following property hold?
f(x+y) = f(x) + f(y)
I think the fully connected network naturally preserves linearity. As for batch normalization, since x, y, and x+y are all normalized with the same mean and standard deviation of the entire batch, it also appears to be linear.
Is there a case where batch normalization preserves linearity, or does it introduce non-linearity that prevents this property from holding?
Very short answer: No, batch normalization cannot be considered a linear transformation.
Short answer: You have to be careful when assuming linearity for the model you describe, but not only for the reason that you mentioned. In short: a fully connected layer is only linear without its bias term, and batch normalization is not linear even without its affine parameters.
Combining the criteria of additivity and homogeneity, we can call a function f linear if and only if it fulfills the following criterion:
f(α·x + β·y) = α·f(x) + β·f(y).
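This criterion can also be checked numerically on concrete inputs, which is the idea behind the code example at the end of this answer. Here is a minimal sketch (the helper name check_linearity, the tolerance, and the test functions are my own choices; passing the check for one pair of inputs does not prove linearity, but failing it does disprove it):

import torch

def check_linearity(f, x, y, alpha=.5, beta=.3, tol=1e-5):
    # Numerically test f(α·x + β·y) == α·f(x) + β·f(y) for one pair (x, y).
    lhs = f(alpha * x + beta * y)
    rhs = alpha * f(x) + beta * f(y)
    return (lhs - rhs).abs().max().item() < tol

x, y = torch.randn(4, 3), torch.randn(4, 3)
print(check_linearity(lambda t: 3. * t, x, y))  # True: scaling is linear
print(check_linearity(lambda t: t ** 2, x, y))  # False: squaring is not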
A fully connected layer l without a bias term (e.g. torch.nn.Linear(…, bias=False)) is defined as
l(x) = xWᵀ.
It is linear, since
l(α·x + β·y) = (α·x + β·y)Wᵀ = α·xWᵀ + β·yWᵀ = α·l(x) + β·l(y).
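A quick sketch that confirms this numerically (layer sizes, coefficients, and the tolerance are my own choices):

import torch

torch.manual_seed(0)
fc = torch.nn.Linear(3, 2, bias=False)  # fully connected layer without bias
x, y = torch.randn(4, 3), torch.randn(4, 3)
alpha, beta = .5, .3

with torch.no_grad():
    lhs = fc(alpha * x + beta * y)
    rhs = alpha * fc(x) + beta * fc(y)
    print((lhs - rhs).abs().max() < 1e-5)  # True, up to floating-point error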
A fully connected layer l with a bias term (e.g. torch.nn.Linear(…, bias=True)) is defined as
l(x) = xWᵀ + b.
It is not linear, since
l(α·x + β·y) = (α·x + β·y)Wᵀ + b,
while
α·l(x) + β·l(y) = α·(xWᵀ + b) + β·(yWᵀ + b) = (α·x + β·y)Wᵀ + (α + β)·b.
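The two expressions differ by exactly (1 − α − β)·b, which a short sketch can confirm (again with my own choices of sizes and coefficients):

import torch

torch.manual_seed(0)
fc = torch.nn.Linear(3, 2, bias=True)  # fully connected layer with bias
x, y = torch.randn(4, 3), torch.randn(4, 3)
alpha, beta = .5, .3

with torch.no_grad():
    lhs = fc(alpha * x + beta * y)
    rhs = alpha * fc(x) + beta * fc(y)
    print((lhs - rhs).abs().max() < 1e-5)                            # False
    print(((lhs - rhs) - (1 - alpha - beta) * fc.bias).abs().max())  # ~0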
This might be confusing, since the very name of the fully connected layer, Linear, implies otherwise. However, its name stems from a different use of "linear" than the one we use here: strictly speaking, a map of the form x ↦ xWᵀ + b is affine rather than linear.
We can similarly examine the batch norm layer, and we will find that even without its affine parameters, and thus without a bias term (e.g. torch.nn.BatchNorm1d(…, affine=False)), it is not linear:
The batch norm layer n without its affine parameters (i.e. without the learned shift and scale parameters, called β and γ in the Wikipedia article on batch normalization; not to be confused with the scalars α and β above) is defined as
n(x) = (x - μ) / σ.
It is not linear since
n(α·x + β·y) = (α·x + β·y - μ) / σ,
while
α·n(x) + β·n(y) = α·(x - μ) / σ + β·(y - μ) / σ = (α·x + β·y - (α + β)·μ) / σ.
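This, too, can be checked numerically; in eval mode, the layer uses its running statistics as μ and σ (the concrete values below are my own choices):

import torch

torch.manual_seed(0)
bn = torch.nn.BatchNorm1d(3, affine=False)  # no learned shift / scale
bn.running_mean.fill_(4.2)  # nonzero μ, so the terms above do not cancel
bn.running_var.fill_(1.5)
bn.eval()  # normalize with the running statistics, not batch statistics

x, y = torch.randn(4, 3), torch.randn(4, 3)
alpha, beta = .5, .3

with torch.no_grad():
    lhs = bn(alpha * x + beta * y)
    rhs = alpha * bn(x) + beta * bn(y)
    print((lhs - rhs).abs().max() < 1e-5)  # False: μ vs. (α + β)·μ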
The same argument applies with the affine parameters enabled (e.g. torch.nn.BatchNorm1d(…, affine=True)).
Here is a basic example that demonstrates the results from above in code:
import torch

torch.manual_seed(42)
batch_size, input_size, output_size = 4, 3, 1
batch_x, batch_y = (torch.randn(batch_size, input_size) for _ in range(2))

class Model(torch.nn.Module):
    def __init__(self, input_size, output_size, use_bias, use_bn_mean):
        super().__init__()
        self.use_bias = use_bias
        self.use_bn_mean = use_bn_mean
        self.fc = torch.nn.Linear(input_size, output_size, bias=use_bias)
        self.bn = torch.nn.BatchNorm1d(output_size)
        # Init mean / var with non-identity values (i.e. mean≠0, var≠1)
        torch.nn.init.constant_(self.bn.running_mean, 4.2 if use_bn_mean else 0.)
        torch.nn.init.constant_(self.bn.running_var, 1.5)
        # Init affine params with non-identity values (i.e. bias≠0, weight≠1)
        torch.nn.init.constant_(self.bn.bias, 1.3 if use_bias else 0.)
        torch.nn.init.constant_(self.bn.weight, 2.5)

    def forward(self, x): return self.bn(self.fc(x))

with torch.no_grad():
    for use_bias in (True, False):
        for use_bn_mean in (True, False):
            model = Model(input_size, output_size, use_bias, use_bn_mean)
            model.eval()
            out_sum_before = model(.5 * batch_x + .3 * batch_y)
            out_sum_after = .5 * model(batch_x) + .3 * model(batch_y)
            is_linear = (out_sum_after - out_sum_before).abs().max() < 1e-5
            print(f"Model is {'' if is_linear else 'not '}linear for "
                  f"{model.use_bias=}, {model.use_bn_mean=}.")
This prints:
Model is not linear for model.use_bias=True, model.use_bn_mean=True.
Model is not linear for model.use_bias=True, model.use_bn_mean=False.
Model is not linear for model.use_bias=False, model.use_bn_mean=True.
Model is linear for model.use_bias=False, model.use_bn_mean=False.
That is: only if we suppress all bias terms, as well as the mean of the batch norm (which, I guess, defeats the purpose of using batch normalization), will our network behave as a linear function. Also, this only holds during inference, when the batch norm layer uses the same, global standard deviation value for all batches.
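To illustrate that last point: in training mode, μ and σ are computed from the current batch, so σ itself depends on the input, and the linearity check fails even without any bias terms or running mean. A minimal sketch of this, using a bare batch norm layer rather than the full model above:

import torch

torch.manual_seed(0)
bn = torch.nn.BatchNorm1d(3, affine=False)
bn.train()  # μ and σ are now computed from each incoming batch

x, y = torch.randn(4, 3), torch.randn(4, 3)
alpha, beta = .5, .3

with torch.no_grad():
    lhs = bn(alpha * x + beta * y)
    rhs = alpha * bn(x) + beta * bn(y)
    # σ differs between the three forward passes, so the check fails.
    print((lhs - rhs).abs().max() < 1e-5)  # False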