This GitHub issue hints that there are tradeoffs in performance / memory / compilation time when choosing between jax.lax.map and jax.vmap. What are the specific details of these tradeoffs with respect to both GPUs and CPUs?
The main difference is that jax.vmap is a vectorizing transformation, while lax.map is an iterative transformation. Let's look at an example.
Example: vector_dot
Suppose you have implemented a simple function that takes 1D vectors as inputs. For simplicity, let's make it a dot product, but one that asserts the inputs are one-dimensional:
import jax
import jax.numpy as jnp
import numpy as np
def vector_dot(x, y):
    assert x.ndim == y.ndim == 1, "vector inputs required"
    return jnp.dot(x, y)
We can create some random 1D vectors to test this:
rng = np.random.default_rng(8675309)
x = rng.uniform(size=50)
y = rng.uniform(size=50)
print(vector_dot(x, y))
# 14.919376
To see what JAX is doing with this function under the hood, we can print the jaxpr, which is JAX's intermediate-level representation of a function:
print(jax.make_jaxpr(vector_dot)(x, y))
# { lambda ; a:f32[50] b:f32[50]. let
# c:f32[] = dot_general[dimension_numbers=(([0], [0]), ([], []))] a b
# in (c,) }
This shows that JAX lowers this code to a single call to dot_general, the primitive for generalized dot products in JAX and XLA.
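For reference, the same operation can be reproduced by calling jax.lax.dot_general directly with the dimension numbers shown in the jaxpr; this is a small sketch, just to connect the jaxpr back to an API call:

# Same contraction the jaxpr shows: contract axis 0 of x with axis 0 of y,
# with no batch dimensions.
direct = jax.lax.dot_general(x, y, dimension_numbers=(((0,), (0,)), ((), ())))
print(direct)
# 14.919376 (same value as vector_dot(x, y))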
Batching vector_dot
Now, suppose you have a 2D input, and you'd like to apply this function to each row. There are several ways you could imagine doing this: three examples are using a Python for loop, using jax.vmap, or using jax.lax.map:
def batched_dot_for_loop(x_batched, y):
    return jnp.array([vector_dot(x, y) for x in x_batched])

def batched_dot_lax_map(x_batched, y):
    return jax.lax.map(lambda x: vector_dot(x, y), x_batched)

batched_dot_vmap = jax.vmap(vector_dot, in_axes=(0, None))
Applying these three functions to a batched input yields the same results, to within floating point precision:
x_batched = rng.uniform(size=(4, 50))
print(batched_dot_for_loop(x_batched, y))
# [11.964929 12.485695 13.683528 12.9286175]
print(batched_dot_lax_map(x_batched, y))
# [11.964929 12.485695 13.683528 12.9286175]
print(batched_dot_vmap(x_batched, y))
# [11.964927 12.485697 13.683528 12.9286175]
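A quick way to check that programmatically is with numpy's testing utilities; this is a small sketch, with the tolerance chosen loosely to accommodate float32 rounding:

# All three implementations agree to within float32 precision:
np.testing.assert_allclose(batched_dot_for_loop(x_batched, y),
                           batched_dot_lax_map(x_batched, y), rtol=1e-6)
np.testing.assert_allclose(batched_dot_for_loop(x_batched, y),
                           batched_dot_vmap(x_batched, y), rtol=1e-6)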
But if we look at the jaxpr for each, we can see that the three approaches lead to very different computational characteristics.

The for loop solution looks like this:
print(jax.make_jaxpr(batched_dot_for_loop)(x_batched, y))
{ lambda ; a:f32[4,50] b:f32[50]. let
    c:f32[1,50] = slice[
      limit_indices=(1, 50)
      start_indices=(0, 0)
      strides=(1, 1)
    ] a
    d:f32[50] = squeeze[dimensions=(0,)] c
    e:f32[] = dot_general[dimension_numbers=(([0], [0]), ([], []))] d b
    f:f32[1,50] = slice[
      limit_indices=(2, 50)
      start_indices=(1, 0)
      strides=(1, 1)
    ] a
    g:f32[50] = squeeze[dimensions=(0,)] f
    h:f32[] = dot_general[dimension_numbers=(([0], [0]), ([], []))] g b
    i:f32[1,50] = slice[
      limit_indices=(3, 50)
      start_indices=(2, 0)
      strides=(1, 1)
    ] a
    j:f32[50] = squeeze[dimensions=(0,)] i
    k:f32[] = dot_general[dimension_numbers=(([0], [0]), ([], []))] j b
    l:f32[1,50] = slice[
      limit_indices=(4, 50)
      start_indices=(3, 0)
      strides=(1, 1)
    ] a
    m:f32[50] = squeeze[dimensions=(0,)] l
    n:f32[] = dot_general[dimension_numbers=(([0], [0]), ([], []))] m b
    o:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] e
    p:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] h
    q:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] k
    r:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] n
    s:f32[4] = concatenate[dimension=0] o p q r
  in (s,) }
The key feature is that the iterations of the for loop are unrolled into a single long program.
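One way to see this concretely is to count the equations in the jaxpr as the batch size grows; this is a rough sketch, and the equation count is only a proxy for compiled program size:

# Each extra row adds its own slice / squeeze / dot_general / broadcast
# equations, so the unrolled program grows with the batch size.
for batch_size in (4, 8, 16):
    xb = rng.uniform(size=(batch_size, 50))
    closed_jaxpr = jax.make_jaxpr(batched_dot_for_loop)(xb, y)
    print(batch_size, len(closed_jaxpr.jaxpr.eqns))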
The lax.map version looks like this:
print(jax.make_jaxpr(batched_dot_lax_map)(x_batched, y))
{ lambda ; a:f32[4,50] b:f32[50]. let
    c:f32[4] = scan[
      jaxpr={ lambda ; d:f32[50] e:f32[50]. let
          f:f32[] = dot_general[dimension_numbers=(([0], [0]), ([], []))] e d
        in (f,) }
      length=4
      linear=(False, False)
      num_carry=0
      num_consts=1
      reverse=False
      unroll=1
    ] b a
  in (c,) }
The key feature is that it is lowered to a scan primitive, which is XLA's native static loop operation.
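To make that connection concrete, lax.map behaves like a scan that carries no state and simply stacks the per-element outputs. A minimal sketch of that equivalence (batched_dot_scan is a name introduced here only for illustration):

# lax.map(f, xs) is essentially lax.scan with an empty carry:
def batched_dot_scan(x_batched, y):
    def step(carry, x):
        return carry, vector_dot(x, y)  # nothing is carried; outputs are stacked
    _, out = jax.lax.scan(step, None, x_batched)
    return out

print(batched_dot_scan(x_batched, y))  # matches batched_dot_lax_map(x_batched, y)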
The vmap version looks like this:
print(jax.make_jaxpr(batched_dot_vmap)(x_batched, y))
{ lambda ; a:f32[4,50] b:f32[50]. let
    c:f32[4] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a b
  in (c,) }
The key feature here is that the vmap transformation is able to recognize that a batched 1D dot product is equivalent to a 2D dot product, so the result is a single extremely efficient native operation.
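In other words, for this function the batched computation collapses to an ordinary matrix-vector product; a quick sketch to confirm:

# The single dot_general above is just a 2D-by-1D matrix-vector product:
print(jnp.allclose(batched_dot_vmap(x_batched, y), x_batched @ y))
# True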
These three approaches can have very different performance characteristics. The details will depend on the specifics of the original function (here vector_dot), but in broad strokes we can consider three aspects: compilation time, runtime, and memory.

If you JIT-compile your program, you'll find:

- The for-loop based solution will have compilation times that grow super-linearly with the number of iterations. This is due to the unrolling seen in the jaxpr above.
- The lax.map and jax.vmap solutions will have fast compilation times, which under normal circumstances will not grow with the size of the batch dimension.

In terms of runtime:

- The for loop solution can be very fast, because XLA can often fuse operations between the unrolled iterations. This is the flip side of the long compilation times.
- The lax.map solution will generally be slow, because it is always executed sequentially with no possibility of fusing or parallelizing between iterations.
- The jax.vmap solution will generally be the fastest, especially on accelerators like GPU or TPU, because it can make use of native batching parallelism on the device.

In terms of memory:

- The for loop and lax.map solutions generally have good memory performance, because they execute sequentially and don't require storage of large intermediate results.
- The downside of the jax.vmap solution is that it can cause memory to blow up, because the entire problem must fit into memory at once. This is not an issue with the simple vector_dot function used here, but it can be for more complicated functions (see the sketch at the end of this answer).

You can see these general principles at play when benchmarking the above functions. The following timings are on a Colab T4 GPU:
y = rng.uniform(size=1000)
x_batched = rng.uniform(size=(200, 1000))
%time jax.jit(batched_dot_for_loop).lower(x_batched, y).compile()
# CPU times: user 4.96 s, sys: 55 ms, total: 5.01 s
# Wall time: 7.24 s
%timeit jax.jit(batched_dot_for_loop)(x_batched, y).block_until_ready()
# 1.09 ms ± 149 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%time jax.jit(batched_dot_lax_map).lower(x_batched, y).compile()
# CPU times: user 117 ms, sys: 2.71 ms, total: 120 ms
# Wall time: 172 ms
%timeit jax.jit(batched_dot_lax_map)(x_batched, y).block_until_ready()
# 2.67 ms ± 56.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%time jax.jit(batched_dot_vmap).lower(x_batched, y).compile()
# CPU times: user 51 ms, sys: 941 µs, total: 52 ms
# Wall time: 103 ms
%timeit jax.jit(batched_dot_vmap)(x_batched, y).block_until_ready()
# 719 µs ± 129 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
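A final note on the memory point above: it doesn't show up in these benchmarks because vector_dot produces only tiny intermediates. The following is a sketch of the kind of case where it matters; the function, names, and shapes here are hypothetical, chosen only for illustration:

# Hypothetical per-row function with a large (n, n) intermediate:
def row_summary(x):
    return jnp.outer(x, x).sum()  # materializes an (n, n) outer product

xs = jnp.asarray(rng.uniform(size=(512, 2048)), dtype=jnp.float32)

# Executed eagerly, vmap materializes a (512, 2048, 2048) batched outer
# product (~8 GB in float32) before reducing it, which may exhaust memory:
# out_vmap = jax.vmap(row_summary)(xs)

# lax.map only ever holds one (2048, 2048) intermediate (~16 MB) at a time:
out_map = jax.lax.map(row_summary, xs)

Under jit, XLA can sometimes fuse away intermediates like this, so whether the vmapped version actually runs out of memory depends on the function and backend; the underlying principle is the one stated above: vmap processes the entire batch at once, while lax.map processes one element at a time.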