
Is there a CUDA threadIdx-like concept in JAX (Google)?


I'm trying to understand the behaviour of jax.vmap/pmap (JAX: https://jax.readthedocs.io/). CUDA has threadIdx to tell you which thread is executing the code; is there a similar concept in JAX? (jax.process_id is not it.)


Solution

  • No, there is no real analog to CUDA's threadIdx in JAX. Details about GPU thread assignment are handled at a lower level by the XLA compiler, and I don't know of any straightforward API that plumbs this information back to JAX's Python runtime.

    One case where JAX does offer higher-level handling of device assignment is pmap: you can explicitly pass an array of device indices to the pmapped function if you want logic that depends on which device is executing the mapped code. For example, I ran the following on an 8-device system:

    import jax
    import jax.numpy as jnp
    
    num_devices = jax.device_count()
    
    def f(device, data):
      return data + device
    
    device_index = jnp.arange(num_devices)
    data = jnp.zeros((num_devices, 10))
    
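    jax.pmap(f)(device_index, data)
    
    # Output on the 8-device system. Older JAX versions print ShardedDeviceArray;
    # current versions print Array with the same values.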
    # ShardedDeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
    #                     [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
    #                     [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
    #                     [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
    #                     [4., 4., 4., 4., 4., 4., 4., 4., 4., 4.],
    #                     [5., 5., 5., 5., 5., 5., 5., 5., 5., 5.],
    #                     [6., 6., 6., 6., 6., 6., 6., 6., 6., 6.],
    #                     [7., 7., 7., 7., 7., 7., 7., 7., 7., 7.]], dtype=float32)
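If what you actually need is the *logical* index of the current mapped instance (not a hardware thread), `jax.lax.axis_index` can return it inside both `vmap` and `pmap`, provided you give the mapped axis a name via `axis_name`. This is a minimal sketch added beyond the original answer; the axis name `"i"` and the helper `add_lane_index` are arbitrary illustrative choices, and it assumes a recent JAX release:

```python
import jax
import jax.numpy as jnp

def add_lane_index(x):
    # jax.lax.axis_index returns this instance's position (0, 1, 2, ...)
    # along the mapped axis named "i". It is a logical index, not a GPU thread id.
    return x + jax.lax.axis_index("i")

# Under vmap: 4 logical lanes; XLA decides how they map onto hardware.
vmapped = jax.vmap(add_lane_index, axis_name="i")
v_out = vmapped(jnp.zeros((4, 3)))
print(v_out)  # row k is filled with k

# Under pmap: one instance per local device (works even with a single CPU device).
n = jax.local_device_count()
pmapped = jax.pmap(add_lane_index, axis_name="i")
p_out = pmapped(jnp.zeros((n, 3)))
print(p_out)
```

Unlike the explicit `device_index` argument above, this needs no extra input array, but it still only reports the position along the mapped axis; how vmapped lanes are laid out across GPU threads remains up to XLA.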