I'm attempting to get 100% reproducibility when resuming a reinforcement learning agent I'm training in PyTorch from a checkpoint. If I train the agent from scratch twice in a row, the training plots (loss, return, etc.) are identical at 10,000 timesteps. However, if I save a checkpoint at 5,000 timesteps, then resume from it and continue training out to 10,000 timesteps, performance is slightly off, as can be seen from the following plot (blue is the run trained from scratch to 10k steps; green resumes from blue's 5k-step checkpoint and trains out to 10k steps):
I've stepped through my code and verified that my model parameters and the RNG states are identical at the 5k-step mark, both when training from scratch and after loading the 5k checkpoint.
I set my seeding as follows:
import os
import random

import numpy as np
import torch

def set_seed(seed, device):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if device.type == "cuda":
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
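(One thing I'm aware of but that isn't in the seeding above: PyTorch's stricter global determinism switch. In case it's relevant, a sketch of what adding it would look like — note the CUBLAS_WORKSPACE_CONFIG variable, which I believe some CUDA ops require under this mode:)

```python
import os
import random

import numpy as np
import torch

def set_seed_strict(seed, device):
    # Same seeding as my set_seed above
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # in recent PyTorch this also seeds all CUDA devices
    if device.type == "cuda":
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # Some cuBLAS ops need this set before they run in deterministic mode
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    # Raise an error on any op that has no deterministic implementation,
    # instead of silently running a nondeterministic kernel
    torch.use_deterministic_algorithms(True)
```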
When I create my environment I also set the following seeds:
env.seed(seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)
When loading from a checkpoint, besides loading the state dicts for all my models and optimizers, I also restore the RNG states (after the initial seeding at the start of my script) as follows:
if args.resume:
    random.setstate(checkpoint["rng_states"]["random_rng_state"])
    np.random.set_state(checkpoint["rng_states"]["numpy_rng_state"])
    torch.set_rng_state(checkpoint["rng_states"]["torch_rng_state"])
    if device.type == "cuda":
        torch.cuda.set_rng_state(checkpoint["rng_states"]["torch_cuda_rng_state"])
        torch.cuda.set_rng_state_all(
            checkpoint["rng_states"]["torch_cuda_rng_state_all"]
        )
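(For reference, the save side mirrors this. Paraphrasing my checkpointing code, roughly — `collect_rng_states` is just an illustrative helper name, but the dict keys match the restore code above:)

```python
import random

import numpy as np
import torch

def collect_rng_states(device):
    # Capture every RNG stream under the same keys used when restoring
    rng_states = {
        "random_rng_state": random.getstate(),
        "numpy_rng_state": np.random.get_state(),
        "torch_rng_state": torch.get_rng_state(),
    }
    if device.type == "cuda":
        rng_states["torch_cuda_rng_state"] = torch.cuda.get_rng_state()
        rng_states["torch_cuda_rng_state_all"] = torch.cuda.get_rng_state_all()
    return rng_states

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = {
    # ... model / optimizer state dicts ...
    "rng_states": collect_rng_states(device),
}
```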
The more complete script is here (I've included only the sections I thought were relevant above, for brevity): https://pastebin.com/1yqn3CLt
Would anyone have any ideas as to what I might be doing wrong such that I can't get exact reproducibility when I'm resuming from my checkpoint? Thanks in advance!