I have a function that instantiates a huge array and does other things. I am running my code on TPUs, so my memory is limited.
How can I execute my function specifically on the CPU?
If I do:

```python
y = jax.device_put(my_function(), device=jax.devices("cpu")[0])
```

I guess that my_function() is first executed on the TPU and the result is then moved to the CPU, which gives me a memory error.
Using jax.config.update('jax_platform_name', 'cpu') at the beginning of my code also seems to have no effect.
Also, please note that I can't modify my_function().
Thanks!
To directly specify the device on which a function should be executed, use the device argument of jax.jit. For example (using a GPU runtime because it's the accelerator I have access to at the moment):
```python
import jax

gpu_device = jax.devices('gpu')[0]
cpu_device = jax.devices('cpu')[0]

def my_function(x):
    return x.sum()

x = jax.numpy.arange(10)

x_gpu = jax.jit(my_function, device=gpu_device)(x)
print(x_gpu.device())
# gpu:0

x_cpu = jax.jit(my_function, device=cpu_device)(x)
print(x_cpu.device())
# TFRT_CPU_0
```
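More recent JAX releases deprecate the device argument of jax.jit, so here is a minimal sketch of an alternative for functions that take inputs, assuming those inputs can be placed on the CPU first: commit them to the CPU with jax.device_put, and the jitted computation then runs on the device that holds its committed inputs. This does not help a zero-argument function like the one in the question, because there is no input to commit.

```python
import jax
import jax.numpy as jnp

cpu_device = jax.devices("cpu")[0]

def my_function(x):
    return x.sum()

# Commit the input to the CPU; a jit-compiled computation runs on the
# device where its committed inputs live, so the result stays on the CPU.
x_on_cpu = jax.device_put(jnp.arange(10), cpu_device)
result = jax.jit(my_function)(x_on_cpu)

# Array.devices() returns the set of devices holding the result.
print(result.devices())
```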
This can also be controlled with the jax.default_device context manager around the call site:
```python
with jax.default_device(cpu_device):
    print(jax.jit(my_function)(x).device())
# TFRT_CPU_0

with jax.default_device(gpu_device):
    print(jax.jit(my_function)(x).device())
# gpu:0
```
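Applied to the situation in the question, here is a minimal sketch, assuming my_function() takes no arguments and builds its arrays with jax.numpy. The function below is a hypothetical stand-in, since the real one can't be modified. Calling it inside jax.default_device should make the allocation happen on the CPU from the start, rather than on the TPU followed by a copy, as with the device_put approach in the question.

```python
import jax
import jax.numpy as jnp

cpu_device = jax.devices("cpu")[0]

# Hypothetical stand-in for the unmodifiable my_function(): it takes no
# arguments and allocates its array internally.
def my_function():
    big = jnp.ones((1000, 1000))  # a real workload would be much larger
    return big * 2.0

# Arrays created inside this block (eagerly or under jit) are placed on the
# default device, here the CPU, so the accelerator never holds them.
with jax.default_device(cpu_device):
    y = my_function()

print(y.devices())
```

No jax.jit is needed here: the context manager also applies to eager operations such as the jnp.ones call inside the function.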