
I'm learning PyTorch, and I'm trying to implement a paper about the progressive growing of GANs. The authors train the networks on a given number of images, instead of for a given number of epochs.

My question is: is there a way to do this in PyTorch, using the default DataLoader? I'd like to do something like:

loader = DataLoader(..., total=800000)
for batch in iter(loader):
   ... #do training

and have the loader loop over the dataset automatically until 800,000 samples have been seen.

I think that would be a better way than calculating the number of times you have to loop through the dataset yourself.
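
For reference, here is a rough sketch of the manual approach I'd like to avoid (the dataset, batch size and training step are just placeholders):

import math
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(100000).float().unsqueeze(1))  #placeholder dataset
batch_size = 64
total = 800000  #number of samples to train on overall

loader = DataLoader(ds, batch_size=batch_size, shuffle=True)
n_epochs = math.ceil(total / len(ds))  #number of passes, computed by hand

seen = 0
for _ in range(n_epochs):
    for (batch,) in loader:
        if seen >= total:
            break
        seen += batch.shape[0]
        ... #do training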

aurelia

2 Answers


You can use torch.utils.data.RandomSampler and sample from your dataset. Here is a minimal setup example:

import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler

class DS(Dataset):
    def __len__(self):
        return 5
    def __getitem__(self, index):
        return torch.empty(1).fill_(index)

>>> ds = DS()

Initialize a random sampler, providing num_samples and setting replacement to True, i.e. the sampler is forced to draw instances multiple times if len(ds) < num_samples:

>>> sampler = RandomSampler(ds, replacement=True, num_samples=10)

Then plug this sampler into a new torch.utils.data.DataLoader:

>>> dl = DataLoader(ds, sampler=sampler, batch_size=2)

>>> for batch in dl:
...     print(batch)
tensor([[6.],
        [4.]])
tensor([[9.],
        [2.]])
tensor([[9.],
        [2.]])
tensor([[6.],
        [2.]])
tensor([[0.],
        [9.]])
Ivan
  • Will it work if the number of samples isn't divisible by the batch_size? (the documentation is a little vague about it) – aurelia Oct 03 '21 at 17:26
  • Your `RandomSampler` will draw `num_samples` instances regardless of the number of elements in your dataset. If this number is not divisible by `batch_size`, then the last batch will not get filled. If you wish to ignore this last partially filled batch, you can set the parameter `drop_last` to `True` on the data loader. With the above setup, compare `DataLoader(ds, sampler=sampler, batch_size=3)` with `DataLoader(ds, sampler=sampler, batch_size=3, drop_last=True)`. – Ivan Oct 03 '21 at 17:31
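
To make the comparison from the comment above concrete, here is a minimal sketch reusing the DS dataset from the answer; the batch sizes shown are what the setup produces (the sampled values themselves are random):

>>> ds = DS()
>>> sampler = RandomSampler(ds, replacement=True, num_samples=10)

>>> dl = DataLoader(ds, sampler=sampler, batch_size=3)
>>> [len(b) for b in dl]
[3, 3, 3, 1]

>>> dl = DataLoader(ds, sampler=sampler, batch_size=3, drop_last=True)
>>> [len(b) for b in dl]
[3, 3, 3]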

torch.utils.data.RandomSampler can be used to randomly sample more entries than exist in a dataset (where num_samples > dataset_size):

sampler = torch.utils.data.RandomSampler(dataset, replacement=True, num_samples=num_samples)
loader = torch.utils.data.DataLoader(dataset=dataloader_dataset, sampler=sampler, batch_size=batch_size)

If sampling from a Hugging Face dataset, the dataloader_dataset wrapper class must catch StopIteration and reset its iterator (start again from the beginning of the dataset), for example:

import torch
from datasets import load_dataset
from tqdm import tqdm

#parameter selection (user configured)
dataset = load_dataset(...)
dataset_size = dataset.num_rows
number_of_dataset_repetitions = 5
num_samples = dataset_size * number_of_dataset_repetitions
batch_size = 8
drop_last = True

dataloader_dataset = DataloaderDatasetRepeatSampler(dataset, dataset_size)
sampler = torch.utils.data.RandomSampler(dataset, replacement=True, num_samples=num_samples)
loader = torch.utils.data.DataLoader(dataset=dataloader_dataset, sampler=sampler, batch_size=batch_size, drop_last=drop_last)
loop = tqdm(loader, leave=True)

for batch_index, batch in enumerate(loop):
    ...

where DataloaderDatasetRepeatSampler is defined as:

class DataloaderDatasetRepeatSampler(torch.utils.data.Dataset):
    def __init__(self, dataset, dataset_size):
        self.dataset = dataset
        self.dataset_size = dataset_size
        self.dataset_iterator = iter(dataset)

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, i):
        #the index drawn by the sampler is ignored; entries are read sequentially from the wrapped iterator
        try:
            dataset_entry = next(self.dataset_iterator)
        except StopIteration:
            #reset iterator (start from beginning of dataset)
            self.dataset_iterator = iter(self.dataset)
            dataset_entry = next(self.dataset_iterator)
        batch_sample = ...  #eg torch.Tensor(dataset_entry)
        return batch_sample
user2585501