neural-network pytorch gradient-descent detach

Gradient with respect to the parameters of a specific layer in PyTorch


I am building a model in PyTorch with multiple networks, say netA and netB. In the loss function I need to work with the composition netA(netB). In different parts of the optimization I need to calculate the gradient of loss_func(netA(netB)) with respect to only the parameters of netA, and in other situations with respect to only the parameters of netB. How should one approach this problem?

My approach: to calculate the gradient with respect to the parameters of netA, I use loss_func(netA(netB.detach())).

If I write loss_func(netA(netB).detach()), it seems that the parameters of both netA and netB are detached.

I tried loss_func(netA.detach(netB)) in order to detach only the parameters of netA, but it doesn't work (I get the error that netA doesn't have an attribute detach).
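
For concreteness, here is a minimal sketch of the working variant above, assuming the detach is applied to the output tensor of netB (netA, netB, and loss_func are from the question; the layer sizes and dummy data are made up for illustration):

    import torch
    import torch.nn as nn

    # Placeholder networks standing in for netA and netB; sizes are arbitrary.
    netA = nn.Linear(8, 1)
    netB = nn.Linear(4, 8)
    loss_func = nn.MSELoss()

    x = torch.randn(16, 4)
    target = torch.randn(16, 1)

    # Detaching the *output tensor* of netB cuts the graph before netB's
    # parameters, so only netA receives gradients.
    loss = loss_func(netA(netB(x).detach()), target)
    loss.backward()
    print(netA.weight.grad is None)  # False: netA has gradients
    print(netB.weight.grad is None)  # True:  netB does not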


Solution

  • Gradients are properties of tensors, not networks.
    Therefore, you can only call .detach() on a tensor.

  • You can have a separate optimizer for each network. This way you compute gradients for all networks all the time, but only update the weights of the relevant network by calling step() on its optimizer; see the sketch after this list.
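
A minimal sketch of that two-optimizer pattern, assuming SGD and placeholder networks (names, shapes, and learning rates are illustrative only):

    import torch
    import torch.nn as nn

    # Placeholder networks and data; shapes and optimizer choice are assumptions.
    netA = nn.Linear(8, 1)
    netB = nn.Linear(4, 8)
    loss_func = nn.MSELoss()

    optA = torch.optim.SGD(netA.parameters(), lr=1e-2)
    optB = torch.optim.SGD(netB.parameters(), lr=1e-2)

    x = torch.randn(16, 4)
    target = torch.randn(16, 1)

    # Phase 1: gradients flow through both networks, but only netA is updated.
    optA.zero_grad()
    optB.zero_grad()
    loss = loss_func(netA(netB(x)), target)
    loss.backward()
    optA.step()

    # Phase 2: same forward/backward pass, but only netB is updated.
    optA.zero_grad()
    optB.zero_grad()
    loss = loss_func(netA(netB(x)), target)
    loss.backward()
    optB.step()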