
I am currently writing a reinforcement learning model using the stable_baselines3 library and gym_anytrading. I have written a custom environment to train the model in, and I train for a fixed number of timesteps.

However, training often finishes while the explained variance is still at an undesirable level. I therefore want to write a callback that stops training as soon as the explained variance falls within a target range, e.g. between 0.9 and 1.
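For reference, the explained variance that stable_baselines3 logs for A2C compares the value predictions collected during a rollout (`y_pred`) with the computed returns (`y_true`): EV = 1 - Var[y_true - y_pred] / Var[y_true]. The standard-library sketch below mirrors that formula and adds the range check. The names `explained_variance` and `in_target_range` are illustrative here, not SB3's API. Because the residual variance is never negative, EV cannot exceed 1, so "between 0.9 and 1" amounts to "at least 0.9". EV is undefined when the returns have zero variance. SB3's own helper reports NaN in that case, and this sketch reports `None`.

```python
from statistics import pvariance
from typing import Optional, Sequence


def explained_variance(y_pred: Sequence[float], y_true: Sequence[float]) -> Optional[float]:
    """Return 1 - Var[y_true - y_pred] / Var[y_true] using population variance.

    Returns None when Var[y_true] == 0, where the ratio is undefined.
    """
    var_y = pvariance(y_true)
    if var_y == 0:
        return None
    residuals = [t - p for t, p in zip(y_true, y_pred)]
    return 1.0 - pvariance(residuals) / var_y


def in_target_range(ev: Optional[float], low: float = 0.9, high: float = 1.0) -> bool:
    """True when ev is defined and lies in the closed interval [low, high]."""
    return ev is not None and low <= ev <= high
```

A perfect predictor gives 1.0, always predicting zero for returns `[1, 2, 3, 4]` gives 0.0, and predictions `[1.1, 1.9, 3.2, 3.8]` for those returns give about 0.98, which falls inside the 0.9 to 1 range.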

This is the environment and model setup I have so far:

```python
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import DummyVecEnv

# MyCustomEnv and df are defined earlier in my script
env_maker = lambda: MyCustomEnv(df=df, frame_bound=(12, 30660), window_size=12)
env = DummyVecEnv([env_maker])

model = A2C('MlpPolicy', env, verbose=1, policy_kwargs=dict(net_arch=[dict(pi=[128, 256, 128], vf=[128, 256, 128])]))

# Pass the custom callback to the learn() method
model.learn(total_timesteps=1000000, callback=custom_stop_callback)
```

And this is the callback, defined above the previous snippet in my script, which should stop training once the explained variance reaches a threshold (and the value loss exceeds a second threshold):

```python
from stable_baselines3.common import logger
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import KVWriter


class CustomLogger(logger.Logger):
    # Intended to keep every logged record in memory so the callback can read it
    def __init__(self, folder, output_formats, *args, **kwargs):
        super().__init__(folder, output_formats, *args, **kwargs)
        self.buffer = []

    def get_writer(self) -> KVWriter:
        return self

    def _write(self, key_values, key_excluded):
        self.buffer.append((key_values, key_excluded))


class CustomStopCallback(BaseCallback):
    def __init__(self, logger, explained_variance_threshold: float, value_loss_threshold: float, starting_step: int = 0):
        super().__init__()
        self.logger = logger
        self.explained_variance_threshold = explained_variance_threshold
        self.value_loss_threshold = value_loss_threshold
        self.starting_step = starting_step

    def _on_step(self) -> bool:
        return True

    def _on_rollout_end(self) -> None:
        if self.num_timesteps >= self.starting_step:
            log_buffer = self.logger.buffer
            explained_variance = None
            value_loss = None

            # Scan the buffered records for the latest explained variance and value loss
            for record in log_buffer:
                key_values, _ = record
                if "explained_variance" in key_values:
                    explained_variance = key_values["explained_variance"]
                if "value_loss" in key_values:
                    value_loss = key_values["value_loss"]

            # Intended to stop training once both conditions hold
            if explained_variance is not None and value_loss is not None:
                if explained_variance >= self.explained_variance_threshold and value_loss > self.value_loss_threshold:
                    print(f"Stopping training at step {self.num_timesteps} due to specified threshold conditions.")
                    self.model.set_attr('stop_training', True)


folder = "logs"
logger.configure(folder=folder)

# Instantiate the custom callback with specified thresholds
custom_stop_callback = CustomStopCallback(logger, explained_variance_threshold=0.9, value_loss_threshold=0, starting_step=10000)
```

Any help on how to resolve this is appreciated!

ET4