I am training a GAN on the CIFAR-10 dataset in PyTorch (and hence don't need train/val/test splits), and I want to combine the two torchvision.datasets.CIFAR10 datasets in the snippet below into a single torch.utils.data.DataLoader iterator. My current solution is something like:
import torchvision
import torch
batch_size = 128
# ToTensor() is needed so the default collate_fn can batch the images
transform = torchvision.transforms.ToTensor()
cifar_trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
cifar_testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
cifar_dl1 = torch.utils.data.DataLoader(cifar_trainset, batch_size=batch_size, num_workers=12, persistent_workers=True,
shuffle=True, pin_memory=True)
cifar_dl2 = torch.utils.data.DataLoader(cifar_testset, batch_size=batch_size, num_workers=12, persistent_workers=True,
shuffle=True, pin_memory=True)
And then in my training loop I have something like:
for dl in [cifar_dl1, cifar_dl2]:
    for data in dl:
        # training step
The problem with this approach in a multi-worker context is that, having found that the optimal num_workers for my setup and this task is 12, I am now declaring 24 workers in total across the two loaders, which is clearly too many, not to mention the start-up cost of re-iterating over each DataLoader, in spite of the persistent_workers flag on each.
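For concreteness, the single-iterator setup I'm hoping for is sketched below. It assumes torch.utils.data.ConcatDataset is the right primitive for merging the two splits, which is exactly the kind of thing I'm unsure about:

import torch
import torchvision

batch_size = 128
transform = torchvision.transforms.ToTensor()

# Both splits as ordinary map-style datasets, sharing the same transform
cifar_trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
cifar_testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)

# Merge the 50k train and 10k test images into one 60k-image dataset,
# then build a single DataLoader over it with one pool of 12 workers
cifar_full = torch.utils.data.ConcatDataset([cifar_trainset, cifar_testset])
cifar_dl = torch.utils.data.DataLoader(cifar_full, batch_size=batch_size, num_workers=12,
                                       persistent_workers=True, shuffle=True, pin_memory=True)

for data in cifar_dl:
    # training step (labels in data[1] are unused for the GAN)
    pass

If ConcatDataset is not the way to go here, or if there is a cleaner way to get a single set of persistent workers over both splits, that is exactly what I'm asking.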
Any solutions to this problem would be much appreciated.