
`jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[1,17].`


I am trying to parallelize my program (which uses JAX) with pmap. I am new to JAX and realize that pmap may not be the right approach. The documentation says that pmap can express SPMD programs (which is the case here), so I decided to use it.

Here is my program. I am trying to run a reinforcement learning program (which also uses JAX) twice, in parallel:

'''
For installation please do -
pip install gym
pip install sbx
pip install mujoco
pip install shimmy
'''
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

import jax
import gym
from sbx import SAC

def my_func():

    env = gym.make("Humanoid-v4")
    model = SAC("MlpPolicy", env,verbose=0)
    model.learn(total_timesteps=7e5, progress_bar=True)

from jax import pmap
import jax.numpy as jnp

out = pmap(lambda _: my_func())(jnp.arange(2))

I get the following error -

(tbd) thoma@thoma-Lenovo-Legion-5-15IMH05H:~/PycharmProjects/tbd$ python new.py
/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/stable_baselines3/common/vec_env/patch_gym.py:49: UserWarning: You provided an OpenAI Gym environment. We strongly recommend transitioning to Gymnasium environments. Stable-Baselines3 is automatically wrapping your environments in a compatibility layer, which could potentially cause issues.
  warnings.warn(
/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/gym/utils/passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`.  (Deprecated NumPy 1.24)
  if not isinstance(terminated, (bool, np.bool8)):
Traceback (most recent call last):
  File "new.py", line 17, in <module>
    out = pmap(lambda _: my_func())(jnp.arange(2))
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/api.py", line 1779, in cache_miss
    execute = pxla.xla_pmap_impl_lazy(fun_, *tracers, **params)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 411, in xla_pmap_impl_lazy
    compiled_fun, fingerprint = parallel_callable(
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
    ans = call(fun, *args)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 678, in parallel_callable
    pmap_computation = lower_parallel_callable(
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 825, in lower_parallel_callable
    jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 748, in stage_parallel_callable
    jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2233, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "new.py", line 17, in <lambda>
    out = pmap(lambda _: my_func())(jnp.arange(2))
  File "new.py", line 12, in my_func
    model.learn(total_timesteps=7e5, progress_bar=True)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/sbx/sac/sac.py", line 173, in learn
    return super().learn(
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 328, in learn
    rollout = self.collect_rollouts(
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 557, in collect_rollouts
    actions, buffer_actions = self._sample_action(learning_starts, action_noise, env.num_envs)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 390, in _sample_action
    unscaled_action, _ = self.predict(self._last_obs, deterministic=False)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/stable_baselines3/common/base_class.py", line 553, in predict
    return self.policy.predict(observation, state, episode_start, deterministic)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/sbx/common/policies.py", line 58, in predict
    actions = np.array(actions).reshape((-1, *self.action_space.shape))
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/jax/_src/core.py", line 605, in __array__
    raise TracerArrayConversionError(self)
jax._src.traceback_util.UnfilteredStackTrace: jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[1,17].
The error occurred while tracing the function <lambda> at new.py:17 for pmap. This value became a tracer due to JAX operations on these lines:

  operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line new.py:11 (my_func)

  operation a:u32[] = convert_element_type[new_dtype=uint32 weak_type=False] b
    from line new.py:11 (my_func)

  operation a:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
    from line new.py:11 (my_func)

  operation a:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
    from line new.py:11 (my_func)

  operation a:f32[376,256] = pjit[
  jaxpr={ lambda ; b:key<fry>[] c:i32[] d:i32[]. let
      e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
      f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
      g:f32[] = div e 1.4142135381698608
      h:f32[] = erf g
      i:f32[] = div f 1.4142135381698608
      j:f32[] = erf i
      k:f32[376,256] = pjit[
        jaxpr={ lambda ; l:key<fry>[] m:f32[] n:f32[]. let
            o:f32[1,1] = broadcast_in_dim[broadcast_dimensions=() shape=(1, 1)] m
            p:f32[1,1] = broadcast_in_dim[broadcast_dimensions=() shape=(1, 1)] n
            q:u32[376,256] = random_bits[bit_width=32 shape=(376, 256)] l
            r:u32[376,256] = shift_right_logical q 9
            s:u32[376,256] = or r 1065353216
            t:f32[376,256] = bitcast_convert_type[new_dtype=float32] s
            u:f32[376,256] = sub t 1.0
            v:f32[1,1] = sub p o
            w:f32[376,256] = mul u v
            x:f32[376,256] = add w o
            y:f32[376,256] = max o x
          in (y,) }
        name=_uniform
      ] b h j
      z:f32[376,256] = erf_inv k
      ba:f32[376,256] = mul 1.4142135381698608 z
      bb:f32[] = stop_gradient e
      bc:f32[] = nextafter bb inf
      bd:f32[] = stop_gradient f
      be:f32[] = nextafter bd -inf
      bf:f32[376,256] = pjit[
        jaxpr={ lambda ; bg:f32[376,256] bh:f32[] bi:f32[]. let
            bj:f32[376,256] = max bh bg
            bk:f32[376,256] = min bi bj
          in (bk,) }
        name=clip
      ] ba bc be
    in (bf,) }
  name=_truncated_normal
] bl bm bn
    from line new.py:11 (my_func)

(Additional originating lines are not shown.)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "new.py", line 17, in <module>
    out = pmap(lambda _: my_func())(jnp.arange(2))
  File "new.py", line 17, in <lambda>
    out = pmap(lambda _: my_func())(jnp.arange(2))
  File "new.py", line 12, in my_func
    model.learn(total_timesteps=7e5, progress_bar=True)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/sbx/sac/sac.py", line 173, in learn
    return super().learn(
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 328, in learn
    rollout = self.collect_rollouts(
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 557, in collect_rollouts
    actions, buffer_actions = self._sample_action(learning_starts, action_noise, env.num_envs)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 390, in _sample_action
    unscaled_action, _ = self.predict(self._last_obs, deterministic=False)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/stable_baselines3/common/base_class.py", line 553, in predict
    return self.policy.predict(observation, state, episode_start, deterministic)
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/sbx/common/policies.py", line 58, in predict
    actions = np.array(actions).reshape((-1, *self.action_space.shape))
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[1,17].
The error occurred while tracing the function <lambda> at new.py:17 for pmap. This value became a tracer due to JAX operations on these lines:

  operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line new.py:11 (my_func)

  operation a:u32[] = convert_element_type[new_dtype=uint32 weak_type=False] b
    from line new.py:11 (my_func)

  operation a:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
    from line new.py:11 (my_func)

  operation a:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
    from line new.py:11 (my_func)

  operation a:f32[376,256] = pjit[
  jaxpr={ lambda ; b:key<fry>[] c:i32[] d:i32[]. let
      e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
      f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
      g:f32[] = div e 1.4142135381698608
      h:f32[] = erf g
      i:f32[] = div f 1.4142135381698608
      j:f32[] = erf i
      k:f32[376,256] = pjit[
        jaxpr={ lambda ; l:key<fry>[] m:f32[] n:f32[]. let
            o:f32[1,1] = broadcast_in_dim[broadcast_dimensions=() shape=(1, 1)] m
            p:f32[1,1] = broadcast_in_dim[broadcast_dimensions=() shape=(1, 1)] n
            q:u32[376,256] = random_bits[bit_width=32 shape=(376, 256)] l
            r:u32[376,256] = shift_right_logical q 9
            s:u32[376,256] = or r 1065353216
            t:f32[376,256] = bitcast_convert_type[new_dtype=float32] s
            u:f32[376,256] = sub t 1.0
            v:f32[1,1] = sub p o
            w:f32[376,256] = mul u v
            x:f32[376,256] = add w o
            y:f32[376,256] = max o x
          in (y,) }
        name=_uniform
      ] b h j
      z:f32[376,256] = erf_inv k
      ba:f32[376,256] = mul 1.4142135381698608 z
      bb:f32[] = stop_gradient e
      bc:f32[] = nextafter bb inf
      bd:f32[] = stop_gradient f
      be:f32[] = nextafter bd -inf
      bf:f32[376,256] = pjit[
        jaxpr={ lambda ; bg:f32[376,256] bh:f32[] bi:f32[]. let
            bj:f32[376,256] = max bh bg
            bk:f32[376,256] = min bi bj
          in (bk,) }
        name=clip
      ] ba bc be
    in (bf,) }
  name=_truncated_normal
] bl bm bn
    from line new.py:11 (my_func)

(Additional originating lines are not shown.)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
   0% ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0/700,000  [ 0:00:00 < -:--:-- , ? it/s ]Exception ignored in: <function tqdm.__del__ at 0x7fa875eb8af0>
Traceback (most recent call last):
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/tqdm/std.py", line 1149, in __del__
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/tqdm/rich.py", line 120, in close
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/rich/progress.py", line 1177, in __exit__
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/rich/progress.py", line 1163, in stop
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/rich/live.py", line 155, in stop
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/rich/console.py", line 1137, in line
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/rich/console.py", line 1674, in print
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/rich/console.py", line 1535, in _collect_renderables
  File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/rich/protocol.py", line 28, in rich_cast
ImportError: sys.meta_path is None, Python is likely shutting down

I am trying to replace the parallelization that joblib currently performs with JAX. Here is the original program that I am changing:

from joblib import Parallel, delayed
import gym
import os
from sbx import SAC
import multiprocessing

def my_func():

    env = gym.make("Humanoid-v4")

    model = SAC("MlpPolicy", env,verbose=0)
    model.learn(total_timesteps=7e5, progress_bar=True)


Parallel(n_jobs=2)(delayed(my_func)() for i in range(2))

Solution

  • The problem is here:

     File "/home/thoma/anaconda3/envs/tbd/lib/python3.8/site-packages/sbx/common/policies.py", line 58, in predict
        actions = np.array(actions).reshape((-1, *self.action_space.shape))
    

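    The sbx package calls np.array on the actions inside predict. Under pmap those actions are tracers rather than concrete arrays, and converting a tracer to a NumPy array is exactly what raises TracerArrayConversionError. This tells me that the sbx training loop relies on host-side NumPy and is not written as a pure JAX function. JAX transformations like pmap are not compatible with NumPy functions; they require functions written with JAX operations. Unless sbx is substantially redesigned, you won't be able to wrap it in pmap, vmap, jit, grad, or other JAX transformations.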
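
    The same failure can be reproduced without sbx. The following is a minimal sketch, assuming only that jax and numpy are installed. The names numpy_reshape and jax_reshape are hypothetical and are not part of sbx; the first imitates the np.array(...).reshape(...) call from the traceback, the second does the same reshape with jax.numpy.

```python
import numpy as np
import jax
import jax.numpy as jnp


def numpy_reshape(x):
    # Imitates the failing line in the traceback: np.array() needs a concrete
    # host array, but under pmap `x` is a tracer.
    return np.array(x).reshape((-1, 2))


def jax_reshape(x):
    # Same reshape expressed with JAX operations, so it can be traced.
    return jnp.reshape(x, (-1, 2))


# pmap maps over the leading axis, which must not exceed the device count.
n_dev = jax.local_device_count()
batch = jnp.ones((n_dev, 4))

try:
    jax.pmap(numpy_reshape)(batch)
    numpy_failed = False
except jax.errors.TracerArrayConversionError:
    numpy_failed = True

out = jax.pmap(jax_reshape)(batch)
print("NumPy version failed under pmap:", numpy_failed)
print("JAX version output shape:", out.shape)
```

    Only the JAX version can be traced. A whole call to model.learn also steps a gym environment from Python, so it cannot be rewritten this way from the outside. Process-level tools such as the joblib version in the question do not trace the function at all, which is why they run it unchanged.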