
How to vmap over a specific function in JAX?


I have this function, which works for a single vector:

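import numpy as np

# `mapping` and `reverse_mapping` are index arrays defined elsewhere (not shown).
def vec_to_board(vector, player, dim, reverse=False):
    player_board = np.zeros(dim * dim)
    # Indices of every entry in `vector` that belongs to `player`.
    player_pos = np.argwhere(vector == player)
    if not reverse:
        player_board[mapping[player_pos.T]] = 1
    else:
        player_board[reverse_mapping[player_pos.T]] = 1
    return np.reshape(player_board, [dim, dim])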

However, I want it to work for a batch of vectors.

What I have tried so far:

import jax.numpy as jnp
from jax import vmap

states = jnp.array([[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2], [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2]])
batch_size = 1
b_states = vmap(vec_to_board)((states, 1, 4), batch_size)

This doesn't work. If I understand correctly, though, shouldn't vmap be able to handle this kind of batching transformation?
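For reference, jax.vmap takes the function to vectorize and returns a new function that is called with the same positional arguments as the original one. The in_axes argument says which inputs carry the batch dimension (0) and which are shared across the batch (None); there is no separate batch size argument, because the batch size is read from the mapped axis. A minimal sketch with a toy function (scale_rows is hypothetical, not from the question):

```python
import jax
import jax.numpy as jnp

def scale_rows(v, c):
    # Written for a single vector v and a scalar c.
    return v * c

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])

# Map over axis 0 of x; the scalar c is shared by every row.
batched = jax.vmap(scale_rows, in_axes=(0, None))
print(batched(x, 10.0))
```

Calling vmap(vec_to_board) this way is still not enough on its own, because the body of vec_to_board has to be JAX-compatible first.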


Solution

  • There are a couple of issues you'll run into when trying to vmap this function:

    1. This function is defined in terms of NumPy arrays, not JAX arrays. How do I know? Statements like arr[idx] = 1 only work on NumPy arrays: JAX arrays are immutable, so they raise an error. You need to replace these with equivalent JAX operations, such as arr = arr.at[idx].set(1) (see JAX Sharp Bits: in-place updates), and make sure the whole function uses jax.numpy operations rather than numpy ones.
    2. Your function makes use of dynamically shaped arrays: for example, player_pos has a shape that depends on the number of True entries in vector == player. You'll have to rewrite your function in terms of statically shaped arrays. There is some discussion of this in the jnp.argwhere docstring; for example, if you know a priori how many True entries to expect in the array, you can pass that number as the size argument to make this work.
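    Putting both points together, a rewrite might look like the sketch below. It assumes each player has exactly n_pieces entries in the vector (3 in the example states), and since the question does not show mapping or reverse_mapping, it substitutes hypothetical identity and mirror-image index arrays for them. The static arguments (dim, n_pieces, reverse) are bound with a closure so that only the state vector is batched.

```python
import jax
import jax.numpy as jnp

DIM = 4
# Hypothetical stand-ins for the question's `mapping` / `reverse_mapping`
# globals (not shown there): an identity map and its mirror image.
mapping = jnp.arange(DIM * DIM)
reverse_mapping = mapping[::-1]

def vec_to_board_jax(vector, player, dim, n_pieces, reverse=False):
    # dim, n_pieces and reverse must be plain Python values (static),
    # because they determine array shapes and Python control flow.
    # Assumes exactly n_pieces entries equal `player`; with fewer, argwhere
    # pads with index 0 and a wrong square would be marked.
    pos = jnp.argwhere(vector == player, size=n_pieces)[:, 0]
    table = reverse_mapping if reverse else mapping
    # Functional update instead of in-place assignment.
    board = jnp.zeros(dim * dim).at[table[pos]].set(1.0)
    return board.reshape(dim, dim)

states = jnp.array([[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2],
                    [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2]])

# Bind the static arguments and vmap over the leading (batch) axis.
batched = jax.vmap(lambda v: vec_to_board_jax(v, player=1, dim=DIM, n_pieces=3))
b_states = batched(states)
print(b_states.shape)  # (2, 4, 4)
```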

    Good luck!
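    If the number of pieces per player is not known in advance, one alternative (a different technique from the argwhere approach above, not part of the original answer) is to skip computing positions altogether and scatter the boolean mask directly: every entry of the vector is sent to its mapped square, and .at[...].max leaves a 1 wherever at least one matching entry lands. A sketch, again with hypothetical identity and mirror mappings standing in for the question's globals:

```python
import jax
import jax.numpy as jnp

DIM = 4
# Hypothetical stand-ins for the question's `mapping` / `reverse_mapping`.
mapping = jnp.arange(DIM * DIM)
reverse_mapping = mapping[::-1]

def vec_to_board_masked(vector, player, dim, reverse=False):
    # 1.0 where the entry belongs to `player`, 0.0 elsewhere; the shape is
    # always (dim * dim,), no matter how many pieces the player has.
    mask = (vector == player).astype(jnp.float32)
    table = reverse_mapping if reverse else mapping
    # Scatter each entry to its mapped square; max() keeps a 1 wherever
    # a matching entry lands, like `board[...] = 1` in the NumPy version.
    board = jnp.zeros(dim * dim, dtype=jnp.float32).at[table].max(mask)
    return board.reshape(dim, dim)

# The second row has only two pieces for player 1.
states = jnp.array([[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2],
                    [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2]])

# dim and reverse stay static via the closure; only `v` is batched.
boards = jax.vmap(lambda v: vec_to_board_masked(v, player=1, dim=DIM))(states)
print(boards.sum(axis=(1, 2)))  # [3. 2.]
```

    The trade-off is that the scatter always touches all dim * dim entries, but no count has to be fixed ahead of time.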