I am trying to run multiple sbx programs (that use JAX) concurrently using joblib. Here is my program -
'''
For installation, run -
pip install gym
pip install sbx-rl
pip install mujoco
pip install shimmy
'''
from joblib import Parallel, delayed
import gym
from sbx import SAC
# from stable_baselines3 import SAC

def train():
    env = gym.make("Humanoid-v4")
    model = SAC("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=int(7e5), progress_bar=True)  # learn() expects an int

def train_model():
    train()

if __name__ == '__main__':
    Parallel(n_jobs=10)(delayed(train)() for i in range(3))
This is the error that I am getting -
/home/dgthomas/.local/lib/python3.10/site-packages/stable_baselines3/common/vec_env/patch_gym.py:49: UserWarning: You provided an OpenAI Gym environment. We strongly recommend transitioning to Gymnasium environments. Stable-Baselines3 is automatically wrapping your environments in a compatibility layer, which could potentially cause issues.
warnings.warn(
/home/dgthomas/.local/lib/python3.10/site-packages/stable_baselines3/common/vec_env/patch_gym.py:49: UserWarning: You provided an OpenAI Gym environment. We strongly recommend transitioning to Gymnasium environments. Stable-Baselines3 is automatically wrapping your environments in a compatibility layer, which could potentially cause issues.
warnings.warn(
/home/dgthomas/.local/lib/python3.10/site-packages/stable_baselines3/common/vec_env/patch_gym.py:49: UserWarning: You provided an OpenAI Gym environment. We strongly recommend transitioning to Gymnasium environments. Stable-Baselines3 is automatically wrapping your environments in a compatibility layer, which could potentially cause issues.
warnings.warn(
2024-01-30 11:19:12.354168: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: jaxlib/gpu/prng_kernels.cc:33: operation gpuGetLastError() failed: out of memory
2024-01-30 11:19:12.354264: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: jaxlib/gpu/prng_kernels.cc:33: operation gpuGetLastError() failed: out of memory; current tracing scope: custom-call.11; current profiling annotation: XlaModule:#prefix=jit(_threefry_split)/jit(main),hlo_module=jit__threefry_split,program_id=2#.
joblib.externals.loky.process_executor._RemoteTraceback:
"""
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/dgthomas/.local/lib/python3.10/site-packages/joblib/externals/loky/process_executor.py", line 463, in _process_worker
r = call_item()
File "/home/dgthomas/.local/lib/python3.10/site-packages/joblib/externals/loky/process_executor.py", line 291, in __call__
return self.fn(*self.args, **self.kwargs)
File "/home/dgthomas/.local/lib/python3.10/site-packages/joblib/parallel.py", line 589, in __call__
return [func(*args, **kwargs)
File "/home/dgthomas/.local/lib/python3.10/site-packages/joblib/parallel.py", line 589, in <listcomp>
return [func(*args, **kwargs)
File "/work/LAS/usr/tbd/5_test.py", line 23, in my_func
model = SAC("MlpPolicy", env,verbose=0)
File "/home/dgthomas/.local/lib/python3.10/site-packages/sbx/sac/sac.py", line 109, in __init__
self._setup_model()
File "/home/dgthomas/.local/lib/python3.10/site-packages/sbx/sac/sac.py", line 126, in _setup_model
self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate)
File "/home/dgthomas/.local/lib/python3.10/site-packages/sbx/sac/policies.py", line 143, in build
key, actor_key, qf_key, dropout_key = jax.random.split(key, 4)
File "/home/dgthomas/.local/lib/python3.10/site-packages/jax/_src/random.py", line 303, in split
return _return_prng_keys(wrapped, _split(typed_key, num))
File "/home/dgthomas/.local/lib/python3.10/site-packages/jax/_src/random.py", line 289, in _split
return prng.random_split(key, shape=shape)
File "/home/dgthomas/.local/lib/python3.10/site-packages/jax/_src/prng.py", line 769, in random_split
return random_split_p.bind(keys, shape=shape)
File "/home/dgthomas/.local/lib/python3.10/site-packages/jax/_src/core.py", line 444, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/dgthomas/.local/lib/python3.10/site-packages/jax/_src/core.py", line 447, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/dgthomas/.local/lib/python3.10/site-packages/jax/_src/core.py", line 935, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/dgthomas/.local/lib/python3.10/site-packages/jax/_src/prng.py", line 781, in random_split_impl
base_arr = random_split_impl_base(
File "/home/dgthomas/.local/lib/python3.10/site-packages/jax/_src/prng.py", line 787, in random_split_impl_base
return split(base_arr)
File "/home/dgthomas/.local/lib/python3.10/site-packages/jax/_src/prng.py", line 786, in <lambda>
split = iterated_vmap_unary(keys_ndim, lambda k: impl.split(k, shape))
File "/home/dgthomas/.local/lib/python3.10/site-packages/jax/_src/prng.py", line 1291, in threefry_split
return _threefry_split(key, shape)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: jaxlib/gpu/prng_kernels.cc:33: operation gpuGetLastError() failed: out of memory; current tracing scope: custom-call.11; current profiling annotation: XlaModule:#prefix=jit(_threefry_split)/jit(main),hlo_module=jit__threefry_split,program_id=2#.
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/work/LAS/usr/tbd/5_test.py", line 27, in <module>
Parallel(n_jobs=3)(delayed(my_func)() for i in range(3))
File "/home/dgthomas/.local/lib/python3.10/site-packages/joblib/parallel.py", line 1952, in __call__
return output if self.return_generator else list(output)
File "/home/dgthomas/.local/lib/python3.10/site-packages/joblib/parallel.py", line 1595, in _get_outputs
yield from self._retrieve()
File "/home/dgthomas/.local/lib/python3.10/site-packages/joblib/parallel.py", line 1699, in _retrieve
self._raise_error_fast()
File "/home/dgthomas/.local/lib/python3.10/site-packages/joblib/parallel.py", line 1734, in _raise_error_fast
error_job.get_result(self.timeout)
File "/home/dgthomas/.local/lib/python3.10/site-packages/joblib/parallel.py", line 736, in get_result
return self._return_or_raise()
File "/home/dgthomas/.local/lib/python3.10/site-packages/joblib/parallel.py", line 754, in _return_or_raise
raise self._result
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: jaxlib/gpu/prng_kernels.cc:33: operation gpuGetLastError() failed: out of memory; current tracing scope: custom-call.11; current profiling annotation: XlaModule:#prefix=jit(_threefry_split)/jit(main),hlo_module=jit__threefry_split,program_id=2#.
I am using a 40 GB GPU (A100-PCIe), so I doubt that my GPU is actually running out of memory. Please let me know if any clarification is needed.
Edit 1: This is how I call my program - export XLA_PYTHON_CLIENT_PREALLOCATE=false && python 5_test.py (the name of my program is 5_test.py).
It appears you are using multiple processes targeting the same GPU. By default, each JAX process will attempt to preallocate 75% of the available GPU memory (see GPU memory allocation), so attempting this with two or more processes will exhaust the available memory: three workers each reserving 75% of a 40 GB card are asking for roughly 90 GB in total.
You could fix this by turning off pre-allocation as mentioned in that doc, by setting the environment variables XLA_PYTHON_CLIENT_PREALLOCATE=false or XLA_PYTHON_CLIENT_MEM_FRACTION=.XX (with .XX set to .08 or something suitable), but I suspect the end result will be less efficient than if you had just run your full program from a single JAX process: multiple host processes targeting a single GPU device concurrently will just compete with each other for resources and lead to suboptimal results.