
Is it possible to access the A2C total loss and whether the environment truncated or terminated within a custom callback?

I'd like to access `truncated` and `terminated` in `_on_step`. That would let me stop training when the environment truncates, and also let me record training episode durations. I'd also like to be able to record the total loss after each update.

2 Answers


You need to attach a callback that implements the `_on_step` method and returns a bool based on your env's variables. Something like this (I always check whether my env is a `VecEnv`, since it accesses its variables a bit differently compared to a non-vectorized one):

    from stable_baselines3.common.callbacks import BaseCallback
    from stable_baselines3.common.vec_env import VecEnv


    class StopOnTruncCallback(BaseCallback):

        def __init__(self, verbose: int = 0):
            super().__init__(verbose)

        def _on_step(self) -> bool:
            # Returning False from _on_step stops training, so keep training
            # only while the environment has not truncated
            return not self._is_trunc()

        def _is_trunc(self) -> bool:
            # A VecEnv exposes its underlying envs' attributes via get_attr
            if isinstance(self.training_env, VecEnv):
                return self.training_env.get_attr("truncated")[0]
            else:
                return self.training_env.truncated
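
For completeness, a minimal usage sketch, assuming a custom environment (here a hypothetical `MyCustomEnv`) that exposes a `truncated` attribute; as the comments below point out, the built-in `CartPole` env does not:

    from stable_baselines3 import A2C
    from stable_baselines3.common.vec_env import DummyVecEnv

    # MyCustomEnv is a placeholder for an env that sets a `truncated` attribute
    env = DummyVecEnv([lambda: MyCustomEnv()])
    model = A2C("MlpPolicy", env, verbose=1)
    # Training stops as soon as the callback's _on_step returns False
    model.learn(total_timesteps=100_000, callback=StopOnTruncCallback())
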
gehirndienst
  • Thanks. My environment is wrapped as a `VecEnv`, but I'm testing it on Cart Pole, and I'm getting this error using your above example: `'CartPoleEnv' object has no attribute 'truncated'` – Alpine Chamois Feb 11 '23 at 15:05
  • I'm using `gymnasium~=0.27.0` and can access the variables outside of Stable Baselines. – Alpine Chamois Feb 23 '23 at 15:48
  • @AlpineChamois you're right, I haven't worked with classic control envs for ages and forgot: `CartPole` never sets `truncated` itself (it is always `False`, which makes sense), and it doesn't hold `terminated` as a class variable either, so you can't access it within your callback call. However, you can access `self.state` and recalculate `terminated` in your callback again as [here](https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/envs/classic_control/cartpole.py#L165) – gehirndienst Feb 24 '23 at 08:00
  • Thanks - you've made me look closer and I've learned something... Outside of Stable Baselines, when I use Cart Pole V1, gymnasium appears to wrap it in a TimeLimit wrapper, which is actually what issues the truncated signal after 500 steps (just like how the vector version of Cart Pole does further down in the file you linked). I had no idea. I assumed the environment in Stable Baselines would be the same. I'll look into it some more and update this when I find a solution. – Alpine Chamois Feb 25 '23 at 08:50
  • You can wrap your environment in any wrapper and then pass it to the SB model. But if you just want to implement truncating after N steps, you can write a callback that simply counts steps until a counter reaches your threshold; after that the callback's `on_step` method returns `False` and your training stops (a sketch of this idea follows these comments). – gehirndienst Feb 25 '23 at 12:26
  • I'll look into using the wrapper and passing it in to the Stable Baselines model. I could calculate truncating myself, but I'm trying to compare the performance of Stable Baselines to a hand-rolled PyTorch RL agent, so I'd like to make the training as similar as possible - i.e. use the wrapper to decide when it should truncate. – Alpine Chamois Feb 25 '23 at 16:17
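
A minimal sketch of the step-counting idea from the comment above; the class name and the `max_steps` parameter are illustrative, not from the original thread:

    from stable_baselines3.common.callbacks import BaseCallback


    class StopAfterNStepsCallback(BaseCallback):
        """Illustrative callback that stops training after a fixed number of steps."""

        def __init__(self, max_steps: int, verbose: int = 0):
            super().__init__(verbose)
            self.max_steps = max_steps

        def _on_step(self) -> bool:
            # self.num_timesteps is maintained by Stable Baselines itself;
            # returning False ends training once the threshold is reached
            return self.num_timesteps < self.max_steps
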

Thanks to advice from gehirndienst, I've taken a more 'SB3' approach rather than trying to write a custom callback. I'm not actually plotting mean episode length and reward, but I am using wrappers and callbacks to stop training when the mean episode length reaches the required value. I also had to revert to using gym rather than gymnasium, as SB3 doesn't seem to have migrated yet.

    import gym
    from stable_baselines3 import A2C
    from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, StopTrainingOnRewardThreshold
    from stable_baselines3.common.monitor import Monitor
    from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm

    # GAME, ACTIVATION_FN, NET_ARCH, N_STEPS, LEARNING_RATE, GAMMA, MAX_EPISODE_DURATION,
    # EVAL_FREQ, AVERAGING_WINDOW and MODEL_FILE are constants defined elsewhere in my module


    def train() -> None:
        """
        Training loop
        """
        # Create environment and agent
        environment: gym.Env = gym.make(GAME)
        policy_kwargs = dict(activation_fn=ACTIVATION_FN, net_arch=NET_ARCH)
        agent: OnPolicyAlgorithm = A2C("MlpPolicy", environment, policy_kwargs=policy_kwargs,
                                       n_steps=N_STEPS, learning_rate=LEARNING_RATE, gamma=GAMMA, verbose=1)

        # Train the agent: stop once the mean evaluation reward
        # (which equals episode duration for Cart Pole) reaches the threshold
        callback_on_best: BaseCallback = StopTrainingOnRewardThreshold(reward_threshold=MAX_EPISODE_DURATION, verbose=1)
        eval_callback: BaseCallback = EvalCallback(Monitor(environment), callback_on_new_best=callback_on_best,
                                                   eval_freq=EVAL_FREQ, n_eval_episodes=AVERAGING_WINDOW)
        # Set a huge number of steps because termination is based on the callback
        agent.learn(int(1e10), callback=eval_callback)

        # Save the agent
        agent.save(MODEL_FILE)