
JAX: parallel multiplication of pairs of matrices with different shapes


Task: I have two lists of matrices, A and B, each of length N. For each i, the shapes of A[i] and B[i] are compatible, so the product A[i] @ B[i] is well-defined; however, the shapes may differ across $i \in \{0, \dots, N-1\}$. Hence, I cannot stack them into a single array. All shapes are static.

I would like to achieve the same result as the following:

out = [None] * len(A)
for i, (a, b) in enumerate(zip(A, B)):
    out[i] = a @ b

However, I would like to do this in parallel with JAX. The natural choice would be vmap, but that is impossible because the shapes differ.
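A minimal sketch of the setup, assuming two illustrative pairs with shapes (2, 3) @ (3, 4) and (5, 1) @ (1, 2) (the shapes are made up for the example): the plain loop works, but vmap maps over a leading batch axis, which requires stacking the operands into one array, and stacking fails as soon as the shapes differ.

```python
import jax
import jax.numpy as jnp

# Illustrative shapes (assumption): each pair is compatible, but pairs differ.
key_a, key_b = jax.random.split(jax.random.key(0))
A = [jax.random.normal(key_a, (2, 3)), jax.random.normal(key_a, (5, 1))]
B = [jax.random.normal(key_b, (3, 4)), jax.random.normal(key_b, (1, 2))]

# Reference result: one matmul per pair in a plain Python loop.
out = [a @ b for a, b in zip(A, B)]
print([o.shape for o in out])  # [(2, 4), (5, 2)]

# vmap would need a single stacked array of operands; with mismatched
# shapes, jnp.stack raises a ValueError.
try:
    jnp.stack(A)
    stack_failed = False
except ValueError:
    stack_failed = True
print(stack_failed)  # True
```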

Here are the solutions I know of, and why they are not satisfactory.

  1. Write a for loop and then jit it. Compilation time then grows super-linearly in N. This is not good: I know all input and output shapes before running the computation, so I would expect compilation time to be constant (given, say, the list of shapes).

  2. Use the fori_loop primitive from JAX. The documentation says:

The semantics of fori_loop are given by this Python implementation:


def fori_loop(lower, upper, body_fun, init_val):
  val = init_val
  for i in range(lower, upper):
    val = body_fun(i, val)
  return val

However, my case is simpler: I don't need to carry val across iterations. Because fori_loop threads val from one iteration to the next, it is sequential, whereas my iterations are independent and could run in parallel. Hence, it should be possible to do better.

  3. Pad with zeros, use vmap, and slice out the results. I don't control the distribution of shapes, so this can blow up memory if even a single shape is large.

  4. Use lax.map. In the question "What are the tradeoffs between jax.lax.map and jax.vmap?" I read the following:

The lax.map solution will generally be slow, because it is always executed sequentially with no possibility of fusing/parallelization between iterations.
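For comparison, here is a rough sketch of the options above on toy data. The shapes and the helper names (`loop_jit`, `pad_and_stack`, `unpad`) are assumptions for illustration. Note that fori_loop and lax.map both need every iteration to see arrays of the same shape, so in this sketch they only run on the zero-padded, stacked operands; only the jitted loop works on the original ragged lists.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Illustrative data (assumption): three compatible pairs with differing shapes.
rng = np.random.default_rng(0)
shapes = [((2, 3), (3, 4)), ((5, 1), (1, 2)), ((3, 3), (3, 3))]
A = [jnp.asarray(rng.standard_normal(sa), dtype=jnp.float32) for sa, _ in shapes]
B = [jnp.asarray(rng.standard_normal(sb), dtype=jnp.float32) for _, sb in shapes]
reference = [np.asarray(a) @ np.asarray(b) for a, b in zip(A, B)]

# Jitted loop: lists are pytrees, so jit traces the comprehension and unrolls
# it into one program with one matmul per pair. It is retraced and recompiled
# for every new list of shapes.
@jax.jit
def loop_jit(A, B):
    return [a @ b for a, b in zip(A, B)]

def pad_and_stack(A, B):
    """Zero-pad every pair to the largest shapes and stack into batches."""
    m = max(a.shape[0] for a in A)
    k = max(a.shape[1] for a in A)  # contracted dimension; zeros add nothing
    n = max(b.shape[1] for b in B)
    A_pad = jnp.stack([jnp.pad(a, ((0, m - a.shape[0]), (0, k - a.shape[1]))) for a in A])
    B_pad = jnp.stack([jnp.pad(b, ((0, k - b.shape[0]), (0, n - b.shape[1]))) for b in B])
    return A_pad, B_pad

def unpad(C_pad, A, B):
    """Slice each true-size product out of the padded batch."""
    return [C_pad[i, :a.shape[0], :b.shape[1]] for i, (a, b) in enumerate(zip(A, B))]

# Padded memory scales with N times the largest shape, not the actual shapes.
A_pad, B_pad = pad_and_stack(A, B)
N, m, n = A_pad.shape[0], A_pad.shape[1], B_pad.shape[2]

# Pad + vmap: a single batched matmul over the padded stack.
C_vmap = jax.vmap(jnp.matmul)(A_pad, B_pad)

# fori_loop: the carry is the whole output buffer, filled one slice per
# iteration, so the iterations run sequentially.
def body(i, out):
    return out.at[i].set(A_pad[i] @ B_pad[i])

C_fori = jax.lax.fori_loop(0, N, body, jnp.zeros((N, m, n), A_pad.dtype))

# lax.map: applies the function to one (a, b) slice of the stacked inputs at
# a time, sequentially, and stacks the outputs.
C_map = jax.lax.map(lambda ab: ab[0] @ ab[1], (A_pad, B_pad))

results = {
    "jit loop": loop_jit(A, B),
    "pad + vmap": unpad(C_vmap, A, B),
    "fori_loop": unpad(C_fori, A, B),
    "lax.map": unpad(C_map, A, B),
}
checks = {
    name: all(np.allclose(o, r, rtol=1e-4, atol=1e-4) for o, r in zip(outs, reference))
    for name, outs in results.items()
}
print(checks)
```

All four variants compute the same products; they differ in compile behaviour (the jitted loop), memory (padding), and sequential execution (fori_loop, lax.map), which are exactly the trade-offs listed above.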

So I don't know what to do. Thanks!

Update after the answer (benchmark with N = 100 matrices of size d × d, d = 1000):

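import numpy as np
import jax.numpy as jnp
import jax.random as jrandom
from jax import vmap, block_until_ready

N = 100
d = 1000
key_a, key_b = jrandom.split(jrandom.key(0))
Ajnp = jrandom.normal(key_a, (N, d, d))  # stacked JAX arrays, shape (N, d, d)
Bjnp = jrandom.normal(key_b, (N, d, d))

Anp = list(np.random.randn(N, d, d))  # lists of NumPy matrices
Bnp = list(np.random.randn(N, d, d))

vmatmul = vmap(jnp.matmul, in_axes=(0, 0))

def lmatmul(A, B):
    return [a @ b for a, b in zip(A, B)]
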
%timeit vmatmul(Ajnp, Bjnp).block_until_ready()  # jax vmap over arrays

6.59 ms ± 73.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit block_until_ready(lmatmul(list(Ajnp), list(Bjnp))) # jax loop over lists

13 ms ± 221 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit lmatmul(Anp, Bnp) # numpy loop over lists

1.28 s ± 13.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


Solution

  • I think your best approach will be something like your original formulation, though you can avoid pre-allocating the out list:

    out = [a @ b for a, b in zip(A, B)]
    

    Because of JAX's asynchronous dispatch, if you run this on an accelerator such as a GPU, the operations will be executed in parallel to the extent possible.

    All of your other proposed solutions either won't work due to static shape limitations, will force sequential computation, or will incur overhead that will make them worse in practice than this more straightforward approach.
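A minimal sketch of the recommended pattern, assuming illustrative shapes: each `a @ b` is dispatched asynchronously, so the Python loop returns before the device work is necessarily finished, and `jax.block_until_ready` waits for the whole list. How much the kernels actually overlap depends on the backend; the timing printout is only a rough indication and is not part of the answer's claim.

```python
import time

import jax

# Illustrative shapes (assumption): compatible pairs of differing sizes.
shapes = [((256, 128), (128, 64)), ((64, 512), (512, 32)), ((300, 300), (300, 10))]
keys = jax.random.split(jax.random.key(0), 2 * len(shapes))
A = [jax.random.normal(keys[2 * i], sa) for i, (sa, _) in enumerate(shapes)]
B = [jax.random.normal(keys[2 * i + 1], sb) for i, (_, sb) in enumerate(shapes)]

start = time.perf_counter()
# Each matmul is enqueued on the device; Python does not wait for it.
out = [a @ b for a, b in zip(A, B)]
dispatched = time.perf_counter() - start

# Wait until every result in the list has actually been computed.
out = jax.block_until_ready(out)
finished = time.perf_counter() - start

print([o.shape for o in out])  # [(256, 64), (64, 32), (300, 10)]
print(f"dispatch: {dispatched:.4f}s, total: {finished:.4f}s")
```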