
I would like to create a new tensor in the validation_epoch_end method of a LightningModule. The official docs (page 48) state that we should avoid direct .cuda() or .to(device) calls:

There are no .cuda() or .to() calls. . . Lightning does these for you.

and we are encouraged to use the type_as method to transfer a tensor to the correct device:

new_x = new_x.type_as(x)

However, in validation_epoch_end I do not have any tensor from which to copy the device (via type_as) in a clean way.

My question is: what should I do if I want to create a new tensor in this method and move it to the device the model is on?

The only thing I can think of is to find a tensor in the outputs dictionary but it feels kinda messy:

avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
output = self(self.__test_input.type_as(avg_loss))

Is there any clean way to achieve that?

Szymon Knop

1 Answer


Did you check part 3.4 (page 34) in the doc you linked?

LightningModules know what device they are on! Construct tensors on the device directly to avoid CPU->Device transfer.

t = torch.rand(2, 2).cuda()  # bad
# self is a LightningModule
t = torch.rand(2, 2, device=self.device)  # good

I had a similar issue when creating tensors, and this helped me. I hope it helps you too.
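Applied to the validation_epoch_end case from the question, this might look like the sketch below. To keep it runnable with plain PyTorch, `TinyModule` is a hypothetical `torch.nn.Module` with a hand-written `device` property standing in for the one a LightningModule provides; the layer sizes and the `val_loss` values are made up for illustration.

```python
import torch
from torch import nn


class TinyModule(nn.Module):
    """Hypothetical stand-in for a LightningModule (not Lightning's own class)."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(3, 1)

    @property
    def device(self):
        # Assumption: a LightningModule exposes self.device on its own.
        # Here it is derived from the parameters so the sketch needs only torch.
        return next(self.parameters()).device

    def forward(self, x):
        return self.layer(x)

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        # Create the new tensor directly on the module's device,
        # without borrowing the device from another tensor via type_as.
        test_input = torch.rand(2, 3, device=self.device)
        return avg_loss, self(test_input)


model = TinyModule()
outputs = [{"val_loss": torch.tensor(0.5)}, {"val_loss": torch.tensor(1.5)}]
avg_loss, prediction = model.validation_epoch_end(outputs)
```

With a real LightningModule the `device` property would not need to be defined; only the `device=self.device` argument matters.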

Robin San Roman
  • Thank you, this was exactly what I was looking for! – Szymon Knop Jul 10 '20 at 17:34
  • I was about to start crying when I read your (working) answer: thanks! – Antonio Sesto Jul 27 '21 at 16:06
  • The documentation changed a bit but the code still works: "device: The device the module is on. Use it to keep your code device agnostic." `def training_step(self): z = torch.rand(2, 3, device=self.device)` – Marine Galantin Jan 05 '22 at 01:00
  • This answer works in many cases, but `torch.Tensor([1], device=self.device)` fails with an error; see this [link](https://discuss.pytorch.org/t/cannot-construct-tensor-directly-on-gpu-in-torch-1-10-1/153751). It is better to use `torch.Tensor([1]).to(self.device)` for these cases – Alejo L.A Oct 22 '22 at 12:48
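
The last comment concerns the `torch.Tensor` class constructor. As a hedged sketch of an alternative, the lowercase `torch.tensor` factory function accepts a `device` argument, so data-backed tensors can also be built on the target device in one step. The `device` variable below is a placeholder for `self.device`, set to the CPU so the snippet runs anywhere.

```python
import torch

# Placeholder for self.device inside a LightningModule (assumption: CPU here).
device = torch.device("cpu")

# Factory function: builds the tensor from data directly on the given device.
direct = torch.tensor([1.0], device=device)

# The workaround from the comment: build on the CPU, then move.
moved = torch.Tensor([1]).to(device)
```

Both tensors end up on the same device and hold the same value; the first form avoids the intermediate CPU tensor when the target is a GPU.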