I am trying to take first and second derivatives of functions in JAX; however, my approaches give me either wrong numbers or zeros. I have a 2x2 array with two columns (one for each variable) and two rows (one for each input).
import jax.numpy as jnp
import jax
# Fixed seed so the sample inputs (and the errors printed below) are reproducible.
rng = jax.random.PRNGKey(1234)
# 2x2 array of standard-normal samples used as the evaluation points.
array = jax.random.normal(rng, (2, 2))
Two test functions
def F1(arr):
    """Return the reciprocal ``1 / arr``.

    First test function. Its closed-form first and second derivatives,
    used as the reference values, are ``-1 / arr**2`` and ``2 / arr**3``.
    """
    return 1 / arr
def F2(arr):
    """Return ``arr[0]**2 + arr[1]**3`` wrapped in a length-1 array.

    Second test function: it reads the first two entries of ``arr`` and
    produces a single value, returned as a one-element ``jnp`` array
    rather than a bare scalar.
    """
    return jnp.array([arr[0] ** 2 + arr[1] ** 3])
and two methods of taking first and second derivatives, one using jax.grad()
def dF_m1(arr, F):
    """Differentiate the first output component of ``F`` at ``arr`` via ``jax.grad``.

    Only ``F(arr)[0]`` is differentiated; every other output component of
    ``F`` is discarded. If ``F(arr)[0]`` depends only on ``arr[0]`` (as for
    an element-wise ``F``), all remaining gradient entries are therefore zero.

    Args:
        arr: Point at which to differentiate.
        F: Callable whose result is indexed with ``[0]`` to obtain the
            scalar handed to ``jax.grad``.

    Returns:
        The gradient of ``F(arr)[0]`` with respect to ``arr``.
    """
    def first_output(x):
        # jax.grad needs a scalar-valued function, hence the [0].
        return F(x)[0]

    return jax.grad(first_output)(arr)
def ddF_m1(arr, F, dF):
    """Differentiate the first component of ``dF(arr, F)`` via ``jax.grad``.

    Only ``dF(arr, F)[0]`` is differentiated; the other components of the
    first derivative are discarded before the second differentiation.

    Args:
        arr: Point at which to differentiate.
        F: Function forwarded unchanged to ``dF``.
        dF: Callable ``dF(arr, F)`` returning the first derivative; its
            result is indexed with ``[0]`` to obtain a scalar.

    Returns:
        The gradient of ``dF(arr, F)[0]`` with respect to ``arr``.
    """
    def first_derivative_component(x):
        # jax.grad needs a scalar-valued function, hence the [0].
        return dF(x, F)[0]

    return jax.grad(first_derivative_component)(arr)
and another using jax.jacobian()
def dF_m2(arr, F):
    """Return ``jnp.diag`` of the Jacobian of ``F`` evaluated at ``arr``.

    For a 2-D Jacobian, ``jnp.diag`` extracts the main diagonal, i.e. the
    entries ``d F(arr)[i] / d arr[i]``; off-diagonal entries are dropped.

    Args:
        arr: Point at which to differentiate.
        F: Function to differentiate.

    Returns:
        The result of ``jnp.diag`` applied to ``jax.jacobian(F)(arr)``.
    """
    jacobian_at_arr = jax.jacobian(F)(arr)
    return jnp.diag(jacobian_at_arr)
def ddF_m2(arr, F, dF):
    """Return ``jnp.diag`` of the Jacobian of ``dF(arr, F)`` with respect to ``arr``.

    For a 2-D Jacobian, ``jnp.diag`` extracts the main diagonal;
    off-diagonal entries are dropped.

    Args:
        arr: Point at which to differentiate.
        F: Function forwarded unchanged to ``dF``.
        dF: Callable ``dF(arr, F)`` returning the first derivative.

    Returns:
        The result of ``jnp.diag`` applied to the Jacobian of
        ``x -> dF(x, F)`` evaluated at ``arr``.
    """
    def first_derivative(x):
        return dF(x, F)

    jacobian_of_derivative = jax.jacobian(first_derivative)(arr)
    return jnp.diag(jacobian_of_derivative)
Computing the first and second derivative (and error) of each function using both methods gives the following
# Closed-form first and second derivatives of 1/x, used as the reference.
exact_dF1 = -1 / array**2
exact_ddF1 = 2 / array**3

# grad()-based helpers, mapped over the rows of `array`.
print("Function 1 using all grad()")
rowwise_dF_m1 = jax.vmap(dF_m1, in_axes=(0, None))
rowwise_ddF_m1 = jax.vmap(ddF_m1, in_axes=(0, None, None))
dF1_m1 = rowwise_dF_m1(array, F1)
ddF1_m1 = rowwise_ddF_m1(array, F1, dF_m1)
print(dF1_m1 - exact_dF1, "\n")
print(ddF1_m1 - exact_ddF1, "\n")

# jacobian()-based helpers, mapped over the rows of `array`.
print("Function 1 using all jacobian()")
rowwise_dF_m2 = jax.vmap(dF_m2, in_axes=(0, None))
rowwise_ddF_m2 = jax.vmap(ddF_m2, in_axes=(0, None, None))
dF1_m2 = rowwise_dF_m2(array, F1)
ddF1_m2 = rowwise_ddF_m2(array, F1, dF_m2)
print(dF1_m2 - exact_dF1, "\n")
print(ddF1_m2 - exact_ddF1, "\n")
Output
Function 1 using all grad()
[[ 0. 48.43877 ]
[ 0. 0.62903005]]
[[ 0. 674.248 ]
[ 0. 0.9977852]]
Function 1 using all jacobian()
[[0. 0.]
[0. 0.]]
[[0. 0.]
[0. 0.]]
and
# Closed-form per-row derivatives of x0**2 + x1**3:
# first derivatives (2*x0, 3*x1**2), second derivatives (2, 6*x1).
first_col = array[:, 0:1]
second_col = array[:, 1:2]
exact_dF2 = jnp.hstack((2 * first_col, 3 * second_col**2))
exact_ddF2 = jnp.hstack((2 + 0 * first_col, 6 * second_col))

# grad()-based helpers, mapped over the rows of `array`.
print("Function 2 using all grad()")
rowwise_dF_m1 = jax.vmap(dF_m1, in_axes=(0, None))
rowwise_ddF_m1 = jax.vmap(ddF_m1, in_axes=(0, None, None))
dF2_m1 = rowwise_dF_m1(array, F2)
ddF2_m1 = rowwise_ddF_m1(array, F2, dF_m1)
print(dF2_m1 - exact_dF2, "\n")
print(ddF2_m1 - exact_ddF2, "\n")

# jacobian()-based helpers, mapped over the rows of `array`.
print("Function 2 using all jacobian()")
rowwise_dF_m2 = jax.vmap(dF_m2, in_axes=(0, None))
rowwise_ddF_m2 = jax.vmap(ddF_m2, in_axes=(0, None, None))
dF2_m2 = rowwise_dF_m2(array, F2)
ddF2_m2 = rowwise_ddF_m2(array, F2, dF_m2)
print(dF2_m2 - exact_dF2, "\n")
print(ddF2_m2 - exact_ddF2, "\n")
Output
Function 2 using all grad()
[[0. 0.]
[0. 0.]]
[[0. 0.86209416]
[0. 7.5651155 ]]
Function 2 using all jacobian()
[[ 0. -0.10149619]
[ 0. -6.925739 ]]
[[0. 2.8620942]
[0. 9.565115 ]]
I would prefer to use only `jax.grad()` for something like `F1`, but right now it seems that only `jax.jacobian` is working. The whole reason for this is that I need to calculate higher-order derivatives of a neural network with respect to its inputs. Thank you for any help.
Assuming `exact_*` is what you're attempting to compute, you're going about it in the wrong way. Your indexing within the differentiated functions (i.e. `...[0]`) is removing some of the elements that you're trying to compute.
What `exact_dF1` and `exact_ddF1` are computing is element-wise first and second derivatives for 2D inputs. You can compute this using either `grad` or `jacobian` by applying `vmap` twice (once for each input dimension). For example:
# Element-wise derivatives: vmap once per input axis so that F1 is
# differentiated at each scalar entry of `array`.
exact_dF1 = -1 / array**2
elementwise_grad = jax.vmap(jax.vmap(jax.grad(F1)))
elementwise_jac = jax.vmap(jax.vmap(jax.jacobian(F1)))
grad_dF1 = elementwise_grad(array)
jac_dF1 = elementwise_jac(array)
print(jnp.allclose(exact_dF1, grad_dF1))  # True
print(jnp.allclose(exact_dF1, jac_dF1))  # True

# Second derivatives: nest the differentiation operator before mapping.
exact_ddF1 = 2 / array**3
elementwise_grad2 = jax.vmap(jax.vmap(jax.grad(jax.grad(F1))))
elementwise_jac2 = jax.vmap(jax.vmap(jax.jacobian(jax.jacobian(F1))))
grad_ddF1 = elementwise_grad2(array)
jac_ddF1 = elementwise_jac2(array)
print(jnp.allclose(exact_ddF1, grad_ddF1))  # True
print(jnp.allclose(exact_ddF1, jac_ddF1))  # True
What `exact_dF2` and `exact_ddF2` are computing is a row-wise Jacobian and (the diagonal of the) Hessian of a 2D->1D mapping. By its nature, this is difficult to compute using `jax.grad`, which is meant for functions with scalar output, but you can compute it using the Jacobian this way:
# Closed-form per-row derivatives of x0**2 + x1**3.
first_col = array[:, 0:1]
second_col = array[:, 1:2]
exact_dF2 = jnp.hstack((2 * first_col, 3 * second_col**2))
exact_ddF2 = jnp.hstack((2 + 0 * first_col, 6 * second_col))


def F2_scalar(a):
    # F2 returns a length-1 array; take its single entry to get a scalar.
    return F2(a)[0]


# Per-row Jacobian, then per-row full Hessian reduced to its diagonal.
jac_dF2 = jax.vmap(jax.jacobian(F2_scalar))(array)
jac_ddF2_full = jax.vmap(jax.jacobian(jax.jacobian(F2_scalar)))(array)
jac_ddF2 = jax.vmap(jnp.diagonal)(jac_ddF2_full)
print(jnp.allclose(exact_dF2, jac_dF2))  # True
print(jnp.allclose(exact_ddF2, jac_ddF2))  # True