Tags: python, optimization, pytorch, gradient, autograd

How to implement a custom forward/backward function for torch.autograd.Function?


I would like to use PyTorch to optimize an objective function that makes use of an operation which cannot be tracked by torch.autograd. I wrapped that operation in a custom forward() of a torch.autograd.Function subclass (as suggested here and here). Since I know the gradient of the operation, I can also write the backward(). Everything looks like this:

class Projector(torch.autograd.Function):

    # non_torch_var1/2/3 are constant values needed by the operation
    @staticmethod
    def forward(ctx, vertices, non_torch_var1, non_torch_var2, non_torch_var3):
        # stash the tensor input and the constants for use in backward()
        ctx.save_for_backward(vertices)
        ctx.non_torch_var1 = non_torch_var1
        ctx.non_torch_var2 = non_torch_var2
        ctx.non_torch_var3 = non_torch_var3

        # run the autograd-untracked operation on a NumPy copy of vertices
        vertices2 = vertices.detach().cpu().numpy()
        out = project_mesh(vertices2, non_torch_var1, non_torch_var2, non_torch_var3)
        out = torch.tensor(out, requires_grad=True)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        vertices = ctx.saved_tensors[0]
        vertices2 = vertices.detach().cpu().numpy()
        non_torch_var1 = ctx.non_torch_var1
        non_torch_var2 = ctx.non_torch_var2
        non_torch_var3 = ctx.non_torch_var3

        # hand-computed gradient of project_mesh with respect to vertices
        grad_vertices = grad_project_mesh(vertices2, non_torch_var1, non_torch_var2, non_torch_var3)
        grad_vertices = torch.tensor(grad_vertices, requires_grad=True)

        # one return value per input of forward(); the constants get no gradient
        return grad_vertices, None, None, None

This implementation, however, does not seem to work. I used the torchviz package to plot what is going on with the following lines:

import torchviz
out = Projector.apply(*input)
grad_x, = torch.autograd.grad(out.sum(), vertices, create_graph=True)
torchviz.make_dot((grad_x, vertices, out), params={"grad_x": grad_x, "vertices": vertices, "out": out}).render("attached", format="png")

and I got this graph, which shows that grad_x is not connected to anything.

Do you have an idea of what is going wrong with this code?


Solution

  • The graph correctly shows how out is computed from vertices (which seems to equal input in your code). The variable grad_x is correctly shown as disconnected because it isn't used to compute out; in other words, out isn't a function of grad_x. The fact that grad_x is disconnected doesn't mean that the gradient doesn't flow or that your custom backward implementation doesn't work. On the contrary, the fact that a path exists from vertices to out in the graph means that the gradient should flow, i.e. the autograd engine can compute the gradient of out w.r.t. vertices. To check the correctness of your custom backward implementation, you need to check whether the value of grad_x is correct.

    In short, the gradient should flow because there is a path from vertices to out; its correctness should be verified by inspecting the computed values (for example with the check sketched below), not by looking at the computation graph.
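
    A convenient way to inspect those values is torch.autograd.gradcheck, which compares the analytic gradient produced by your backward() against a finite-difference estimate. Below is a minimal sketch, assuming Projector is defined as above; the vertex shape and the placeholder constants (const1, const2, const3) are made up for illustration, and gradcheck requires double-precision inputs, so project_mesh and grad_project_mesh must be able to work in float64.

import torch

# Hypothetical example inputs: a small double-precision vertex tensor and
# placeholder constants (replace with the real non_torch_var1/2/3).
vertices = torch.randn(10, 3, dtype=torch.double, requires_grad=True)
const1 = const2 = const3 = ...  # fill in the real constant values here

# Close over the constants so gradcheck only perturbs `vertices`.
func = lambda v: Projector.apply(v, const1, const2, const3)

# gradcheck perturbs `vertices` to build a numerical Jacobian and compares it
# with the vector-Jacobian products returned by Projector.backward; it returns
# True (or raises an error) when analytic and numerical gradients agree.
print(torch.autograd.gradcheck(func, (vertices,), eps=1e-6, atol=1e-4))

    Note that backward() is expected to return the gradient of the final loss, i.e. grad_out chained with the local gradient of project_mesh, so a check like this will also catch a missing grad_out factor.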