multiprocessing tpu jax pmap

Test jax.pmap before deploying on multi-device hardware


I am coding on a single-device laptop and I am using jax.pmap because my code will run on multiple TPUs. I would like to "fake" having multiple devices to test my code and try different things.

Is there any way to achieve this? Thanks!


Solution

  • You can spoof multiple XLA devices backed by a single device by setting the following environment variable:

    $ export XLA_FLAGS="--xla_force_host_platform_device_count=8"
    

    In Python, you can do it like this:

    # Note: must set this env variable before jax is imported
    import os
    os.environ['XLA_FLAGS'] = "--xla_force_host_platform_device_count=8"
    
    import jax
    
    print(jax.devices())
    # [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3),
    #  CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
    
    import jax.numpy as jnp
    out = jax.pmap(lambda x: x ** 2)(jnp.arange(8))
    print(out)
    # [ 0  1  4  9 16 25 36 49]
    

    Note that when only a single physical device is present, all the "devices" here are backed by the same thread pool. This will not make your code any faster, but it is useful for testing the semantics of parallel implementations on a single-device machine.
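Because every fake device runs the same program, cross-device collectives should produce the same values they would on real devices (though not the same performance), so you can test that logic on the laptop. A minimal sketch, assuming you want to test on the CPU backend even if a GPU is present (hence the JAX_PLATFORMS setting); the function name `normalize` and the axis name "i" are illustrative choices:

```python
import os

# Assumption: test on the CPU backend even if a GPU is present, and fake
# 8 devices there. Both variables must be set before jax is imported.
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp

n = jax.device_count()  # 8 fake CPU devices

def normalize(x):
    # x is this device's shard (here a scalar); psum adds the shards of
    # every device along the mapped axis named "i".
    total = jax.lax.psum(x, axis_name="i")
    return x / total

x = jnp.arange(1.0, n + 1.0)  # one value per device: 1.0, 2.0, ..., n
out = jax.pmap(normalize, axis_name="i")(x)
print(out)        # each value divided by the global sum 1 + 2 + ... + n
print(out.sum())  # approximately 1.0
```

The same `pmap` call with `axis_name` is what you would run on the TPUs; only the devices behind it change.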
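If your environment (a test runner or CI job, for example) already sets XLA_FLAGS, assigning the variable directly as in the Python snippet discards those flags. A minimal sketch of a hypothetical helper, `force_host_devices`, that keeps existing flags, replaces any earlier device-count flag, and then checks that JAX actually sees the fake devices:

```python
import os

# Hypothetical helper: add the fake-device flag without discarding any
# XLA_FLAGS already set in the environment. Call it before importing jax.
def force_host_devices(n):
    flag = f"--xla_force_host_platform_device_count={n}"
    existing = os.environ.get("XLA_FLAGS", "")
    # Drop any earlier device-count flag so the new value takes effect.
    kept = [f for f in existing.split()
            if not f.startswith("--xla_force_host_platform_device_count")]
    os.environ["XLA_FLAGS"] = " ".join(kept + [flag])

force_host_devices(8)

import jax  # imported only after XLA_FLAGS is final

cpu_devices = jax.devices("cpu")
# Fail fast if jax was initialized before the flag was set.
assert len(cpu_devices) == 8, f"expected 8 CPU devices, got {len(cpu_devices)}"
print(cpu_devices)
```

Querying `jax.devices("cpu")` rather than `jax.devices()` matters on a machine with a GPU: the flag only affects the host (CPU) platform, while the default backend would then be the GPU.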