What is the correct method for using multiple CPU cores with jax.pmap?
The following example sets the XLA_FLAGS environment variable to expose two CPU devices for SPMD, checks that JAX recognises the devices, and attempts to lock both cores with an infinite loop.
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=2'
import jax as jx
import jax.numpy as jnp
jx.local_device_count()
# WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
# 2
jx.devices("cpu")
# [CpuDevice(id=0), CpuDevice(id=1)]
def sfunc(x):
    while True:
        pass
jx.pmap(sfunc)(jnp.arange(2))
Executing this from a Jupyter kernel and observing htop shows that only one core is locked. I receive the same output from htop when omitting the first two lines and running:
$ env XLA_FLAGS=--xla_force_host_platform_device_count=2 python test.py
Replacing sfunc with
def sfunc(x): return 2.0*x
and calling
jx.pmap(sfunc)(jnp.arange(2))
# ShardedDeviceArray([0., 2.], dtype=float32, weak_type=True)
does return a ShardedDeviceArray.
Clearly I am not correctly configuring JAX/XLA to use two cores. What am I missing and what can I do to diagnose the problem?
As far as I can tell, you are configuring the cores correctly (see e.g. Issue #2714). The problem lies in your test function:
def sfunc(x):
    while True:
        pass
This function gets stuck in an infinite loop at trace-time, not at run-time. Tracing happens in your host Python process on a single CPU (see How to think in JAX for an introduction to the idea of tracing within JAX transformations).
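To make the trace-time versus run-time distinction concrete, here is a minimal sketch. It uses jax.jit rather than jax.pmap for brevity (pmap also traces the Python function before anything runs on a device), and the names double and trace_log are illustrative, not part of the question. The Python side effect in the function body fires only while JAX is tracing:

```python
import jax as jx
import jax.numpy as jnp

trace_log = []  # illustrative: records each time the Python body executes

def double(x):
    # Plain Python code like this runs in the host process during tracing,
    # not on the devices that execute the compiled computation.
    trace_log.append(x.shape)
    return 2.0 * x

jitted = jx.jit(double)
first = jitted(jnp.arange(3.0))   # traces the Python body, compiles, then runs
second = jitted(jnp.arange(3.0))  # same shape and dtype: reuses the compiled code

print(len(trace_log))  # 1
```

The body ran once even though the function was called twice. An infinite Python loop inside the body therefore never gets past that single tracing step, so nothing is ever handed to the devices.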
If you want to observe CPU usage at runtime, you'll have to use a function that finishes tracing and begins running. For that you could use any long-running function that actually produces results. Here is a simple example:
def sfunc(x):
    for i in range(100):
        x = (x @ x)
    return x
jx.pmap(sfunc)(jnp.zeros((2, 1000, 1000)))
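As a rough diagnostic, the pieces above can be combined into one script. This is a sketch under a few assumptions: the XLA_FLAGS line takes effect only if it runs before JAX is first imported in the process, the names steps and size are illustrative, and size is kept small here so the script finishes quickly. JAX dispatches work asynchronously, so the call to block_until_ready() keeps the script waiting until the devices have actually finished.

```python
import os

# Must be set before JAX is first imported in this process.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import jax as jx
import jax.numpy as jnp

n = jx.local_device_count()  # expected to be 2 on a CPU-only machine
steps = 100                  # illustrative: number of chained matrix products
size = 256                   # illustrative: raise to 1000 or more to watch htop

def sfunc(x):
    # The Python loop is unrolled during tracing; the matrix products
    # themselves run later, on the device that holds each slice.
    for _ in range(steps):
        x = x @ x
    return x

# The leading axis has one slice per local device.
out = jx.pmap(sfunc)(jnp.zeros((n, size, size)))
out.block_until_ready()  # wait for the asynchronous computation to finish

print(n, out.shape)
```

If n prints as 1 on a CPU-only machine, the flag was probably set after JAX had already been imported, for example in a notebook kernel that was not restarted.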