python, reinforcement-learning, stable-baselines

State-action transformation of collected experience in stable baselines replay buffer


I am working with Stable Baselines 3 applied to a very expensive problem. I have already set everything up for maximum sample efficiency and would now like to implement the method described in this article: https://arxiv.org/pdf/2111.03454.pdf Namely, for every step taken, I would like to apply geometric transformations to produce additional valid experiences and add them to the replay buffer alongside the single experience that was actually simulated. Does anyone know if there is an existing way to do this? I have reviewed the documentation and the ReplayBuffer class, but I admit the answer is still not obvious to me.


Solution

  • It seems that implementing a derived ReplayBuffer class was the answer. Whether or not it helps with training remains to be seen, but it works the way I intended. The transformations below were chosen for Gym's Pendulum environment.

    from typing import Any, Dict, List, Union

    import numpy as np
    import torch as th
    from gymnasium import spaces  # `from gym import spaces` for SB3 < 2.0
    from stable_baselines3.common.buffers import ReplayBuffer


    class CustomReplayBuffer(ReplayBuffer):
        """ Specialised buffer that applies geometric transformations to the observations
        and actions in order to fill up the space more quickly and (hopefully) provide
        higher sample efficiency.
        """
    
        def __init__(
            self,
            buffer_size: int,
            observation_space: spaces.Space,
            action_space: spaces.Space,
            device: Union[th.device, str] = "auto",
            n_envs: int = 1,
            optimize_memory_usage: bool = False,
            handle_timeout_termination: bool = True,
        ):
            # optimize_memory_usage is forced to False because add() below writes
            # to self.next_observations, which is not allocated when memory
            # optimisation is enabled.
            super().__init__(
                buffer_size,
                observation_space,
                action_space,
                device,
                n_envs=n_envs,
                optimize_memory_usage=False,
                handle_timeout_termination=handle_timeout_termination,
            )
    
        def add(
            self,
            obs: np.ndarray,
            next_obs: np.ndarray,
            action: np.ndarray,
            reward: np.ndarray,
            done: np.ndarray,
            infos: List[Dict[str, Any]],
        ) -> None:
            # Reshape needed when using multiple envs with discrete observations
            # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
            if isinstance(self.observation_space, spaces.Discrete):
                obs = obs.reshape((self.n_envs, *self.obs_shape))
                next_obs = next_obs.reshape((self.n_envs, *self.obs_shape))
    
            # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
            action = action.reshape((self.n_envs, self.action_dim))
    
            # Pendulum observation is [cos(theta), sin(theta), theta_dot].
            transformations_obs = [
                # Identity: keep the simulated experience as-is.
                [1., 1., 1.],
                # Mirror everything around the vertical axis:
                # cos(theta) is unchanged, sin(theta) and theta_dot flip sign.
                # NOTE: the env definition has the x-axis pointing UP!
                [1., -1., -1.],
            ]
            # Corresponding transformation of the torque action.
            transformations_act = [
                1.,
                -1.,
            ]
    
            for i in range(len(transformations_obs)):
                # Copy to avoid modification by reference, apply transformation.
                self.observations[self.pos] = np.array(obs).copy() * transformations_obs[i]
                self.next_observations[self.pos] = np.array(next_obs).copy() * transformations_obs[i]
                self.actions[self.pos] = np.array(action).copy() * transformations_act[i]
                # Reward and done are unchanged.
                self.rewards[self.pos] = np.array(reward).copy()
                self.dones[self.pos] = np.array(done).copy()
                if self.handle_timeout_termination:
                    self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])
                # Advance in the buffer. Check if full.
                self.pos += 1
                if self.pos == self.buffer_size:
                    self.full = True
                    self.pos = 0
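
  • For completeness, a minimal sketch of how such a buffer could be plugged into one of SB3's off-policy algorithms via the replay_buffer_class argument. SAC on Pendulum-v1 is only an example choice here; any off-policy algorithm that uses a ReplayBuffer should accept it the same way, and extra constructor arguments could be passed via replay_buffer_kwargs.

    import gymnasium as gym
    from stable_baselines3 import SAC

    env = gym.make("Pendulum-v1")

    # Off-policy SB3 algorithms take the custom buffer class directly.
    model = SAC(
        "MlpPolicy",
        env,
        replay_buffer_class=CustomReplayBuffer,
        verbose=1,
    )
    model.learn(total_timesteps=10_000)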