tensorflow reinforcement-learning tensorflow-agents

Can a tf-agents environment be defined with an unobservable exogenous state?


Apologies in advance if the question in the title isn't very clear. I'm trying to train a reinforcement learning policy using tf-agents in which some unobservable stochastic variable affects the state.

For example, consider the standard CartPole problem with added wind whose velocity changes over time. I don't want to train an agent that relies on observing the wind velocity at each step; instead, I want the wind to affect the position and angular velocity of the pole, and the agent to learn to adapt just as it would in the wind-free environment. In this example, however, the wind velocity at the current time would need to be correlated with the wind velocity at the previous time; for example, we wouldn't want it to change from 10 m/s at time t to -10 m/s at time t+1.

The problem I'm trying to solve is how to track the state of the exogenous variable without making it part of the observation spec that gets fed into the neural network when training the agent. Any guidance would be appreciated.


Solution

  • Yes, that is no problem at all. Your environment object (a subclass of PyEnvironment or TFEnvironment) can do whatever you want internally. The observation_spec only constrains the TimeStep you return from the step and reset methods (more precisely, from your implementations of the abstract _step and _reset methods). Anything you keep as an attribute but leave out of that TimeStep's observation is never seen by the agent.

    Your environment is, however, completely free to have any additional attributes you want (such as parameters controlling the wind generation) and any number of additional methods (such as a method that generates the wind at the current time step according to self._wind_hyper_params). A quick schematic of your code is below:

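    from tf_agents.environments.py_environment import PyEnvironment
    from tf_agents.trajectories import time_step as ts

    class MyCartPole(PyEnvironment):
        def __init__(self, wind_HP):    # add any other constructor arguments you need
            super().__init__()
            ...    # self._observation_spec and self._action_spec can be left unchanged
            self._wind_hyper_params = wind_HP
            self._wind_velocity = 0.0    # hidden exogenous state, not part of the observation spec
            self._state = ...

        def _update_wind_velocity(self):
            # e.g. derive the new value from the previous one so consecutive values stay correlated
            self._wind_velocity = ...

        def _factor_in_wind(self):
            self._state = ...    # update according to self._wind_velocity

        def _reset(self):
            self._wind_velocity = 0.0    # re-initialise the hidden wind at the start of every episode
            self._state = ...
            return ts.restart(self._state_to_observations())

        def _step(self, action):
            ...    # usual self._state update computations
            self._update_wind_velocity()
            self._factor_in_wind()

            # only the observable part of the state goes into the TimeStep; the wind stays internal
            observations = self._state_to_observations()
            ...    # return ts.transition(...) or ts.termination(...) built from observations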
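To make the idea concrete, here is a minimal, runnable sketch of the same pattern under a few assumptions of my own. The wind follows a first-order autoregressive (AR(1)) process, w[t+1] = rho * w[t] + sigma * eps[t] with eps[t] ~ N(0, 1), which keeps consecutive wind values correlated as the question requires. The wind velocity is turned into an extra horizontal force on the cart through a hypothetical `wind_gain` factor, and the dynamics are the classic cart-pole equations with Euler integration. So that it runs without TensorFlow, the class does not subclass PyEnvironment and `_reset`/`_step` return plain Python values; `WindyCartPole`, `wind_rho`, `wind_sigma` and `wind_gain` are illustrative names (not tf-agents API), and the default values are arbitrary.

```python
import math
import random


class WindyCartPole:
    """Cart-pole with a hidden, autocorrelated wind force (hypothetical sketch).

    Mirrors the _reset/_step split of a tf-agents PyEnvironment, but returns
    plain tuples instead of TimeSteps so it runs without TensorFlow.
    """

    # Classic cart-pole constants.
    GRAVITY, MASS_CART, MASS_POLE, HALF_LENGTH = 9.8, 1.0, 0.1, 0.5
    FORCE_MAG, TAU = 10.0, 0.02
    X_LIMIT, THETA_LIMIT = 2.4, 12 * 2 * math.pi / 360

    def __init__(self, wind_rho=0.95, wind_sigma=0.5, wind_gain=1.0, seed=None):
        # Hypothetical wind hyperparameters for an AR(1) process.
        self._wind_rho = wind_rho      # correlation between consecutive steps
        self._wind_sigma = wind_sigma  # std of the per-step random innovation
        self._wind_gain = wind_gain    # converts wind velocity into a cart force
        self._rng = random.Random(seed)
        self._wind_velocity = 0.0      # hidden: never returned in an observation
        self._state = [0.0, 0.0, 0.0, 0.0]  # x, x_dot, theta, theta_dot

    def _update_wind_velocity(self):
        # AR(1) update: w_next = rho * w + sigma * eps, with eps ~ N(0, 1).
        eps = self._rng.gauss(0.0, 1.0)
        self._wind_velocity = (self._wind_rho * self._wind_velocity
                               + self._wind_sigma * eps)

    def _observation(self):
        # Only the cart-pole state is exposed; the wind stays internal.
        return tuple(self._state)

    def _reset(self):
        self._state = [self._rng.uniform(-0.05, 0.05) for _ in range(4)]
        self._wind_velocity = 0.0  # fresh wind at the start of each episode
        return self._observation()

    def _step(self, action):
        self._update_wind_velocity()
        x, x_dot, theta, theta_dot = self._state
        force = self.FORCE_MAG if action == 1 else -self.FORCE_MAG
        force += self._wind_gain * self._wind_velocity  # wind pushes the cart

        total_mass = self.MASS_CART + self.MASS_POLE
        pole_ml = self.MASS_POLE * self.HALF_LENGTH
        cos_t, sin_t = math.cos(theta), math.sin(theta)
        temp = (force + pole_ml * theta_dot ** 2 * sin_t) / total_mass
        theta_acc = (self.GRAVITY * sin_t - cos_t * temp) / (
            self.HALF_LENGTH * (4.0 / 3.0 - self.MASS_POLE * cos_t ** 2 / total_mass))
        x_acc = temp - pole_ml * theta_acc * cos_t / total_mass

        # Explicit Euler integration over one time step of length TAU.
        x += self.TAU * x_dot
        x_dot += self.TAU * x_acc
        theta += self.TAU * theta_dot
        theta_dot += self.TAU * theta_acc
        self._state = [x, x_dot, theta, theta_dot]

        done = abs(x) > self.X_LIMIT or abs(theta) > self.THETA_LIMIT
        reward = 1.0  # +1 for every step the pole stays up
        return self._observation(), reward, done
```

Inside tf-agents, the same logic would live in a PyEnvironment subclass: `_reset` would return `ts.restart(np.array(obs, dtype=np.float32))`, and `_step` would return `ts.termination(...)` when `done` is true and `ts.transition(...)` otherwise, with an observation_spec describing only the four cart-pole values. Because `self._wind_velocity` never appears in that spec, the policy network never sees it. Hiding an autocorrelated variable does make the task partially observable, so depending on how strong the wind is, a recurrent policy or a short history of observations may help the agent adapt.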