
JAX pmap with multi-core CPU


What is the correct method for using multiple CPU cores with jax.pmap?

The following example sets an environment variable so that XLA exposes multiple CPU devices for SPMD, checks that JAX recognises the devices, and then tries to lock (fully occupy) each device with an infinite loop.

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=2'

import jax as jx
import jax.numpy as jnp

jx.local_device_count()
# WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
# 2

jx.devices("cpu")
# [CpuDevice(id=0), CpuDevice(id=1)]

def sfunc(x):
    while True:
        pass

jx.pmap(sfunc)(jnp.arange(2))

Running this from a Jupyter kernel and watching htop shows that only one core is locked:

[Screenshot: htop output when executing from a Jupyter kernel]

I see the same htop output when I omit the first two lines (the os import and the XLA_FLAGS assignment) and run:

$ env XLA_FLAGS=--xla_force_host_platform_device_count=2 python test.py

Replacing sfunc with

def sfunc(x): return 2.0*x

and calling

jx.pmap(sfunc)(jnp.arange(2))
# ShardedDeviceArray([0., 2.], dtype=float32, weak_type=True)

does return a ShardedDeviceArray.

Clearly I am not correctly configuring JAX/XLA to use two cores. What am I missing and what can I do to diagnose the problem?


Solution

  • As far as I can tell, you are configuring the cores correctly (see e.g. Issue #2714). The problem lies in your test function:

    def sfunc(x):
      while True:
        pass
    

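    This function gets stuck in an infinite loop at trace time, not at run time. pmap has to call sfunc in Python to trace it before anything is compiled and dispatched to the devices, so the loop never finishes and nothing ever runs on the two CPU devices. Tracing happens in your host Python process on a single CPU (see How to think in JAX for an introduction to the idea of tracing within JAX transformations).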

    If you want to observe CPU usage at runtime, you'll have to use a function that finishes tracing and begins running. For that you could use any long-running function that actually produces results. Here is a simple example:

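    def sfunc(x):
      # The Python loop is unrolled at trace time into 100 chained matrix products
      for i in range(100):
        x = x @ x
      return x
    
    # The leading axis of size 2 is mapped across the two CPU devices
    jx.pmap(sfunc)(jnp.zeros((2, 1000, 1000)))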
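To diagnose this kind of problem directly, the trace-time and run-time phases can be separated in a small script. The sketch below is illustrative and not part of the original answer: the names `traced_body`, `workload` and `timed_run` are hypothetical, and it assumes a recent JAX in which `pmap` returns a `jax.Array` (so `addressable_shards` is available). It records the shape the Python body sees while tracing (the per-device shape, without the mapped axis), then times a real matrix-product workload, waiting with `block_until_ready()` because JAX dispatches work asynchronously.

```python
import os

# Must be set before JAX is imported: asks XLA to expose two CPU devices.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import time

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()
print("devices:", jax.devices())

# 1) Trace time: the Python body runs on the host, not on the devices.
#    x is an abstract per-device value here, so its shape lacks the mapped axis.
trace_shapes = []  # hypothetical record of what the body saw while tracing


def traced_body(x):
    trace_shapes.append(x.shape)  # Python side effect: happens only while tracing
    return 2.0 * x


jax.pmap(traced_body)(jnp.ones((n_dev, 3)))
print("shape seen while tracing:", trace_shapes[0])


# 2) Run time: the compiled program executes on each device.
def workload(x, steps):
    # `steps` is a plain Python int, so this loop is unrolled during tracing.
    for _ in range(steps):
        x = x @ x
    return x


def timed_run(size, steps):
    f = jax.pmap(lambda x: workload(x, steps))
    x = jnp.zeros((n_dev, size, size))
    f(x).block_until_ready()  # first call also pays for compilation
    start = time.perf_counter()
    out = f(x).block_until_ready()  # wait for the asynchronous result
    return out, time.perf_counter() - start


out, seconds = timed_run(size=256, steps=10)
print(f"{n_dev} devices, {seconds:.3f} s")
# Each device holds one shard of the mapped output.
print("shard devices:", sorted(s.device.id for s in out.addressable_shards))
```

To watch the cores in htop, raise `size` and `steps` (for example to the 1000x1000, 100-step workload above); the small values here only keep the check quick.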