Tags: machine-learning, jax, flax

Computing the gradient of a batched function using JAX


I need to compute the gradient of a batched function using JAX. The following is a minimal example of what I would like to do:

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

x = jnp.expand_dims(jnp.linspace(-1, 1, 20), axis=1)

u = lambda x: jnp.sin(jnp.pi * x)
ux = jax.vmap(jax.grad(u))

plt.plot(x, u(x))
plt.plot(x, ux(x))  # fails because u does not return a scalar here
plt.show()

I have tried a variety of ways of making this work with vmap, but I can't get the code to run without removing the batch dimension from the input x. I have seen some workarounds using the Jacobian, but that doesn't seem natural, since u is a scalar function of a single variable.
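
A sketch of such a Jacobian-based workaround (it builds a full (20, 1, 20, 1) Jacobian just to recover 20 scalar derivatives):

import jax
import jax.numpy as jnp

x = jnp.expand_dims(jnp.linspace(-1, 1, 20), axis=1)
u = lambda x: jnp.sin(jnp.pi * x)

J = jax.jacobian(u)(x)               # shape (20, 1, 20, 1)
ux = jnp.diag(J.squeeze())[:, None]  # keep only the diagonal, back to (20, 1)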

In the end u will be a neural network (implemented in Flax) that I need to differentiate with respect to the input (not the parameters of the network), so I cannot remove the batch dimension.


Solution

  • jax.grad only makes sense for a function that returns a scalar, so the batch dimension also has to be mapped over: the outer vmap maps over the batch axis and the inner vmap maps over the trailing size-1 axis, so that u only ever receives (and returns) scalars.

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    x = jnp.expand_dims(jnp.linspace(-1, 1, 20), axis=1)
    
    u = lambda x: jnp.sin(jnp.pi * x)
    ux = jax.vmap(jax.vmap(jax.grad(u)))
    # ux = lambda x : jax.lax.map(jax.vmap(jax.grad(u)), x) # sequential version
    # ux = lambda x : jax.vmap(jax.grad(u))(x.reshape(-1)).reshape(x.shape) # flattened map version
    
    plt.plot(x, u(x))
    plt.plot(x, ux(x))
    plt.show()
    

    Which composition of maps to use depends on what's happening in the batched dimension.
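
  • Since u will eventually be a Flax network differentiated with respect to its input, the same pattern carries over. A minimal sketch, assuming a small illustrative MLP (the module, its layer sizes, and the PRNG key are assumptions, not part of the original code):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn
    
    # A hypothetical stand-in for u: any Flax module mapping a length-1
    # input to a length-1 output works the same way.
    class MLP(nn.Module):
        @nn.compact
        def __call__(self, x):
            x = jnp.tanh(nn.Dense(16)(x))
            return nn.Dense(1)(x)
    
    x = jnp.expand_dims(jnp.linspace(-1, 1, 20), axis=1)  # shape (20, 1)
    
    model = MLP()
    params = model.init(jax.random.PRNGKey(0), x)
    
    # u maps a single input of shape (1,) to a scalar, so jax.grad applies;
    # vmap then maps the gradient over the batch dimension.
    u = lambda xi: model.apply(params, xi).squeeze()
    ux = jax.vmap(jax.grad(u))
    
    print(ux(x).shape)  # (20, 1): gradient w.r.t. the input, not the parameters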