JAX vmap vs pmap vs Python multiprocessing


I am rewriting some code from pure Python to JAX. I have gotten to the point where in my old code, I was using Python's multiprocessing module to parallelize the evaluation of a function over all of the CPU cores in a single node as follows:

import multiprocessing

# start pool processes
pool = multiprocessing.Pool(processes=10)  # if node has 10 CPU cores, start 10 processes

# use pool.map to evaluate function(input) for each input in parallel
# suppose len(inputs) is very large and 10 inputs are processed in parallel at a time
# store the results in a list called out
out = pool.map(function, inputs)

# close pool processes to free memory
pool.close()
pool.join()

I know that JAX has vmap and pmap, but I don't understand whether either of them is a drop-in replacement for how I'm using multiprocessing.pool.map above.

  1. Is vmap(function,in_axes=0)(inputs) distributing to all available CPU cores or what?
  2. How is pmap(function,in_axes=0)(inputs) different from vmap and multiprocessing.pool.map?
  3. Is my usage of multiprocessing.pool.map above an example of a "single-program, multiple-data (SPMD)" code that pmap is meant for?
  4. When I actually do pmap(function,in_axes=0)(inputs) I get an error -- ValueError: compiling computation that requires 10 logical devices, but only 1 XLA devices are available (num_replicas=10, num_partitions=1) -- what does this mean?
  5. Finally, my use case is very simple: I merely want to use some/all of the CPU cores on a single node (e.g., all 10 CPU cores on my Macbook). But I have heard about nesting pmap(vmap) -- is this used to parallelize over the cores of multiple connected nodes (say on a supercomputer)? This would be more akin to mpi4py rather than multiprocessing (the latter is restricted to a single node).

Solution

    1. Is vmap(function,in_axes=0)(inputs) distributing to all available CPU cores or what?

    No, vmap has nothing to do with parallelization. It is a vectorizing transformation, not a parallelizing transformation. In the course of normal operation, JAX may use multiple cores via XLA, so vmapped operations may also do this. But there's no explicit parallelization in vmap.
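    To make the "vectorizing, not parallelizing" point concrete, here is a minimal sketch. The per-input function and the input shapes are made up for illustration; the only claim is that vmap produces the same values as a Python loop over the leading axis, computed as one batched operation on a single device.

```python
import jax
import jax.numpy as jnp

def function(x):
    # Hypothetical per-input function: takes one vector, returns a scalar.
    return jnp.sum(x ** 2)

inputs = jnp.arange(12.0).reshape(4, 3)  # 4 inputs, each a length-3 vector

# vmap turns the per-input function into one batched computation over axis 0.
batched = jax.vmap(function, in_axes=0)(inputs)

# Same values as looping over the inputs in Python and stacking the results.
looped = jnp.stack([function(x) for x in inputs])
```

    Both arrays hold one result per input; vmap simply avoids the Python-level loop.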

    2. How is pmap(function,in_axes=0)(inputs) different from vmap and multiprocessing.pool.map?

    pmap parallelizes over multiple XLA devices. vmap does not parallelize, but rather vectorizes on a single device. multiprocessing parallelizes over multiple Python processes.

    3. Is my usage of multiprocessing.pool.map above an example of a "single-program, multiple-data (SPMD)" code that pmap is meant for?

    Yes, it could be described as SPMD across multiple Python processes.

    4. When I actually do pmap(function,in_axes=0)(inputs) I get an error -- ValueError: compiling computation that requires 10 logical devices, but only 1 XLA devices are available (num_replicas=10, num_partitions=1) -- what does this mean?

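    pmap parallelizes over multiple XLA devices, assigning one element of the mapped axis to each device. Your input has 10 elements along that axis, so pmap needs 10 devices, but you have configured only a single XLA device, so the requested operation is not possible.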
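    One way to get more than one CPU device is the XLA flag --xla_force_host_platform_device_count, which asks XLA to split the host CPU into several logical devices. The sketch below is a hedged illustration, not a recommendation: the per-input function is hypothetical, the count of 10 is taken from the question, and the flag only takes effect if it is set before JAX is first imported in the process. The input length is derived from the device count JAX actually reports, so the call stays valid even if the flag is ignored.

```python
import os

# Assumption: this runs before the first `import jax` in the process.
# Append rather than overwrite, in case XLA_FLAGS is already set.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=10"
)

import jax
import jax.numpy as jnp

def function(x):
    # Hypothetical per-input function.
    return x ** 2 + 1.0

# pmap needs exactly one slice of the mapped axis per device,
# so size the inputs from what JAX actually sees.
n_devices = jax.local_device_count()
inputs = jnp.arange(float(n_devices))

out = jax.pmap(function, in_axes=0)(inputs)
```

    Printing jax.devices() afterwards is a quick way to check how many devices were actually created.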

    5. Finally, my use case is very simple: I merely want to use some/all of the CPU cores on a single node (e.g., all 10 CPU cores on my Macbook). But I have heard about nesting pmap(vmap) -- is this used to parallelize over the cores of multiple connected nodes (say on a supercomputer)? This would be more akin to mpi4py rather than multiprocessing (the latter is restricted to a single node).

    Yes, I believe that pmap can be used to compute on multiple CPU cores. Whether it's nested with vmap is irrelevant. See JAX pmap with multi-core CPU.
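    As an illustration of where pmap(vmap) typically shows up on a single node: when there are more inputs than devices, the inputs can be reshaped into one chunk per device, with pmap splitting the chunks across devices and vmap vectorizing within each chunk. This is a sketch under assumptions: the function and the sizes are hypothetical, and the number of inputs is chosen to divide evenly by the device count. With a single device it simply runs as one chunk.

```python
import jax
import jax.numpy as jnp

def function(x):
    # Hypothetical per-input function.
    return jnp.sin(x) * 2.0

n_devices = jax.local_device_count()
per_device = 1000  # illustrative chunk size
inputs = jnp.linspace(0.0, 1.0, n_devices * per_device)

# Leading axis: one chunk per device (pmap). Second axis: batch within a device (vmap).
chunks = inputs.reshape(n_devices, per_device)
out = jax.pmap(jax.vmap(function))(chunks).reshape(-1)
```

    The final reshape flattens the per-device results back into the original input order.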

    Note also that jax.pmap is deprecated in favor of the newer jax.shard_map, which is a much more flexible transformation for multi-device/multi-host computation. There's some info here: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html and https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html
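    For comparison, a rough sketch of the same kind of per-input map written with shard_map. Treat the details as assumptions: the import location of shard_map has differed between JAX releases (hence the fallback import), the function is hypothetical, and the mesh axis name "i" is arbitrary. Each device receives a contiguous block of the inputs, and vmap applies the per-input function within that block.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # location in newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # location in older releases

def function(x):
    # Hypothetical per-input function.
    return x ** 2 + 1.0

# A 1D mesh over all visible devices; "i" is an arbitrary axis name.
devices = np.array(jax.devices())
mesh = Mesh(devices, ("i",))

# The input length must be divisible by the number of devices.
inputs = jnp.arange(4.0 * len(devices))

# Split axis 0 of the input across the mesh, map within each block, reassemble.
sharded_fn = shard_map(jax.vmap(function), mesh=mesh, in_specs=P("i"), out_specs=P("i"))
out = jax.jit(sharded_fn)(inputs)
```

    Unlike pmap, the leading axis here does not have to equal the device count; it only has to divide evenly across the mesh.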