python, pytorch, autograd

Directly access derivative of primitive functions in PyTorch


For backpropagation in PyTorch, the gradients of many simple functions are of course already implemented.

But what if I want a function that evaluates the gradient of an existing primitive function directly, e.g. the derivative of torch.sigmoid(x) with respect to x? I'd also like to be able to backpropagate through this new function.

The goal would be something like the following, but using only torch.sigmoid instead of a custom (re-)implementation.

import torch
import matplotlib.pyplot as plt

def dsigmoid_dx(x):
  return torch.sigmoid(x) * (1-torch.sigmoid(x))

xx = torch.linspace(-3.5, 3.5, 100)
yy = dsigmoid_dx(xx)
# ... do other stuff with yy

Of course, I could make x require gradients, pass it through the function, and then use autograd, e.g. as follows:

import torch
import matplotlib.pyplot as plt


xx = torch.linspace(-3.5, 3.5, 100, requires_grad=True)
yy = torch.sigmoid(xx)

grad = torch.autograd.grad(yy, [xx], grad_outputs=torch.ones_like(yy), create_graph=True)[0]

plt.plot(xx.detach(), grad.detach())
plt.plot(xx.detach(), yy.detach(), color='red')
plt.show()

Is it possible (for individual, primitive functions) to somehow directly access the implemented backward function? The PyTorch docs show how to extend autograd, but I can't figure out how to directly access these functions for existing ones (again, e.g. torch.sigmoid).

To summarize, I want to avoid having to reimplement simple derivatives of functions that are obviously already implemented in the framework (and presumably in a numerically stable way). Is this possible? Or do I always have to reimplement them myself?


Solution

  • Since the computation of yy only involves one (native) function, torch.sigmoid, calling autograd.grad (or, similarly, yy.backward) will ultimately result in directly calling sigmoid's implemented backward function. By the looks of it, that is what you are looking for in the first place. In other words, backpropagating on yy is exactly what accessing (i.e. calling) the implemented backward function at a given point means.

    So one alternative interface you can use is backward (a more general wrapper along these lines is sketched after the code below):

    xx = torch.linspace(-3.5, 3.5, 100, requires_grad=True)
    yy = torch.sigmoid(xx)
    
    yy.sum().backward()
    plt.plot(xx.detach(), xx.grad)
    plt.plot(xx.detach(), yy.detach(), color='red')
    plt.show()
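
    For completeness, here is a minimal sketch that packages the same idea into the interface asked for in the question: a helper that evaluates the derivative of an existing primitive at given points by reusing PyTorch's own backward implementation via torch.autograd.grad. The helper name derivative is made up for illustration; this is one possible wrapper, not the only way to do it.

    import torch
    import matplotlib.pyplot as plt

    def derivative(fn, x, create_graph=False):
        # Hypothetical helper: evaluate the elementwise derivative of `fn` at `x`
        # by calling into PyTorch's own backward implementation of `fn`.
        x = x.detach().requires_grad_(True)      # fresh leaf tensor to differentiate w.r.t.
        y = fn(x)                                # e.g. fn = torch.sigmoid
        (grad,) = torch.autograd.grad(
            y, [x],
            grad_outputs=torch.ones_like(y),     # elementwise: vector-Jacobian product with ones
            create_graph=create_graph,           # keep the graph to backprop through the derivative
        )
        return grad

    xx = torch.linspace(-3.5, 3.5, 100)
    yy = derivative(torch.sigmoid, xx)           # matches sigmoid(x) * (1 - sigmoid(x))

    plt.plot(xx, yy)
    plt.show()

    With create_graph=True the returned gradient stays attached to the autograd graph, so it can itself be backpropagated through, which covers the second requirement from the question.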