Tags: python, arrays, jax, google-jax

Can jax.vmap() do a hstack()?


As the title says, I currently call hstack() manually on the first axis of a 3D array returned by jax.vmap(). In my code, the copy operation in hstack() is currently a speed bottleneck. Can I avoid this by instructing jax.vmap() to produce the stacked layout directly?

Here is a simplified example:

import jax
import jax.numpy as jnp

def f(a, b, c):
  return jnp.array([[a.sum(), b.sum()], [c.sum(), 0.]]) # Returns a 2x2 array

def arr(m, n):
  return jnp.arange(m*n).reshape((m, n))

m = 3

a = arr(m, 2)
b = arr(m, 5)
c = arr(m, 7)

fv = jax.vmap(f)

vmap_output = fv(a, b, c)
desired_output = jnp.hstack(fv(a, b, c))

print(vmap_output)
print(desired_output)

This yields:

# vmap() output
[[[  1.  10.]
  [ 21.   0.]]

 [[  5.  35.]
  [ 70.   0.]]

 [[  9.  60.]
  [119.   0.]]]
# Desired output
[[  1.  10.   5.  35.   9.  60.]
 [ 21.   0.  70.   0. 119.   0.]]

If this is not possible, I would resort to pre-allocating an array and simply writing to the columns manually, but I hope to avoid this. Thanks for any clue!


Update from @jakevdp's answer

Alright, it isn't possible. So I resort to writing to the columns, but this fails as well:

def g(output, idx, a, b, c):
  block = jnp.array([[a.sum(), b.sum()], [c.sum(), 0.]]) # Returns a 2x2 array
  jax.lax.dynamic_update_slice_in_dim(output, block, idx*2, axis=1)

# Defined above: jax, jnp, m, a, b, c

g_output = jnp.zeros((2, 2*m))
idxs = jnp.arange(m)

gv = jax.vmap(g, in_axes=(None, 0, 0, 0, 0))

gv(g_output, idxs, a, b, c)

print(g_output)

This yields:

[[0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]]

So writing to g_output in the function g is not retained. Is there a way around this?


Solution

  • No, vmap does not have any built-in capability to stack outputs differently than the batching semantics would imply. But if you're interested in fusing the hstack operation with the vmap operation to the extent possible, you could do so by wrapping it in jit. For example:

    @jax.jit
    def do_the_thing(a, b, c):
      return jnp.hstack(fv(a, b, c))
    
    print(do_the_thing(a, b, c))
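
One thing vmap can control is where the batch axis lands in the output, through its out_axes parameter. The following is a minimal sketch, assuming the same f and inputs as in the question: placing the batch axis at position 1 gives a (2, m, 2) array, and merging the last two axes then produces the same layout as hstack. The name stacked is made up for illustration, and this sketch does not establish whether the approach is any faster than jnp.hstack under jit.

```python
import jax
import jax.numpy as jnp

def f(a, b, c):
  return jnp.array([[a.sum(), b.sum()], [c.sum(), 0.]])  # 2x2 block per batch element

def arr(m, n):
  return jnp.arange(m * n).reshape((m, n))

m = 3
a, b, c = arr(m, 2), arr(m, 5), arr(m, 7)

@jax.jit
def stacked(a, b, c):
  # out_axes=1 puts the batch axis between the block's rows and columns: shape (2, m, 2)
  out = jax.vmap(f, out_axes=1)(a, b, c)
  # Merging the last two axes lays the blocks side by side: shape (2, 2*m)
  return out.reshape(out.shape[0], -1)

result = stacked(a, b, c)
reference = jnp.hstack(jax.vmap(f)(a, b, c))  # the layout the question asks for
print(result)
```

Comparing result against reference with jnp.array_equal is a quick way to confirm the two layouts agree for your own shapes.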
    

    Edit, responding to your edited question: the result is all zeros because your function doesn't do anything. It returns None, and jax.lax.dynamic_update_slice_in_dim returns a new array rather than modifying its input, so the discarded result never reaches the array called g_output. JAX requires pure functions, so side-effecting code like what you wrote above is not compatible. If you wanted to replace the hstack with an indexed update, you could do something like this:

    # i selects the output row; j maps (batch k, block column c) to output column 2*k + c
    i = jnp.arange(2).reshape(1, 2, 1)
    j = jnp.arange(2*m).reshape(m, 1, 2)
    g_output = jnp.zeros((2, 2*m)).at[i, j].set(fv(a, b, c))
    

    but a nontrivial scatter operation like this will not typically be faster than a simple reshape, especially if you're running on an accelerator like GPU.
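
Whether the scatter or the hstack version is faster depends on the backend and array sizes, so it is worth measuring rather than assuming. Below is a rough timing sketch under stated assumptions: the batch size m = 1000, the helper names with_hstack, with_scatter and bench, and the repeat count are all chosen only for illustration. It calls block_until_ready() so that JAX's asynchronous dispatch does not hide the actual compute time, and it runs each function once beforehand so compilation is not timed.

```python
import timeit
import jax
import jax.numpy as jnp

def f(a, b, c):
  return jnp.array([[a.sum(), b.sum()], [c.sum(), 0.]])  # 2x2 block per batch element

m = 1000  # hypothetical batch size, for illustration only
a = jnp.ones((m, 2))
b = jnp.ones((m, 5))
c = jnp.ones((m, 7))

fv = jax.vmap(f)

@jax.jit
def with_hstack(a, b, c):
  return jnp.hstack(fv(a, b, c))

@jax.jit
def with_scatter(a, b, c):
  i = jnp.arange(2).reshape(1, 2, 1)
  j = jnp.arange(2 * m).reshape(m, 1, 2)
  return jnp.zeros((2, 2 * m)).at[i, j].set(fv(a, b, c))

def bench(fn, number=100):
  fn(a, b, c).block_until_ready()  # warm-up call so compilation is excluded
  total = timeit.timeit(lambda: fn(a, b, c).block_until_ready(), number=number)
  return total / number  # average seconds per call

t_hstack = bench(with_hstack)
t_scatter = bench(with_scatter)
print(f"hstack:  {t_hstack * 1e6:.1f} us per call")
print(f"scatter: {t_scatter * 1e6:.1f} us per call")
```

The numbers will vary between CPU, GPU and TPU, so treat any single run as indicative only.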

    If your arrays are large enough that reshapes are costly, you might find that a more direct implementation is better; for example:

    @jax.jit
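    def g(a, b, c):
      m = a.shape[0]  # batch size; shapes are static under jit
      output = jnp.zeros((2, 2 * m))
      # Row 0 interleaves the per-row sums of a (even columns) and b (odd columns)
      output = output.at[0, 0::2].set(a.sum(1))
      output = output.at[0, 1::2].set(b.sum(1))
      # Row 1 holds the per-row sums of c in the even columns; odd columns stay 0
      output = output.at[1, 0::2].set(c.sum(1))
      return output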
    
    g_output = g(a, b, c)
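
For completeness, the column-writing idea from the edited question can be made pure by returning the updated array and threading it through a loop. This is only a sketch, assuming the inputs from the question; the name fill_columns is hypothetical, and jax.lax.fori_loop is swapped in for vmap here because each step depends on the output of the previous one. It runs the updates sequentially, so it is not expected to beat the vectorized versions above.

```python
import jax
import jax.numpy as jnp

def arr(m, n):
  return jnp.arange(m * n).reshape((m, n))

m = 3
a, b, c = arr(m, 2), arr(m, 5), arr(m, 7)

@jax.jit
def fill_columns(a, b, c):
  n = a.shape[0]  # batch size; static under jit

  def body(idx, output):
    # Build the 2x2 block for batch element idx
    block = jnp.array([[a[idx].sum(), b[idx].sum()], [c[idx].sum(), 0.]])
    block = block.astype(output.dtype)
    # The update returns a NEW array; returning it carries it into the next iteration
    return jax.lax.dynamic_update_slice_in_dim(output, block, idx * 2, axis=1)

  return jax.lax.fori_loop(0, n, body, jnp.zeros((2, 2 * n)))

g_output = fill_columns(a, b, c)
print(g_output)
```

The key difference from the version in the question is that body returns the result of dynamic_update_slice_in_dim instead of discarding it.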