
How to use and interpret JAX Vector-Jacobian Product (VJP) for this example?


I am trying to learn how to find the Jacobian of a vector-valued ODE function using JAX, following the examples at https://implicit-layers-tutorial.org/implicit_functions/. That page implements its own ODE integrator along with custom forward-mode and reverse-mode Jacobian functions. I am trying to reproduce its results using the official jax odeint and the diffrax library, but both of these primarily support the reverse-mode Vector-Jacobian Product (VJP) rather than the forward-mode Jacobian-Vector Product (JVP), which is what the example code on that page uses.

Here is a code snippet that I adapted from that page:

import matplotlib.pyplot as plt

import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax import jit, jvp, vjp
from jax.experimental.ode import odeint

from diffrax import diffeqsolve, ODETerm, PIDController, SaveAt, Dopri5, NoAdjoint

# returns time derivatives of each of our 3 state variables (vector-valued function)
def f(state, t, args):
    x, y, z = state
    rho, sigma, beta = args 
    return jnp.array([sigma * (y - x), x * (rho - z) - y, x * y - beta * z])

# convenience function that calls jax-odeint given input initial conditions and parameters (this is the function that we want Jacobian/sensitivities of)
def evolve(y0, rho, sigma, beta): 
    return odeint(f, y0, tarr, (rho, sigma, beta))


# set up initial conditions, timespan for integration, and fiducial parameter values
y0 = jnp.array([5., 5., 5.])
tarr = jnp.linspace(0, 1., 1000)
rho = 28.
sigma = 10.
beta = 8/3. 


# first just make sure evolve() works 
ys = evolve(y0, rho, sigma, beta)

fig, ax = plt.subplots(1,figsize=(6,4),dpi=150,subplot_kw={'projection':'3d'})   
ax.plot(ys.T[0],ys.T[1],ys.T[2],'b-',lw=0.5)

# now try to take reverse-mode vector-jacobian product (VJP) since forward-mode JVP is not defined for jax-odeint
vjp_ys, vjp_evolve = vjp(evolve,y0,rho,sigma,beta)

# vjp_ys and ys are equal -- they are the solution time series of the 3 components (state variables) of y 
print(jnp.array_equal(ys,vjp_ys))

# define some perturbation in y0 and parameters 
delta_y0 = jnp.array([0., 0., 0.])
delta_rho = 0.
delta_sigma = 0.
delta_beta = 1.

####### THIS FAILS 
# vjp_evolve is a function but I am not sure how to use it to get perturbations delta_ys given y0/parameter variations
vjp_evolve(delta_y0,delta_rho,delta_sigma,delta_beta)

That last line raises an error:

TypeError: The function returned by `jax.vjp` applied to evolve was called with 4 arguments, but functions returned by `jax.vjp` must be called with a single argument corresponding to the single value returned by evolve (even if that returned value is a tuple or other container).

For example, if we have:

  def f(x):
    return (x, x)
  _, f_vjp = jax.vjp(f, 1.0)

the function `f` returns a single tuple as output, and so we call `f_vjp` with a single tuple as its argument:

  x_bar, = f_vjp((2.0, 2.0))

If we instead call `f_vjp(2.0, 2.0)`, with the values 'splatted out' as arguments rather than in a tuple, this error can arise.

I suspect I am confused about the concept of reverse-mode VJP and what the input should be in the case of this vector-valued ODE. The same problem would persist if I used the diffrax solvers.

For what it's worth, I can reproduce the forward-mode JVP results on that website if I use a diffrax solver while specifying adjoint=NoAdjoint, so that jax.jvp can be used:

# I am similarly confused about how to use VJP with diffrax's default reverse-mode autodiff of the ODE system
# however I am able to use forward-mode JVP with diffrax's ODE solver if I specify adjoint=NoAdjoint

# diffrax expects a different argument order (time first, then state, then args), whereas jax's odeint takes the state first, then time, then args
def f_diffrax(t, state, args):
    x, y, z = state
    rho, sigma, beta = args 
    return jnp.array([sigma * (y - x), x * (rho - z) - y, x * y - beta * z])

# set up diffrax inputs as closely to jax-odeint as possible 
terms = ODETerm(f_diffrax)
t0 = 0.0
t1 = 1.0 
dt0 = None
max_steps = 16**3 # not sure if this is needed
tsave = SaveAt(ts=tarr,dense=True)

def evolve_diffrax(y0, rho, sigma, beta):
    return diffeqsolve(terms,Dopri5(),t0,t1,dt0,y0,jnp.array([rho,sigma,beta]),saveat=tsave,
                       stepsize_controller=PIDController(rtol=1.4e-8,atol=1.4e-8),max_steps=max_steps,adjoint=NoAdjoint())

# get solution AND differentials assuming the same changes in y0 and parameters as we tried (and failed) to get above 
diffrax_ys, diffrax_delta_ys = jvp(evolve_diffrax, (y0,rho,sigma,beta),(delta_y0,delta_rho,delta_sigma,delta_beta))

# get the actual solution arrays from the diffrax Solution objects 
diffrax_ys = diffrax_ys.ys
diffrax_delta_ys = diffrax_delta_ys.ys

# plot 
fig, ax = plt.subplots(1,figsize=(6,4),dpi=150,subplot_kw={'projection':'3d'})   
ax.plot(diffrax_ys.T[0],diffrax_ys.T[1],diffrax_ys.T[2],color='violet',lw=0.5)
ax.quiver(diffrax_ys.T[0][::10],diffrax_ys.T[1][::10],diffrax_ys.T[2][::10],
          diffrax_delta_ys.T[0][::10],diffrax_delta_ys.T[1][::10],diffrax_delta_ys.T[2][::10])
    

[Figure: 3D plot of the diffrax solution trajectory with the JVP sensitivity vectors overplotted as arrows.]

That reproduces one of the main plots of that website (showing that the ODE is very sensitive to variations in the beta parameter). So I understand the concept of forward-mode JVP (given perturbations in initial conditions and/or parameters, JVP gives the corresponding perturbation in the ODE solution as a function of time). But what does reverse-mode VJP do and what would be the correct input to the vjp_evolve function above?


Solution

  • JVP is forward-mode autodiff: given tangents on the inputs to the function at a primal point, it returns tangents on the outputs.

    VJP is reverse-mode autodiff: given cotangents on the output of the function at a primal point, it returns cotangents on the inputs.

    So you can call vjp_evolve with cotangents of the same shape as vjp_ys:

    print(vjp_evolve(jnp.ones_like(vjp_ys)))
    
    (Array([ 1.74762118, 26.45747015, -2.03017559], dtype=float64),
     Array(871.66349663, dtype=float64),
     Array(-83.07586548, dtype=float64),
     Array(-1754.48788565, dtype=float64))
    

    Conceptually, JVP propagates gradients forward through a computation, while VJP propagates gradients backward. The JAX docs might be useful background for understanding the JVP & VJP transformations more deeply: https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#vector-jacobian-products-vjps-aka-reverse-mode-autodiff
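
    To make the interpretation concrete, here is a minimal sketch (it assumes the question's snippets have already been run, so that evolve, y0, the parameter values, and the diffrax_delta_ys array are all in scope) showing that a one-hot cotangent pulls out the gradient of a single output element, and that the forward- and reverse-mode results agree through the duality <w, J v> = <J^T w, v>:

    from jax import jacrev, vjp
    import jax.numpy as jnp

    ys, vjp_evolve = vjp(evolve, y0, rho, sigma, beta)

    # A one-hot cotangent selects a single output element, so the VJP returns
    # the gradient of that one element with respect to every input -- here the
    # z component of the state at the final saved time.
    w = jnp.zeros_like(ys).at[-1, 2].set(1.0)
    dz_dy0, dz_drho, dz_dsigma, dz_dbeta = vjp_evolve(w)

    # Duality check: with the tangent v perturbing only beta (as in the
    # question), <w, J v> from the forward-mode diffrax JVP should equal
    # <J^T w, v> = dz_dbeta from the reverse-mode odeint VJP, up to the
    # tolerances of the two solvers.
    print(jnp.vdot(w, diffrax_delta_ys), dz_dbeta)

    # The all-ones cotangent in the answer above therefore returns, for each
    # input, the sum of that input's gradients over every output element and
    # saved time.

    # One VJP per output element recovers the full Jacobian; jax.jacrev does
    # exactly this under the hood:
    J_y0, J_rho, J_sigma, J_beta = jacrev(evolve, argnums=(0, 1, 2, 3))(y0, rho, sigma, beta)
    print(J_beta.shape)  # (1000, 3): d y(t) / d beta at every saved time

    As a rule of thumb, forward mode costs one pass per input direction and reverse mode one pass per output element, so for this problem (6 scalar inputs, 3000 output values) forward mode is the cheaper way to assemble the whole Jacobian, while reverse mode shines when a few scalar outputs depend on many parameters.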