Tags: python · jax · automatic-differentiation

Why does jax.grad(lambda v: jnp.linalg.norm(v-v))(jnp.ones(2)) produce nans?


Can someone explain the following behaviour? Is it a bug?

from jax import grad
import jax.numpy as jnp

x = jnp.ones(2)
grad(lambda v: jnp.linalg.norm(v-v))(x) # returns DeviceArray([nan, nan], dtype=float32)

grad(lambda v: jnp.linalg.norm(0))(x) # returns DeviceArray([0., 0.], dtype=float32)

I've tried looking up the error online but didn't find anything relevant.

I also skimmed through https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html


Solution

  • When you compute grad(lambda v: jnp.linalg.norm(v-v))(x), your function looks roughly like this:

    f(x) = sqrt[(x - x)^2]
    

    so, evaluating with the chain rule, the derivative is

    df/dx = (x - x) / sqrt[(x - x)^2]
    

    which, when you plug in any finite x, evaluates to

    0 / sqrt(0)

    which is undefined and is represented by NaN in floating-point arithmetic.

    When you compute grad(lambda v: jnp.linalg.norm(0))(x), your function looks roughly like this:

    g(x) = sqrt[0.0^2]
    

    and because it has no dependence on x, the derivative is simply

    dg/dx = 0.0
    

    Does that answer your question? A quick check of where the NaN comes from, and a common workaround, are sketched below.
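
    To see that the NaN really comes from differentiating the norm at zero (and not from the subtraction itself), here is a quick check; this is not part of the original answer, but it should reproduce on any recent JAX version:

    from jax import grad
    import jax.numpy as jnp

    # The derivative of sqrt blows up at zero ...
    print(grad(jnp.sqrt)(0.0))                  # inf

    # ... and the norm's gradient x / ||x|| is 0 / 0 at the origin.
    print(grad(jnp.linalg.norm)(jnp.zeros(2)))  # [nan nan]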
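
    If you actually need a gradient of 0 at the origin instead of NaN, a common workaround is the "double where" pattern described in the JAX FAQ: replace the zero case inside the sqrt with a dummy value so that the NaN never enters the backward pass. A minimal sketch, assuming you want the 2-norm of a vector (the name safe_norm is just for illustration):

    from jax import grad
    import jax.numpy as jnp

    def safe_norm(v):
        # 2-norm of v whose gradient is 0 (not NaN) at v = 0.
        sq = jnp.sum(v * v)
        is_zero = sq == 0.0
        # Inner where: feed sqrt a dummy value so its derivative stays finite.
        safe_sq = jnp.where(is_zero, 1.0, sq)
        # Outer where: return the correct value of the norm itself.
        return jnp.where(is_zero, 0.0, jnp.sqrt(safe_sq))

    x = jnp.ones(2)
    print(grad(lambda v: safe_norm(v - v))(x))  # [0. 0.] instead of [nan nan]

    The inner where is essential: with only the outer one, sqrt's backward pass still multiplies its infinite derivative at 0 by a zero cotangent, and 0 * inf is NaN.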