Tags: python, machine-learning, deep-learning, jax, automatic-differentiation

Taking derivatives with multiple inputs in JAX


I am trying to take first and second derivatives of functions in JAX; however, my approaches give either the wrong numbers or zeros. I have an array with a column for each of the two variables and a row for each of the two inputs:

import jax.numpy as jnp
import jax

rng = jax.random.PRNGKey(1234)
array = jax.random.normal(rng, (2,2))

Two test functions

def F1(arr):
    return 1/arr

def F2(arr):
    return jnp.array([arr[0]**2 + arr[1]**3])

and two methods of taking first and second derivatives, one using jax.grad()

def dF_m1(arr, F):
    return jax.grad(lambda arr: F(arr)[0])(arr)

def ddF_m1(arr, F, dF):
    return jax.grad(lambda arr: dF(arr, F)[0])(arr)

and another using jax.jacobian()

def dF_m2(arr, F):
    jac = jax.jacobian(lambda arr: F(arr))(arr)
    return jnp.diag(jac)

def ddF_m2(arr, F, dF):
    hess = jax.jacobian(lambda arr: dF(arr, F))(arr)
    return jnp.diag(hess)

Computing the first and second derivatives of each function with both methods, and comparing against the exact values, gives the following

exact_dF1  = (-1/array**2)
exact_ddF1 = (2/array**3)

print("Function 1 using all grad()")
dF1_m1 = jax.vmap(dF_m1, in_axes=(0,None))(array, F1)
ddF1_m1 = jax.vmap(ddF_m1, in_axes=(0,None,None))(array, F1, dF_m1)
print(dF1_m1  - exact_dF1,"\n")
print(ddF1_m1 - exact_ddF1,"\n")

print("Function 1 using all jacobian()")
dF1_m2 = jax.vmap(dF_m2, in_axes=(0,None))(array, F1)
ddF1_m2 = jax.vmap(ddF_m2, in_axes=(0,None,None))(array, F1, dF_m2)
print(dF1_m2  - exact_dF1,"\n")
print(ddF1_m2 - exact_ddF1,"\n")

Output

Function 1 using all grad()
[[ 0.         48.43877   ]
 [ 0.          0.62903005]] 

[[  0.        674.248    ]
 [  0.          0.9977852]] 

Function 1 using all jacobian()
[[0. 0.]
 [0. 0.]] 

[[0. 0.]
 [0. 0.]] 

and

exact_dF2  = jnp.hstack( (2*array[:, 0:1], 3*array[:, 1:2]**2))
exact_ddF2 = jnp.hstack( (2 + 0*array[:, 0:1], 6*array[:, 1:2]))

print("Function 2 using all grad()")
dF2_m1 = jax.vmap(dF_m1, in_axes=(0,None))(array, F2)
ddF2_m1 = jax.vmap(ddF_m1, in_axes=(0,None,None))(array, F2, dF_m1)
print(dF2_m1  - exact_dF2,"\n")
print(ddF2_m1 - exact_ddF2,"\n")

print("Function 2 using all jacobian()")
dF2_m2 = jax.vmap(dF_m2, in_axes=(0,None))(array, F2)
ddF2_m2 = jax.vmap(ddF_m2, in_axes=(0,None,None))(array, F2, dF_m2)
print(dF2_m2  - exact_dF2,"\n")
print(ddF2_m2 - exact_ddF2,"\n")

Output

Function 2 using all grad()
[[0. 0.]
 [0. 0.]] 

[[0.         0.86209416]
 [0.         7.5651155 ]] 

Function 2 using all jacobian()
[[ 0.         -0.10149619]
 [ 0.         -6.925739  ]] 

[[0.        2.8620942]
 [0.        9.565115 ]] 

I would prefer to use only jax.grad() for something like F1, but right now it seems that only jax.jacobian() works. The whole reason for this is that I need to calculate higher-order derivatives of a neural network with respect to its inputs. Thank you for any help.


Solution

  • Assuming exact_* is what you're attempting to compute, you're going about it in the wrong way. Your indexing within the differentiated functions (i.e. ...[0]) is removing some of the elements that you're trying to compute.

    What exact_dF1 and exact_ddF1 are computing is element-wise first and second derivatives for 2D inputs. You can compute this using either grad or jacobian by applying vmap twice (once for each input dimension). For example:

    exact_dF1  = (-1/array**2)
    grad_dF1 = jax.vmap(jax.vmap(jax.grad(F1)))(array)
    jac_dF1 = jax.vmap(jax.vmap(jax.jacobian(F1)))(array)
    print(jnp.allclose(exact_dF1, grad_dF1))  # True
    print(jnp.allclose(exact_dF1, jac_dF1))  # True
    
    exact_ddF1 = (2/array**3)
    grad_ddF1 = jax.vmap(jax.vmap(jax.grad(jax.grad(F1))))(array)
    jac_ddF1 = jax.vmap(jax.vmap(jax.jacobian(jax.jacobian(F1))))(array)
    print(jnp.allclose(exact_ddF1, grad_ddF1))  # True
    print(jnp.allclose(exact_ddF1, jac_ddF1))  # True
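
    If you'd rather not nest vmap by hand, one alternative (a sketch, not part of the original answer) is jnp.vectorize, which broadcasts a scalar-to-scalar function over every element of the input array:

    # Sketch: elementwise derivatives via jnp.vectorize instead of nested vmap
    vec_dF1 = jnp.vectorize(jax.grad(F1))(array)              # elementwise -1/x**2
    vec_ddF1 = jnp.vectorize(jax.grad(jax.grad(F1)))(array)   # elementwise 2/x**3
    print(jnp.allclose(exact_dF1, vec_dF1))    # True
    print(jnp.allclose(exact_ddF1, vec_ddF1))  # True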
    

    What exact_dF2 and exact_ddF2 are computing is a row-wise jacobian and hessian of a 2D->1D mapping. By its nature, this is difficult to compute using jax.grad, which is meant for functions with scalar output, but you can compute it using the jacobian this way:

    exact_dF2  = jnp.hstack( (2*array[:, 0:1], 3*array[:, 1:2]**2))
    exact_ddF2 = jnp.hstack( (2 + 0*array[:, 0:1], 6*array[:, 1:2]))
    
    jac_dF2 = jax.vmap(jax.jacobian(lambda a: F2(a)[0]))(array)
    jac_ddF2_full = jax.vmap(jax.jacobian(jax.jacobian(lambda a: F2(a)[0])))(array)
    jac_ddF2 = jax.vmap(jnp.diagonal)(jac_ddF2_full)
    print(jnp.allclose(exact_dF2, jac_dF2))  # True
    print(jnp.allclose(exact_ddF2, jac_ddF2))  # True
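
    For the second derivative, jax.hessian is a shorthand for composing two Jacobians when the output is scalar. A small sketch (not from the original answer), reusing the F2 and array defined above:

    hess_full = jax.vmap(jax.hessian(lambda a: F2(a)[0]))(array)  # per-row 2x2 Hessians
    hess_diag = jax.vmap(jnp.diagonal)(hess_full)                 # keep only the d^2/dx_i^2 terms
    print(jnp.allclose(exact_ddF2, hess_diag))  # True

    The same pattern should carry over to the neural-network case mentioned in the question: treat the network's forward pass (with parameters held fixed) as a function of its input, differentiate it with jacobian / hessian (or nested grad for scalar outputs), and vmap the result over the batch of inputs.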