
aggregate calculation vmap jax

I'm trying to implement a fast routine that calculates an array of energies and finds the smallest calculated value and its index. Here is my code, which works fine:

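from jax import vmap
import jax.numpy as jnp

def findMinEnergy(x):
  def calcEnergy(a):
    return a*a  # very simplified body, it is actually 15 lines of code
  # Compute the energy of every element along axis 0.
  energies = vmap(calcEnergy, in_axes=0)(x)
  # Separate pass to locate the minimum energy.
  idx = energies.argmin(axis=0)
  minenergy = energies[idx]
  return idx, minenergy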

I wonder whether it is possible to avoid the separate argmin call and instead return the minimum energy and its index directly from the vmap (similar to how other aggregate functions such as jnp.sum work). I hope that this could be more efficient.


  • If you JIT-compile your current approach, you should find that it's as efficient as doing something more sophisticated.
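As a minimal sketch of what JIT-compiling the original approach looks like, assuming the simplified `calcEnergy` body (`a*a`) from the question stands in for the real 15-line energy function, you can wrap the whole routine in `jax.jit` so that the vmap, the argmin, and the indexing are compiled together:

```python
import jax
import jax.numpy as jnp


@jax.jit
def findMinEnergy(x):
    # Stand-in for the real energy function (the question's body is ~15 lines).
    def calcEnergy(a):
        return a * a

    # Map the per-element energy over the leading axis.
    energies = jax.vmap(calcEnergy, in_axes=0)(x)
    # Under jit, the vmap, argmin and indexing become one compiled computation.
    idx = jnp.argmin(energies, axis=0)
    return idx, energies[idx]


x = jnp.array([3.0, -0.5, 2.0, 0.25])
idx, min_energy = findMinEnergy(x)
print(int(idx), float(min_energy))  # 3 0.0625
```

The first call pays the compilation cost; later calls with the same shape and dtype reuse the compiled program.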

    Looking at the implementation of argmin, you'll see that it already computes both the value and the index (as a single reduction over value/index pairs) before returning only the index.
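One way to check this yourself, as a sketch that assumes a recent JAX version where `Lowered.as_text()` prints StableHLO, is to lower `jnp.argmin` and look for a `stablehlo.reduce` op that carries both the values and an index array. The exact IR text varies between JAX versions, so treat the printed output as illustrative:

```python
import jax
import jax.numpy as jnp

x = jnp.array([3.0, 1.0, 2.0], dtype=jnp.float32)

# Lower jnp.argmin to StableHLO text without executing it.
hlo_text = jax.jit(jnp.argmin).lower(x).as_text()

# argmin lowers to a reduction over (value, index) pairs; the index operand
# is typically built with an iota op and reduced alongside the values.
print(hlo_text)
```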

    If you want, you could follow this implementation and define a function using jax.lax.reduce that returns both values in a single pass:

    import jax
    import jax.numpy as jnp

    @jax.jit
    def min_and_argmin_onepass(x):
      # This only works for 1D float arrays, but you could generalize it.
      assert x.ndim == 1
      assert jnp.issubdtype(x.dtype, jnp.floating)
      def reducer(op_val_index, acc_val_index):
        op_val, op_index = op_val_index
        acc_val, acc_index = acc_val_index
        pick_op_val = (op_val < acc_val) | jnp.isnan(op_val)
        pick_op_index = pick_op_val | ((op_val == acc_val) & (op_index < acc_index))
        return (jnp.where(pick_op_val, op_val, acc_val),
                jnp.where(pick_op_index, op_index, acc_index))
      indices = jnp.arange(len(x))
      return jax.lax.reduce((x, indices), (jnp.inf, 0), reducer, (0,))
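The comment in the function above notes that it only handles 1D arrays. A minimal sketch of one possible generalization, reducing along a single static `axis` of an N-D float array, might look like the following. `min_and_argmin_axis` is a hypothetical name, and the explicitly typed init values are an assumption added so that the init dtypes match the operands even when 64-bit mode is enabled:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="axis")
def min_and_argmin_axis(x, axis=0):
    # Hypothetical generalization of the 1D one-pass version to any axis.
    assert jnp.issubdtype(x.dtype, jnp.floating)
    axis = axis % x.ndim  # allow negative axis values

    def reducer(op_val_index, acc_val_index):
        op_val, op_index = op_val_index
        acc_val, acc_index = acc_val_index
        # Same selection rule as the 1D version: the smaller value wins,
        # NaN wins, and ties go to the lower index.
        pick_op_val = (op_val < acc_val) | jnp.isnan(op_val)
        pick_op_index = pick_op_val | ((op_val == acc_val) & (op_index < acc_index))
        return (jnp.where(pick_op_val, op_val, acc_val),
                jnp.where(pick_op_index, op_index, acc_index))

    # Position of each element along `axis`, broadcast to the full shape.
    indices = jax.lax.broadcasted_iota(jnp.int32, x.shape, axis)
    # Init values typed explicitly to match the two operands.
    init = (jnp.array(jnp.inf, dtype=x.dtype), jnp.array(0, dtype=jnp.int32))
    return jax.lax.reduce((x, indices), init, reducer, (axis,))


x = jax.random.uniform(jax.random.key(0), (4, 5))
vals, idxs = min_and_argmin_axis(x, axis=1)
print(vals.shape, idxs.shape)  # (4,) (4,)
```

Because the reducer breaks ties by the lower index, the order in which XLA combines elements should not change which index is returned.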

    Testing this, we see it matches the output of the less sophisticated approach:

    @jax.jit
    def min_and_argmin(x):
      i = jnp.argmin(x)
      return x[i], i

    x = jax.random.uniform(jax.random.key(0), (1000000,))
    print(min_and_argmin_onepass(x))
    # (Array(9.536743e-07, dtype=float32), Array(24430, dtype=int32))
    print(min_and_argmin(x))
    # (Array(9.536743e-07, dtype=float32), Array(24430, dtype=int32))

    If you compare the runtimes of the two, you'll see that they are roughly the same:

    %timeit jax.block_until_ready(min_and_argmin_onepass(x))
    # 2.17 ms ± 68.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    %timeit jax.block_until_ready(min_and_argmin(x))
    # 2.07 ms ± 66.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
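`%timeit` is an IPython magic, so it is not available in a plain Python script. A rough equivalent, sketched with the standard-library `timeit` module and a hypothetical helper `best_time_ms`, warms up the function first so compilation is excluded and calls `jax.block_until_ready` so that JAX's asynchronous dispatch does not hide the real work. Unlike `%timeit`, which prints mean ± std. dev., this sketch reports the best per-call time:

```python
import timeit

import jax
import jax.numpy as jnp


@jax.jit
def min_and_argmin(x):
    i = jnp.argmin(x)
    return x[i], i


def best_time_ms(fn, x, number=100, repeat=7):
    # Hypothetical helper. Call once first so JIT compilation is not timed.
    jax.block_until_ready(fn(x))
    # Wait for each result so we time execution, not just dispatch.
    times = timeit.repeat(lambda: jax.block_until_ready(fn(x)),
                          number=number, repeat=repeat)
    # Best of `repeat` runs, converted to milliseconds per call.
    return min(times) / number * 1e3


x = jax.random.uniform(jax.random.key(0), (1000000,))
print(f"{best_time_ms(min_and_argmin, x):.3f} ms per call")
```

Absolute numbers depend on the hardware and backend, so the comparison between the two functions is more meaningful than either time on its own.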

    With jax.jit, the compiler optimizes the sequence of operations in the less sophisticated approach, so you gain little by trying to express things more cleverly. Given this, I think your best option is to stick with your original code rather than trying to out-optimize the XLA compiler.