Tags: python, numpy, jax

Efficiently fill an array from a function


I want to construct a 2D array from a function in a way that lets me use jax.jit.

The way I would normally do this using numpy is to create an empty array, and then fill that array in-place.

import numpy as np

xx = np.empty((num_a, num_b))
yy = np.empty((num_a, num_b))
zz = np.empty((num_a, num_b))

for ii_a in range(num_a):
    for ii_b in range(num_b):
        a = aa[ii_a, ii_b]
        b = bb[ii_a, ii_b]

        xyz = self.get_coord(a, b)

        xx[ii_a, ii_b] = xyz[0]
        yy[ii_a, ii_b] = xyz[1]
        zz[ii_a, ii_b] = xyz[2]

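To make this work within JAX (whose arrays are immutable, so in-place assignment fails), I replaced the assignments with functional index updates, i.e. the .at[...].set() syntax that superseded jax.ops.index_update: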

        xx = xx.at[ii_a, ii_b].set(xyz[0])
        yy = yy.at[ii_a, ii_b].set(xyz[1])
        zz = zz.at[ii_a, ii_b].set(xyz[2])

This runs without errors, but it is very slow when I wrap the function with the @jax.jit decorator (at least an order of magnitude slower than the pure Python/NumPy version).
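One likely cause (an assumption; the post does not profile it): jax.jit traces Python for-loops by unrolling them, so the compiled program contains num_a * num_b separate indexing, get_coord, and update steps, and tracing/compilation cost grows with the array size. The sketch below uses a hypothetical get_coord (the real one is not shown) and jax.make_jaxpr to count the operations in the traced program for two input sizes, without running it.

```python
import jax
import jax.numpy as jnp

def get_coord(a, b):
    # Hypothetical stand-in for the asker's get_coord: scalar (a, b) -> (x, y, z).
    return a * jnp.cos(b), a * jnp.sin(b), a + b

def fill_with_loops(aa, bb):
    # The approach from the question: Python loops plus functional .at[].set() updates.
    num_a, num_b = aa.shape
    xx = jnp.zeros((num_a, num_b))
    yy = jnp.zeros((num_a, num_b))
    zz = jnp.zeros((num_a, num_b))
    for ii_a in range(num_a):
        for ii_b in range(num_b):
            x, y, z = get_coord(aa[ii_a, ii_b], bb[ii_a, ii_b])
            xx = xx.at[ii_a, ii_b].set(x)
            yy = yy.at[ii_a, ii_b].set(y)
            zz = zz.at[ii_a, ii_b].set(z)
    return xx, yy, zz

def count_traced_ops(n):
    # Trace (but do not execute) on n x n inputs and count top-level operations.
    aa = jnp.ones((n, n))
    bb = jnp.ones((n, n))
    closed = jax.make_jaxpr(fill_with_loops)(aa, bb)
    return len(closed.jaxpr.eqns)

ops_2x2 = count_traced_ops(2)
ops_4x4 = count_traced_ops(4)
# The count grows with the number of loop iterations (n * n), not staying constant.
print(ops_2x2, ops_4x4)
```

Because every iteration adds its own operations to the program, a jitted version of this loop gets larger (and slower to compile) as num_a and num_b grow.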

What is the best way to fill a multi-dimensional array from a function using jax?


Solution

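  • JAX's vmap transform is designed for exactly this kind of problem: it takes a function written for single inputs and maps it over one axis of its array arguments, producing a batched computation with no Python loop.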

    As long as your get_coord function is compatible with JAX (i.e. it is a pure function with no side effects, written with jax.numpy operations), you can accomplish this in one line by nesting one vmap per array axis:

    from jax import vmap
    xx, yy, zz = vmap(vmap(get_coord))(aa, bb)
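A self-contained sketch of the answer, assuming a hypothetical get_coord that maps scalars a, b to a tuple (x, y, z) using only jax.numpy operations (the real function is not shown in the question). The outer vmap maps over axis 0 of aa and bb (the num_a rows) and the inner vmap over the num_b entries of each row, so each returned array has shape (num_a, num_b); wrapping the result in jax.jit compiles one vectorized program rather than an unrolled loop. If get_coord instead returns a single length-3 array (the hypothetical get_coord_array below), a plain nested vmap yields shape (num_a, num_b, 3), and xx, yy, zz = ... would unpack along the wrong axis; passing out_axes=1 to both vmaps moves the coordinate axis to the front.

```python
import jax
import jax.numpy as jnp
from jax import vmap

# Hypothetical get_coord: pure, jax.numpy only, returns a tuple (x, y, z).
def get_coord(a, b):
    return a * jnp.cos(b), a * jnp.sin(b), a + b

num_a, num_b = 3, 4
aa = jnp.arange(num_a * num_b, dtype=jnp.float32).reshape(num_a, num_b)
bb = jnp.linspace(0.0, 1.0, num_a * num_b).reshape(num_a, num_b)

# Outer vmap: over rows (axis 0, size num_a); inner vmap: over entries of each row.
fill = jax.jit(vmap(vmap(get_coord)))
xx, yy, zz = fill(aa, bb)  # each has shape (num_a, num_b)

# Hypothetical variant: get_coord returns one length-3 array instead of a tuple.
def get_coord_array(a, b):
    return jnp.stack([a * jnp.cos(b), a * jnp.sin(b), a + b])

# out_axes=1 at both levels puts the coordinate axis first: (3, num_a, num_b).
fill_array = jax.jit(vmap(vmap(get_coord_array, out_axes=1), out_axes=1))
xyz = fill_array(aa, bb)
xx2, yy2, zz2 = xyz  # unpacks along axis 0, the coordinate axis
```

With the default out_axes, indexing the last axis (xyz[..., 0], xyz[..., 1], xyz[..., 2]) would work equally well; out_axes just avoids the extra step.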