
How does PyTorch autograd work?

I submitted this as an issue on the CycleGAN PyTorch implementation, but since nobody replied there, I'm asking again here.

I'm mainly puzzled by the fact that multiple forward passes are made before a single backward pass. See the following code in cycle_gan_model:

```python
# GAN loss
# D_A(G_A(A))
self.fake_B = self.netG_A.forward(self.real_A)
pred_fake = self.netD_A.forward(self.fake_B)
self.loss_G_A = self.criterionGAN(pred_fake, True)
# D_B(G_B(B))
self.fake_A = self.netG_B.forward(self.real_B)
pred_fake = self.netD_B.forward(self.fake_A)
self.loss_G_B = self.criterionGAN(pred_fake, True)
# Forward cycle loss G_B(G_A(A))
self.rec_A = self.netG_B.forward(self.fake_B)
self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
# Backward cycle loss G_A(G_B(B))
self.rec_B = self.netG_A.forward(self.fake_A)
self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
# combined loss
self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
```

The way I see it, G_A and G_B each have three forward passes (counting the identity loss, whose computation is not shown above): two on real data (real_A and real_B) and one on fake data (fake_A for G_A, fake_B for G_B).

In TensorFlow (I think), the backward pass is always computed with respect to the last input data. If that were the case here, the backpropagation of loss_G would be wrong; one would instead have to run the backward pass three times, each immediately after the forward pass it involves.

Specifically, netG_A's gradients from loss_G_A are with respect to real_A, but its gradients from loss_cycle_B are with respect to fake_A.
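
Written out with the chain rule, the gradient of loss_G with respect to netG_A's parameters has one term per call to netG_A. This is a sketch in our own notation: θ_A stands for netG_A's parameters, fake_A is treated as a fixed input with respect to θ_A because it is produced by netG_B, and the trailing dots stand for further terms (such as the identity losses) whose computation is not shown in the excerpt.

$$
\frac{\partial \mathcal{L}_G}{\partial \theta_A}
= \frac{\partial \mathcal{L}_G}{\partial\,\mathrm{fake\_B}}\,
  \left.\frac{\partial G_A(x;\theta_A)}{\partial \theta_A}\right|_{x=\mathrm{real\_A}}
+ \frac{\partial \mathcal{L}_G}{\partial\,\mathrm{rec\_B}}\,
  \left.\frac{\partial G_A(x;\theta_A)}{\partial \theta_A}\right|_{x=\mathrm{fake\_A}}
+ \cdots
$$

The first term collects loss_G_A and loss_cycle_A (fake_B feeds both netD_A and netG_B); the second comes from loss_cycle_B. Each Jacobian is evaluated at the input that particular call actually received, so both contributions can be obtained from one backward pass over loss_G as long as the input of every call is remembered.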

I assume this is somehow handled in PyTorch. But how does the model know with respect to which input it should compute the gradients?


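  • PyTorch uses a tape-based system for automatic differentiation: as the forward pass runs, every operation is recorded together with the inputs it received, and backward() replays that record in reverse, starting from the last operation (the loss). I think the best way to understand this is to draw a diagram of the process; I attach one that I drew by hand:

    [Image: hand-drawn diagram of the computation graph]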

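    Now you will see that some modules are "repeated". I think about them the same way I think about RNNs: the same weights are used at several points in the graph, and the gradients from each use are simply added together (accumulated into the parameters' .grad).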
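
To make this concrete without relying on PyTorch internals, here is a toy tape-style reverse-mode autodiff in plain Python. It is not PyTorch's implementation: Var, toy_cycle and the one-weight "generators" a and b are hypothetical stand-ins, and the summed loss is a placeholder rather than a real GAN or cycle loss. Each multiplication records the values it saw, so when G_A's weight a is used on real_A and later on fake_A, backward() adds one contribution per use, each computed from its own input. The second half shows that one backward() on the summed loss gives the same gradients as one backward() per term.

```python
class Var:
    """Toy scalar autograd node: a value plus a record of how it was produced."""

    def __init__(self, value, parents=()):
        self.value = value
        # (parent, local derivative) pairs captured when the op runs,
        # so each node remembers the exact inputs it saw.
        self.parents = parents
        self.grad = 0.0  # only leaves accumulate here, as with PyTorch's .grad

    def __add__(self, other):
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        # d(xy)/dx = y and d(xy)/dy = x, using the values at call time
        return Var(self.value * other.value, ((self, other.value), (other, self.value)))

    def backward(self):
        # Order the recorded graph so every node comes after its inputs.
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent, _ in node.parents:
                    visit(parent)
                order.append(node)

        visit(self)
        # Walk it in reverse, starting from this node (the loss).
        grads = {id(self): 1.0}
        for node in reversed(order):
            g = grads.get(id(node), 0.0)
            if not node.parents:
                node.grad += g  # leaf: accumulate across backward() calls
            for parent, local in node.parents:
                grads[id(parent)] = grads.get(id(parent), 0.0) + local * g


def toy_cycle(a, b):
    """Hypothetical one-weight generators: G_A(x) = a*x, G_B(x) = b*x."""
    real_A, real_B = Var(3.0), Var(4.0)
    fake_B = a * real_A  # G_A sees real data
    fake_A = b * real_B  # G_B sees real data
    rec_A = b * fake_B   # G_B sees fake data
    rec_B = a * fake_A   # G_A sees fake data
    return [fake_B, fake_A, rec_A, rec_B]  # placeholder loss terms


# One backward pass on the summed loss.
a, b = Var(2.0), Var(0.5)
t = toy_cycle(a, b)
(t[0] + t[1] + t[2] + t[3]).backward()
print(a.grad, b.grad)  # 6.5 18.0

# One backward pass per term on a fresh graph: gradients accumulate to the same totals.
a2, b2 = Var(2.0), Var(0.5)
for term in toy_cycle(a2, b2):
    term.backward()
print(a2.grad, b2.grad)  # 6.5 18.0
```

Here a.grad = 3 + 1.5 + 2: the 3 and 1.5 come from the use on real_A (directly and through rec_A), the 2 from the use on fake_A. In PyTorch itself, calling backward() separately on terms that share part of the graph would need retain_graph=True on all but the last call, since by default the saved intermediate values are freed after each backward pass; a single backward() on the sum avoids that.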