
I am training a GAN on the CIFAR-10 dataset in PyTorch (and hence don't need train/val/test splits), and I want to combine the two torchvision.datasets.CIFAR10 instances in the snippet below into a single torch.utils.data.DataLoader iterator. My current solution is something like:

import torchvision
import torch

batch_size = 128

# a transform is needed so the default collate_fn can batch the images
transform = torchvision.transforms.ToTensor()

cifar_trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
cifar_testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)

cifar_dl1 = torch.utils.data.DataLoader(cifar_trainset, batch_size=batch_size, num_workers=12,
                                        persistent_workers=True, shuffle=True, pin_memory=True)
cifar_dl2 = torch.utils.data.DataLoader(cifar_testset, batch_size=batch_size, num_workers=12,
                                        persistent_workers=True, shuffle=True, pin_memory=True)

And then in my training loop I have something like:

for dl in [cifar_dl1, cifar_dl2]:
    for data in dl:
        ...  # training step

The problem with this approach is the worker count: I have found that for my setup and this task the optimal number of workers is 12, but with two DataLoaders I am now declaring 24 workers in total, which is clearly too many. On top of that there are the start-up costs of re-iterating over each DataLoader every epoch, in spite of the benefits of the persistent_workers flag for each.

Any solutions to this problem would be much appreciated.

IntegrateThis
  • How is your question different from: https://stackoverflow.com/questions/60840500/pytorch-concatenating-datasets-before-using-dataloader? – Charlie Parker Sep 27 '22 at 01:32
  • @CharlieParker I'm a bit unsure what you mean. Do you mean that if you have two datasets with mutually exclusive labels (say N_1 classes for dataset 1 and N_2 for dataset 2), you want the largest label to be (N_1 + N_2 - 1) or something? I think you could simply offset each label in dataset 2 by N_1. – IntegrateThis Sep 27 '22 at 01:41
  • Yes, I think that is what I want. OK, so I think that would have to happen in the collate function or something. But how do I find the number of labels for a dataset to begin with? I'm praying that each dataset has a `dataset.labels` attribute. E.g. if we try to merge/concatenate/union MNIST and CIFAR, I'd want labels in the range `0 ... 10+10-1`, I think. – Charlie Parker Sep 27 '22 at 01:48
  • Well, assuming you know the dataset source, you have access to the literature, and it will state the number of classes. For instance, MNIST has 10 digits (0, 1, 2, 3, 4, 5, 6, 7, 8, 9), hence 10 classes. It should be obvious from the data source how many labels there are. – IntegrateThis Sep 27 '22 at 01:50
  • Thanks, I think I figured out how to solve this, though it's not something people might easily know; I outlined the solution here: https://discuss.pytorch.org/t/does-concatenate-datasets-preserve-class-labels-and-indices/62611/12?u=brando_miranda. The idea is to use the library I describe there, but implementing it yourself is not too hard: you just need to write a custom dataset that applies the correct label offset given an arbitrary list of datasets (a sketch follows this thread). – Charlie Parker Sep 27 '22 at 02:02
  • Actually, it's likely easier to preprocess the data point indices to map to the required label (as you loop through each dataset you'd know this value easily and can keep a single counter) instead of bisecting. – Charlie Parker Sep 27 '22 at 02:06
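
A minimal sketch of the label-offset wrapper discussed in the comments above, assuming each wrapped dataset yields (sample, label) pairs and that you supply the class count of each dataset yourself; the class name LabelOffsetConcatDataset and the num_classes parameter are illustrative, not an existing PyTorch API:

import bisect
from torch.utils.data import Dataset

class LabelOffsetConcatDataset(Dataset):
    # Concatenates datasets while keeping their label spaces disjoint:
    # dataset 0 keeps labels 0..N_1-1, dataset 1 is shifted to N_1..N_1+N_2-1, and so on.
    def __init__(self, datasets, num_classes):
        self.datasets = list(datasets)
        # cumulative sample counts, used to locate which dataset an index falls in
        self.cum_sizes = []
        total = 0
        for d in self.datasets:
            total += len(d)
            self.cum_sizes.append(total)
        # label offsets: 0, N_1, N_1 + N_2, ...
        self.offsets = [0]
        for n in num_classes[:-1]:
            self.offsets.append(self.offsets[-1] + n)

    def __len__(self):
        return self.cum_sizes[-1]

    def __getitem__(self, idx):
        # same bisection trick torch.utils.data.ConcatDataset uses internally
        ds_idx = bisect.bisect_right(self.cum_sizes, idx)
        start = self.cum_sizes[ds_idx - 1] if ds_idx > 0 else 0
        sample, label = self.datasets[ds_idx][idx - start]
        return sample, label + self.offsets[ds_idx]

For example, wrapping MNIST (10 classes) and CIFAR-10 (10 classes) with num_classes=[10, 10] yields labels in the range 0..19, matching the `0 ... 10+10-1` range mentioned in the comments.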

1 Answer


You can use ConcatDataset from the torch.utils.data module. It chains the given datasets into a single dataset, so one DataLoader (and one pool of workers) can shuffle and iterate over the combined data.

Code Snippet:

import torch
import torchvision

batch_size = 128

# a transform is needed so the default collate_fn can batch the images
transform = torchvision.transforms.ToTensor()

cifar_trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
cifar_testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)

# chain the two splits into a single dataset
cifar_dataset = torch.utils.data.ConcatDataset([cifar_trainset, cifar_testset])

# one DataLoader, one pool of 12 persistent workers
cifar_dataloader = torch.utils.data.DataLoader(cifar_dataset, batch_size=batch_size, num_workers=12,
                                               persistent_workers=True, shuffle=True, pin_memory=True)

for data in cifar_dataloader:
    ...  # training step
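
With ConcatDataset, len(cifar_dataset) is len(cifar_trainset) + len(cifar_testset) = 50,000 + 10,000 = 60,000, shuffle=True samples uniformly across both splits, and the 12 persistent workers are spun up once instead of once per DataLoader.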
Kishore Sampath