
How is it possible that jax.vmap returns a non-iterable?


import jax
import pgx
from jax import vmap, jit
import jax.numpy as jnp

env = pgx.make("tic_tac_toe")
key = jax.random.PRNGKey(42)

states = jax.jit(vmap(env.init))(jax.random.split(key, 4))
print(type(states))

states has type pgx.tic_tac_toe.State. I was expecting an iterable object of length 4; instead, the batched results seem to be stored inside a single pgx.tic_tac_toe.State.

Can you please explain how it is possible that jax.vmap returns a non-iterable?

How can I force vmap to return the following result?

states = [env.init(key) for key in jax.random.split(key, 4)]

Note that this code works as expected:

def square(x):
    return x ** 2
inputs = jnp.array([1, 2, 3, 4])
result = jax.vmap(square)(inputs)
print(result)  # [ 1  4  9 16]  (a jax.Array, not a list)

Solution

  • Can you please explain how it is possible that jax.vmap returns a non-iterable?

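    When the mapped function returns a non-array object (a pytree, such as pgx's State dataclass), vmap does not return a list of such objects. Instead, it returns a single object of the same type in which each array in the flattened pytree representation has gained a new leading batch axis. You can see the shapes of the flattened leaves here: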

    print([arr.shape for arr in jax.tree_util.tree_flatten(states)[0]])
    # [(4,), (4, 3, 3, 2), (4, 2), (4,), (4,), (4, 9), (4,), (4,), (4, 9)]
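To reproduce this behavior without pgx, here is a minimal sketch that uses a hypothetical ToyState NamedTuple (NamedTuples are pytrees in JAX) in place of pgx.tic_tac_toe.State, and a hypothetical toy_init function in place of env.init. The fields and shapes are made up for illustration.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    # Hypothetical stand-in for pgx.tic_tac_toe.State.
    board: jax.Array           # shape (9,) for a single game
    current_player: jax.Array  # scalar for a single game


def toy_init(key: jax.Array) -> ToyState:
    # Hypothetical per-game init: empty board, random starting player.
    player = jax.random.randint(key, (), 0, 2)
    return ToyState(board=jnp.zeros(9, dtype=jnp.int32), current_player=player)


keys = jax.random.split(jax.random.PRNGKey(42), 4)
batched = jax.jit(jax.vmap(toy_init))(keys)

# One ToyState comes back; each field gains a leading axis of size 4.
# The unbatched board is broadcast along the batch axis by vmap.
print(type(batched).__name__)  # ToyState
print([leaf.shape for leaf in jax.tree_util.tree_leaves(batched)])
# [(4, 9), (4,)]
```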
    

    This is an example of the struct-of-arrays pattern used by vmap, whereas it sounds like you were expecting an array-of-structs pattern.
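Because the batch lives inside each field rather than in an outer list, one way to get a single game back from the struct-of-arrays layout is to index every leaf at the same position with jax.tree_util.tree_map. A sketch, again using a hypothetical ToyState with made-up values and a hypothetical helper get_game:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    # Hypothetical stand-in for a batched game state.
    board: jax.Array
    current_player: jax.Array


# Struct of arrays: one ToyState whose fields share a leading batch axis of 4.
batched = ToyState(
    board=jnp.arange(36).reshape(4, 9),
    current_player=jnp.array([0, 1, 0, 1]),
)


def get_game(states: ToyState, i: int) -> ToyState:
    # Slice index i out of every leaf, yielding one unbatched ToyState.
    return jax.tree_util.tree_map(lambda leaf: leaf[i], states)


game2 = get_game(batched, 2)
print(game2.board.shape, int(game2.current_player))  # (9,) 0
```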

  • How can I force vmap to return the following result?

    If you wanted to convert this output into the list of state objects you were expecting, you could do so using utilities in jax.tree_util:

    leaves, treedef = jax.tree_util.tree_flatten(states)
    # zip(*leaves) yields, for each index i, the tuple of i-th slices of every leaf
    states_list = [treedef.unflatten(slices) for slices in zip(*leaves)]
    print(len(states_list))
    # 4
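If you ever need to go the other way, turning a list of per-game states back into the struct-of-arrays layout that vmap produces, stacking matching leaves with jax.tree_util.tree_map is one option. This sketch round-trips a hypothetical ToyState batch through the unstacking code above and back:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    # Hypothetical stand-in for pgx.tic_tac_toe.State.
    board: jax.Array
    current_player: jax.Array


batched = ToyState(board=jnp.zeros((4, 9)), current_player=jnp.arange(4))

# Unstack: struct of arrays -> list of structs.
leaves, treedef = jax.tree_util.tree_flatten(batched)
states_list = [treedef.unflatten(slices) for slices in zip(*leaves)]

# Restack: list of structs -> struct of arrays, stacking corresponding leaves.
restacked = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *states_list)

print(len(states_list), restacked.board.shape)  # 4 (4, 9)
```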
    

    That said, it appears that pgx is built to work natively with the original struct-of-arrays pattern, so you may find that you won't actually need this unstacked version in practice.
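To illustrate why the unstacked version is often unnecessary: any function written for a single state can be vmapped over the batched state directly, and it returns a batched state in the same layout. The toy_step function below is hypothetical, not pgx's step logic; it only shows the pattern.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    # Hypothetical stand-in for a single game's state.
    board: jax.Array           # shape (9,)
    current_player: jax.Array  # scalar, 0 or 1


def toy_step(state: ToyState, action: jax.Array) -> ToyState:
    # Hypothetical per-game step: mark cell `action` with player + 1, then switch players.
    board = state.board.at[action].set(state.current_player + 1)
    return ToyState(board=board, current_player=1 - state.current_player)


batched = ToyState(
    board=jnp.zeros((4, 9), dtype=jnp.int32),
    current_player=jnp.array([0, 1, 0, 1], dtype=jnp.int32),
)
actions = jnp.array([0, 4, 8, 2])

# vmap consumes and produces the struct-of-arrays layout; no unstacking needed.
next_states = jax.jit(jax.vmap(toy_step))(batched, actions)
print(next_states.board[jnp.arange(4), actions])  # [1 2 1 2]
print(next_states.current_player)                 # [1 0 1 0]
```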