I am trying to write a device-agnostic library for PyTorch, and I have run into the problem that PyTorch keeps dtypes that are not compatible with my compute device:
import scipy.signal
import torch

raw_window = scipy.signal.windows.cosine(128)
print(raw_window.dtype)  # float64

device = torch.device("cpu")
window = torch.as_tensor(raw_window, device=device)
print(window.device)  # cpu
print(window.dtype)   # torch.float64

device = torch.device("cuda")
window = torch.as_tensor(raw_window, device=device)
print(window.device)  # cuda:0
print(window.dtype)   # torch.float64
As you can see in the last line, torch keeps the input dtype torch.float64 even on the GPU, although my CUDA device is not able to handle double-precision float values.
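For now I work around this by passing an explicit dtype when converting, since torch.as_tensor accepts a dtype argument (continuing the snippet above):

# Convert with an explicit dtype instead of inheriting float64 from NumPy.
# torch.get_default_dtype() returns torch.float32 unless changed globally.
window = torch.as_tensor(
    raw_window,
    dtype=torch.get_default_dtype(),
    device=device,
)
print(window.dtype)  # torch.float32

But this hard-codes the decision at every call site, which is exactly what I want to avoid in a device-agnostic library.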
Is there a way to make PyTorch pick the most suitable dtype for the target device instead of the input data's dtype? Or am I misunderstanding the whole concept of setting devices and dtypes?
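To make it concrete, what I am imagining is something like the sketch below. preferred_float_dtype is my own hypothetical helper, not a PyTorch API, and the selection rule is just an assumption based on my hardware:

import scipy.signal
import torch

def preferred_float_dtype(device: torch.device) -> torch.dtype:
    # Hypothetical helper, not part of PyTorch: pick a float dtype
    # the given device handles well.
    if device.type == "cuda":
        # My GPU cannot handle float64, so fall back to the global
        # default dtype (torch.float32 unless changed).
        return torch.get_default_dtype()
    return torch.float64  # CPUs handle double precision natively

raw_window = scipy.signal.windows.cosine(128)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
window = torch.as_tensor(
    raw_window,
    dtype=preferred_float_dtype(device),
    device=device,
)
print(window.dtype)  # torch.float32 on my CUDA device

Is there a built-in mechanism that does this, or is a helper like this the expected approach?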