python, deep-learning, pytorch, perceptron

PyTorch optimizer.step() function doesn't update weights


The code is shown below.
The problem is that optimizer.step() appears to have no effect. I print model.parameters() before and after training, and the weights don't change.

I'm trying to make a perceptron that can solve the AND-problem. I've been successful in doing this with my own tiny library, where I've implemented a perceptron with the two functions predict() and train().

Just to clarify, I've only just started learning deep learning with PyTorch, so it's probably a very newbie problem. I've tried searching for a solution, but without luck. I've also compared my code with other code that works, but I can't see what I'm doing wrong.

import torch
from torch import nn, optim
from random import randint

class NeuralNet(nn.Module):
  def __init__(self):
    super(NeuralNet, self).__init__()
    self.layer1 = nn.Linear(2, 1)

  def forward(self, input):
    out = input
    out = self.layer1(out)
    out = torch.sign(out)
    out = torch.clamp(out, 0, 1) # 0=false, 1=true
    return out

data = torch.Tensor([[0, 0], [0, 1], [1, 0], [1, 1]])
target = torch.Tensor([0, 0, 0, 1])
model = NeuralNet()
epochs = 1000
lr = 0.01

print(list(model.parameters()))
print() # Print parameters before training
loss_func = nn.L1Loss()
optimizer = optim.Rprop(model.parameters(), lr)
for epoch in range(epochs + 1):
  optimizer.zero_grad()
  rand_int = randint(0, len(data) - 1)
  x = data[rand_int]
  y = target[rand_int]

  pred = model(x)
  loss = loss_func(pred, y)

  loss.backward()
  optimizer.step()

# Print parameters again
# But they haven't changed
print(list(model.parameters()))

Solution

  • The issue here is that you are back-propagating through sign(), which provides no useful gradient. Its output is piecewise constant, so its derivative is zero everywhere except at 0, where it is undefined. PyTorch's autograd returns a zero gradient for it, so backward() runs without error, but every gradient that reaches the trainable weights before it is zero and the optimizer has nothing to update them with. Such functions are usually easy to spot: they are discrete, step-like operations that resemble 'if' statements.

    Unfortunately, PyTorch does not do any hand-holding in this regard and will not warn you about it. One way to alleviate the issue is to remap your targets from {0, 1} to {-1, 1} and apply a Tanh() non-linearity in place of the sign() and clamp() operators, since tanh is smooth and passes a non-zero gradient back to the weights.
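
As a rough sketch of both points, and not the asker's exact setup: the snippet below first checks that the gradients reaching a Linear layer through sign() and clamp() are all zero, then trains a tanh-based variant on the AND data. The class name TanhPerceptron, the choice of MSELoss and SGD, the full-batch updates, the learning rate, and the number of steps are all assumptions made for illustration; the original code used L1Loss and Rprop on single random samples.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)  # fixed seed so the run is repeatable

data = torch.tensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])

# 1) Diagnosis: push the data through Linear -> sign -> clamp and inspect
#    the gradients that arrive at the layer's parameters.
layer = nn.Linear(2, 1)
out = torch.clamp(torch.sign(layer(data)), 0, 1)
out.sum().backward()
sign_grad_is_zero = bool(
    (layer.weight.grad == 0).all() and (layer.bias.grad == 0).all()
)
print("gradients all zero with sign():", sign_grad_is_zero)


# 2) Hypothetical fix: a tanh output with targets remapped to -1 / +1.
class TanhPerceptron(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(2, 1)

    def forward(self, x):
        return torch.tanh(self.layer1(x))


# AND truth table with false = -1 and true = +1 (shape matches the output).
target = torch.tensor([[-1.], [-1.], [-1.], [1.]])

model = TanhPerceptron()
before = [p.detach().clone() for p in model.parameters()]

loss_func = nn.MSELoss()                            # assumed loss
optimizer = optim.SGD(model.parameters(), lr=0.1)   # assumed optimizer and lr

for step in range(2000):
    optimizer.zero_grad()
    loss = loss_func(model(data), target)  # full-batch update on all 4 rows
    loss.backward()
    optimizer.step()

weights_changed = any(
    not torch.equal(b, p.detach()) for b, p in zip(before, model.parameters())
)
print("weights changed with tanh():", weights_changed)

# Threshold the tanh output at 0 to get back 0 = false, 1 = true.
with torch.no_grad():
    predictions = (model(data) > 0).float().flatten()
print("predictions:", predictions.tolist())
```

Thresholding the tanh output at 0 only at prediction time keeps the discrete step out of the training graph, so the 0/1 labels are recovered without blocking the gradient.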