I tried to model the simplest coin-flipping game, where you have to predict whether the flip will come up heads. Sadly it won't run, giving me:
Using cpu device
Traceback (most recent call last):
  File "/home/user/python/simplegame.py", line 40, in <module>
    model.learn(total_timesteps=10000)
  File "/home/user/python/mypython3.10/lib/python3.10/site-packages/stable_baselines3/ppo/ppo.py", line 315, in learn
    return super().learn(
  File "/home/user/python/mypython3.10/lib/python3.10/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 264, in learn
    total_timesteps, callback = self._setup_learn(
  File "/home/user/python/mypython3.10/lib/python3.10/site-packages/stable_baselines3/common/base_class.py", line 423, in _setup_learn
    self._last_obs = self.env.reset()  # type: ignore[assignment]
  File "/home/user/python/mypython3.10/lib/python3.10/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py", line 77, in reset
    obs, self.reset_infos[env_idx] = self.envs[env_idx].reset(seed=self._seeds[env_idx], **maybe_options)
TypeError: CoinFlipEnv.reset() got an unexpected keyword argument 'seed'
Here is the code:
import gymnasium as gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

class CoinFlipEnv(gym.Env):
    def __init__(self, heads_probability=0.8):
        super(CoinFlipEnv, self).__init__()
        self.action_space = gym.spaces.Discrete(2)  # 0 for heads, 1 for tails
        self.observation_space = gym.spaces.Discrete(2)  # 0 for heads, 1 for tails
        self.heads_probability = heads_probability
        self.flip_result = None

    def reset(self):
        # Reset the environment
        self.flip_result = None
        return self._get_observation()

    def step(self, action):
        # Perform the action (0 for heads, 1 for tails)
        self.flip_result = int(np.random.rand() < self.heads_probability)
        # Compute the reward (1 for correct prediction, -1 for incorrect)
        reward = 1 if self.flip_result == action else -1
        # Return the observation, reward, done, and info
        return self._get_observation(), reward, True, {}

    def _get_observation(self):
        # Return the current coin flip result
        return self.flip_result

# Create the environment with heads probability of 0.8
env = DummyVecEnv([lambda: CoinFlipEnv(heads_probability=0.8)])

# Create the PPO model
model = PPO("MlpPolicy", env, verbose=1)

# Train the model
model.learn(total_timesteps=10000)

# Save the model
model.save("coin_flip_model")

# Evaluate the model
obs = env.reset()
for _ in range(10):
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    print(f"Action: {action}, Observation: {obs}, Reward: {rewards}")
What am I doing wrong?
This is in version 2.2.1.
The gymnasium.Env class defines reset with the following signature, which differs from yours (which takes no arguments) and which DummyVecEnv calls with a seed:

Env.reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) → tuple[ObsType, dict[str, Any]]

In other words, seed and options are keyword-only arguments that your own reset function needs to accept, and it has to return an (observation, info) tuple.
The problems to note:

- reset does not match: it needs to accept seed and options.
- reset does not match: it needs to return a valid observation (ObsType) and an info dictionary.
- step does not match: it needs to also report whether the episode was truncated, e.g. cut off by a time limit or the model going out of bounds (see below).

def reset(self, *, seed=None, options=None):  # fix the input signature
    # Reset the environment
    self.flip_result = 0  # None is not a valid observation
    return self.flip_result, {}  # fix the return signature

If you return None, another error would follow, because the observation is written into an underlying NumPy array (array([0])[0] = obs <- None).
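Optionally, if you also want the environment to respect the seed that DummyVecEnv passes in, you can forward it to the base class; gymnasium's Env.reset seeds self.np_random when a seed is given. A minimal sketch of that optional addition (not required to fix the error above):

def reset(self, *, seed=None, options=None):
    super().reset(seed=seed)  # seeds self.np_random so the flips become reproducible
    self.flip_result = 0
    return self.flip_result, {}

You could then draw the flip in step from self.np_random.random() instead of np.random.rand().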
step needs to return five values: observation, reward, terminated, truncated, info
def step(self, action):
    # Perform the action (0 for heads, 1 for tails)
    self.flip_result = int(np.random.rand() < self.heads_probability)
    # Compute the reward (1 for correct prediction, -1 for incorrect)
    reward = 1 if self.flip_result == action else -1
    # Return the observation, reward, terminated, truncated, and info
    return self._get_observation(), reward, True, False, {}
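With both methods fixed, it can be worth running the environment through Stable-Baselines3's environment checker before training. This is just an optional sanity check, not part of your original script:

from stable_baselines3.common.env_checker import check_env

check_env(CoinFlipEnv(heads_probability=0.8), warn=True)  # warns/raises if the API still does not match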
Now the model trains:
-----------------------------
| time/ | |
| fps | 5608 |
| iterations | 1 |
| time_elapsed | 0 |
| total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/ | |
| fps | 3530 |
| iterations | 2 |
| time_elapsed | 1 |
| total_timesteps | 4096 |
| train/ | |
| approx_kl | 0.020679139 |
| clip_fraction | 0.617 |
| clip_range | 0.2 |
| entropy_loss | -0.675 |
| explained_variance | 0 |
| learning_rate | 0.0003 |
| loss | 0.38 |
| n_updates | 10 |
| policy_gradient_loss | -0.107 |
| value_loss | 1 |
-----------------------------------------
-----------------------------------------
| time/ | |
| fps | 3146 |
| iterations | 3 |
| time_elapsed | 1 |
| total_timesteps | 6144 |
| train/ | |
| approx_kl | 0.032571375 |
| clip_fraction | 0.628 |
| clip_range | 0.2 |
| entropy_loss | -0.599 |
| explained_variance | 0 |
| learning_rate | 0.0003 |
| loss | 0.392 |
| n_updates | 20 |
| policy_gradient_loss | -0.104 |
| value_loss | 0.987 |
-----------------------------------------
---------------------------------------
| time/ | |
| fps | 2984 |
| iterations | 4 |
| time_elapsed | 2 |
| total_timesteps | 8192 |
| train/ | |
| approx_kl | 0.0691616 |
| clip_fraction | 0.535 |
| clip_range | 0.2 |
| entropy_loss | -0.417 |
| explained_variance | 0 |
| learning_rate | 0.0003 |
| loss | 0.335 |
| n_updates | 30 |
| policy_gradient_loss | -0.09 |
| value_loss | 0.941 |
---------------------------------------
----------------------------------------
| time/ | |
| fps | 2898 |
| iterations | 5 |
| time_elapsed | 3 |
| total_timesteps | 10240 |
| train/ | |
| approx_kl | 0.12130852 |
| clip_fraction | 0.125 |
| clip_range | 0.2 |
| entropy_loss | -0.189 |
| explained_variance | 0 |
| learning_rate | 0.0003 |
| loss | 0.536 |
| n_updates | 40 |
| policy_gradient_loss | -0.0397 |
| value_loss | 0.806 |
----------------------------------------
Action: [1], Observation: [0], Reward: [1.]
Action: [1], Observation: [0], Reward: [-1.]
Action: [1], Observation: [0], Reward: [-1.]
Action: [1], Observation: [0], Reward: [1.]
Action: [1], Observation: [0], Reward: [1.]
Action: [1], Observation: [0], Reward: [-1.]
Action: [1], Observation: [0], Reward: [1.]
Action: [1], Observation: [0], Reward: [-1.]
Action: [1], Observation: [0], Reward: [1.]
Action: [1], Observation: [0], Reward: [1.]
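The policy has collapsed to always predicting 1, which the biased flip produces with probability 0.8, so the expected reward per episode is roughly 0.8 * 1 + 0.2 * (-1) = 0.6. The printed observation is always [0] because each episode ends after one step and DummyVecEnv auto-resets, returning the fresh reset observation. If you want to confirm the average reward over more than 10 episodes, you can use Stable-Baselines3's evaluate_policy helper (a small optional check, assuming the corrected environment above):

from stable_baselines3.common.evaluation import evaluate_policy

mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=1000)
print(f"mean reward per episode: {mean_reward:.2f} +/- {std_reward:.2f}")  # should land around 0.6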