Tags: reinforcement-learning, stable-baselines

Accessing training metrics in stable-baselines3


Is it possible, from within a custom callback, to access the A2C total loss and whether the environment truncated or terminated?

I'd like to access truncated and terminated in _on_step. That would allow me to terminate training when the environment truncates, and also allow me to record training episode durations. I'd also like to be able to record total loss after an update.
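Something along these lines is what I'm after, assuming the relevant values are even exposed to the callback. This is only a sketch (EpisodeInfoCallback is just a placeholder name, and the dones/infos lookup is my guess); it assumes the environment is wrapped in Monitor and gym's TimeLimit wrapper, which is where the episode and TimeLimit.truncated keys in info would come from:

    from stable_baselines3.common.callbacks import BaseCallback


    class EpisodeInfoCallback(BaseCallback):
        """Record episode durations and stop training when an episode is truncated."""

        def _on_step(self) -> bool:
            # self.locals holds the training loop's local variables; during A2C's
            # rollout collection these include the per-env "dones" and "infos".
            for done, info in zip(self.locals["dones"], self.locals["infos"]):
                if done:
                    truncated = info.get("TimeLimit.truncated", False)  # set by gym's TimeLimit wrapper
                    episode_stats = info.get("episode")  # added by the Monitor wrapper at episode end
                    if episode_stats is not None:
                        print(f"length={episode_stats['l']} reward={episode_stats['r']} truncated={truncated}")
                    if truncated:
                        return False  # returning False from _on_step stops training
            return True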


Solution

  • Thanks to advice from gehirndienst, I've taken a more 'SB3' approach instead of trying to write a custom callback. I'm not actually plotting mean episode length and reward, but I am using wrappers and callbacks to terminate training when the mean episode length reaches the required value. I also had to revert to using gym rather than gymnasium, as SB3 doesn't seem to have migrated yet.

    import gym
    from stable_baselines3 import A2C
    from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, StopTrainingOnRewardThreshold
    from stable_baselines3.common.monitor import Monitor
    from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm


    def train() -> None:
        """
        Training loop
        """
        # Create environment and agent (GAME, ACTIVATION_FN, NET_ARCH and the other
        # upper-case names are hyperparameter constants defined elsewhere in the project)
        environment: gym.Env = gym.make(GAME)
        policy_kwargs = dict(activation_fn=ACTIVATION_FN, net_arch=NET_ARCH)
        agent: OnPolicyAlgorithm = A2C("MlpPolicy", environment, policy_kwargs=policy_kwargs,
                                       n_steps=N_STEPS, learning_rate=LEARNING_RATE, gamma=GAMMA, verbose=1)
    
        # Train the agent
        callback_on_best: BaseCallback = StopTrainingOnRewardThreshold(reward_threshold=MAX_EPISODE_DURATION, verbose=1)
        eval_callback: BaseCallback = EvalCallback(Monitor(environment), callback_on_new_best=callback_on_best,
                                                   eval_freq=EVAL_FREQ, n_eval_episodes=AVERAGING_WINDOW)
        # Set huge number of steps because termination is based on the callback
        agent.learn(int(1e10), callback=eval_callback)
    
        # Save the agent
        agent.save(MODEL_FILE)
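
For context: EvalCallback evaluates the agent every EVAL_FREQ steps over AVERAGING_WINDOW episodes, and StopTrainingOnRewardThreshold ends training once the best mean evaluation reward reaches MAX_EPISODE_DURATION. After training stops, the saved agent can be loaded back and run as usual. A minimal sketch (not part of the original code; it reuses the GAME and MODEL_FILE constants above and the old gym step API):

    # Load the saved agent and run one greedy episode.
    agent = A2C.load(MODEL_FILE)
    environment = gym.make(GAME)
    observation = environment.reset()
    done = False
    while not done:
        action, _state = agent.predict(observation, deterministic=True)
        observation, reward, done, info = environment.step(action)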