
OpenAI Gym: How do I access environment registration data (e.g. max_episode_steps) from within a custom environment?


I created a custom environment using OpenAI Gym. I want to access the max_episode_steps and reward_threshold values that are specified in __init__.py. For example:

from gym.envs.registration import registry, register, make, spec
register(
    id='myenv-v0',
    entry_point='gym.envs.algorithmic:myenv',
    tags={'wrapper_config.TimeLimit.max_episode_steps': 200},
    reward_threshold=25.0,
)

But how do I access this from gym_myenv.py? If I first create the environment and use env._max_episode_steps, I have access. However, I don't have access to _max_episode_steps from within gym_myenv.py.


Solution

  • As pointed out by the Gymnasium team, the max_episode_steps parameter is not passed to the base environment on purpose. We can, however, use a simple Gymnasium wrapper to inject it into the base environment:

    """This file contains a small gymnasium wrapper that injects the `max_episode_steps`
    argument of a potentially nested `TimeLimit` wrapper into the base environment under the
    `_time_limit_max_episode_steps` attribute.
    """
    import gymnasium as gym
    
    
    def get_time_limit_wrapper_max_episode_steps(env):
        """Returns the ``max_episode_steps`` attribute of a potentially nested
        ``TimeLimit`` wrapper.
    
        Args:
            env (gym.Env): The gymnasium environment.
    
        Returns:
            int: The value of the ``max_episode_steps`` attribute of a potentially nested
                ``TimeLimit`` wrapper. If the environment is not wrapped in a ``TimeLimit``
                wrapper, then this function returns ``None``.
        """
        if hasattr(env, "env"):
            if isinstance(env, gym.wrappers.TimeLimit):
                return env._max_episode_steps
            # Recurse into the wrapped environment and propagate its result.
            return get_time_limit_wrapper_max_episode_steps(env.env)
        return None
    
    
    def inject_attribute_into_base_env(env, attribute_name, attribute_value):
        """Injects an attribute into the base (innermost) environment by unwrapping
        any nested wrappers and setting the attribute on the environment found there.
    
        Args:
            env (gym.Env): The gymnasium environment.
            attribute_name (str): The attribute's name to inject into the base
                environment.
            attribute_value (object): The attribute's value to inject into the base
                environment.
        """
        if hasattr(env, "env"):
            return inject_attribute_into_base_env(env.env, attribute_name, attribute_value)
        setattr(env, attribute_name, attribute_value)
    
    
    class MaxEpisodeStepsInjectionWrapper(gym.Wrapper):
        """A gymnasium wrapper that injects the ``max_episode_steps`` attribute of the
        ``TimeLimit`` wrapper into the base environment as the
        ``_time_limit_max_episode_steps`` attribute. If the environment is not wrapped in
        a ``TimeLimit`` wrapper, then the ``_time_limit_max_episode_steps`` attribute is
        set to ``None``.
        """
        def __init__(self, env):
            """Wrap a gymnasium environment.
            Args:
                env (gym.Env): The gymnasium environment.
            """
            super().__init__(env)
    
            # Retrieve max_episode_steps from potentially nested TimeLimit wrappers.
            max_episode_steps = get_time_limit_wrapper_max_episode_steps(self.env)
    
            # Inject the max_episode_steps attribute into the base environment.
            inject_attribute_into_base_env(
                self.env, "_time_limit_max_episode_steps", max_episode_steps
            )
    

    This wrapper can then be specified through the additional_wrappers parameter of the gym.register method:

    import gymnasium as gym
    from MY_PACKAGE.max_episode_steps_injection_wrapper import MaxEpisodeStepsInjectionWrapper
    
    gym.register(
        id="ENV_ID",
        entry_point="ENV_ENTRY_POINT",
        max_episode_steps=MAX_EPISODE_STEPS,  # placeholder: an integer, e.g. 200
        additional_wrappers=(MaxEpisodeStepsInjectionWrapper.wrapper_spec(),),
    )
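
One thing to watch for: wrappers are applied after the base environment has been constructed, so the injected attribute is not yet there inside the base environment's `__init__`. A minimal sketch of reading it lazily is shown below. To keep it runnable without gymnasium installed, it uses hypothetical stand-in classes (`CountingEnv`, `StandInWrapper`, `StandInTimeLimit`) that only mimic the `.env` nesting of real wrappers; the helper names and the `steps_remaining` method are likewise illustrative and not part of gymnasium.

```python
class CountingEnv:
    """Hypothetical base environment; stands in for a custom gymnasium.Env."""

    def __init__(self):
        # The injected attribute does not exist yet at this point.
        self.steps_taken = 0

    def steps_remaining(self):
        # Read the attribute lazily and fall back to None if nothing injected it.
        limit = getattr(self, "_time_limit_max_episode_steps", None)
        return None if limit is None else limit - self.steps_taken

    def step(self):
        self.steps_taken += 1


class StandInWrapper:
    """Stand-in for gymnasium.Wrapper: only keeps a reference to the inner env."""

    def __init__(self, env):
        self.env = env


class StandInTimeLimit(StandInWrapper):
    """Stand-in for gymnasium.wrappers.TimeLimit (assumed attribute name)."""

    def __init__(self, env, max_episode_steps):
        super().__init__(env)
        self._max_episode_steps = max_episode_steps


def find_max_episode_steps(env):
    # Walk down the wrapper chain until a time-limit wrapper is found.
    while hasattr(env, "env"):
        if isinstance(env, StandInTimeLimit):
            return env._max_episode_steps
        env = env.env
    return None


def inject_into_base_env(env, name, value):
    # Unwrap to the innermost environment, then set the attribute there.
    while hasattr(env, "env"):
        env = env.env
    setattr(env, name, value)


base = CountingEnv()
before = base.steps_remaining()  # nothing injected yet

wrapped = StandInWrapper(StandInTimeLimit(base, max_episode_steps=200))
inject_into_base_env(
    wrapped, "_time_limit_max_episode_steps", find_max_episode_steps(wrapped)
)

base.step()
after = base.steps_remaining()
print(before, after)
```

In a real environment the same `getattr(..., None)` pattern would go inside `step` or `reset`, so the environment still works when it is created without a time limit.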