
How to use JAX pmap with CPU cores


I am trying to use JAX pmap on my CPU cores, but I get an error saying that not enough XLA devices are available. Here's my code:

import jax.numpy as jnp
import os
from jax import pmap
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

out = pmap(lambda x: x ** 2)(jnp.arange(8))
print(out)

Here's the error:

Traceback (most recent call last):
  File "new.py", line 6, in <module>
    out = pmap(lambda x: x ** 2)(jnp.arange(8))
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/api.py", line 1779, in cache_miss
    execute = pxla.xla_pmap_impl_lazy(fun_, *tracers, **params)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 411, in xla_pmap_impl_lazy
    compiled_fun, fingerprint = parallel_callable(
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
    ans = call(fun, *args)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 682, in parallel_callable
    pmap_executable = pmap_computation.compile()
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 923, in compile
    executable = UnloadedPmapExecutable.from_hlo(
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 993, in from_hlo
    raise ValueError(msg.format(shards.num_global_shards,
jax._src.traceback_util.UnfilteredStackTrace: ValueError: compiling computation that requires 8 logical devices, but only 1 XLA devices are available (num_replicas=8)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "new.py", line 6, in <module>
    out = pmap(lambda x: x ** 2)(jnp.arange(8))
ValueError: compiling computation that requires 8 logical devices, but only 1 XLA devices are available (num_replicas=8)

Based on this and this discussion, I added os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8', but it doesn't seem to have any effect.


Edit 1:

I tried setting the flag before importing JAX, but it still doesn't work:

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

import jax
from jax import pmap
import jax.numpy as jnp

out = pmap(lambda x: x ** 2)(jnp.arange(8))
print(out)


Solution

  • XLA flags are read when JAX is imported, so you need to set them before importing JAX if you want the flags to have an effect.

    You should also make sure you're in a clean runtime (i.e. not using a Jupyter kernel where you have previously imported jax).

    Additionally, keep in mind that --xla_force_host_platform_device_count=8 only affects the host (CPU) device count, so the code as written above won't work if you're using GPU-enabled JAX with a single GPU device. If this is the case, you can force pmap to run on the non-default CPU devices using the devices argument:

    out = pmap(lambda x: x ** 2, devices=jax.devices('cpu'))(jnp.arange(8))
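Putting the answer's points together, here is a minimal sketch of a complete script. It assumes it runs in a fresh Python process where jax has not been imported yet; the names cpu_devices, n, and square are illustrative only. The input length is taken from the number of CPU devices because, when devices is passed, pmap expects the mapped axis to have one entry per device, and the default_backend() print just shows which platform a plain pmap would otherwise use.

```python
import os

# XLA flags are read when JAX initializes its backends, so set the flag
# before importing jax (in a fresh process).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp

# Shows which platform JAX uses by default; if it is not the CPU,
# a pmap without devices= would map over that platform's devices instead.
print("default backend:", jax.default_backend())

# The flag only affects host (CPU) devices, so request them explicitly.
cpu_devices = jax.devices("cpu")
n = len(cpu_devices)
print("CPU devices:", n)

# With devices= given, the mapped (leading) axis needs one entry per
# device, so the input is sized from n rather than hard-coded.
square = jax.pmap(lambda x: x ** 2, devices=cpu_devices)
out = square(jnp.arange(n))
print(out)
```

Run it as a standalone script (for example python new.py). In a notebook, restart the kernel first so the flag is read before JAX initializes.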