
I am periodically saving checkpoints like this:

    loss = trn_metrics_t[METRICS_LOSS_NDX].mean().item()
    torch.save({
        'epoch': epoch_ndx - 1,
        'model_state_dict': self.model.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        'loss': loss,
    }, self.checkpoint_path)
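If the job is killed while `torch.save` is still writing, the checkpoint file itself can end up truncated, and the last good checkpoint is lost. One option is to write to a temporary file and then rename it over the real path. Here is a minimal sketch, assuming a POSIX-like filesystem where `os.replace` within a single directory is atomic. `save_checkpoint` is a hypothetical helper, not part of PyTorch, and the dictionary keys match the ones used above:

```python
import os
import tempfile

import torch


def save_checkpoint(path, epoch, model, optimizer, loss):
    """Hypothetical helper: write a checkpoint so a killed job never leaves a partial file at `path`."""
    directory = os.path.dirname(os.path.abspath(path))
    # Create the temp file next to the target so os.replace is a same-filesystem rename.
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    os.close(fd)
    try:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
        }, tmp_path)
        # Swap the finished file into place in a single rename.
        os.replace(tmp_path, path)
    except BaseException:
        # Clean up the partial temp file, then re-raise the original error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
```

Inside the training loop it would replace the direct `torch.save` call, e.g. `save_checkpoint(self.checkpoint_path, epoch_ndx - 1, self.model, self.optimizer, loss)`.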

That way, if the process is killed for exceeding its time limit, I can reload the latest checkpoint with the following code:

    checkpoint = torch.load(self.checkpoint_path)
    self.model.load_state_dict(checkpoint['model_state_dict'])
    self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
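The reload step can also be written so that every tensor ends up on the training device, regardless of which device it was saved from. This is a minimal sketch, assuming the model and optimizer are rebuilt from scratch when the job restarts. `resume_from_checkpoint` and `optimizer_factory` are hypothetical names. The ordering is the point of the sketch. The model is moved to `device` before the optimizer is constructed and before its state is loaded, because `Optimizer.load_state_dict` casts the saved buffers to the device of the parameters at that moment:

```python
import torch


def resume_from_checkpoint(checkpoint_path, model, optimizer_factory, device):
    """Restore model and optimizer state so every tensor lives on `device`.

    `optimizer_factory` is a hypothetical callable that builds a fresh
    optimizer from an iterable of parameters (e.g. a lambda around torch.optim.Adam).
    """
    # 1. Move the parameters first; nn.Module.to() modifies the module in place.
    model.to(device)
    # 2. Build the optimizer only after the parameters are on `device`.
    optimizer = optimizer_factory(model.parameters())
    # 3. map_location sends every saved tensor to `device`,
    #    whichever device it was saved from.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Optimizer buffers such as Adam's exp_avg are cast to the device of the
    # parameters they belong to, which is now `device`.
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # The stored epoch is returned unchanged; how it maps onto the next
    # loop index depends on the training loop.
    return optimizer, checkpoint['epoch']
```

Because the saving code stores `epoch_ndx - 1`, the caller decides which epoch to resume from.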

However, after reloading the checkpoint and resuming training, the first call to `self.optimizer.step()` fails with this error:

    Traceback (most recent call last):
      File "/home/miran045/reine097/projects/motion-qc-deep-learning/src/dcan/training/training.py", line 222, in <module>
        InfantMRIMotionQCTrainingApp().main()
      File "/home/miran045/reine097/projects/motion-qc-deep-learning/src/dcan/training/training.py", line 178, in main
        self.run_epochs(self.train_dl, self.val_dl, epoch=epoch)
      File "/panfs/jay/groups/4/miran045/reine097/projects/motion-qc-deep-learning/src/dcan/training/TrainingApp.py", line 297, in run_epochs
        trn_metrics_t = self.do_training(epoch_ndx, train_dl)
      File "/panfs/jay/groups/4/miran045/reine097/projects/motion-qc-deep-learning/src/dcan/training/TrainingApp.py", line 272, in do_training
        self.optimizer.step()
      File "/home/miran045/reine097/projects/AlexNet_Abrol2021/venv/lib/python3.9/site-packages/torch/optim/optimizer.py", line 88, in wrapper
        return func(*args, **kwargs)
      File "/home/miran045/reine097/projects/AlexNet_Abrol2021/venv/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
        return func(*args, **kwargs)
      File "/home/miran045/reine097/projects/AlexNet_Abrol2021/venv/lib/python3.9/site-packages/torch/optim/adam.py", line 133, in step
        F.adam(params_with_grad,
      File "/home/miran045/reine097/projects/AlexNet_Abrol2021/venv/lib/python3.9/site-packages/torch/optim/_functional.py", line 86, in adam
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
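The last frame of the traceback shows Adam's `exp_avg` buffer and the gradient on different devices. One plausible cause is that `optimizer.load_state_dict` ran while the model parameters were still on the CPU, so the restored buffers were cast to the CPU, and the model was moved to `cuda:0` only afterwards. This is an assumption, since the code that builds the model and optimizer is not shown. If reordering those steps is inconvenient, a hypothetical helper `optimizer_state_to` could move the already-loaded state onto the training device. Here is a minimal sketch:

```python
import torch


def optimizer_state_to(optimizer, device):
    """Hypothetical helper: move every tensor in `optimizer.state` onto `device`, in place.

    The 'step' entry is skipped: older PyTorch versions store it as a plain int,
    and newer ones deliberately keep it as a CPU tensor unless the optimizer is
    capturable or fused.
    """
    for state in optimizer.state.values():
        for key, value in state.items():
            # Replacing values of existing keys while iterating is safe.
            if key != 'step' and torch.is_tensor(value):
                state[key] = value.to(device)
```

It would be called once, after the model has been moved to the GPU and `self.optimizer.load_state_dict(...)` has run, e.g. `optimizer_state_to(self.optimizer, torch.device('cuda'))`.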
