1

With Stable-Baselines3, you can access environment metrics and info through `self.training_env.get_attr("your_attribute_name")`. However, how does one access the training metrics generated by the model itself?

Setting `verbose=1` prints these training metrics to the console, but how can they be accessed from a custom logger (callback)?

[Image: training logs printed to the console]

Here is some starter code:

import gym
from stable_baselines3 import A2C
from stable_baselines3.common.callbacks import BaseCallback

class MeticLogger(BaseCallback):
    """Placeholder callback meant to inspect A2C training metrics while learning.

    As written it performs no processing: every step is accepted and
    training simply continues.
    """

    def __init__(self, verbose=0):
        # Forward verbosity straight to the SB3 base class.
        super().__init__(verbose)

    def _on_step(self) -> bool:
        # TODO: fetch the latest entropy_loss here for further processing.
        # Returning True lets SB3 keep training; returning False would stop it.
        keep_training = True
        return keep_training

# Build the CartPole environment.
# NOTE(review): Stable-Baselines3 >= 2.0 targets `gymnasium`; with those
# versions the import above should be `import gymnasium as gym` — confirm
# against the installed SB3 version.
env = gym.make('CartPole-v1')

# verbose=1 makes SB3 print its training metrics (e.g. entropy_loss) to the console.
model = A2C('MlpPolicy', env, verbose=1)
# Train for 1000 timesteps; the callback's _on_step runs after each env step during learn().
model.learn(total_timesteps=1000, callback=MeticLogger())

How can I access `entropy_loss` inside `_on_step` of the custom callback?

rich
  • 520
  • 6
  • 21

1 Answer

0
class MeticLogger(BaseCallback):
    """Callback that periodically samples Stable-Baselines3 training metrics.

    SB3 keeps the most recently recorded metrics in
    ``self.model.logger.name_to_value`` until the logger dumps (and clears)
    them. The ``train/*`` entries (``n_updates``, ``policy_loss``,
    ``value_loss``, ``entropy_loss``) are written by the algorithm's
    ``train()`` step, so while rollouts are being collected they hold the
    values from the previous update; before the first update they are absent
    and are reported as ``None``.

    Args:
        log_frequency: Sample (and, when verbose, print) the metrics every
            ``log_frequency`` calls to ``_on_step``. Must be >= 1.
        verbose: Print the sampled metrics when >= 1.

    Attributes:
        value_lossess: Sampled ``train/value_loss`` values (misspelled name
            kept for backward compatibility).
        entropy_losses: Sampled ``train/entropy_loss`` values.

    Raises:
        ValueError: If ``log_frequency`` is smaller than 1.
    """

    def __init__(self, log_frequency=100, verbose=0):
        super().__init__(verbose)  # BaseCallback stores ``self.verbose``.
        if log_frequency < 1:
            # log_frequency=0 would otherwise crash with ZeroDivisionError at the first step.
            raise ValueError(f"log_frequency must be >= 1, got {log_frequency}")
        self.log_frequency = log_frequency
        self.value_lossess = []
        self.entropy_losses = []

    def _mean_episode_reward(self):
        """Return the mean reward over recently finished episodes, or None if there are none.

        ``rollout/ep_rew_mean`` is only recorded right before the logger dumps
        and clears its values, so it is never visible from ``_on_step``;
        compute it from the model's episode-info buffer instead.
        """
        episodes = getattr(self.model, "ep_info_buffer", None)
        if not episodes:
            return None
        return sum(ep["r"] for ep in episodes) / len(episodes)

    def _on_step(self) -> bool:
        """Sample the training metrics every ``log_frequency`` calls; never stops training."""
        if self.n_calls % self.log_frequency != 0:
            return True

        # name_to_value is a defaultdict: indexing a missing key would silently
        # insert a 0.0 entry that the logger could later dump, so read with .get().
        metrics = self.model.logger.name_to_value
        value_loss = metrics.get("train/value_loss")
        entropy_loss = metrics.get("train/entropy_loss")

        # Collection must not depend on whether anything is printed.
        if value_loss is not None:
            self.value_lossess.append(value_loss)
        if entropy_loss is not None:
            self.entropy_losses.append(entropy_loss)

        if self.verbose >= 1:
            print(f"iterations: {metrics.get('train/n_updates')}")
            print(f"ep_rew_mean: {self._mean_episode_reward()}")
            print(f"policy_loss: {metrics.get('train/policy_loss')}")
            print(f"value_loss: {value_loss}")
            print(f"entropy_loss: {entropy_loss}")
            print("--------------------------------")
        return True
Nabat Farsi
  • 840
  • 1
  • 9
  • 17