python reinforcement-learning openai-gym stable-baselines

Python stable_baselines3 - AssertionError: The observation returned by `reset()` method must be an int


I am trying to learn reinforcement learning so I can train AI on custom games in Python, and decided to use Gym for the environment and stable-baselines3 for the training. I decided to start off with a basic tic-tac-toe environment. Here's my code:

import gym
from gym import spaces
import numpy as np
from stable_baselines3.common.env_checker import check_env

class tictactoe(gym.Env):
    def __init__(self):
        # creating the grid, action space and observation space
        self.box = [0,0,0,0,0,0,0,0,0]
        self.done=False
        self.turn = 1
        self.action_space = spaces.Discrete(9)
        self.observation_space = spaces.Discrete(9)

    def _get_obs(self):
        #returns the observation (the grid)
        return np.array(self.box)

    def iswinner(self, b, l):
        #function to check if a side has won
        return (b[1] == l and b[2] == l and b[3] == l) or (b[4] == l and b[5] == l and b[6] == l) or (b[7] == l and b[8] == l and b[9] == l) or (b[1] == l and b[4] == l and b[7] == l) or (b[7] == l and b[5] == l and b[3] == l) or (b[1] == l and b[5] == l and b[9] == l) or (b[8] == l and b[5] == l and b[2] == l) or (b[9] == l and b[6] == l and b[3] == l)
    
    def reset(self):
        #resets the env (grid, turn and done variable) and returns the observation
        self.box = [0,0,0,0,0,0,0,0,0]
        self.turn = 1
        self.done=False
        return self._get_obs()

    def step(self, action):
        #gives negative reward for illegal move (square occupied)
        if self.box[action] != 0:
            return self._get_obs(), -10, True, {} 
        #enters a value (1 or 2) in the grid and flips the turn
        self.box[action] = self.turn
        self.turn = (1 if self.turn == 2 else 2)
        reward = 0
        #checks if the game is over and sets a reward (+5 win, 0 draw)
        if self.iswinner([0]+self.box,1) and self.turn == 1: reward,self.done = 5,True
        elif 0 not in self.box: reward,self.done = 0,True
        #returns the observation (grid), reward, if the game is finished and extra information (empty dict for me)
        return self._get_obs(), reward, self.done, {}

    def render(self):
        #renders the board so it looks like a grid
        print(self.box[:3],self.box[3:6],self.box[6:],sep='\n')

#checking the env
env = tictactoe()
print(check_env(env))

Running this code, I got the error AssertionError: The observation returned by `reset()` method must be an int. I don't understand what this means, since my reset function returns the observation from _get_obs. Is it trying to say that my observation must be an integer? That makes even less sense, because then I have no idea how I'm supposed to represent the board.


Solution

  • When you do

    self.observation_space = spaces.Discrete(9)
    

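    you're actually defining your observation space as a single integer that can take any value in {0, 1, 2, 3, 4, 5, 6, 7, 8}, because Discrete(9) is a one-dimensional discrete space. That is what the assertion is telling you: for a Discrete observation space, check_env expects reset() to return a single int, but your _get_obs() returns a NumPy array of nine values.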
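To see the mismatch concretely, here is a simplified, stdlib-only sketch of the membership rule a Discrete(n) space applies. `discrete_contains` is a hypothetical helper written for illustration: it mirrors the idea behind gym's `Discrete.contains`, not its exact source (the real method also accepts NumPy integer types), and a plain list stands in for the NumPy array that `_get_obs()` returns.

```python
def discrete_contains(n, x):
    """Hypothetical, simplified Discrete(n) check: x must be a single integer in [0, n)."""
    # bool is a subclass of int in Python, so exclude it explicitly
    if isinstance(x, bool) or not isinstance(x, int):
        return False
    return 0 <= x < n


# The board from reset(): nine cells, each 0 (empty), 1 or 2
board = [0, 0, 0, 0, 0, 0, 0, 0, 0]

print(discrete_contains(9, 4))      # True: a single cell index fits Discrete(9)
print(discrete_contains(9, board))  # False: a whole board of nine values does not
```

In other words, Discrete(9) describes "which square", which is a good fit for the action space but not for the board itself.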

    As you said you were trying to make a tic-tac-toe environment, I presume what you were actually trying to do was something like

    self.observation_space = spaces.MultiDiscrete([3, 3, 3, 3, 3, 3, 3, 3, 3])
    # or self.observation_space = spaces.MultiDiscrete(9 * [3]), which would be cleaner
    

    which means you have 9 tiles in total and each tile can be in three different states (empty, X or O).
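As a rough illustration of what MultiDiscrete(9 * [3]) accepts, the stdlib-only sketch below checks a board cell by cell. `multidiscrete_contains`, `encode_board` and `decode_board` are hypothetical helpers for illustration, not gym APIs. The encoding pair also speaks to the "how would I make it an integer?" part of the question: reading the nine cells as base-3 digits gives a number in range(3**9), which would fit a Discrete(3**9) space. MultiDiscrete keeps the per-tile structure visible, so it is arguably the more natural fit here.

```python
NVEC = 9 * [3]  # 9 tiles, each 0 (empty), 1 or 2, as in MultiDiscrete(9 * [3])


def multidiscrete_contains(nvec, x):
    """Hypothetical, simplified MultiDiscrete check: one int per slot with 0 <= x[i] < nvec[i]."""
    if len(x) != len(nvec):
        return False
    return all(isinstance(v, int) and 0 <= v < n for v, n in zip(x, nvec))


def encode_board(board):
    """Read the nine cells as base-3 digits (cell 0 most significant) -> int in range(3**9)."""
    code = 0
    for cell in board:
        code = code * 3 + cell
    return code


def decode_board(code):
    """Inverse of encode_board: recover the nine cells from the integer."""
    cells = []
    for _ in range(9):
        code, cell = divmod(code, 3)  # peel off the least significant digit
        cells.append(cell)
    return cells[::-1]  # digits were collected last-cell-first


board = [1, 2, 0, 0, 1, 0, 0, 0, 2]
print(multidiscrete_contains(NVEC, board))  # True
print(3 ** 9)                               # 19683 possible observations
print(decode_board(encode_board(board)) == board)  # True
```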