python-3.x, openai-gym, stable-baselines

Dict Observation Space for Stable Baselines3 Not Working


I've created a minimal reproducible example below, this can be run in a new Google Colab notebook for ease. Once the first install finishes, just Runtime > Restart and Run All for it to take effect.

I've made a simple roulette game environment below for testing. For the observation space, I've created a gym.spaces.Dict(), which you'll see below (the code is well commented).

It trains just fine, but when it gets to the testing loop, I get this error:

ValueError                                Traceback (most recent call last)
<ipython-input-56-7c2cb900b44f> in <module>
      6 obs = env.reset()
      7 for i in range(1000):
----> 8     action, _state = model.predict(obs, deterministic=True)
      9     obs, reward, done, info = env.step(action)
     10     env.render()

ValueError: Error: Unexpected observation shape () for Box environment, please use (1,) or (n_env, 1) for the observation shape.

I read somewhere that the dict space needs to be flattened with gym.wrappers.FlattenObservation, so I change this line:

    action, _state = model.predict(obs, deterministic=True)

...to:

    action, _state = model.predict(FlattenObservation(obs), deterministic=True)

...which results in this error:

AttributeError                            Traceback (most recent call last)
<ipython-input-57-87824c61fc45> in <module>
      6 obs = env.reset()
      7 for i in range(1000):
----> 8     action, _state = model.predict(FlattenObservation(obs), deterministic=True)
      9     obs, reward, done, info = env.step(action)
     10     env.render()

AttributeError: 'collections.OrderedDict' object has no attribute 'observation_space'

I've also tried doing this, which results in the same error as the last one:

obs = env.reset()
obs = FlattenObservation(obs)

So clearly I'm doing something wrong, but I just don't know what it is, as this is the first time I've worked with a Dict space.

import os, sys
if not os.path.isdir('/usr/local/lib/python3.7/dist-packages/stable_baselines3'):
    !pip3 install stable_baselines3
    print("\n\n\n Stable Baselines3 has been installed, Restart and Run All now. DO NOT factory reset, or you'll have to start over\n")
    sys.exit(0)

from random import randint
from numpy import inf, float32, array, int32, int64
import gym
from gym.wrappers import FlattenObservation
from stable_baselines3 import A2C, DQN, PPO

"""Roulette environment class"""
class Roulette_Environment(gym.Env):

    metadata = {'render.modes': ['human', 'text']}

    """Initialize the environment"""
    def __init__(self):
        super(Roulette_Environment, self).__init__()

        # Some global variables
        self.max_table_limit = 1000
        self.initial_bankroll = 2000

        # Spaces
        # Each number on roulette board can have 0-1000 units placed on it
        self.action_space = gym.spaces.Box(low=0, high=1000, shape=(37,))

        # We're going to keep track of how many times each number shows up
        # while we're playing, plus our current bankroll and the max
        # table betting limit so the agent knows how much $ in total is allowed
        # to be placed on the table. Going to use a Dict space for this.
        self.observation_space = gym.spaces.Dict(
            {
                # One hit counter per roulette number, keyed "0" through "36"
                **{str(i): gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int)
                   for i in range(37)},

                "current_bankroll": gym.spaces.Box(low=-inf, high=inf, shape=(1,), dtype=int),

                "max_table_limit": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
            }
        )

    """Reset the Environment"""
    def reset(self):
        self.current_bankroll = self.initial_bankroll
        self.done = False

        # Take a sample from the observation_space to modify the values of
        self.current_state = self.observation_space.sample()
        
        # Reset each number being tracked throughout gameplay to 0
        for i in range(0, 37):
            self.current_state[str(i)] = 0

        # Reset our globals
        self.current_state['current_bankroll'] = self.current_bankroll
        self.current_state['max_table_limit'] = self.max_table_limit
        
        return self.current_state


    """Step Through the Environment"""
    def step(self, action):
        
        # Convert actions to ints because they show up as floats,
        # even when defined as ints in the environment.
        # https://github.com/openai/gym/issues/3107
        for i in range(len(action)):
            action[i] = int(action[i])
        self.current_action = action
        
        # Total amount wagered this spin (it comes off the bankroll via the reward below)
        sum_of_bets = sum(self.current_action)

        # Spin the wheel
        self.current_number = randint(a=0, b=36)

        # Calculate payout/reward
        self.reward = 36 * self.current_action[self.current_number] - sum_of_bets

        self.current_bankroll += self.reward

        # Update the current state
        self.current_state['current_bankroll'] = self.current_bankroll
        self.current_state[str(self.current_number)] += 1

        # If we've doubled our money, or lost our money
        if self.current_bankroll >= self.initial_bankroll * 2 or self.current_bankroll <= 0:
            self.done = True

        return self.current_state, self.reward, self.done, {}


    """Render the Environment"""
    def render(self, mode='text'):
        # Text rendering
        if mode == "text":
            print(f'Bets Placed: {self.current_action}')
            print(f'Number rolled: {self.current_number}')
            print(f'Reward: {self.reward}')
            print(f'New Bankroll: {self.current_bankroll}')

env = Roulette_Environment()

model = PPO('MultiInputPolicy', env, verbose=1)
model.learn(total_timesteps=10000)

obs = env.reset()
# obs = FlattenObservation(obs)

for i in range(1000):
    action, _state = model.predict(obs, deterministic=True)
    # action, _state = model.predict(FlattenObservation(obs), deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
      obs = env.reset()

Solution

  • Unfortunately, stable-baselines3 is pretty picky about the observation format.
    I ran into the same problem in the last few days.
    Some documentation, as well as an example model, helped me figure things out:

    It is possible to use Dict observations.

    However, the values of Box spaces must be passed as numpy.ndarrays with the correct dtype.
    For Discrete observations, the observation can also be passed as a plain int. However, I'm not completely sure whether this still holds for multidimensional MultiDiscrete spaces.
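
    For example, this is the difference SB3 sees when it validates a Box entry (a minimal sketch):

    import numpy as np

    print(np.asarray(2000).shape)             # () -> triggers the "Unexpected observation shape ()" error
    print(np.array([2000], dtype=int).shape)  # (1,) -> matches Box(shape=(1,), dtype=int)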

    A very simple solution

    A solution for your example is to wrap the value in a NumPy array every time you reassign an entry of your Dict:
    self.current_state[key] = np.array([value], dtype=int)
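
    For instance, the bankroll update in reset() changes like this (before/after, taken from your code and the fix below):

    # Before: a bare Python int, which NumPy reports as shape () and SB3 rejects
    self.current_state['current_bankroll'] = self.current_bankroll

    # After: a 1-element ndarray matching Box(shape=(1,), dtype=int)
    self.current_state['current_bankroll'] = np.array([self.current_bankroll], dtype=int)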

    Below you'll find a working implementation of your problem (my system has Python 3.10 installed, but it should work on lower versions as well).

    Working code:

    import os, sys
    
    from random import randint
    from numpy import inf, float32, array, int32, int64
    import gym
    from gym.wrappers import FlattenObservation
    from stable_baselines3 import A2C, DQN, PPO
    import numpy as np
    
    """Roulette environment class"""
    class Roulette_Environment(gym.Env):
    
        metadata = {'render.modes': ['human', 'text']}
    
        """Initialize the environment"""
        def __init__(self):
            super(Roulette_Environment, self).__init__()
    
            # Some global variables
            self.max_table_limit = 1000
            self.initial_bankroll = 2000
    
            # Spaces
            # Each number on roulette board can have 0-1000 units placed on it
            self.action_space = gym.spaces.Box(low=0, high=1000, shape=(37,))
    
            # We're going to keep track of how many times each number shows up
            # while we're playing, plus our current bankroll and the max
            # table betting limit so the agent knows how much $ in total is allowed
            # to be placed on the table. Going to use a Dict space for this.
            self.observation_space = gym.spaces.Dict(
                {
                    # One hit counter per roulette number, keyed "0" through "36"
                    **{str(i): gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int)
                       for i in range(37)},

                    "current_bankroll": gym.spaces.Box(low=-inf, high=inf, shape=(1,), dtype=int),

                    "max_table_limit": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                }
            )
    
        """Reset the Environment"""
        def reset(self):
            self.current_bankroll = self.initial_bankroll
            self.done = False
    
            # Take a sample from the observation_space to modify the values of
            self.current_state = self.observation_space.sample()
            
            # Reset each number being tracked throughout gameplay to 0
            for i in range(0, 37):
                self.current_state[str(i)] = np.array([0], dtype=int)
    
            # Reset our globals
            self.current_state['current_bankroll'] = np.array([self.current_bankroll], dtype=int)
            self.current_state['max_table_limit'] = np.array([self.max_table_limit], dtype=int)
            
            return self.current_state
    
    
        """Step Through the Environment"""
        def step(self, action):
            
            # Convert actions to ints because they show up as floats,
            # even when defined as ints in the environment.
            # https://github.com/openai/gym/issues/3107
            for i in range(len(action)):
                action[i] = int(action[i])
            self.current_action = action
            
            # Total amount wagered this spin (it comes off the bankroll via the reward below)
            sum_of_bets = sum(self.current_action)
    
            # Spin the wheel
            self.current_number = randint(a=0, b=36)
    
            # Calculate payout/reward
            self.reward = 36 * self.current_action[self.current_number] - sum_of_bets
    
            self.current_bankroll += self.reward
    
            # Update the current state
            self.current_state['current_bankroll'] = np.array([self.current_bankroll], dtype=int)
            self.current_state[str(self.current_number)] += np.array([1], dtype=int)
    
            # If we've doubled our money, or lost our money
            if self.current_bankroll >= self.initial_bankroll * 2 or self.current_bankroll <= 0:
                self.done = True
    
            return self.current_state, self.reward, self.done, {}
    
    
        """Render the Environment"""
        def render(self, mode='text'):
            # Text rendering
            if mode == "text":
                print(f'Bets Placed: {self.current_action}')
                print(f'Number rolled: {self.current_number}')
                print(f'Reward: {self.reward}')
                print(f'New Bankroll: {self.current_bankroll}')
    
    env = Roulette_Environment()
    
    model = PPO('MultiInputPolicy', env, verbose=1)
    model.learn(total_timesteps=10)
    
    obs = env.reset()
    # obs = FlattenObservation(obs)
    
    for i in range(1000):
        action, _state = model.predict(obs, deterministic=True)
        # action, _state = model.predict(FlattenObservation(obs), deterministic=True)
        obs, reward, done, info = env.step(action)
        env.render()
        if done:
          obs = env.reset()
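
    As an aside, the AttributeError in your question comes from the fact that gym.wrappers.FlattenObservation wraps an environment, not a single observation dict. If you would rather work with a flat observation than a Dict, a sketch (untested on my side) would be:

    from gym.wrappers import FlattenObservation

    env = FlattenObservation(Roulette_Environment())  # flattens the Dict obs into one Box
    model = PPO('MlpPolicy', env, verbose=1)          # MlpPolicy, since the obs is now a Box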