I've been trying to move each batch to the device inside my DataLoader by calling `to(device)` from the `collate_fn`, as seen in https://github.com/pytorch/pytorch/issues/11372. I set it up for FashionMNIST as follows:

    import torch
    from torch.utils.data.dataloader import default_collate
    from torchvision import datasets, transforms

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    batch_size = 32
    trainset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/',
                                     download=True,
                                     train=True,
                                     transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False,
                                               collate_fn=lambda x: default_collate(x).to(device))

But I get the following error:

    AttributeError: 'list' object has no attribute 'to'

It seems that the output of `default_collate` is a list of length 2, whose first element is the image tensor and whose second element is the labels tensor (that is what `next(iter(train_loader))` returns with `collate_fn=None`). So I tried the following function instead:

    def to_device_list(l, device):
        return [l[0].to(device), l[1].to(device)]

    train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False,
                                               collate_fn=lambda x: to_device_list(x, device))

And I got this error instead:

    AttributeError: 'tuple' object has no attribute 'to'
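A quick check on a few synthetic samples shows where both errors come from. The list-of-tuples `samples` below is a hypothetical stand-in for what `trainset` yields (an image tensor and an integer label per item). `collate_fn` receives such a raw list, while `default_collate` turns it into a list of two batched tensors; neither a tuple nor a list has a `.to()` method.

```python
import torch
from torch.utils.data.dataloader import default_collate

# Hypothetical stand-in for a few FashionMNIST samples: each sample is an
# (image, label) tuple, as produced by ToTensor() plus an integer target.
samples = [(torch.rand(1, 28, 28), i % 10) for i in range(4)]

# collate_fn receives this raw list, so samples[0] is the first (image, label)
# tuple, not the image batch; calling .to() on it raises the 'tuple' error.
print(type(samples[0]).__name__)  # tuple

# default_collate stacks the samples field by field and returns a list of two
# batched tensors, [images, labels]; calling .to() on it raises the 'list' error.
batch = default_collate(samples)
images, labels = batch
print(type(batch).__name__)  # list
print(images.shape)          # torch.Size([4, 1, 28, 28])
print(labels.shape)          # torch.Size([4])
```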
How can I move the batches to the device inside the DataLoader?
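One possible approach, sketched here under stated assumptions rather than as the method from the linked issue, is to run `default_collate` on the raw samples first and only then move each tensor of the resulting batch. The helper name `collate_to_device` and the list-based `trainset` are hypothetical stand-ins. The sketch leaves `num_workers` at its default of 0, because moving tensors to CUDA inside forked worker processes can fail with a CUDA re-initialization error.

```python
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
batch_size = 32

# Hypothetical stand-in for trainset: a list of (image, label) tuples.
trainset = [(torch.rand(1, 28, 28), i % 10) for i in range(64)]


def collate_to_device(samples, device):
    """Hypothetical helper: batch the raw samples first, then move each tensor."""
    images, labels = default_collate(samples)
    return images.to(device), labels.to(device)


train_loader = DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=False,
    # num_workers is left at its default of 0, so collation (and the .to() call)
    # runs in the main process rather than in worker subprocesses.
    collate_fn=lambda samples: collate_to_device(samples, device),
)

images, labels = next(iter(train_loader))
print(images.shape, images.device)
```

Returning a tuple keeps the usual `for images, labels in train_loader:` loop unchanged.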