I have this function, which works for a single vector:
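```python
import numpy as np

# `mapping` and `reverse_mapping` are index arrays defined elsewhere (not shown).
def vec_to_board(vector, player, dim, reverse=False):
    player_board = np.zeros(dim * dim)
    # Indices of the cells occupied by `player`
    player_pos = np.argwhere(vector == player)
    if not reverse:
        player_board[mapping[player_pos.T]] = 1
    else:
        player_board[reverse_mapping[player_pos.T]] = 1
    return np.reshape(player_board, [dim, dim])
```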
However, I want it to work for a batch of vectors.
What I have tried so far:
```python
import jax.numpy as jnp
from jax import vmap

states = jnp.array([[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2],
                    [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2]])
batch_size = 1
b_states = vmap(vec_to_board)((states, 1, 4), batch_size)
```
This doesn't work. If I understand correctly, shouldn't vmap be able to handle this kind of batching?
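For reference, here is a minimal sketch of the vmap calling convention. It uses a hypothetical toy function, `count_player`, rather than `vec_to_board`, which needs further changes before it can be traced. `vmap` wraps the function, and the wrapped function is then called with the same arguments as the original. `in_axes` says which arguments carry the batch dimension, and the batch size is read from that axis, so no separate `batch_size` argument is needed.

```python
import jax.numpy as jnp
from jax import vmap


def count_player(vector, player):
    # Hypothetical toy function: number of cells in `vector` held by `player`.
    return jnp.sum(vector == player)


states = jnp.array([[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2],
                    [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2]])

# Map over axis 0 of `states`; the scalar `player` is not batched (None).
batched_count = vmap(count_player, in_axes=(0, None))
counts = batched_count(states, 1)  # one count per row, shape (2,)
```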
There are a couple of issues you'll run into when trying to vmap this function:

1. In-place updates such as `arr[idx] = 1` will raise errors, because JAX arrays are immutable. You need to replace these with equivalent JAX operations (see JAX Sharp Bits: in-place updates) and make sure your function uses JAX array operations rather than NumPy array operations.
2. `player_pos` has a shape that depends on the number of True entries in `vector == player`. You'll have to rewrite your function in terms of statically-shaped arrays. There is some discussion of this in the `jnp.argwhere` docstring; for example, if you know a priori how many True entries to expect in the array, you can specify the `size` argument to make this work.

Good luck!
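A minimal sketch combining both points, under stated assumptions. The question does not show `mapping` or `reverse_mapping`, so this assumes they are 1-D integer index arrays of length `dim * dim`; the identity and reversed permutations below are hypothetical stand-ins. The first function replaces `argwhere` with a fixed-shape boolean mask. The second follows the `size=` route from the docstring, with the hypothetical `max_pieces` bound. Its padded slots are routed to an out-of-bounds index and dropped by the scatter. In both, `dim`, `reverse` and `max_pieces` are bound with `functools.partial` so they stay static Python values under vmap.

```python
from functools import partial

import jax
import jax.numpy as jnp

DIM = 4
# Hypothetical stand-ins for the question's `mapping` / `reverse_mapping`,
# which are not shown there: identity and reversed permutations of the cells.
mapping = jnp.arange(DIM * DIM)
reverse_mapping = mapping[::-1]


def vec_to_board(vector, player, dim, reverse=False):
    """Mask-based version: every intermediate has a fixed shape."""
    idx_map = reverse_mapping if reverse else mapping
    # Same shape as `vector`, unlike argwhere's data-dependent output.
    mask = (vector == player).astype(jnp.float32)
    # Functional scatter replaces `player_board[...] = 1`; `.max` marks a cell
    # if any mapped position holds `player`, even if idx_map has duplicates.
    player_board = jnp.zeros(dim * dim).at[idx_map].max(mask)
    return player_board.reshape(dim, dim)


def vec_to_board_argwhere(vector, player, dim, max_pieces, reverse=False):
    """`size=` version; assumes at most `max_pieces` cells hold `player`."""
    idx_map = reverse_mapping if reverse else mapping
    n = dim * dim
    # Output shape is always (max_pieces, 1); unused rows are padded with n,
    # one past the last valid cell index.
    pos = jnp.argwhere(vector == player, size=max_pieces, fill_value=n)[:, 0]
    # Send padded slots to the out-of-bounds index n so the scatter drops them.
    targets = jnp.where(pos < n, idx_map[jnp.minimum(pos, n - 1)], n)
    player_board = jnp.zeros(n).at[targets].set(1.0, mode="drop")
    return player_board.reshape(dim, dim)


states = jnp.array([[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2],
                    [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2]])

# Batch over axis 0 of `states`; `player` is shared by every row.
batched = jax.vmap(partial(vec_to_board, dim=DIM), in_axes=(0, None))
batched_argwhere = jax.vmap(
    partial(vec_to_board_argwhere, dim=DIM, max_pieces=5), in_axes=(0, None)
)

boards = batched(states, 1)  # shape (2, 4, 4)
boards_argwhere = batched_argwhere(states, 1)  # same values as `boards`
```

With the identity stand-in, each board for player 1 has ones in the first three cells of the top row. The mask version is usually the simpler choice. The `size=` version is only correct if `max_pieces` is a true upper bound, because `argwhere` silently truncates any extra matches.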