I am building a model in PyTorch with multiple networks, for example `netA` and `netB`. In the loss function I need to work with their composition, `netA(netB(...))`. In different parts of the optimization I need to compute the gradient of `loss_func(netA(netB(...)))` with respect to only the parameters of `netA`, and in other parts with respect to only the parameters of `netB`. How should one approach this problem?
My approach: to compute the gradient with respect to the parameters of `netA` only, I use `loss_func(netA(netB.detach()))`. If I write `loss_func(netA(netB).detach())` instead, it seems that the parameters of both `netA` and `netB` are detached.
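Concretely, assuming an input tensor `x` (the small networks and data below are hypothetical stand-ins for my real ones), what I mean is detaching the output of `netB`:

```python
import torch
import torch.nn as nn

netA = nn.Linear(10, 1)     # hypothetical stand-ins for the real networks
netB = nn.Linear(10, 10)
loss_func = nn.MSELoss()

x = torch.randn(4, 10)      # hypothetical input batch
target = torch.randn(4, 1)  # hypothetical targets

# Detach netB's output tensor so the backward pass stops there:
# only netA's parameters receive gradients.
loss = loss_func(netA(netB(x).detach()), target)
```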
I tried `loss_func(netA.detach(netB))` in order to detach only the parameters of `netA`, but it doesn't work (I get an error saying that `netA` has no attribute `detach`).
Gradients are properties of tensors, not of networks. Therefore, you can only `.detach()` a tensor (such as the output of `netB`), never a module like `netA`.
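For instance, in a minimal self-contained sketch (hypothetical stand-in networks), you detach the intermediate tensor that `netB` produces, and the gradient is then confined to `netA`:

```python
import torch
import torch.nn as nn

netA, netB = nn.Linear(10, 1), nn.Linear(10, 10)  # hypothetical stand-ins
loss_func = nn.MSELoss()
x, target = torch.randn(4, 10), torch.randn(4, 1)

# Detach the tensor that netB produces, not netB itself.
loss_func(netA(netB(x).detach()), target).backward()
print(netA.weight.grad is None)  # False: gradients reached netA
print(netB.weight.grad is None)  # True: the detach blocked netB
```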
You can also have a separate optimizer for each network. This way you can compute gradients for all networks all the time, but only update the weights of the relevant network by calling `step()` on its optimizer.
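A minimal sketch of that pattern (again with hypothetical stand-in networks and data):

```python
import torch
import torch.nn as nn

netA, netB = nn.Linear(10, 1), nn.Linear(10, 10)  # hypothetical stand-ins
loss_func = nn.MSELoss()
optA = torch.optim.SGD(netA.parameters(), lr=0.01)
optB = torch.optim.SGD(netB.parameters(), lr=0.01)
x, target = torch.randn(4, 10), torch.randn(4, 1)

# Gradients are computed for both networks...
optA.zero_grad()
optB.zero_grad()
loss_func(netA(netB(x)), target).backward()

# ...but only netA is updated on this step; call optB.step()
# instead when netB is the network being optimized.
optA.step()
```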