The `in_axes` parameter of `vmap` seems to work only for positional arguments: when the vmapped function is called with a keyword argument, it throws an `AssertionError` (with no message).
from jax import vmap
import numpy as np
def foo(a, b, c):
    """Return ``a * b + c`` for a single example; this is the function that gets vmapped."""
    product = a * b
    return product + c
# Batch foo over the leading axis of a and b; the None in the third slot means
# c is not mapped. The in_axes tuple is matched against the positional
# arguments only.
foo = vmap(foo, in_axes=(0, 0, None))
# Unpack the leading axis of a (2, 100, 1) array: aj and bj are each (100, 1).
aj, bj = np.random.rand(2, 100, 1)
foo(aj, bj, 10) # works
# Fails: with c passed by keyword, only two positional arguments remain, so the
# 3-entry in_axes tuple no longer matches them ("Tuple arity mismatch: 2 != 3").
# Workaround: pass c positionally, or call through a plain Python wrapper that
# forwards c positionally to the vmapped function.
foo(aj, bj, c=10) # throws error
Console output:
Traceback (most recent call last):
File "C:\Users\Amith\PycharmProjects\nntp\venv\lib\site-packages\jax\_src\api_util.py", line 300, in flatten_axes
tree_map(add_leaves, _replace_nones(proxy, axis_tree), dummy)
File "C:\Users\Amith\PycharmProjects\nntp\venv\lib\site-packages\jax\_src\tree_util.py", line 183, in tree_map
all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
File "C:\Users\Amith\PycharmProjects\nntp\venv\lib\site-packages\jax\_src\tree_util.py", line 183, in <listcomp>
all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
ValueError: Tuple arity mismatch: 2 != 3; tuple: (<object object at 0x00000187F7BF4380>, <object object at 0x00000187F7BF4380>).
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "C:\Users\Amith\PycharmProjects\nntp\venv\lib\site-packages\IPython\core\interactiveshell.py", line 3433, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-2-20500a2f8a08>", line 1, in <module>
runfile('C:\\Users\\Amith\\PycharmProjects\\nntp\\tests\\test2.py', wdir='C:\\Users\\Amith\\PycharmProjects\\nntp\\tests')
File "C:\Program Files\JetBrains\PyCharm 2022.2\plugins\python\helpers\pydev\_pydev_bundle\pydev_umd.py", line 198, in runfile
pydev_imports.execfile(filename, global_vars, local_vars) # execute the script
File "C:\Program Files\JetBrains\PyCharm 2022.2\plugins\python\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "C:\Users\Amith\PycharmProjects\nntp\tests\test2.py", line 11, in <module>
foo(aj, bj, c=10)
File "C:\Users\Amith\PycharmProjects\nntp\venv\lib\site-packages\jax\_src\traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "C:\Users\Amith\PycharmProjects\nntp\venv\lib\site-packages\jax\_src\api.py", line 1481, in vmap_f
in_axes_flat = flatten_axes("vmap in_axes", in_tree, (in_axes, 0), kws=True)
File "C:\Users\Amith\PycharmProjects\nntp\venv\lib\site-packages\jax\_src\api_util.py", line 306, in flatten_axes
assert treedef_is_leaf(leaf)
AssertionError
How would one go about calling `foo` as `foo(aj, bj, c=10)` without provoking the error?
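Yes, it's true that the `in_axes` argument of `vmap` only works for positional arguments; arguments passed by keyword are always mapped over their leading axis. If you want a more general vmapped function, the best option currently is probably a wrapper function that forwards its arguments positionally. For example: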
def _foo(a, b, c):
return a * b + c
def foo(a, b, c):
return vmap(_foo, in_axes=(0, 0, None))(a, b, c)