
I am new to PyTorch, but it seems pretty nice. My only question is when to use tensor.to(device) versus nn.Module.to(device).

I was reading the documentation on this topic, and it indicates that this method moves the tensor or model to the specified device. But I am not clear on which operations require this, and what kind of errors I will get if I don't call .to() at the right time.
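One difference that matters when choosing between the two calls: nn.Module.to() moves the module's parameters and buffers in place and returns the module itself, while Tensor.to() leaves the original tensor untouched and returns a tensor on the target device (the same tensor if it is already there). A minimal sketch, assuming a toy nn.Linear model and falling back to the CPU when no GPU is present:

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(4, 2)          # toy model, used only for illustration
returned = model.to(device)      # moves the parameters in place ...
assert returned is model         # ... and returns the same module object

x = torch.randn(3, 4)            # created on the CPU by default
x_dev = x.to(device)             # tensor on `device`; x itself is not modified
print(x.device, x_dev.device)
print(next(model.parameters()).device)
```

This is why `model.to(device)` works without reassignment, whereas a tensor needs `data = data.to(device)`.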

For example, if I just create a tensor, I imagine that it is stored in CPU-accessible memory until I move it to the GPU. Once the tensor is on the GPU, the GPU executes any mathematical operations on that tensor.
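That mental model can be checked directly: a tensor created without a device argument lives on the CPU, and an operation on GPU tensors produces its result on the GPU. A small sketch that only touches the GPU when one is available:

```python
import torch

a = torch.ones(3)                       # no device given -> CPU
print(a.device)                         # cpu

if torch.cuda.is_available():
    b = torch.ones(3, device="cuda")    # allocate directly on the GPU
    c = a.to("cuda") * b                # both operands on the GPU, so computed there
    print(c.device)                     # a CUDA device, e.g. cuda:0
    back = c.cpu()                      # explicit copy back to host memory
    print(back.device)                  # cpu
```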

However, do I have to worry about accidentally transferring the data tensor to the GPU without transferring the model to the GPU? Will this just give me errors outright, or will it quietly perform a lot of expensive data transfer behind the scenes? This example is easy enough for me to test, but I was wondering about other cases where the problem might not be so obvious.
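For the less obvious cases (for example, a tensor created inside a loop without a device argument), a check that fails early with a readable message can help. A minimal sketch; `assert_same_device` is a hypothetical helper, not part of PyTorch, and it assumes the model has at least one parameter:

```python
import torch
import torch.nn as nn

def assert_same_device(model: nn.Module, *tensors: torch.Tensor) -> torch.device:
    """Hypothetical helper: raise if any tensor is not on the model's device."""
    # Assumes the model has at least one parameter to read the device from.
    model_device = next(model.parameters()).device
    for i, t in enumerate(tensors):
        if t.device != model_device:
            raise RuntimeError(
                f"input {i} is on {t.device}, but the model is on {model_device}"
            )
    return model_device

model = nn.Linear(4, 2)
x = torch.randn(8, 4)
print(assert_same_device(model, x))   # cpu
```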

Any guidance would be helpful.

krishnab
  • If an operation involves one tensor on the GPU and another on the CPU, you'll receive a `RuntimeError: Expected object of device type cuda but got device type cpu`, which is quite clear. – Berriel Jul 23 '20 at 22:17
  • @Berriel Yes, this makes sense. Now I know how to interpret this message if it pops up. Thanks for the tip. I was just worried that the system might not fail and I would end up with a very slow training loop or some other hard-to-diagnose issue. Thanks again. – krishnab Jul 24 '20 at 14:54

1 Answer


Both the model and the data must be on the same device, either CPU or GPU, for the model to process the data. Data on the CPU with the model on the GPU, or vice versa, will result in a RuntimeError.
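To see the error for yourself, the sketch below puts a toy model on the GPU and deliberately leaves the input on the CPU. It only attempts the mismatched call when CUDA is available, and the exact wording of the error message differs between PyTorch versions:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)                   # toy model for illustration
x_cpu = torch.randn(8, 4)                 # input stays on the CPU

if torch.cuda.is_available():
    model.to("cuda")                      # parameters now live on the GPU
    try:
        model(x_cpu)                      # mismatched devices -> RuntimeError
    except RuntimeError as err:
        print("RuntimeError:", err)
    out = model(x_cpu.to("cuda"))         # fixed: input moved to the model's device
else:
    out = model(x_cpu)                    # CPU-only machine: both already on the CPU
print(out.device)
```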

You can set a variable device to 'cuda' if a GPU is available and to 'cpu' otherwise, and then transfer both the data and the model to device:

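import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = nn.Linear(4, 2)    # stand-in for your own model
data = torch.randn(8, 4)   # stand-in for your own input batch
model.to(device)           # nn.Module.to() moves the parameters in place
data = data.to(device)     # Tensor.to() returns a new tensor, so reassign it
output = model(data)       # model and data are now on the same device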
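In a training loop the same rule applies to every batch: move the model once, then move each batch as it comes out of the DataLoader. A minimal sketch, assuming a toy random regression dataset wrapped in a TensorDataset; the model, loss, and optimizer are placeholders:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy data created on the CPU (placeholder for a real dataset)
dataset = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
loader = DataLoader(dataset, batch_size=16, shuffle=True)

model = nn.Linear(4, 1).to(device)        # move the model once, before the loop
# Construct the optimizer after moving the model, as the torch.optim docs recommend
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

for inputs, targets in loader:
    # Batches come out of the loader on the CPU; move each one to the model's device
    inputs, targets = inputs.to(device), targets.to(device)
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()

print(loss.item())                        # .item() copies the scalar back to the host
```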
Garima Jain