Can anyone help me understand the similarities and differences between xmap and pmap in JAX? I have read through the documentation multiple times, but I still cannot understand the new concepts in the tutorials.
I am specifically interested in whether there are any examples of how to convert a pmap training setup to an xmap training setup.
It seems to me that mesh + xmap can do similar things to pmap. Am I understanding this correctly?
pmap is a simple parallelizing transform that only supports distribution of data over a single leading axis. It is deprecated, and will likely be removed in a future version of JAX.
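To make the "single leading axis" restriction concrete, here is a minimal sketch of a pmap call. The function name normalize and the axis name "devices" are made up for illustration; the only requirement shown is that the leading dimension of the input equals the number of local devices.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

# pmap maps over the leading axis, so its size must equal the device count.
x = jnp.arange(n_dev * 4, dtype=jnp.float32).reshape(n_dev, 4)

def normalize(shard):
    # Each device sees one row of x; psum adds the row sums across devices.
    total = jax.lax.psum(shard.sum(), axis_name="devices")
    return shard / total

y = jax.pmap(normalize, axis_name="devices")(x)  # shape (n_dev, 4)
```

The output keeps the same leading device axis, which is why pmap-based training code usually reshapes batches to (n_devices, per_device_batch, ...) before each step.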
xmap is a generalization of pmap that allows for more flexible parallelization over multiple named axes. It has always been experimental, and will likely be removed in a future version of JAX.
The best way to do parallel computation in JAX going forward is either implicitly via sharded inputs into jit, or explicitly via shard_map. Unfortunately, neither of these approaches is very well documented at the moment; there is some information at Distributed Arrays and Automatic Parallelization and shard_map for simple per-device code, but both are written more for developers than for end-users. That said, more comprehensive docs for the newer parallelism models in JAX are currently in progress, and should be on the website soon.
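As a rough illustration of the two approaches, here is a sketch of a data-parallel training step for a toy linear regression, written once with sharded inputs into jit and once with shard_map. The model, the learning rate of 0.1, the mesh axis name "data", and all function names are assumptions made up for this example, and the import location of shard_map depends on the installed JAX version, hence the try/except.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:
    from jax import shard_map  # location in newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

# One-dimensional device mesh; the axis name "data" is an arbitrary choice.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
n_dev = mesh.devices.size

# Toy data whose batch size divides evenly across the devices.
kx, ky = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (4 * n_dev, 3))
y = jax.random.normal(ky, (4 * n_dev,))
w = jnp.zeros(3)

def loss_fn(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

# Approach 1 (implicit): shard the batch, then call an ordinary jitted step.
# The compiler decides how to partition the computation.
batch_sharding = NamedSharding(mesh, P("data"))
x_sh = jax.device_put(x, batch_sharding)
y_sh = jax.device_put(y, batch_sharding)

@jax.jit
def train_step_jit(w, x, y):
    return w - 0.1 * jax.grad(loss_fn)(w, x, y)

w_jit = train_step_jit(w, x_sh, y_sh)

# Approach 2 (explicit): write per-device code, as one would with pmap.
# Each device computes the loss on its slice; pmean averages across devices.
def per_device_loss(w, x_local, y_local):
    return jax.lax.pmean(loss_fn(w, x_local, y_local), axis_name="data")

sharded_loss = shard_map(
    per_device_loss,
    mesh=mesh,
    in_specs=(P(), P("data"), P("data")),  # w replicated, batch split
    out_specs=P(),                         # scalar loss, replicated
)

@jax.jit
def train_step_shmap(w, x, y):
    # Differentiate through the shard_map'd loss from the outside.
    return w - 0.1 * jax.grad(sharded_loss)(w, x, y)

w_shmap = train_step_shmap(w, x, y)
```

Both steps should produce the same updated weights as the single-device computation. Compared with pmap, the batch keeps its ordinary shape (no extra leading device axis), and the mesh plus PartitionSpec say how it is split.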