I have a LightningModule that looks something like this:
import torch
import pytorch_lightning as pl

class MyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.module1 = Module1()
        self.module2 = Module2()

    def save(self, path):
        # Pickle the whole submodules, not just their state_dicts.
        torch.save((self.module1, self.module2), path)

    def load(self, path):
        obj = torch.load(path)
        self.module1 = obj[0]
        self.module2 = obj[1]
Now I want to train it on GPU using
mymodule = MyModule()
trainer = pl.Trainer(accelerator="gpu")
trainer.fit(model=mymodule)
This works in a notebook with a newly started kernel. After training, I save the module to disk. I can restart the kernel, create a new MyModule, load the state from disk and continue training. What doesn't work is reloading from disk without restarting the kernel. So basically
trainer = pl.Trainer(accelerator="gpu")
trainer.fit(model=mymodule)
mymodule.save("test.pt")
mymodule.load("test.pt")
trainer = pl.Trainer(accelerator="gpu")
trainer.fit(model=mymodule)
results in
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument tensors in method wrapper_CUDA_cat)
where the exact place where the exception is thrown varies. Puzzlingly, if I enter the post-mortem debugger with %debug and go up the stack to the training_step of MyModule, self.device is cpu and not cuda.
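For reference, here is a minimal diagnostic sketch (not part of the training code; mymodule is the instance from above) for checking which devices the parameters actually end up on after load():

from collections import Counter

# Count parameters per device and compare with what Lightning reports.
print(Counter(str(p.device) for p in mymodule.parameters()))
print(mymodule.device)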
I've tried running garbage collection followed by torch.cuda.empty_cache() (exact snippet below). I've also tried
def load(self, path):
    # Drop the existing submodules before loading the saved ones.
    del self.module1
    del self.module2
    obj = torch.load(path)
    self.module1 = obj[0]
    self.module2 = obj[1]
without any effect.
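For completeness, the garbage-collection attempt mentioned above was just this:

import gc
import torch

gc.collect()
torch.cuda.empty_cache()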
It might be worth mentioning that module1 is a Pyro model and module2 is a Pyro guide (this is also why I can't just save the state_dict, as the guide creates the trainable parameters only when it sees the first minibatch). Both are subclasses of pyro.nn.PyroModule, which itself subclasses torch.nn.Module.
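To illustrate the lazy parameter creation (this is not my actual model or guide, just a minimal standalone Pyro example using an AutoNormal guide), the guide has nothing to put into a state_dict until it has been called on data:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import AutoNormal

def model(data):
    loc = pyro.sample("loc", dist.Normal(0., 1.))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

guide = AutoNormal(model)
print(len(list(guide.parameters())))  # 0 -- no trainable parameters exist yet
guide(torch.randn(8))                 # first call creates the guide's parameters
print(len(list(guide.parameters())))  # > 0 only after the guide has seen data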