Tags: python, pytorch, autograd, automatic-differentiation

How to generate the Jacobian of a tensor-valued function using torch.autograd?


Computing the Jacobian of a function f : R^d -> R^d, applied row-wise to a batch x of shape [k, d], is not too hard:

import torch

def jacobian(y, x):
    # x and y both have shape [k, d]; the result has shape [k, d, d].
    k, d = x.shape
    jacobian = list()

    for i in range(d):
        # One-hot weight selecting the i-th output component of every sample.
        v = torch.zeros_like(y)
        v[:, i] = 1.
        # Vector-Jacobian product: one row of each per-sample Jacobian.
        dy_dx = torch.autograd.grad(y, x, grad_outputs=v, retain_graph=True,
                                    create_graph=True, allow_unused=True)[0]  # shape [k, d]
        jacobian.append(dy_dx)
    jacobian = torch.stack(jacobian, dim=1).requires_grad_()
    return jacobian
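
For example, applied row-wise to f(x) = x**2 this gives the expected per-sample Jacobians (a quick sketch of my own):

x = torch.rand(5, 3, requires_grad=True)   # k = 5 points in R^3
y = x ** 2                                 # f(x), shape [5, 3]
J = jacobian(y, x)                         # shape [5, 3, 3], diagonal entries 2*x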

Above, jacobian is invoked with y = f(x). However, I now have a function g = g(t, x), where t is a torch.Tensor of shape (k,) and x is a torch.Tensor of shape (k, d1, d2, d3). The result of g(t, x) is again a torch.Tensor of shape (k, d1, d2, d3).
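
For illustration only, a placeholder g with exactly these shapes could be:

def g(t, x):
    # Placeholder, not the real function: t has shape (k,), x has shape
    # (k, d1, d2, d3), and the result again has shape (k, d1, d2, d3).
    return t.view(-1, 1, 1, 1) * x ** 2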

I tried to reuse my existing jacobian function. What I did was:

y = g(t, x)
x = x.flatten(1)
y = y.flatten(1)
jacobian(y, x)

The problem is that dy_dx is always None. The only explanation I have is that, most probably, the dependency graph is broken by the flatten(1) call.
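
A minimal sketch of what I suspect is happening (x2 is a new tensor that y was never computed from):

x = torch.rand(2, 3, 4, 5, requires_grad=True)
y = x ** 2                        # y depends on x ...
x2 = x.flatten(1)                 # ... but not on this reshaped tensor
print(torch.autograd.grad(y.sum(), x2, allow_unused=True)[0])   # prints None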

So, what can I do here? I should note that what I actually want to compute is the divergence, i.e. the trace of the Jacobian. If there is a more performant solution for that specific case, I'd be interested in that one.


Solution

  • You are correct: you are passing x.flatten(1) as the input even though y (let alone y.flatten(1)) was computed from x, not from x.flatten(1). Since y does not depend on that new tensor, torch.autograd.grad with allow_unused=True returns None for it. Instead, you can avoid the flattening altogether with something like this:

    def jacobian(y, x):
        # x and y share the shape (k, d1, d2, d3); the result has shape
        # (k, D, d1, d2, d3) with D = d1*d2*d3.
        jacobian = []
        for i in range(x.numel() // len(x)):
            # One-hot weight for the i-th (flattened) output component of every sample.
            v = torch.zeros_like(y.flatten(1))
            v[:, i] = 1.
            # grad_outputs must match y's shape, hence the reshape of v.
            dy_dx, *_ = torch.autograd.grad(y, x, v.view_as(y),
                                            retain_graph=True, create_graph=True,
                                            allow_unused=True)
            jacobian.append(dy_dx)
        jacobian = torch.stack(jacobian, dim=1)
        return jacobian
    

    So after calling the function, you can flatten d1, d2, and d3 together, giving one D x D Jacobian per sample. Here is a minimal example:

    x = torch.rand(1, 2, 2, 2, requires_grad=True)   # k = 1 sample, D = 2*2*2 = 8
    J = jacobian(x**2, x)                            # shape (1, 8, 2, 2, 2)
    
    > J.flatten(-3)
    tensor([[[1.2363, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.2386, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 1.0451, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 1.4160, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.4090, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.7642, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3041, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.1995]]],
           grad_fn=<ViewBackward0>)
    

    From here, you can compute the trace you're looking for with torch.Tensor.trace:

    > [j.trace() for j in J.flatten(-3)]
    [tensor(7.6128, grad_fn=<TraceBackward0>)]
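
    Since what you ultimately want is the divergence, you do not have to materialize the full D x D matrix at all: inside the same loop you can keep just the i-th entry of each vector-Jacobian product and accumulate it. A sketch along those lines, under the same shape assumptions as above:

    def divergence(y, x):
        # Trace of each per-sample Jacobian, one vector-Jacobian product per column,
        # without ever storing the full D x D matrix.
        div = torch.zeros(len(x), device=x.device, dtype=x.dtype)
        for i in range(x.numel() // len(x)):
            v = torch.zeros_like(y.flatten(1))
            v[:, i] = 1.
            dy_dx, *_ = torch.autograd.grad(y, x, v.view_as(y),
                                            retain_graph=True, create_graph=True)
            div = div + dy_dx.flatten(1)[:, i]   # keep only the diagonal entry
        return div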
    

    Keep in mind you can also use the built-in torch.autograd.functional.jacobian function:

    J = torch.autograd.functional.jacobian(lambda x: g(t, x), x)
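
    Applied to the minimal example above (x of shape (1, 2, 2, 2)), the raw result looks like this:

    J = torch.autograd.functional.jacobian(lambda x: x**2, x)
    print(J.shape)   # torch.Size([1, 2, 2, 2, 1, 2, 2, 2])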
    

    However, you will need to reshape the result: its shape is the output shape followed by the input shape. Note that the simple view below only works because the batch size is 1 here; for larger batches, functional.jacobian also computes cross-sample derivatives, so the element count no longer matches (len(x), k, k).

    k = x.numel()//len(x)   # features per sample, i.e. d1*d2*d3 (not the batch size)
    
    > [j.trace() for j in J.view(len(x), k, k)]
    [tensor(7.6128)]
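
    Finally, since you asked about performance: if an unbiased estimate of the divergence is good enough, a Hutchinson-style trace estimator (a standard trick, sketched here under the same assumption that y and x share a shape; it is not needed for the exact values above) requires only one vector-Jacobian product per random probe instead of one per feature:

    def divergence_estimate(y, x, n_probes=1):
        # E[v^T (dy/dx) v] over random +/-1 vectors v equals the trace of dy/dx.
        est = torch.zeros(len(x), device=x.device, dtype=x.dtype)
        for _ in range(n_probes):
            v = torch.randint_like(y, 0, 2) * 2 - 1             # Rademacher probe
            vjp, *_ = torch.autograd.grad(y, x, v, retain_graph=True, create_graph=True)
            est = est + (vjp * v).flatten(1).sum(dim=1)         # per-sample v^T J v
        return est / n_probes

    Averaging over more probes (n_probes > 1) reduces the variance of the estimate.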