Tags: pytorch, backpropagation

Gradients of loss with respect to random parameters are the same in pytorch


In the code below, I apply a linear operation to an input tensor of ones and compute the binary cross-entropy loss against a vector of zeros as the expected output. When I compute the gradient of the loss with respect to w, all of its rows are identical and equal to the gradient with respect to b. This is counterintuitive, since w and b are initialized with random values. What is the reason?

import torch

n_input, n_output = 5, 3
x = torch.ones(n_input)
y = torch.zeros(n_output)  # expected output
w = torch.randn(n_input, n_output, requires_grad=True)
b = torch.randn(n_output, requires_grad=True)
z = torch.matmul(x, w) + b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

loss.backward()
print(w.grad)
print(b.grad)

Output:

tensor([[0.2179, 0.4337, 0.1959],
        [0.2179, 0.4337, 0.1959],
        [0.2179, 0.4337, 0.1959],
        [0.2179, 0.4337, 0.1959],
        [0.2179, 0.4337, 0.1959]])
tensor([0.2179, 0.4337, 0.1959])
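
A quick check of that observation (continuing directly from the snippet above, so torch, w and b are already defined): every row of w.grad compares equal to b.grad.

# continuing from the snippet above: each row of w.grad matches b.grad
print(all(torch.allclose(row, b.grad) for row in w.grad))  # prints True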

Solution

  • It's because your input is symmetric.

    Look at it from the point of view of a single output neuron (you have 3 of them in your setup): every input is 1.0, so it makes no difference which input a given weight is attached to. Concretely, for z = matmul(x, w) + b the gradient of the loss with respect to w[i, j] is x[i] * ∂loss/∂z[j], and ∂loss/∂b[j] is ∂loss/∂z[j]; with x all ones, every row of w.grad is therefore exactly ∂loss/∂z, i.e. equal to b.grad.

    If you diversify the input, everything works just fine (see the sketch after the output below for an explicit check of this gradient formula):

        import torch

        n_input, n_output = 5, 3
        x = torch.randn(n_input)
        y = torch.ones(n_output)/2.  # expected output
        w = torch.randn(n_input, n_output, requires_grad=True)
        b = torch.randn(n_output, requires_grad=True)
        z = torch.matmul(x, w) + b
    
        loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
        loss.backward()
        print(w.grad)
        print(b.grad)
    
    Output:

        tensor([[-0.1939,  0.1657, -0.2501],
                [ 0.0561, -0.0480,  0.0724],
                [-0.3162,  0.2703, -0.4079],
                [ 0.0947, -0.0809,  0.1221],
                [-0.0140,  0.0120, -0.0181]])
        tensor([-0.1263,  0.1080, -0.1630])
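
    More generally, for z = matmul(x, w) + b the gradient of the loss with respect to w is the outer product of x and ∂loss/∂z, while the gradient with respect to b is ∂loss/∂z itself. Below is a minimal sketch of that identity (the use of torch.outer and the manual seed are my own additions for this check): for any x, w.grad equals torch.outer(x, b.grad), so its rows can only coincide when all entries of x are equal, as in the question.

        import torch

        torch.manual_seed(0)  # arbitrary seed, only for reproducibility
        n_input, n_output = 5, 3
        x = torch.randn(n_input)
        y = torch.zeros(n_output)
        w = torch.randn(n_input, n_output, requires_grad=True)
        b = torch.randn(n_output, requires_grad=True)

        z = torch.matmul(x, w) + b
        loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
        loss.backward()

        # row i of w.grad is x[i] * dloss/dz, and b.grad is dloss/dz,
        # so w.grad is the outer product of x and b.grad
        print(torch.allclose(w.grad, torch.outer(x, b.grad)))  # prints True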