Tags: python, pytorch, gradient, autograd

How do I use autograd for a separate function, independently of backpropagation, in PyTorch?


I have two variables, x and theta. I am trying to minimise my loss with respect to theta only, but as part of my loss function I need the derivative of a different function (f) with respect to x. The derivative itself is not part of the minimisation; only its value is used in the loss. However, when I implement this in PyTorch I get a RuntimeError.

A minimal example is as follows:

# minimal example of two different autograds
import torch

from torch.autograd.functional import jacobian
def f(theta, x):
    return torch.sum(theta * x ** 2)

def df(theta, x):
    # Jacobian (here just the gradient) of f with respect to x, with theta held fixed
    J = jacobian(lambda x: f(theta, x), x)
    return J

# example evaluations of the autograd gradient
x = torch.tensor([1., 2.])
theta = torch.tensor([1., 1.], requires_grad = True)

# derivative should be 2*theta*x (same as the analytical result)
with torch.no_grad():
    print(df(theta, x))
    print(2*theta*x)

tensor([2., 4.])
tensor([2., 4.])

# define some arbitrary loss as a fn of theta
loss = torch.sum(df(theta, x)**2)
loss.backward()

gives the following error

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

If I provide an analytic derivative (2*theta*x), it works fine:

loss = torch.sum((2*theta*x)**2)
loss.backward()

Is there a way to do this in PyTorch? Or am I limited in some way?

Let me know if anyone needs any more details.

PS

I am imagining the solution is something similar to the way JAX does autograd, as that is what I am more familiar with. In JAX, I believe you would just do:

from jax import grad
df = grad(lambda x: f(theta, x))

and then df would just be a function that can be called at any point. But is PyTorch the same? Or is there some conflict within .backward() that causes this error?
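
To make the analogy concrete, here is a rough (untested) sketch of the full pattern I have in mind in JAX, using the same f as above:

import jax
import jax.numpy as jnp

def f(theta, x):
    return jnp.sum(theta * x ** 2)

theta = jnp.array([1., 1.])
x = jnp.array([1., 2.])

# jax.grad differentiates with respect to the first argument by default,
# so closing over theta gives df/dx as a reusable function
df = jax.grad(lambda x: f(theta, x))
print(df(x))  # [2. 4.], i.e. 2 * theta * x

# the gradient function can itself be used inside a loss that is then
# differentiated with respect to theta
def loss(theta):
    return jnp.sum(jax.grad(lambda x: f(theta, x))(x) ** 2)

print(jax.grad(loss)(theta))  # 8 * theta * x**2 -> [ 8. 32.]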


Solution

  • PyTorch's jacobian does not create a computation graph unless you explicitly ask for it:

    J = jacobian(lambda x: f(theta, x), x, create_graph=True)
    

    ... i.e., by passing the create_graph argument.

    The documentation is quite clear about this:

    create_graph (bool, optional) – If True, the Jacobian will be computed in a differentiable manner
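
    For completeness, here is a sketch of the minimal example from the question with that single change applied (only create_graph=True is new); loss.backward() then reaches theta and fills in theta.grad:

    import torch
    from torch.autograd.functional import jacobian

    def f(theta, x):
        return torch.sum(theta * x ** 2)

    def df(theta, x):
        # create_graph=True keeps the Jacobian computation in the autograd graph,
        # so the result stays differentiable with respect to theta
        return jacobian(lambda x: f(theta, x), x, create_graph=True)

    x = torch.tensor([1., 2.])
    theta = torch.tensor([1., 1.], requires_grad=True)

    loss = torch.sum(df(theta, x) ** 2)  # = sum((2 * theta * x) ** 2)
    loss.backward()
    print(theta.grad)  # 8 * theta * x**2 -> tensor([ 8., 32.])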