I'm trying to vectorize the following "power-of-grad" function so that it accepts multiple orders (see here):
```python
import jax.numpy as jnp
from jax import grad

def grad_pow(f, order, argnum):
    for i in jnp.arange(order):
        f = grad(f, argnums=argnum)
    return f
```
After applying `vmap` to the argument `order`, this function produces the following error:

```
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
It arose in the jnp.arange argument 'stop'
```
I have tried writing a static version of `grad_pow` using `jax.lax.cond` and `jax.lax.scan`, following the logic here:
```python
import jax
import jax.numpy as jnp
from jax import grad
from jax.lax import cond, scan

def static_grad_pow(f, order, argnum):
    order_max = 3  ## maximum order
    def grad_pow(f, i):
        return cond(i <= order, grad(f, argnum), f), None
    return scan(grad_pow, f, jnp.arange(order_max+1))[0]

if __name__ == "__main__":
    test_func = lambda x: jnp.exp(-2*x)
    test_func_grad_pow = static_grad_pow(jax.tree_util.Partial(test_func), 1, 0)
```
Nevertheless, this solution still produces an error:

```
    return cond(i <= order, grad(f, argnum), f), None
TypeError: differentiating with respect to argnums=0 requires at least 1 positional arguments to be passed by the caller, but got only 0 positional arguments.
```
How can this issue be resolved?
The fundamental issue with your question is that a vmapped function cannot return a function; it can only return arrays. All other details aside, that precludes any possibility of writing a valid function that does what you intend.
There are alternatives: for example, rather than attempting to create a function that will return a function, you could instead create a function that accepts arguments and applies that function to those arguments.
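A minimal sketch of that alternative, assuming the order is a concrete Python `int` known at trace time: the hypothetical helper `apply_grad_pow` applies `jax.grad` in a Python loop and then evaluates the result at the given arguments, so `vmap` only sees arrays going in and arrays coming out. This lets you map over `x`, but not over the order itself.

```python
import jax
import jax.numpy as jnp


def apply_grad_pow(f, order, *args, argnum=0):
    """Hypothetical helper: evaluate the `order`-th derivative of f at args.

    `order` must be a concrete Python int: the loop runs at trace time and
    decides how many times jax.grad is applied, so it cannot be traced.
    """
    g = f
    for _ in range(order):
        g = jax.grad(g, argnums=argnum)
    return g(*args)


xs = jnp.linspace(0.0, 1.0, 5)

# Map over x only; f and the order are closed over, so vmap never sees a function.
second_derivs = jax.vmap(lambda x: apply_grad_pow(jnp.sin, 2, x))(xs)
print(second_derivs)  # equals -sin(xs) up to float32 rounding
```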
In that case, you'll run into another issue: if `n` is traced, there is no way to apply `grad` `n` times. JAX transformations like `grad` are evaluated at trace time, and traced values like `n` are not available until runtime. One way to work around this is to pre-define all the functions you're interested in, and to use `lax.switch` to choose between them at runtime. The result would look something like this:
```python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, static_argnums=[0], static_argnames=['argnum', 'max_order'])
def apply_multi_grad(f, order, *args, argnum=0, max_order=10):
  funcs = [f]
  for _ in range(max_order):
    funcs.append(jax.grad(funcs[-1], argnum))
  return jax.lax.switch(order, funcs, *args)

order = jnp.arange(3)
x = jnp.ones(3)
f = jnp.sin

print(jax.vmap(apply_multi_grad, in_axes=(None, 0, 0))(f, order, x))
# [ 0.84147096 0.5403023 -0.84147096]

# Compare by doing it manually:
print(jnp.array([f(x[0]), jax.grad(f)(x[1]), jax.grad(jax.grad(f))(x[2])]))
# [ 0.84147096 0.5403023 -0.84147096]
```