
Efficiently implementing custom array creation routines in JAX


I'm still getting a handle on best practices in JAX. My broad question is the following:

What are best practices for the implementation of custom array creation routines in jax?

For instance, I want to implement a function that creates a matrix with zeros everywhere except with ones in a given column. I went for this (Jupyter notebook):

import numpy as np
import jax.numpy as jnp

def ones_at_col(shape_mat, idx):
    idxs = jnp.arange(shape_mat[1])[None,:]
    mat = jnp.where(idx==idxs, 1, 0)
    mat = jnp.repeat(mat, shape_mat[0], axis=0)
    return mat

shape_mat = (5,10)

print(ones_at_col(shape_mat, 5))

%timeit np.zeros(shape_mat)

%timeit jnp.zeros(shape_mat)

%timeit ones_at_col(shape_mat, 5)

The output is

[[0 0 0 0 0 1 0 0 0 0]
 [0 0 0 0 0 1 0 0 0 0]
 [0 0 0 0 0 1 0 0 0 0]
 [0 0 0 0 0 1 0 0 0 0]
 [0 0 0 0 0 1 0 0 0 0]]
127 ns ± 0.717 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
31.3 µs ± 331 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
123 µs ± 1.79 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

My function is a factor of 4 slower than the jnp.zeros() routine, which is not too bad. This tells me that what I'm doing is not crazy.

But then both jax routines are much slower than the equivalent numpy routines. These functions cannot be jitted because they take the shape as an argument, and so cannot be traced. I presume this is why they are inherently slower? I guess that if either of them appeared within the scope of another jitted function, they could be traced and sped up?

Is there something better I can do or am I pushing the limits of what is possible in jax?


Solution

  • The best way to do this is probably something like this:

    mat = jnp.zeros(shape_mat).at[:, 5].set(1)
    

    Regarding timing comparisons with NumPy, relevant reading is JAX FAQ: is JAX faster than NumPy? The summary is that for this particular case (creating a simple array) you would not expect JAX to match NumPy performance-wise, due to JAX's per-operation dispatch overhead.

    If you wish for faster performance in JAX, you should always use jax.jit to just-in-time compile your function. For example, this version of the function should be pretty optimal (though again, not nearly as fast as NumPy for the reasons discussed at the FAQ link):

    from functools import partial

    import jax
    import jax.numpy as jnp

    @partial(jax.jit, static_argnames=['shape_mat', 'idx'])
    def ones_at_col(shape_mat, idx):
      return jnp.zeros(shape_mat).at[:, idx].set(1)
    

    You could leave idx non-static if you'll be calling this function multiple times with different index values, and if you're creating these arrays within another function, you should just put the code inline and JIT-compile that outer function.
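As a sketch of that variant, here only the shape is marked static, so calls with different index values reuse the same compiled function rather than triggering a recompile per index:

```python
from functools import partial

import jax
import jax.numpy as jnp

# shape_mat must be static (array shapes have to be known at trace time),
# but idx can be a traced argument: one compilation covers all index values.
@partial(jax.jit, static_argnames=['shape_mat'])
def ones_at_col(shape_mat, idx):
    return jnp.zeros(shape_mat).at[:, idx].set(1)

print(ones_at_col((5, 10), 5))
print(ones_at_col((5, 10), 2))  # no recompilation for a new idx
```

Indexing with a traced integer is fine inside `jit` because the output shape doesn't depend on the index value, only on `shape_mat`.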

    Another side-note: your microbenchmarks may not be measuring what you think they're measuring: for tips on this see JAX FAQ: benchmarking JAX code. In particular, be careful of compilation time and asynchronous dispatch effects.
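As a minimal sketch of those benchmarking tips (the function `f` here is just an illustrative example): run the jitted function once before timing so compilation is excluded, and call `block_until_ready()` so you measure actual execution rather than asynchronous dispatch:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # example workload: zeros everywhere except ones in column 5
    return jnp.zeros_like(x).at[:, 5].set(1)

x = jnp.zeros((5, 10))

# Warm-up call: triggers tracing and compilation, which would otherwise
# dominate the first timed iteration.
f(x).block_until_ready()

# In a notebook you would then time the compiled, synchronized call:
# %timeit f(x).block_until_ready()
```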