Tags: python, machine-learning, deep-learning, pytorch, artificial-intelligence

torch.nn.Linear weight doesn't update


import torch

# linear layer and activation function
Linear = torch.nn.Linear(6, 1)
sig = torch.nn.Sigmoid()

# optimizer
optim = torch.optim.SGD(Linear.parameters(), lr=0.001)

#input
#x => (891,6)

#output
y = y.reshape(891,1)
 
#cost function
loss_f = torch.nn.BCELoss()



for iter in range(10):
  for i in range(1000):
    optim.zero_grad()
    forward = sig(Linear(x)) > 0.5
    forward = forward.to(torch.float32)
    forward.requires_grad = True 
    loss = loss_f(forward, y)
    
    loss.backward()
    optim.step()

In this code I want to update Linear.weight and Linear.bias, but it doesn't work. I think my code doesn't know what the weight and bias are, so I tried to change

optim = torch.optim.SGD(Linear.parameters(), lr=0.001)

to

optim = torch.optim.SGD([Linear.weight, Linear.bias], lr=0.001)

but it still didn't work.

// I would like to explain my problem in more detail, but my English level is quite low 🥲 sorry


Solution

  • The BCELoss is defined as

    loss(x, y) = -[y * log(x) + (1 - y) * log(1 - x)]

    As you can see, the inputs x are probabilities, so your use of sig(Linear(x)) > 0.5 is wrong. Moreover, sig(Linear(x)) > 0.5 returns a boolean tensor with no autograd history, which breaks the computation graph. You explicitly set requires_grad=True on it, but since the graph is broken, backpropagation cannot reach the linear layer, so its weights are never learned/changed.
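
    You can see the break directly with a quick sketch on random placeholder data (the names lin, probs, and preds below are just for illustration):

    import torch

    lin = torch.nn.Linear(6, 1)
    x = torch.rand(4, 6)

    probs = torch.sigmoid(lin(x))   # still attached to the autograd graph
    preds = probs > 0.5             # boolean tensor, detached from the graph

    print(probs.requires_grad, probs.grad_fn)   # True, e.g. <SigmoidBackward0 ...>
    print(preds.requires_grad, preds.grad_fn)   # False, None

    Setting requires_grad=True on the detached tensor only turns it into a new leaf; it does not reconnect it to the linear layer.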

    Correct sample usage:

    import torch

    Linear = torch.nn.Linear(6, 1)
    sig = torch.nn.Sigmoid()

    # optimizer
    optim = torch.optim.SGD(Linear.parameters(), lr=0.001)

    # sample data
    x = torch.rand(891, 6)
    y = torch.rand(891, 1)

    loss_f = torch.nn.BCELoss()

    for iter in range(10):
        optim.zero_grad()
        output = sig(Linear(x))    # probabilities, still attached to the graph
        loss = loss_f(output, y)   # BCELoss expects probabilities, not hard 0/1 predictions

        loss.backward()
        optim.step()

        print(Linear.bias.item())


    Output:

    0.10717090964317322
    0.10703673213720322
    0.10690263658761978
    0.10676861554384232
    0.10663467645645142
    0.10650081932544708
    0.10636703670024872
    0.10623333603143692
    0.10609971731901169
    0.10596618056297302
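
    As a side note, a common variant of this loop uses torch.nn.BCEWithLogitsLoss, which applies the sigmoid and the BCE in one numerically stable step; the 0.5 threshold is then used only at inference time, outside the loss. A minimal sketch, assuming the same shapes as above and random placeholder targets:

    import torch

    Linear = torch.nn.Linear(6, 1)
    optim = torch.optim.SGD(Linear.parameters(), lr=0.001)
    loss_f = torch.nn.BCEWithLogitsLoss()       # sigmoid + BCE combined

    x = torch.rand(891, 6)
    y = (torch.rand(891, 1) > 0.5).float()      # placeholder 0/1 targets

    for _ in range(10):
        optim.zero_grad()
        logits = Linear(x)                      # no sigmoid here
        loss = loss_f(logits, y)                # sigmoid is applied inside the loss
        loss.backward()
        optim.step()

    # thresholding only when making predictions, outside the loss
    with torch.no_grad():
        preds = (torch.sigmoid(Linear(x)) > 0.5).float()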