pytorch, gradient, autograd

How to "manually" apply your gradients in Pytorch?


What would be the equivalent in PyTorch of the following TensorFlow code, where loss is the loss calculated in the current iteration and net is the neural network?

with tf.GradientTape() as tape:              
    grads = tape.gradient(loss, net.trainable_variables)
    optimizer.apply_gradients(zip(grads, net.trainable_variables))

So, we compute the gradients of the loss with respect to all the trainable variables in our network, and in the next line we apply those gradients via the optimizer. In my use case this is the way to do it, and it works fine.

Now, how would I do the same in PyTorch? I am aware of the "standard" way:

optimizer.zero_grad()
loss.backward()
optimizer.step()

That is, however, not applicable in my case. So how can I apply the gradients "manually"? Google doesn't help, unfortunately, although I think it is probably a rather simple question.

Hope one of you can enlighten me!

Thanks!


Solution

  • Let's break down the standard PyTorch way of doing updates; hopefully, that will clarify what you want.

    In PyTorch, each NN parameter has a .data and a .grad attribute. .data is the actual weight tensor, and .grad is the attribute that holds the gradient; it is None if the gradient has not been computed yet. With this knowledge, let's walk through the update steps.
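
    For instance, here's a minimal sketch of what these attributes look like (the tiny nn.Linear model is purely illustrative, not part of your setup):

    import torch
    import torch.nn as nn

    model = nn.Linear(3, 1)  # hypothetical single-layer model, just for inspection

    for param in model.parameters():
        print(param.data.shape)  # the actual weight/bias tensor
        print(param.grad)        # None, since no backward pass has run yet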

    First, we call optimizer.zero_grad(). This zeros out (or empties) every parameter's .grad attribute; .grad may already be None if you have never computed the gradients.
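
    If you are doing things entirely by hand, the manual equivalent is roughly the sketch below (assuming model is your network). Note that newer PyTorch versions default to setting .grad back to None instead of zeroing it in place (zero_grad(set_to_none=True)):

    # roughly what optimizer.zero_grad() does for the registered parameters
    for param in model.parameters():
        if param.grad is not None:
            param.grad.zero_()  # in-place zeroing of the gradient tensor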

    Next, we call loss.backward(). This is the backprop step that computes the gradients and accumulates them into each parameter's .grad attribute (it accumulates rather than overwrites, which is why the zeroing step above matters).

    Once we have gradients, we want to update the weights with some rule (SGD, Adam, etc.), and we call optimizer.step(). This iterates over all the parameters and updates the weights using the computed .grad attributes.
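
    For plain SGD, for example, optimizer.step() essentially boils down to this sketch (the learning_rate value here is just an illustration):

    # roughly what optimizer.step() amounts to for plain SGD
    # (no momentum, no weight decay)
    learning_rate = 0.01
    for param in model.parameters():
        if param.grad is not None:
            param.data -= learning_rate * param.grad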

    So, to apply the gradients manually, you can replace optimizer.step() with a for loop like the one below:

    for param in model.parameters():
        if param.grad is None:  # skip parameters that did not receive a gradient
            continue
        param.data = custom_rule(param.data, param.grad, learning_rate, **any_other_arguments)
    

    and that should do the trick.
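
    Putting the pieces together, one manual training step could look like the following self-contained sketch. The model, data, and custom_rule below are made-up placeholders (here, SGD with weight decay) just to show the overall shape; swap in whatever rule your use case needs:

    import torch
    import torch.nn as nn

    # made-up model and data, purely for illustration
    model = nn.Linear(3, 1)
    inputs = torch.randn(8, 3)
    targets = torch.randn(8, 1)
    loss_fn = nn.MSELoss()
    learning_rate = 0.01

    def custom_rule(weight, grad, lr, weight_decay=0.0):
        # placeholder update rule: SGD with optional weight decay
        return weight - lr * (grad + weight_decay * weight)

    model.zero_grad()  # clear old gradients (same idea as optimizer.zero_grad())
    loss = loss_fn(model(inputs), targets)
    loss.backward()    # fills each parameter's .grad

    # manual "step": apply the gradients yourself
    for param in model.parameters():
        if param.grad is not None:
            param.data = custom_rule(param.data, param.grad, learning_rate, weight_decay=1e-4)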