Tags: pytorch, derivative, autograd

Derivative of BatchNorm2d in PyTorch


In my network, I want to compute both the forward pass and the backward pass during the forward pass. To do this, I have to manually define backward methods for all of the forward-pass layers.
For the activation functions, that's easy, and it also worked well for the linear and conv layers. But I'm really struggling with BatchNorm, as the BatchNorm paper only discusses the 1D case. So far, my implementation looks like this:

def backward_batchnorm2d(input, output, grad_output, layer):
    gamma = layer.weight
    beta = layer.bias
    avg = layer.running_mean
    var = layer.running_var
    eps = layer.eps
    B = input.shape[0]

    # avg, var, gamma and beta are of shape [channel_size]
    # while input, output, grad_output are of shape [batch_size, channel_size, w, h]
    # for my calculations I have to reshape avg, var, gamma and beta to [batch_size, channel_size, w, h] by repeating the channel values over the whole image and batches

    dL_dxi_hat = grad_output * gamma
    dL_dvar = (-0.5 * dL_dxi_hat * (input - avg) / ((var + eps) ** 1.5)).sum((0, 2, 3), keepdim=True)
    dL_davg = (-1.0 / torch.sqrt(var + eps) * dL_dxi_hat).sum((0, 2, 3), keepdim=True) + dL_dvar * (-2.0 * (input - avg)).sum((0, 2, 3), keepdim=True) / B
    dL_dxi = dL_dxi_hat / torch.sqrt(var + eps) + 2.0 * dL_dvar * (input - avg) / B + dL_davg / B # dL_dxi_hat / sqrt()
    dL_dgamma = (grad_output * output).sum((0, 2, 3), keepdim=True)
    dL_dbeta = (grad_output).sum((0, 2, 3), keepdim=True)
    return dL_dxi, dL_dgamma, dL_dbeta

When I check my gradients with torch.autograd.grad(), I notice that dL_dgamma and dL_dbeta are correct, but dL_dxi is incorrect (by a lot). I can't find my mistake. Where is it?
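For reference, the kind of check I'm running can be sketched like this (a minimal example; the layer size and tensor shapes here are arbitrary):

```python
import torch
import torch.nn as nn

# Minimal sketch of the gradient check: get the reference gradients
# from torch.autograd.grad to compare against a manual backward pass.
torch.manual_seed(0)
layer = nn.BatchNorm2d(3)  # training mode by default
x = torch.randn(2, 3, 4, 4, requires_grad=True)
out = layer(x)
grad_output = torch.randn_like(out)

# Reference gradients w.r.t. input, weight (gamma) and bias (beta)
dx_ref, dgamma_ref, dbeta_ref = torch.autograd.grad(
    out, (x, layer.weight, layer.bias), grad_output
)
print(dx_ref.shape, dgamma_ref.shape, dbeta_ref.shape)
# torch.Size([2, 3, 4, 4]) torch.Size([3]) torch.Size([3])
```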

For reference, here is the definition of BatchNorm (computed per channel over a mini-batch of size m):

    mu = (1/m) * sum_i(x_i)
    var = (1/m) * sum_i((x_i - mu)^2)
    x_hat_i = (x_i - mu) / sqrt(var + eps)
    y_i = gamma * x_hat_i + beta

And here are the formulas for the derivatives in the 1D case (from the BatchNorm paper):

    dL/dx_hat_i = dL/dy_i * gamma
    dL/dvar     = sum_i(dL/dx_hat_i * (x_i - mu)) * (-1/2) * (var + eps)^(-3/2)
    dL/dmu      = sum_i(dL/dx_hat_i * (-1/sqrt(var + eps))) + dL/dvar * (1/m) * sum_i(-2 * (x_i - mu))
    dL/dx_i     = dL/dx_hat_i / sqrt(var + eps) + dL/dvar * 2 * (x_i - mu) / m + dL/dmu / m
    dL/dgamma   = sum_i(dL/dy_i * x_hat_i)
    dL/dbeta    = sum_i(dL/dy_i)


Solution

def backward_batchnorm2d(input, output, grad_output, layer):
    gamma = layer.weight
    gamma = gamma.view(1, -1, 1, 1) # edit
    # beta = layer.bias
    # avg = layer.running_mean
    # var = layer.running_var
    eps = layer.eps
    B = input.shape[0] * input.shape[2] * input.shape[3] # edit

    # add new: use the batch statistics, not the running statistics
    mean = input.mean(dim=(0, 2, 3), keepdim=True)
    variance = input.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
    x_hat = (input - mean) / torch.sqrt(variance + eps)

    dL_dxi_hat = grad_output * gamma
    # dL_dvar = (-0.5 * dL_dxi_hat * (input - avg) / ((var + eps) ** 1.5)).sum((0, 2, 3), keepdim=True)
    # dL_davg = (-1.0 / torch.sqrt(var + eps) * dL_dxi_hat).sum((0, 2, 3), keepdim=True) + dL_dvar * (-2.0 * (input - avg)).sum((0, 2, 3), keepdim=True) / B
    dL_dvar = (-0.5 * dL_dxi_hat * (input - mean)).sum((0, 2, 3), keepdim=True) * ((variance + eps) ** -1.5) # edit
    dL_davg = (-1.0 / torch.sqrt(variance + eps) * dL_dxi_hat).sum((0, 2, 3), keepdim=True) + (dL_dvar * (-2.0 * (input - mean)).sum((0, 2, 3), keepdim=True) / B) # edit

    dL_dxi = (dL_dxi_hat / torch.sqrt(variance + eps)) + (2.0 * dL_dvar * (input - mean) / B) + (dL_davg / B)
    # dL_dgamma = (grad_output * output).sum((0, 2, 3), keepdim=True)
    dL_dgamma = (grad_output * x_hat).sum((0, 2, 3), keepdim=True) # edit
    dL_dbeta = (grad_output).sum((0, 2, 3), keepdim=True)
    return dL_dxi, dL_dgamma, dL_dbeta
    
1. You didn't post your forward code, so: if your gamma is one-dimensional (shape [channel_size]), you need to reshape it to [1, gamma.shape[0], 1, 1] so that it broadcasts over tensors of shape [batch_size, channel_size, h, w].
2. Your formulas follow the 1D case, where the normalizing factor is the batch size. In 2D, the statistics are computed over three dimensions (batch, height, width), so B = input.shape[0] * input.shape[2] * input.shape[3].
3. running_mean and running_var are only used in test/inference mode; they are not used during training (you can find this in the paper). The mean and variance you need are computed from the input. You can store mean, variance and x_hat = (x - mean) / sqrt(variance + eps) in your layer object during the forward pass, or recompute them as in the code above (marked # add new). Then substitute them into the formulas for dL_dvar, dL_davg and dL_dxi.
4. Your dL_dgamma should also be incorrect, since you multiplied grad_output by the output (which is gamma * x_hat + beta) instead of by x_hat; change it to grad_output * x_hat.
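The corrected backward pass can be verified end-to-end against autograd. The following is a minimal, self-contained sketch (the layer size, tensor shapes and tolerance are arbitrary choices for illustration):

```python
import torch
import torch.nn as nn

def backward_batchnorm2d(input, output, grad_output, layer):
    # Corrected backward pass: batch statistics, not running statistics
    gamma = layer.weight.view(1, -1, 1, 1)
    eps = layer.eps
    B = input.shape[0] * input.shape[2] * input.shape[3]

    mean = input.mean(dim=(0, 2, 3), keepdim=True)
    variance = input.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
    x_hat = (input - mean) / torch.sqrt(variance + eps)

    dL_dxi_hat = grad_output * gamma
    dL_dvar = (-0.5 * dL_dxi_hat * (input - mean)).sum((0, 2, 3), keepdim=True) * ((variance + eps) ** -1.5)
    dL_davg = (-1.0 / torch.sqrt(variance + eps) * dL_dxi_hat).sum((0, 2, 3), keepdim=True) \
        + dL_dvar * (-2.0 * (input - mean)).sum((0, 2, 3), keepdim=True) / B
    dL_dxi = dL_dxi_hat / torch.sqrt(variance + eps) + 2.0 * dL_dvar * (input - mean) / B + dL_davg / B
    dL_dgamma = (grad_output * x_hat).sum((0, 2, 3), keepdim=True)
    dL_dbeta = grad_output.sum((0, 2, 3), keepdim=True)
    return dL_dxi, dL_dgamma, dL_dbeta

torch.manual_seed(0)
layer = nn.BatchNorm2d(3)  # training mode by default
x = torch.randn(4, 3, 5, 5, requires_grad=True)
y = layer(x)
grad_output = torch.randn_like(y)

# Manual gradients vs. reference gradients from autograd
dx, dgamma, dbeta = backward_batchnorm2d(x, y, grad_output, layer)
dx_ref, dgamma_ref, dbeta_ref = torch.autograd.grad(y, (x, layer.weight, layer.bias), grad_output)

print(torch.allclose(dx, dx_ref, atol=1e-4))                    # True
print(torch.allclose(dgamma.flatten(), dgamma_ref, atol=1e-4))  # True
print(torch.allclose(dbeta.flatten(), dbeta_ref, atol=1e-4))    # True
```

Note that dgamma and dbeta come back with shape [1, C, 1, 1] because of keepdim=True, so they are flattened before comparing against autograd's [C]-shaped gradients.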