I am trying to use multiprocessing to parallelize my program (which uses JAX) with `pmap`. I am new to JAX and realize that `pmap` may not be the right approach. The JAX documentation says that `pmap` can express SPMD programs (which is the case here), so I decided to use it.
Here is my program. I am basically trying to run a reinforcement learning program (which also uses JAX) twice, in parallel:
'''
For installation please do -
pip install gym
pip install sbx
pip install mujoco
pip install shimmy
'''
import os
# Ask XLA to expose 8 host (CPU) devices so pmap has several devices to map
# over; set before importing jax so the flag is in place when XLA starts.
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import jax
import gym
from sbx import SAC


def my_func():
    """Train one SAC agent on the Humanoid-v4 gym environment.

    Returns None: the trained model is not returned, so the result of
    training is discarded when the function exits.
    """
    env = gym.make("Humanoid-v4")
    model = SAC("MlpPolicy", env,verbose=0)
    model.learn(total_timesteps=7e5, progress_bar=True)


from jax import pmap
import jax.numpy as jnp

# NOTE: this is the call that fails. pmap does not run my_func eagerly; it
# traces it. JAX operations inside it (e.g. sbx's network weight
# initialization) therefore produce abstract tracers instead of concrete
# arrays, and sbx's policy.predict later calls np.array() on one of them,
# which raises TracerArrayConversionError (see the traceback below).
# pmap can only transform functions made of JAX operations; it cannot
# parallelize arbitrary Python/NumPy code such as a gym training loop.
out = pmap(lambda _: my_func())(jnp.arange(2))
I get the following error:
(tbd) thoma@thoma-Lenovo-Legion-5-15IMH05H:~/PycharmProjects/tbd$ python new.py
/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/stable_baselines3/common/vec_env/patch_gym.py:49: UserWarning: You provided an OpenAI Gym environment. We strongly recommend transitioning to Gymnasium environments. Stable-Baselines3 is automatically wrapping your environments in a compatibility layer, which could potentially cause issues.
warnings.warn(
/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/gym/utils/passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`. (Deprecated NumPy 1.24)
if not isinstance(terminated, (bool, np.bool8)):
Traceback (most recent call last):
File "new.py", line 17, in <module>
out = pmap(lambda _: my_func())(jnp.arange(2))
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/api.py", line 1779, in cache_miss
execute = pxla.xla_pmap_impl_lazy(fun_, *tracers, **params)
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 411, in xla_pmap_impl_lazy
compiled_fun, fingerprint = parallel_callable(
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
ans = call(fun, *args)
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 678, in parallel_callable
pmap_computation = lower_parallel_callable(
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 825, in lower_parallel_callable
jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 748, in stage_parallel_callable
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2233, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "new.py", line 17, in <lambda>
out = pmap(lambda _: my_func())(jnp.arange(2))
File "new.py", line 12, in my_func
model.learn(total_timesteps=7e5, progress_bar=True)
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/sbx/sac/sac.py", line 173, in learn
return super().learn(
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 328, in learn
rollout = self.collect_rollouts(
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 557, in collect_rollouts
actions, buffer_actions = self._sample_action(learning_starts, action_noise, env.num_envs)
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 390, in _sample_action
unscaled_action, _ = self.predict(self._last_obs, deterministic=False)
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/stable_baselines3/common/base_class.py", line 553, in predict
return self.policy.predict(observation, state, episode_start, deterministic)
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/sbx/common/policies.py", line 58, in predict
actions = np.array(actions).reshape((-1, *self.action_space.shape))
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/core.py", line 605, in __array__
raise TracerArrayConversionError(self)
jax._src.traceback_util.UnfilteredStackTrace: jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[1,17].
The error occurred while tracing the function <lambda> at new.py:17 for pmap. This value became a tracer due to JAX operations on these lines:
operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
from line new.py:11 (my_func)
operation a:u32[] = convert_element_type[new_dtype=uint32 weak_type=False] b
from line new.py:11 (my_func)
operation a:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
from line new.py:11 (my_func)
operation a:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
from line new.py:11 (my_func)
operation a:f32[376,256] = pjit[
jaxpr={ lambda ; b:key<fry>[] c:i32[] d:i32[]. let
e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
g:f32[] = div e 1.4142135381698608
h:f32[] = erf g
i:f32[] = div f 1.4142135381698608
j:f32[] = erf i
k:f32[376,256] = pjit[
jaxpr={ lambda ; l:key<fry>[] m:f32[] n:f32[]. let
o:f32[1,1] = broadcast_in_dim[broadcast_dimensions=() shape=(1, 1)] m
p:f32[1,1] = broadcast_in_dim[broadcast_dimensions=() shape=(1, 1)] n
q:u32[376,256] = random_bits[bit_width=32 shape=(376, 256)] l
r:u32[376,256] = shift_right_logical q 9
s:u32[376,256] = or r 1065353216
t:f32[376,256] = bitcast_convert_type[new_dtype=float32] s
u:f32[376,256] = sub t 1.0
v:f32[1,1] = sub p o
w:f32[376,256] = mul u v
x:f32[376,256] = add w o
y:f32[376,256] = max o x
in (y,) }
name=_uniform
] b h j
z:f32[376,256] = erf_inv k
ba:f32[376,256] = mul 1.4142135381698608 z
bb:f32[] = stop_gradient e
bc:f32[] = nextafter bb inf
bd:f32[] = stop_gradient f
be:f32[] = nextafter bd -inf
bf:f32[376,256] = pjit[
jaxpr={ lambda ; bg:f32[376,256] bh:f32[] bi:f32[]. let
bj:f32[376,256] = max bh bg
bk:f32[376,256] = min bi bj
in (bk,) }
name=clip
] ba bc be
in (bf,) }
name=_truncated_normal
] bl bm bn
from line new.py:11 (my_func)
(Additional originating lines are not shown.)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "new.py", line 17, in <module>
out = pmap(lambda _: my_func())(jnp.arange(2))
File "new.py", line 17, in <lambda>
out = pmap(lambda _: my_func())(jnp.arange(2))
File "new.py", line 12, in my_func
model.learn(total_timesteps=7e5, progress_bar=True)
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/sbx/sac/sac.py", line 173, in learn
return super().learn(
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 328, in learn
rollout = self.collect_rollouts(
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 557, in collect_rollouts
actions, buffer_actions = self._sample_action(learning_starts, action_noise, env.num_envs)
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 390, in _sample_action
unscaled_action, _ = self.predict(self._last_obs, deterministic=False)
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/stable_baselines3/common/base_class.py", line 553, in predict
return self.policy.predict(observation, state, episode_start, deterministic)
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/sbx/common/policies.py", line 58, in predict
actions = np.array(actions).reshape((-1, *self.action_space.shape))
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[1,17].
The error occurred while tracing the function <lambda> at new.py:17 for pmap. This value became a tracer due to JAX operations on these lines:
operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
from line new.py:11 (my_func)
operation a:u32[] = convert_element_type[new_dtype=uint32 weak_type=False] b
from line new.py:11 (my_func)
operation a:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
from line new.py:11 (my_func)
operation a:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
from line new.py:11 (my_func)
operation a:f32[376,256] = pjit[
jaxpr={ lambda ; b:key<fry>[] c:i32[] d:i32[]. let
e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
g:f32[] = div e 1.4142135381698608
h:f32[] = erf g
i:f32[] = div f 1.4142135381698608
j:f32[] = erf i
k:f32[376,256] = pjit[
jaxpr={ lambda ; l:key<fry>[] m:f32[] n:f32[]. let
o:f32[1,1] = broadcast_in_dim[broadcast_dimensions=() shape=(1, 1)] m
p:f32[1,1] = broadcast_in_dim[broadcast_dimensions=() shape=(1, 1)] n
q:u32[376,256] = random_bits[bit_width=32 shape=(376, 256)] l
r:u32[376,256] = shift_right_logical q 9
s:u32[376,256] = or r 1065353216
t:f32[376,256] = bitcast_convert_type[new_dtype=float32] s
u:f32[376,256] = sub t 1.0
v:f32[1,1] = sub p o
w:f32[376,256] = mul u v
x:f32[376,256] = add w o
y:f32[376,256] = max o x
in (y,) }
name=_uniform
] b h j
z:f32[376,256] = erf_inv k
ba:f32[376,256] = mul 1.4142135381698608 z
bb:f32[] = stop_gradient e
bc:f32[] = nextafter bb inf
bd:f32[] = stop_gradient f
be:f32[] = nextafter bd -inf
bf:f32[376,256] = pjit[
jaxpr={ lambda ; bg:f32[376,256] bh:f32[] bi:f32[]. let
bj:f32[376,256] = max bh bg
bk:f32[376,256] = min bi bj
in (bk,) }
name=clip
] ba bc be
in (bf,) }
name=_truncated_normal
] bl bm bn
from line new.py:11 (my_func)
(Additional originating lines are not shown.)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
0% ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0/700,000 [ 0:00:00 < -:--:-- , ? it/s ]Exception ignored in: <function tqdm.__del__ at 0x7fa875eb8af0>
Traceback (most recent call last):
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/tqdm/std.py", line 1149, in __del__
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/tqdm/rich.py", line 120, in close
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/rich/progress.py", line 1177, in __exit__
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/rich/progress.py", line 1163, in stop
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/rich/live.py", line 155, in stop
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/rich/console.py", line 1137, in line
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/rich/console.py", line 1674, in print
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/rich/console.py", line 1535, in _collect_renderables
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/rich/protocol.py", line 28, in rich_cast
ImportError: sys.meta_path is None, Python is likely shutting down
Basically, I am trying to replace the parallelization performed by `joblib` with JAX. Here is the original program that I am changing:
from joblib import Parallel, delayed
import gym
import os
from sbx import SAC
import multiprocessing
def my_func(total_timesteps=700_000):
    """Train one SAC agent on the Humanoid-v4 gym environment.

    Each call creates its own environment and model, so calls share no state
    and can run side by side in separate worker processes.

    Args:
        total_timesteps: Number of environment steps to train for. Defaults
            to 700,000, the value previously passed as ``7e5``.

    Returns:
        None. The trained model is not returned.
    """
    env = gym.make("Humanoid-v4")
    model = SAC("MlpPolicy", env, verbose=0)
    # ``learn`` is annotated to take an int step count; convert so callers
    # may still pass a float such as 7e5.
    model.learn(total_timesteps=int(total_timesteps), progress_bar=True)


if __name__ == "__main__":
    # Only launch the jobs when run as a script, so importing this module
    # (for example from a spawned worker process) does not start them again.
    n_runs = 2
    Parallel(n_jobs=n_runs)(delayed(my_func)() for _ in range(n_runs))
The problem is in this frame of the traceback:
File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/sbx/common/policies.py", line 58, in predict
actions = np.array(actions).reshape((-1, *self.action_space.shape))
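The `sbx` package calls `np.array` on the policy's actions. This tells me that, although `sbx` uses JAX for its networks, its `predict` path and the Stable-Baselines3 training loop it inherits are ordinary Python/NumPy code, not JAX code. JAX transformations like `pmap` are not compatible with NumPy functions; they require functions written with JAX operations. Unless `sbx` is substantially redesigned, you won't be able to use it with `pmap`, `vmap`, `jit`, `grad`, or other JAX transformations. To run several independent training jobs in parallel, a process-based tool such as `joblib` or `multiprocessing` (your original approach) is the right fit. `pmap` is meant for splitting a single JAX computation across devices.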