
How do I optimize the weights of the input layer with backward() in this simple PyTorch neural network when .grad is None?

I defined the following simple neural network:

import torch
import torch.nn as nn

X = torch.tensor(([1, 2]), dtype=torch.float)
y = torch.tensor([1.])
learning_rate = 0.001

class Neural_Network(nn.Module):
    def __init__(self):
        super(Neural_Network, self).__init__()
        self.W1 = torch.nn.Parameter(torch.tensor(([1, 0], [2, 3]), dtype=torch.float, requires_grad=True))
        self.W2 = torch.nn.Parameter(torch.tensor(([2], [1]), dtype=torch.float, requires_grad=True))

    def forward(self, X):
        self.xW1 = torch.matmul(X, self.W1)
        self.h = torch.tensor([torch.tanh(self.xW1[0]), torch.tanh(self.xW1[1])])
        return torch.sigmoid(torch.matmul(self.h, self.W2))

net = Neural_Network()
criterion = nn.MSELoss()  # same value as (y - net(X))**2 for a single output
optim = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)

for z in range(60):
    optim.zero_grad()             # clear gradients from the previous iteration
    loss = criterion(net(X), y)
    loss.backward()               # compute d(loss)/d(parameter) for every parameter
    optim.step()                  # apply the SGD update to W1 and W2

The code runs, and print(net.W1) and print(net.W2) output:

Parameter containing:
tensor([[1., 0.],
        [2., 3.]], requires_grad=True)
Parameter containing:
        [1.0078]], requires_grad=True)

My problem is that W1 does not seem to be updated (W2 is). When I call print(net.W1.grad), I get None on every iteration, which confuses me a lot.
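A quick way to see where the gradient is lost is to check grad_fn and requires_grad on each intermediate tensor. The sketch below reuses the question's W1 and X as standalone tensors rather than the full module (a simplification for illustration):

```python
import torch

# Same first-layer weights and input as in the question.
W1 = torch.nn.Parameter(torch.tensor([[1., 0.], [2., 3.]]))
X = torch.tensor([1., 2.])

xW1 = torch.matmul(X, W1)        # produced by an op on W1, so it has a grad_fn
print(xW1.grad_fn is not None)   # True

# Rebuilding the activations with torch.tensor() copies the values into a
# fresh leaf tensor: it has no grad_fn and requires_grad is False.
h_copied = torch.tensor([torch.tanh(xW1[0]), torch.tanh(xW1[1])])
print(h_copied.requires_grad, h_copied.grad_fn)  # False None

# Applying torch.tanh to the whole tensor keeps the result connected to W1.
h_graph = torch.tanh(xW1)
print(h_graph.requires_grad)     # True

h_graph.sum().backward()
print(W1.grad is not None)       # True: gradients now reach W1
```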

  • I tried writing the whole loss as a single expression, like so: loss = (y - torch.sigmoid(math.tanh(x[0] * W_1[0][0] + x[1] * W_1[1][0]) * W_2[0] + math.tanh(x[0] * W_1[0][1] + x[1] * W_1[1][1]) * W_2[1])) ** 2, but it did not help.

  • Of course I could hard-code the derivatives and everything else, but that seems painful, and I thought .backward() could be used in this case.
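The one-line attempt fails for a related reason: math.tanh converts its tensor argument to a plain Python float (through Tensor.__float__), and a float carries no autograd history. A minimal check, using a hypothetical scalar weight w chosen only for illustration:

```python
import math
import torch

w = torch.tensor(0.5, requires_grad=True)

# math.tanh turns the tensor into a Python float, so the graph ends here.
a = math.tanh(w * 2)
print(type(a))  # <class 'float'>

# torch.tanh returns a tensor that is still connected to w.
b = torch.tanh(w * 2)
b.backward()
# w.grad holds d/dw tanh(2w) = 2 * (1 - tanh(2w)**2), evaluated at w = 0.5
print(w.grad)
```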

How can I optimize W1 using backward()?


  • I suspect that the following line:

    self.h = torch.tensor([torch.tanh(self.xW1[0]), torch.tanh(self.xW1[1])])

    is the culprit.

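    torch.tensor() copies the values it is given into a brand-new leaf tensor. That tensor has no autograd history (no grad_fn), and its requires_grad attribute defaults to False, so it does not inherit anything from self.xW1. During backward() the gradient therefore stops at self.h: W2 still receives a gradient, but nothing flows back to W1, and net.W1.grad stays None.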

    Instead, write self.h = torch.tanh(self.xW1). The operation is applied element-wise to all entries of self.xW1, and the result stays in the autograd graph.

    In addition, I suggest inspecting your gradients with PyTorch hooks (for example, Tensor.register_hook).
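Putting the answer together, here is a sketch of the question's network with self.h = torch.tanh(self.xW1), a zero_grad/backward/step loop, and a hook registered with Tensor.register_hook that records each gradient arriving at W1. The grad_log list and the use of nn.MSELoss are illustrative choices, not part of the original post:

```python
import torch
import torch.nn as nn

X = torch.tensor([1., 2.])
y = torch.tensor([1.])
learning_rate = 0.001

class Neural_Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.W1 = nn.Parameter(torch.tensor([[1., 0.], [2., 3.]]))
        self.W2 = nn.Parameter(torch.tensor([[2.], [1.]]))

    def forward(self, X):
        self.xW1 = torch.matmul(X, self.W1)
        # Element-wise tanh on the whole tensor keeps the graph intact.
        self.h = torch.tanh(self.xW1)
        return torch.sigmoid(torch.matmul(self.h, self.W2))

net = Neural_Network()

# Hook called on every backward pass with the gradient for W1.
# It returns None (list.append's result), so the gradient is left unchanged.
grad_log = []
net.W1.register_hook(lambda grad: grad_log.append(grad.clone()))

criterion = nn.MSELoss()
optim = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)

for z in range(60):
    optim.zero_grad()
    loss = criterion(net(X), y)
    loss.backward()
    optim.step()

print(net.W1.grad is not None)  # True
print(len(grad_log))            # 60: one recorded gradient per iteration
print(net.W1.grad)
```

Note that with these particular starting weights X @ W1 = [5, 6], where tanh is almost flat, so the gradient reaching W1 is tiny (on the order of 1e-6 or smaller here). W1 may therefore look unchanged when printed to four decimal places even though it now receives a gradient; different initial weights or a larger learning rate make the updates visible.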