I am coding on a single-device laptop and I am using jax.pmap because my code will run on multiple TPUs. I would like to "fake" having multiple devices to test my code and try different things.
Is there any way to achieve this? Thanks!
You can spoof multiple XLA (CPU) devices backed by a single physical device by setting the following environment variable:
$ export XLA_FLAGS="--xla_force_host_platform_device_count=8"
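If your tests are themselves driven from Python (for example by a runner that may already have imported JAX in its own process), a minimal sketch is to start a fresh interpreter with the variable already in its environment, which is equivalent to exporting it in the shell. The `JAX_PLATFORMS=cpu` setting is an addition not in the original answer; it assumes you want JAX on the CPU backend, the only backend this flag multiplies, even if an accelerator is present.

```python
import os
import subprocess
import sys

# Child environment: the flag is in place before the child ever imports JAX.
# JAX_PLATFORMS=cpu (an assumption here) keeps JAX on the CPU backend.
env = dict(
    os.environ,
    XLA_FLAGS="--xla_force_host_platform_device_count=8",
    JAX_PLATFORMS="cpu",
)

child_code = "import jax; print(jax.device_count())"
result = subprocess.run(
    [sys.executable, "-c", child_code],
    env=env,
    capture_output=True,
    text=True,
    check=True,  # raise CalledProcessError if the child fails
)
print(result.stdout.strip())  # expected: 8
```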
In Python itself, you can do it like this:
# Note: must set this env variable before jax is imported
import os
os.environ['XLA_FLAGS'] = "--xla_force_host_platform_device_count=8"
import jax
print(jax.devices())
# [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3),
# CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
import jax.numpy as jnp
out = jax.pmap(lambda x: x ** 2)(jnp.arange(8))
print(out)
# [ 0  1  4  9 16 25 36 49]
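Because the fake devices behave like real ones under `pmap`, you can also exercise collectives such as `jax.lax.psum`. This is a minimal sketch assuming a CPU-only machine (with a GPU or TPU present, `jax.device_count()` reports the accelerators, which the flag does not multiply). The function `normalize` and the axis name `"i"` are illustrative, and the flag is appended to any existing `XLA_FLAGS` rather than overwriting them.

```python
import os

# Append to any existing XLA flags instead of replacing them.
# This must run before JAX initializes its backend (i.e. before importing jax).
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
).strip()

import jax
import jax.numpy as jnp

n = jax.device_count()  # 8 on a CPU-only machine with the flag above


def normalize(shard):
    # psum adds each device's partial sum over the mapped axis "i",
    # so every device sees the same global total.
    total = jax.lax.psum(shard.sum(), axis_name="i")
    return shard / total


# The leading axis is split across devices; its size may not exceed n.
data = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)
out = jax.pmap(normalize, axis_name="i")(data)
print(out.shape)         # one row per device: (8, 4) with eight fake devices
print(float(out.sum()))  # approximately 1.0
```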
Note that when only a single physical device is present, all the "devices" here are backed by the same thread pool. This will not make your code run faster, but it is useful for testing the semantics of parallel implementations on a single-device machine.
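One way to use this for testing, sketched here as an assumption rather than a prescribed workflow, is to compare the `pmap` result against `jax.vmap` of the same function: `vmap` also accepts an `axis_name` and supports the same collectives on a single device. The function `step` and the axis name `"devices"` are hypothetical. The second check relies on every shard having the same size, so the mean of per-shard means equals the global mean.

```python
import os

# Append the flag before JAX is imported so eight CPU devices are created.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
).strip()

import jax
import jax.numpy as jnp
import numpy as np


def step(shard):
    # Per-device mean of squares, then averaged across the named axis.
    return jax.lax.pmean(jnp.mean(shard ** 2), axis_name="devices")


n = jax.device_count()
x = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)

# Parallel version: one shard per (fake) device.
parallel = jax.pmap(step, axis_name="devices")(x)
# Reference version: same collective semantics, vectorized on one device.
reference = jax.vmap(step, axis_name="devices")(x)

np.testing.assert_allclose(parallel, reference, rtol=1e-5)
# Equal-size shards, so every device should hold the global mean of squares.
expected = np.full(n, np.mean(np.asarray(x) ** 2))
np.testing.assert_allclose(parallel, expected, rtol=1e-5)
print("pmap and vmap agree")
```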