I am periodically saving checkpoints like this:
loss = trn_metrics_t[METRICS_LOSS_NDX].mean().item()
torch.save({
    'epoch': epoch_ndx - 1,
    'model_state_dict': self.model.state_dict(),
    'optimizer_state_dict': self.optimizer.state_dict(),
    'loss': loss,
}, self.checkpoint_path)
That way, if the process is killed for exceeding its execution time limit, I can reload the latest checkpoint with the following code:
checkpoint = torch.load(self.checkpoint_path)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
However, I am getting this error:
Traceback (most recent call last):
  File "/home/miran045/reine097/projects/motion-qc-deep-learning/src/dcan/training/training.py", line 222, in <module>
    InfantMRIMotionQCTrainingApp().main()
  File "/home/miran045/reine097/projects/motion-qc-deep-learning/src/dcan/training/training.py", line 178, in main
    self.run_epochs(self.train_dl, self.val_dl, epoch=epoch)
  File "/panfs/jay/groups/4/miran045/reine097/projects/motion-qc-deep-learning/src/dcan/training/TrainingApp.py", line 297, in run_epochs
    trn_metrics_t = self.do_training(epoch_ndx, train_dl)
  File "/panfs/jay/groups/4/miran045/reine097/projects/motion-qc-deep-learning/src/dcan/training/TrainingApp.py", line 272, in do_training
    self.optimizer.step()
  File "/home/miran045/reine097/projects/AlexNet_Abrol2021/venv/lib/python3.9/site-packages/torch/optim/optimizer.py", line 88, in wrapper
    return func(*args, **kwargs)
  File "/home/miran045/reine097/projects/AlexNet_Abrol2021/venv/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/home/miran045/reine097/projects/AlexNet_Abrol2021/venv/lib/python3.9/site-packages/torch/optim/adam.py", line 133, in step
    F.adam(params_with_grad,
  File "/home/miran045/reine097/projects/AlexNet_Abrol2021/venv/lib/python3.9/site-packages/torch/optim/_functional.py", line 86, in adam
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
The error is raised by `self.optimizer.step()` when training resumes after the checkpoint has been loaded.
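For reference, here is a minimal sketch of one common way to keep the model and the Adam state on the same device when resuming. It is not a confirmed diagnosis of the traceback above. It assumes the usual PyTorch behavior that `Optimizer.load_state_dict` casts the loaded state tensors to the device of the matching parameter, so the model is moved to the target device before the optimizer state is loaded. The helpers `save_checkpoint`, `load_checkpoint` and `train_step`, and the toy `nn.Linear` model, are hypothetical stand-ins and not the training app's actual code.

```python
import os
import tempfile

import torch
from torch import nn


def save_checkpoint(path, model, optimizer, epoch, loss):
    """Save the same fields as the checkpoint in the question."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, path)


def load_checkpoint(path, model, optimizer, device):
    """Restore model and optimizer so that both end up on `device`."""
    # Move the parameters first: the optimizer state loaded below is
    # cast to the device of the parameters it belongs to.
    model.to(device)
    # Load onto the CPU; the load_state_dict calls copy tensors to
    # wherever the parameters currently live.
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['loss']


def train_step(model, optimizer, device):
    """One toy optimization step on random data (illustration only)."""
    x = torch.randn(8, 4, device=device)
    loss = model(x).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
checkpoint_path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pt')

# First run: train one step so Adam has state, then save.
model = nn.Linear(4, 2).to(device)
optimizer = torch.optim.Adam(model.parameters())
loss = train_step(model, optimizer, device)
save_checkpoint(checkpoint_path, model, optimizer, epoch=0, loss=loss)

# Simulated restart: fresh objects, created on the CPU by default.
resumed_model = nn.Linear(4, 2)
resumed_optimizer = torch.optim.Adam(resumed_model.parameters())
epoch, saved_loss = load_checkpoint(
    checkpoint_path, resumed_model, resumed_optimizer, device)

# Resuming training should not raise a device-mismatch error.
train_step(resumed_model, resumed_optimizer, device)
```

The ordering inside `load_checkpoint` is the point of the sketch: if the model were moved to the GPU only after `optimizer.load_state_dict`, the Adam buffers (`exp_avg`, `exp_avg_sq`) could be left on the CPU while the gradients are on `cuda:0`.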