I would like to use PyTorch to optimize an objective function which makes use of an operation that cannot be tracked by torch.autograd. I wrapped that operation in a custom forward() of the torch.autograd.Function class (as suggested here and here). Since I know the gradient of the operation, I can also write the backward(). Everything looks like this:
class Projector(torch.autograd.Function):
    # non_torch_var are constant values needed by the operation
    @staticmethod
    def forward(ctx, vertices, non_torch_var1, non_torch_var2, non_torch_var3):
        ctx.save_for_backward(vertices)
        vertices2 = vertices.detach().cpu().numpy()
        ctx.non_torch_var1 = non_torch_var1
        ctx.non_torch_var2 = non_torch_var2
        ctx.non_torch_var3 = non_torch_var3
        out = project_mesh(vertices2, non_torch_var1, non_torch_var2, non_torch_var3)
        out = torch.tensor(out, requires_grad=True)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        vertices = ctx.saved_tensors[0]
        vertices2 = vertices.detach().cpu().numpy()
        non_torch_var1 = ctx.non_torch_var1
        non_torch_var2 = ctx.non_torch_var2
        non_torch_var3 = ctx.non_torch_var3
        grad_vertices = grad_project_mesh(vertices2, non_torch_var1, non_torch_var2, non_torch_var3)
        grad_vertices = torch.tensor(grad_vertices, requires_grad=True)
        return grad_vertices, None, None, None
This implementation, however, does not seem to work. I used the torchviz package to plot what is going on with the following lines:
import torchviz
out = Projector.apply(*input)
grad_x, = torch.autograd.grad(out.sum(), vertices, create_graph=True)
torchviz.make_dot((grad_x, vertices, out), params={"grad_x": grad_x, "vertices": vertices, "out": out}).render("attached", format="png")
and I got this graph, which shows that grad_x is not connected to anything.
Do you have an idea of what is going wrong with this code?
The graph correctly shows how out is computed from vertices (which seems to equal input in your code). Variable grad_x is correctly shown as disconnected because it isn't used to compute out; in other words, out isn't a function of grad_x. The fact that grad_x is disconnected doesn't mean that the gradient doesn't flow or that your custom backward implementation doesn't work. On the contrary, the fact that a path exists from vertices to out in the graph means that the gradient should flow, i.e. the autograd engine can compute the gradient of out w.r.t. vertices. To check the correctness of your custom backward implementation, you need to check whether the value of grad_x is correct.

In short, the gradient should flow because there is a path from vertices to out, and its correctness should be verified by inspecting its values, not by looking at the computation graph.
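One way to do that check is torch.autograd.gradcheck, which compares the analytical gradient produced by your backward against a finite-difference estimate. Below is a minimal sketch, assuming project_mesh and grad_project_mesh accept double-precision input; the (10, 3) shape is only a placeholder, and non_torch_var1/2/3 are the constants from your own setup:

import torch
from torch.autograd import gradcheck

# gradcheck needs double precision and requires_grad=True on the tensor
# being checked; the shape below is just an illustrative placeholder.
vertices = torch.rand(10, 3, dtype=torch.double, requires_grad=True)

# Close over the non-tensor constants so gradcheck only differentiates
# with respect to vertices.
func = lambda v: Projector.apply(v, non_torch_var1, non_torch_var2, non_torch_var3)

# Returns True if the gradient from Projector.backward matches the
# numerical (finite-difference) estimate within the given tolerances.
print(gradcheck(func, (vertices,), eps=1e-6, atol=1e-4))

One thing to keep in mind while inspecting the values: backward receives grad_out and is expected to return the chain-rule product of grad_out with the Jacobian of forward, so if grad_project_mesh only computes the local gradient of project_mesh, it still needs to be combined with grad_out.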