I have a LightningModule that looks something like this:

import torch
import pytorch_lightning as pl

class MyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.module1 = Module1()
        self.module2 = Module2()

    def save(self, path):
        # pickle the submodules whole rather than their state_dicts
        torch.save((self.module1, self.module2), path)

    def load(self, path):
        obj = torch.load(path)
        self.module1 = obj[0]
        self.module2 = obj[1]

Now I want to train it on GPU using

trainer = pl.Trainer(accelerator="gpu")
trainer.fit(model=mymodule)

This works in a notebook with a newly started kernel. After training, I save the module to disk. I can restart the kernel, create a new MyModule, load the state from disk and continue training. What doesn't work is reloading from disk without restarting the kernel. So basically

trainer = pl.Trainer(accelerator="gpu")
trainer.fit(model=mymodule)
mymodule.save("test.pt")
mymodule.load("test.pt")
trainer = pl.Trainer(accelerator="gpu")
trainer.fit(model=mymodule)

results in

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument tensors in method wrapper_CUDA_cat)

The exact place where the exception is thrown varies. Puzzlingly, if I enter the post-mortem debugger with %debug and walk up the stack to MyModule's training_step, self.device is cpu rather than cuda.
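For completeness, this is roughly what I inspect in the debugger (illustrative; only the self.device observation above is an actual result):

# inside %debug, after walking up to MyModule.training_step:
self.device                            # reports cpu, not cuda:0
{p.device for p in self.parameters()}  # which devices the parameters live on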

I've tried running garbage collection followed by torch.cuda.empty_cache(), i.e.:
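import gc

gc.collect()
torch.cuda.empty_cache()

I've also tried deleting the old submodules before reloading: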

def load(self, path):
    # explicitly delete the old submodules before re-assigning
    del self.module1
    del self.module2
    obj = torch.load(path)
    self.module1 = obj[0]
    self.module2 = obj[1]

without any effect.

It might be worth mentioning that module1 is a Pyro model and module2 is a Pyro guide (this is also why I can't just save the state_dict, as the guide creates the trainable parameters only when it sees the first minibatch). Both are subclasses of pyro.nn.PyroModule, which itself subclasses torch.nn.Module.
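
For illustration, the rough shape of the setup, with AutoNormal standing in for my actual guide (names here are simplified placeholders, not my real code):

import torch
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample
from pyro.infer.autoguide import AutoNormal

class Model(PyroModule):
    def __init__(self):
        super().__init__()
        self.loc = PyroSample(dist.Normal(0.0, 1.0))

    def forward(self, x):
        with pyro.plate("data", x.shape[0]):
            return pyro.sample("obs", dist.Normal(self.loc, 1.0), obs=x)

model = Model()            # plays the role of module1
guide = AutoNormal(model)  # plays the role of module2
guide(torch.randn(8))      # the trainable parameters only exist after this first call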

  • I suggest moving both modules to `cpu` before saving, and moving them to `self.device` when loading. If you provide an MRE it could be useful to further debug your issue. – Luca Di Liello Jul 07 '23 at 07:02
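
For reference, a minimal sketch of what that comment suggests, applied to the save/load methods above (untested; assumes self.device reflects the trainer's device, and note that .cpu() moves the live modules in place):

def save(self, path):
    # serialize CPU copies so no CUDA state ends up in the pickle
    torch.save((self.module1.cpu(), self.module2.cpu()), path)

def load(self, path):
    obj = torch.load(path, map_location="cpu")
    # move the loaded modules to wherever this LightningModule currently lives
    self.module1 = obj[0].to(self.device)
    self.module2 = obj[1].to(self.device)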
