I just started playing with RLlib and wanted to test offline DQN training on CartPole. So I generated the data as in the tutorial:
rllib train --run=PG --env=CartPole-v1 --config='{"output": "/tmp/cartpole-out", "output_max_file_size": 5000000}' --stop='{"timesteps_total": 100000}'
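
For reference, this is what I believe the Python-API counterpart of that data-generation command looks like; I am assuming the `output` and `output_max_file_size` arguments of `offline_data()` mirror the JSON config keys of the same names:

from ray.rllib.algorithms.pg import PGConfig

# Assumed Python-API equivalent of the CLI data-generation command above.
# Note there is no obvious place for the CLI's timesteps_total stop criterion
# here; I just run a fixed number of iterations instead (which is exactly
# what my question below is about).
config = (
    PGConfig()
    .environment(env="CartPole-v1")
    .offline_data(output="/tmp/cartpole-out", output_max_file_size=5_000_000)
)
algo = config.build()
for _ in range(100):
    algo.train()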
and then ran the offline training via:
rllib train --run=DQN --env=CartPole-v1 --config='{"input": "/tmp/cartpole-out","explore": false}' --stop='{"timesteps_total": 1000000}'
However, now I want to reproduce the offline training using the Python API, and I got a little confused about `timesteps_total`. I have written the following code:
from ray.rllib.algorithms.dqn import DQNConfig
from tqdm import tqdm

if __name__ == '__main__':
    config = (
        DQNConfig()
        .environment(env="CartPole-v1")
        .framework("torch")
        .offline_data(input_="/tmp/cartpole-out")
        .exploration(explore=False)
    )
    algo = config.build()
    for _ in tqdm(range(100)):
        algo.train()
but I am not sure how `timesteps_total` relates to the training loop in the snippet above. I looked inside the `AlgorithmConfig` class, but all I found was `self.timesteps_per_iteration = DEPRECATED_VALUE`. Hence the question: how do I set `timesteps_total` in the config with the Python API?
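
The closest workaround I have come up with is to stop the loop manually, assuming the result dict returned by `algo.train()` exposes the same `timesteps_total` key that the CLI stop criterion matches against (that key name is my assumption):

# Sketch of a manual stop based on the reported "timesteps_total" key,
# which I assume is present in the result dict of algo.train().
# (Reusing the DQN `config` from the snippet above.)
algo = config.build()
result = {}
while result.get("timesteps_total", 0) < 1_000_000:
    result = algo.train()

I also suspect the CLI's `--stop` flag maps to Tune's stop criteria (something like `tune.Tuner("DQN", param_space=config.to_dict(), run_config=air.RunConfig(stop={"timesteps_total": 1_000_000}))`), but I am not sure which of these is the intended replacement.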
Disclaimer: I asked ChatGPT about that, but it confidently gave me the wrong answer with a detailed explanation.