I am testing the performance obtained with different hyperparameters when solving the CartPole Gym environment with TF-Agents, using the starter code that circulates on the Internet. Here is the central part of the code:
```python
# env, agent, replay_buffer, collect_step, compute_avg_return and returns
# are defined earlier in the script (not shown).
collect_steps_per_iteration = 1
batch_size = 64

# Each sampled item holds num_steps=2 consecutive time steps.
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3,
    sample_batch_size=batch_size,
    num_steps=2).prefetch(3)
iterator = iter(dataset)

num_iterations = 500000

env.reset()
# Pre-fill the replay buffer with batch_size * 10 steps.
for _ in range(batch_size * 10):
    collect_step(env, agent.policy, replay_buffer)

for _ in range(num_iterations):
    # Collect a few steps using collect_policy and save to the replay buffer.
    for _ in range(collect_steps_per_iteration):
        last_inserted_batch = collect_step(env, agent.collect_policy, replay_buffer)

    # Sample a batch of data from the buffer and update the agent's network.
    experience, unused_info = next(iterator)
    print(experience)
    train_loss = agent.train(experience).loss

    step = agent.train_step_counter.numpy()

    # Print loss every 200 steps.
    if step % 200 == 0:
        print('step = {0}: loss = {1}'.format(step, train_loss))

    # Evaluate agent's performance every 1000 steps.
    if step % 1000 == 0:
        avg_return = compute_avg_return(env, agent.policy, 5)
        print('step = {0}: Average Return = {1}'.format(step, avg_return))
        returns.append(avg_return)
```
From what I understand, the `iterator`/`next` loop extracts batches from `replay_buffer` sequentially, rather than randomly as the literature suggests (at least as a starting point). How could I select random batches instead? That is, pick a random position inside the replay buffer (and/or the dataset retrieved from it) as the starting point of the batch (its first element), while keeping the elements inside the batch consecutive, as they are in the original dataset/replay buffer. I searched for solutions, but I always get stuck on TensorFlow formalisms.

So much for the coding side. I also wonder whether it would help to wait until another 64 (`batch_size`) experiences have been stored before training, and then choose the batch position randomly, but only on multiples of 64 elements.
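For concreteness, here is a minimal, framework-independent sketch of the sampling I have in mind, written in plain NumPy rather than TensorFlow. It assumes the buffer is a flat array whose items 0..N-1 are stored in insertion order (no ring-buffer wraparound). `sample_consecutive_window` and its `block` parameter are hypothetical names of my own, not part of the TF-Agents API. With `block=None` the start index is any position that leaves room for a full window; with `block=64` it is restricted to multiples of 64, as in my second question.

```python
import numpy as np


def sample_consecutive_window(buffer_size, window_len, rng, block=None):
    """Return indices of `window_len` consecutive items starting at a random position.

    Hypothetical helper: assumes items 0..buffer_size-1 are in insertion order.
    If `block` is given, the start index is restricted to multiples of `block`.
    """
    if window_len > buffer_size:
        raise ValueError("window_len exceeds the number of stored items")
    if block is None:
        # Any start that still leaves room for a full window.
        start = int(rng.integers(0, buffer_size - window_len + 1))
    else:
        # Number of block-aligned starts whose window fits inside the buffer.
        num_starts = (buffer_size - window_len) // block + 1
        start = int(rng.integers(0, num_starts)) * block
    return np.arange(start, start + window_len)


# Usage: a stand-in array of 1000 "experiences" (just their insertion indices).
rng = np.random.default_rng(0)
stored = np.arange(1000)
batch = stored[sample_consecutive_window(len(stored), 64, rng)]
aligned_batch = stored[sample_consecutive_window(len(stored), 64, rng, block=64)]
```

Note that consecutive elements in one batch are strongly correlated, which is the usual reason uniform, independent sampling is recommended; the sketch only shows the mechanics of the scheme described above.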