Tags: python, loops, vectorization, jax

Vectorizing power of `jax.grad`


I'm trying to vectorize the following "power-of-grad" function so that it accepts multiple orders (see here):

import jax
import jax.numpy as jnp
from jax import grad

def grad_pow(f, order, argnum):
    for i in jnp.arange(order):
        f = grad(f, argnums=argnum)
    return f

This function produces the following error when vmap is applied over the order argument:

jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
It arose in the jnp.arange argument 'stop'
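
For reference, the exact vmap invocation isn't shown in the question; my guess is a call along these lines, using grad_pow as defined above and an example test_func, which reproduces the same error:

test_func = lambda x: jnp.exp(-2 * x)
# vmap traces `order`, so jnp.arange receives an abstract tracer as its `stop` argument:
jax.vmap(lambda order: grad_pow(test_func, order, 0)(1.0))(jnp.arange(3))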

I have tried writing a static version of grad_pow using jax.lax.cond and jax.lax.scan, following the logic here:

import jax
import jax.numpy as jnp
from jax import grad
from jax.lax import cond, scan

def static_grad_pow(f, order, argnum):

    order_max = 3  ## maximum order

    def grad_pow(f, i):
        return cond(i <= order, grad(f, argnum), f), None

    return scan(grad_pow, f, jnp.arange(order_max + 1))[0]


if __name__ == "__main__":

    test_func = lambda x: jnp.exp(-2*x)
    test_func_grad_pow = static_grad_pow(jax.tree_util.Partial(test_func), 1, 0)
    print(test_func_grad_pow(1.))

Nevertheless, this solution still produces an error:

    return cond(i <= order, grad(f, argnum), f), None
TypeError: differentiating with respect to argnums=0 requires at least 1 positional arguments to be passed by the caller, but got only 0 positional arguments.

How can this issue be resolved?


Solution

  • The fundamental issue with your question is that a vmapped function cannot return a function; it can only return arrays. All other details aside, that precludes writing a valid function that does what you intend.

    There are alternatives: for example, rather than attempting to create a function that will return a function, you could instead create a function that accepts arguments and applies that function to those arguments.
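
    As a rough sketch of that shape (apply_grad_pow is an illustrative name, and order is kept as a static Python int here; the traced-order problem is addressed next):

    import jax
    import jax.numpy as jnp

    def apply_grad_pow(f, order, *args, argnum=0):
      # Differentiate `order` times, then apply the result to the arguments,
      # returning array values rather than a function.
      for _ in range(order):
        f = jax.grad(f, argnum)
      return f(*args)

    print(apply_grad_pow(jnp.sin, 2, 1.0))  # second derivative: -sin(1.0) ≈ -0.841471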

    In that case, you'll run into another issue: if order is traced, there is no way to apply grad order times. JAX transformations like grad are evaluated at trace time, and traced values like order are not available until runtime. One way to work around this is to pre-define all the functions you're interested in, and to use lax.switch to choose between them at runtime. The result would look something like this:

    import jax
    import jax.numpy as jnp
    from functools import partial
    
    # Build the list [f, f', f'', ...] up to max_order derivatives at trace time,
    # then let lax.switch select the requested order at runtime.
    @partial(jax.jit, static_argnums=[0], static_argnames=['argnum', 'max_order'])
    def apply_multi_grad(f, order, *args, argnum=0, max_order=10):
      funcs = [f]
      for _ in range(max_order):
        funcs.append(jax.grad(funcs[-1], argnum))
      return jax.lax.switch(order, funcs, *args)
    
    
    order = jnp.arange(3)
    x = jnp.ones(3)
    f = jnp.sin
    
    print(jax.vmap(apply_multi_grad, in_axes=(None, 0, 0))(f, order, x))
    # [ 0.84147096  0.5403023  -0.84147096]
    
    # Compare by doing it manually:
    print(jnp.array([f(x[0]), jax.grad(f)(x[1]), jax.grad(jax.grad(f))(x[2])]))
    # [ 0.84147096  0.5403023  -0.84147096]
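
    One thing to keep in mind about this approach: lax.switch traces every branch, so all max_order + 1 derivative functions are built at trace time. max_order is therefore exposed as a static keyword argument and must be an upper bound on the highest order you expect to request.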