I'm trying to understand the behaviour of jax.vmap/pmap (JAX: https://jax.readthedocs.io/). CUDA has threadIdx to let you know which thread is executing the code. Is there a similar concept in JAX? (jax.process_id is not.)
No, there is no real analog to CUDA's threadIdx in JAX. Details about GPU thread assignment are handled at a lower level by the XLA compiler, and I don't know of any straightforward API to plumb this information back to JAX's Python runtime.
One case where JAX does offer higher-level handling of device assignment is when using pmap. In this case you can explicitly pass an array of device indices as a mapped argument to the pmapped function if you want logic that depends on the device on which the mapped code is being executed. For example, I ran the following on an 8-device system:
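import jax
import jax.numpy as jnp

num_devices = jax.device_count()

def f(device, data):
    # `device` is the scalar slice of device_index seen by this device
    return data + device

# One index per device; pmap maps over the leading axis of each argument
device_index = jnp.arange(num_devices)
data = jnp.zeros((num_devices, 10))
jax.pmap(f)(device_index, data)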
# ShardedDeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
#                     [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
#                     [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
#                     [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
#                     [4., 4., 4., 4., 4., 4., 4., 4., 4., 4.],
#                     [5., 5., 5., 5., 5., 5., 5., 5., 5., 5.],
#                     [6., 6., 6., 6., 6., 6., 6., 6., 6., 6.],
#                     [7., 7., 7., 7., 7., 7., 7., 7., 7., 7.]], dtype=float32)
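
A possible alternative to passing the indices in by hand: if you give the mapped axis a name, jax.lax.axis_index should return the position of the current slice along that axis from inside the mapped function. This is still not a hardware thread id, only an index along the mapped axis. The sketch below uses vmap so that it runs on a single device; the axis name "i", the function name add_index, and the array shape are my own choices for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

def add_index(row):
    # Index of this slice along the mapped axis named "i"
    # (a logical position, not a GPU thread id).
    idx = lax.axis_index("i")
    return row + idx

# Four rows of zeros; each row gets its own position added to it.
data = jnp.zeros((4, 3))
result = jax.vmap(add_index, axis_name="i")(data)
print(result)
```

Using jax.pmap(add_index, axis_name="i") instead should give each device its position along the mapped axis in the same way, provided the leading dimension of data does not exceed the number of devices.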