What are the tradeoffs between jax.lax.map and jax.vmap?


This GitHub issue hints that there are tradeoffs in performance, memory, and compilation time when choosing between jax.lax.map and jax.vmap. What are the specific details of these tradeoffs with respect to both GPUs and CPUs?


Solution

  • The main difference is that jax.vmap is a vectorizing transformation, while lax.map is an iterative transformation. Let's look at an example.

    Example function: vector_dot

    Suppose you have implemented a function that takes 1D vectors as inputs. For simplicity, let's make it a dot product that asserts its inputs are one-dimensional:

    import jax
    import jax.numpy as jnp
    import numpy as np
    
    def vector_dot(x, y):
      assert x.ndim == y.ndim == 1, "vector inputs required"
      return jnp.dot(x, y)
    

    We can create some random 1D vectors to test this:

    rng = np.random.default_rng(8675309)
    x = rng.uniform(size=50)
    y = rng.uniform(size=50)
    
    print(vector_dot(x, y))
    # 14.919376
    

    To see what JAX is doing with this function under the hood, we can print the jaxpr, which is JAX's intermediate-level representation of a function:

    print(jax.make_jaxpr(vector_dot)(x, y))
    # { lambda ; a:f32[50] b:f32[50]. let
    #     c:f32[] = dot_general[dimension_numbers=(([0], [0]), ([], []))] a b
    #   in (c,) }
    

    This shows that JAX lowers this code to a single call to dot_general, the primitive for generalized dot products in JAX and XLA.
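
    As a rough illustrative sketch (not part of the original answer), you can call the lax.dot_general primitive directly with the same dimension numbers shown in the jaxpr and get the same result as jnp.dot:

    # Contracting dimensions ((0,), (0,)) and no batch dimensions, exactly
    # as in the jaxpr printed above.
    print(jax.lax.dot_general(x, y, dimension_numbers=(((0,), (0,)), ((), ()))))
    # Expect the same value as vector_dot(x, y).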

    Iterating over vector_dot

    Now, suppose you have a 2D input, and you'd like to apply this function to each row. There are several ways you could imagine doing this: three examples are using a Python for loop, using jax.vmap, or using jax.lax.map:

    def batched_dot_for_loop(x_batched, y):
      return jnp.array([vector_dot(x, y) for x in x_batched])
    
    def batched_dot_lax_map(x_batched, y):
      return jax.lax.map(lambda x: vector_dot(x, y), x_batched)
    
    batched_dot_vmap = jax.vmap(vector_dot, in_axes=(0, None))
    

    Applying these three functions to a batched input yields the same results, to within floating point precision:

    x_batched = rng.uniform(size=(4, 50))
    
    print(batched_dot_for_loop(x_batched, y))
    # [11.964929  12.485695  13.683528  12.9286175]
    
    print(batched_dot_lax_map(x_batched, y))
    # [11.964929  12.485695  13.683528  12.9286175]
    
    print(batched_dot_vmap(x_batched, y))
    # [11.964927  12.485697  13.683528  12.9286175]
    

    But if we look at the jaxpr for each, we can see that the three approaches lead to very different computational characteristics.

    The for loop solution looks like this:

    print(jax.make_jaxpr(batched_dot_for_loop)(x_batched, y))
    
    { lambda ; a:f32[4,50] b:f32[50]. let
        c:f32[1,50] = slice[
          limit_indices=(1, 50)
          start_indices=(0, 0)
          strides=(1, 1)
        ] a
        d:f32[50] = squeeze[dimensions=(0,)] c
        e:f32[] = dot_general[dimension_numbers=(([0], [0]), ([], []))] d b
        f:f32[1,50] = slice[
          limit_indices=(2, 50)
          start_indices=(1, 0)
          strides=(1, 1)
        ] a
        g:f32[50] = squeeze[dimensions=(0,)] f
        h:f32[] = dot_general[dimension_numbers=(([0], [0]), ([], []))] g b
        i:f32[1,50] = slice[
          limit_indices=(3, 50)
          start_indices=(2, 0)
          strides=(1, 1)
        ] a
        j:f32[50] = squeeze[dimensions=(0,)] i
        k:f32[] = dot_general[dimension_numbers=(([0], [0]), ([], []))] j b
        l:f32[1,50] = slice[
          limit_indices=(4, 50)
          start_indices=(3, 0)
          strides=(1, 1)
        ] a
        m:f32[50] = squeeze[dimensions=(0,)] l
        n:f32[] = dot_general[dimension_numbers=(([0], [0]), ([], []))] m b
        o:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] e
        p:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] h
        q:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] k
        r:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] n
        s:f32[4] = concatenate[dimension=0] o p q r
      in (s,) }
    

    The key feature is that the iterations in the for loop are unrolled into a single long program.
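
    To see how this unrolling scales, you can count the equations in the jaxpr as the batch size grows (a quick sketch; exact counts may vary across JAX versions):

    # The unrolled for-loop program grows with the batch size, while the
    # lax.map and vmap programs stay the same size.
    for batch in (4, 8, 16):
      xb = rng.uniform(size=(batch, 50))
      n_eqns = len(jax.make_jaxpr(batched_dot_for_loop)(xb, y).jaxpr.eqns)
      print(f"batch={batch}: {n_eqns} equations")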

    The lax.map version looks like this:

    print(jax.make_jaxpr(batched_dot_lax_map)(x_batched, y))
    
    { lambda ; a:f32[4,50] b:f32[50]. let
        c:f32[4] = scan[
          jaxpr={ lambda ; d:f32[50] e:f32[50]. let
              f:f32[] = dot_general[dimension_numbers=(([0], [0]), ([], []))] e d
            in (f,) }
          length=4
          linear=(False, False)
          num_carry=0
          num_consts=1
          reverse=False
          unroll=1
        ] b a
      in (c,) }
    

    The key feature is that the whole iteration is lowered to a single scan primitive, a native sequential loop with a static trip count.
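
    In fact, lax.map behaves like lax.scan with no carried state; a hand-written equivalent of the lax.map version might look like this (an illustrative sketch, not the library's actual implementation):

    def batched_dot_scan(x_batched, y):
      # Carry nothing between iterations; collect one output per row.
      def step(carry, x):
        return carry, vector_dot(x, y)
      _, out = jax.lax.scan(step, None, x_batched)
      return out

    print(batched_dot_scan(x_batched, y))
    # Expect the same values as batched_dot_lax_map(x_batched, y).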

    The vmap version looks like this:

    print(jax.make_jaxpr(batched_dot_vmap)(x_batched, y))
    
    { lambda ; a:f32[4,50] b:f32[50]. let
        c:f32[4] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a b
      in (c,) }
    

    The key feature here is that the vmap transformation is able to recognize that a batched 1D dot product is equivalent to a 2D dot product, so the result is a single extremely efficient native operation.
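
    In other words, for this function the vmap version does the same computation as a matrix-vector product on the 2D batch, which you can check directly (a quick sketch):

    # A batched 1D dot product is just a matrix-vector product.
    print(jnp.dot(x_batched, y))
    # Expect the same values as batched_dot_vmap(x_batched, y).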

    Performance considerations

    These three approaches can have very different performance characteristics. The details will depend on the specifics of the original function (here vector_dot) but in broad strokes, we can consider three aspects:

    Compilation Cost

    If you JIT-compile your program, you'll find:

    • The for-loop based solution will have compilation times that grow super-linearly with the number of iterations. This is due to the unrolling seen in the jaxpr above.
    • The lax.map and jax.vmap solutions will have fast compilation time, which under normal circumstances will not grow with the size of the batch dimension.

    Runtime

    In terms of runtime:

    • The for loop solution can be very fast, because XLA can often fuse operations between the unrolled iterations. This is the flip side of the long compilation times.
    • The lax.map solution will generally be slow, because it always executes sequentially, with no possibility of fusion or parallelization between iterations (see the scan-based sketch after this list for a partial middle ground).
    • The jax.vmap solution will generally be the fastest, especially on accelerators like GPU or TPU, because it can make use of native batching parallelism on the device.
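
    If you want sequential semantics but some fusion across iterations, one partial middle ground is to write the loop with lax.scan and its unroll argument (visible as unroll=1 in the jaxpr above), which unrolls several iterations per loop step. A minimal sketch:

    def batched_dot_scan_unrolled(x_batched, y, unroll=2):
      # unroll > 1 trades longer compilation for potential fusion between
      # the unrolled iterations, much like the manual for loop.
      def step(carry, x):
        return carry, vector_dot(x, y)
      _, out = jax.lax.scan(step, None, x_batched, unroll=unroll)
      return out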

    Memory Cost

    • The for loop and lax.map solutions generally have good memory performance, because they execute sequentially and don't require storage of large intermediate results.
    • The main downside of the jax.vmap solution is that memory use can blow up, because the entire batched problem must fit into memory at once. This is not an issue for the simple vector_dot function used here, but it can be for more complicated functions; a common mitigation is chunking the batch, as in the sketch after this list.
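
    A common way to navigate this memory/speed tradeoff is to combine the two transformations: map sequentially over chunks of the batch, and vectorize within each chunk. A minimal sketch, assuming the chunk size divides the batch size evenly:

    def batched_dot_chunked(x_batched, y, chunk_size=2):
      # Memory use is bounded by chunk_size rather than the full batch size.
      n, d = x_batched.shape
      assert n % chunk_size == 0, "sketch assumes chunk_size divides the batch"
      chunks = x_batched.reshape(n // chunk_size, chunk_size, d)
      dot_chunk = jax.vmap(vector_dot, in_axes=(0, None))
      out = jax.lax.map(lambda c: dot_chunk(c, y), chunks)
      return out.reshape(n)

    print(batched_dot_chunked(x_batched, y))
    # Expect the same values as batched_dot_vmap(x_batched, y).

    Recent JAX releases also add a batch_size argument to lax.map that implements essentially this pattern internally; check the documentation for your version.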

    Benchmarks

    You can see these general principles at play when benchmarking the above functions. The following timings are on a Colab T4 GPU:

    y = rng.uniform(size=1000)
    x_batched = rng.uniform(size=(200, 1000))
    
    %time jax.jit(batched_dot_for_loop).lower(x_batched, y).compile()
    # CPU times: user 4.96 s, sys: 55 ms, total: 5.01 s
    # Wall time: 7.24 s
    %timeit jax.jit(batched_dot_for_loop)(x_batched, y).block_until_ready()
    # 1.09 ms ± 149 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    
    %time jax.jit(batched_dot_lax_map).lower(x_batched, y).compile()
    # CPU times: user 117 ms, sys: 2.71 ms, total: 120 ms
    # Wall time: 172 ms
    %timeit jax.jit(batched_dot_lax_map)(x_batched, y).block_until_ready()
    # 2.67 ms ± 56.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    
    %time jax.jit(batched_dot_vmap).lower(x_batched, y).compile()
    # CPU times: user 51 ms, sys: 941 µs, total: 52 ms
    # Wall time: 103 ms
    %timeit jax.jit(batched_dot_vmap)(x_batched, y).block_until_ready()
    # 719 µs ± 129 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
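
    The question also asks about CPUs. The same broad principles apply, though vmap's runtime advantage is typically smaller on CPU, where there is less hardware parallelism to exploit. You can repeat the measurements without a GPU by committing the inputs to a CPU device, since jit-compiled computation follows the placement of its inputs (a sketch; timings vary by machine, so no numbers are claimed here):

    cpu = jax.devices("cpu")[0]
    x_cpu = jax.device_put(x_batched, cpu)
    y_cpu = jax.device_put(y, cpu)

    %timeit jax.jit(batched_dot_for_loop)(x_cpu, y_cpu).block_until_ready()
    %timeit jax.jit(batched_dot_lax_map)(x_cpu, y_cpu).block_until_ready()
    %timeit jax.jit(batched_dot_vmap)(x_cpu, y_cpu).block_until_ready()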