I am currently writing a reinforcement learning model using stable_baselines3 library and gym_anytrading. I have written the code for an environment to train the model in and have a number of timesteps.
However, the model often finishes its training when the explained variance
is at an undesirable level. Hence, I want to write a callback function that helps stop the training when the explained variance
is within a range e.g. between 0.9 to 1.
This is the environment that I have created thus far.
env_maker = lambda: MyCustomEnv(df=df, frame_bound=(12,30660), window_size=12)
env = DummyVecEnv([env_maker])
model = A2C('MlpPolicy', env, verbose=1, policy_kwargs=dict(net_arch=[dict(pi=[128, 256, 128], vf=[128, 256, 128])]))
# Pass the custom callback to the learn() method
model.learn(total_timesteps=1000000, callback=custom_stop_callback)
And this is the callback function that is above the previous that should stop the model when the explained variance
is above a certain value.
class CustomLogger(logger.Logger):
def __init__(self, folder, output_formats, *args, **kwargs):
super().__init__(folder, output_formats, *args, **kwargs)
self.buffer = []
def get_writer(self) -> KVWriter:
return self
def _write(self, key_values, key_excluded):
self.buffer.append((key_values, key_excluded))
class CustomStopCallback(BaseCallback):
def __init__(self, logger, explained_variance_threshold: float, value_loss_threshold: float, starting_step: int = 0):
super(CustomStopCallback, self).__init__()
self.logger = logger
self.explained_variance_threshold = explained_variance_threshold
self.value_loss_threshold = value_loss_threshold
self.starting_step = starting_step
def _on_step(self) -> bool:
return True
def _on_rollout_end(self) -> None:
if self.num_timesteps >= self.starting_step:
log_buffer = self.logger.buffer
explained_variance = None
value_loss = None
for record in log_buffer:
key_values, _ = record
if "explained_variance" in key_values:
explained_variance = key_values["explained_variance"]
if "value_loss" in key_values:
value_loss = key_values["value_loss"]
if explained_variance is not None and value_loss is not None:
if explained_variance >= self.explained_variance_threshold and value_loss > self.value_loss_threshold:
print(f"Stopping training at step {self.num_timesteps} due to specified threshold conditions.")
self.model.set_attr('stop_training', True)
folder = "logs"
logger.configure(folder=folder)
# Instantiate the custom callback with specified thresholds
custom_stop_callback = CustomStopCallback(logger, explained_variance_threshold=0.9, value_loss_threshold=0, starting_step=10000)
Any help on how to resolve this is appreciated!