python, pytorch

How can I call backward() on individual image pixels without getting the "Trying to backward through the graph a second time" error?


I have this code:

```python
...
vertex_id = 275

deform_verts.retain_grad()  # input
predicted_silhouette.retain_grad()  # output

impact_img = torch.zeros_like(predicted_silhouette, requires_grad=False)
for i in range(image_size):
    for j in range(image_size):
        pixel = predicted_silhouette[i][j]
        pixel.retain_grad()
        pixel.backward()  # fails on the second pixel
        impact = deform_verts.grad[vertex_id]  # gradient w.r.t. the chosen vertex
        impact_img[i][j] += impact.sum()

plt.imshow(impact_img.detach().cpu().numpy())
```

I'm trying to create an image that shows how a single entry of deform_verts affects the entire image. To do this, I go through every pixel of the output image, call .backward() on it, and write the resulting gradient into a new image for visualization. However, after I call backward() on the first pixel, I cannot call it on the second one; I suspect the intermediate values were already used (and freed) by the first backpropagation. I tried passing retain_graph=True, but I don't think it does what I want, because the resulting image looks far from what I expected.
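A minimal toy reproduction of both symptoms, as a sketch: it assumes a two-element tensor `x` standing in for `deform_verts` and `y = x ** 2` standing in for the rendered image (neither is the question's actual code). Without `retain_graph`, the second backward pass fails; with `retain_graph=True` it runs, but `.grad` holds the sum over every pixel processed so far.

```python
import torch

# Toy stand-ins (assumptions): x plays deform_verts, y plays the rendered image.
x = torch.tensor([1.0, 2.0], requires_grad=True)
y = x ** 2  # two-"pixel" image; the pow node saves x for the backward pass

# 1) Without retain_graph, the first backward() frees the saved tensors,
#    so backpropagating through the same graph again raises an error.
y[0].backward()
try:
    y[1].backward()
    second_call_failed = False
except RuntimeError:  # "Trying to backward through the graph a second time ..."
    second_call_failed = True

# 2) With retain_graph=True the second call works, but .grad keeps a running sum.
x.grad = None
y = x ** 2
y[0].backward(retain_graph=True)
after_pixel0 = x.grad.clone()  # tensor([2., 0.])
y[1].backward(retain_graph=True)
after_pixel1 = x.grad.clone()  # tensor([2., 4.]), not tensor([0., 4.])
```

The accumulated value in `after_pixel1` is why an image built with `retain_graph=True` alone looks wrong: each pixel also carries the gradients of all earlier pixels.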


Solution

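  • You need to zero out the gradient after reading it, with deform_verts.grad.zero_(), while still passing retain_graph=True to backward(). backward() accumulates gradients into .grad, so without the reset each pixel's value also includes the gradients of every pixel processed before it. You want to retain the graph, not the previous gradients.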
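A self-contained sketch of the corrected loop. The question's mesh and renderer are elided, so this assumes a hypothetical toy renderer (a fixed random linear map `mix` followed by a sigmoid) and small sizes; only the `retain_graph=True` and `deform_verts.grad.zero_()` parts are meant to carry over to the real code. The helper `impact_via_autograd_grad` is also hypothetical: it swaps in `torch.autograd.grad`, which returns the gradient instead of accumulating it into `.grad`, so no zeroing is needed.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for the question's elided setup.
num_verts, image_size, vertex_id = 300, 8, 275
deform_verts = (0.1 * torch.randn(num_verts, 3)).requires_grad_()  # input (leaf tensor)
# Toy differentiable "renderer": each pixel is a sigmoid of a fixed linear mix of all offsets.
mix = torch.randn(image_size * image_size, num_verts * 3)
predicted_silhouette = torch.sigmoid(mix @ deform_verts.reshape(-1)).reshape(image_size, image_size)

impact_img = torch.zeros(image_size, image_size)
for i in range(image_size):
    for j in range(image_size):
        # Keep the graph alive so the next pixel can backpropagate through it too.
        predicted_silhouette[i, j].backward(retain_graph=True)
        impact_img[i, j] = deform_verts.grad[vertex_id].sum()
        # backward() accumulates into .grad, so reset it before the next pixel.
        deform_verts.grad.zero_()


def impact_via_autograd_grad(silhouette, verts, vid):
    """Hypothetical alternative: torch.autograd.grad leaves verts.grad untouched."""
    out = torch.zeros_like(silhouette)  # does not require grad
    for i in range(silhouette.shape[0]):
        for j in range(silhouette.shape[1]):
            (g,) = torch.autograd.grad(silhouette[i, j], verts, retain_graph=True)
            out[i, j] = g[vid].sum()
    return out
```

Both versions run one backward pass per pixel, so the cost grows with `image_size ** 2`; the result can then be shown with `plt.imshow(impact_img.numpy())` as in the question.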