6

I have a generator that creates synthetic data. How can I convert this into a PyTorch dataloader?

Rylan Schaeffer

2 Answers

6

You can wrap your generator with a data.IterableDataset:

from torch.utils import data

class IterDataset(data.IterableDataset):
    def __init__(self, generator):
        # generator is a callable that returns a fresh generator when called
        self.generator = generator

    def __iter__(self):
        return self.generator()

Naturally, you can then wrap this dataset with a data.DataLoader.

Here is a minimal example showing its use:

>>> gen = lambda: (x for x in range(10))

>>> dataset = IterDataset(gen)
>>> for i in data.DataLoader(dataset, batch_size=2):
...    print(i)
tensor([0, 1])
tensor([2, 3])
tensor([4, 5])
tensor([6, 7])
tensor([8, 9])
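
One detail worth spelling out: `IterDataset.__iter__` calls `self.generator()`, so the class expects a generator function (or any callable that returns an iterator), not a generator object that has already been created. The sketch below uses a hypothetical `count_to_four` function as the data source. It suggests that passing the function allows several passes over the loader, whereas passing `count_to_four()` raises the `TypeError` reported in the comments.

```python
from torch.utils import data


class IterDataset(data.IterableDataset):
    def __init__(self, generator):
        self.generator = generator

    def __iter__(self):
        # Calling the function creates a fresh generator for every pass.
        return self.generator()


def count_to_four():
    # Hypothetical stand-in for a synthetic-data generator.
    for x in range(4):
        yield x


# Passing the function itself: the loader can be iterated more than once.
loader = data.DataLoader(IterDataset(count_to_four), batch_size=2)
epochs = [[batch.tolist() for batch in loader] for _ in range(2)]

# Passing an already-created generator object fails inside __iter__,
# because a generator object cannot be called.
try:
    iter(IterDataset(count_to_four()))
    error = None
except TypeError as exc:
    error = str(exc)
```

If you only have a generator object, returning `self.generator` directly from `__iter__` avoids the error, but the data is then exhausted after a single pass.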
Ivan
  • If I do this with an infinite generator, will the dataset/dataloader have any notion of an epoch? Or will the training loop run forever? – Rylan Schaeffer Aug 09 '22 at 16:27
  • No, it won't have any notion of length, since `data.IterableDataset` doesn't implement a `__len__` method. – Ivan Aug 09 '22 at 17:37
  • The proposed solution did not work for me. I keep getting this error: TypeError: 'generator' object is not callable – Basilique Dec 06 '22 at 14:35
  • Did you replicate the code from above? Can you provide your code? – Ivan Dec 06 '22 at 17:01
  • It should be `return self.generator` (without parentheses) in `__iter__`; otherwise I also get `TypeError: 'generator' object is not callable` – roygbiv Jun 23 '23 at 19:33
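
On the infinite-generator question raised in the comments: one possible way to get a notion of an epoch is to cut the stream off after a fixed number of samples. The sketch below does this with `itertools.islice`. The names `infinite_samples`, `BoundedIterDataset`, and `samples_per_epoch` are hypothetical and chosen only for illustration.

```python
import itertools
import random

from torch.utils import data


def infinite_samples():
    # Hypothetical endless source of synthetic scalars.
    while True:
        yield random.random()


class BoundedIterDataset(data.IterableDataset):
    def __init__(self, generator, samples_per_epoch):
        self.generator = generator
        self.samples_per_epoch = samples_per_epoch

    def __iter__(self):
        # islice stops after samples_per_epoch items, which ends the "epoch".
        return itertools.islice(self.generator(), self.samples_per_epoch)


loader = data.DataLoader(BoundedIterDataset(infinite_samples, 10), batch_size=4)

# 10 samples with batch_size=4 give batches of 4, 4, and 2.
batch_sizes = [len(batch) for batch in loader]
```

This assumes the default `num_workers=0`. With several workers, each worker iterates its own copy of the dataset, so the per-epoch sample count would need to be divided among them.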
3

Given the limited information you provided, this is the simplest solution (I assume your generator creates images from noise, as in the original GAN):

import torch

def get_data(batch_size, generator, latent_dim=512):
    # Sample latent noise and map it to a batch of images
    z = torch.randn(batch_size, latent_dim)
    return generator(z)

def dataloader(batch_size, generator, iteration, latent_dim=512):
    # Yield one generated batch per iteration
    for _ in range(iteration):
        yield get_data(batch_size, generator, latent_dim)

batch_size = 64
generator = GANs(...)
iteration = 100
latent_dim = 512

loader = dataloader(batch_size, generator, iteration, latent_dim)
for images in loader:
    pass  # do something with each batch of images
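
The function above is a plain Python generator rather than a `torch.utils.data.DataLoader`. If an actual `DataLoader` is needed, a minimal sketch is to move the same sampling logic into an `IterableDataset` and pass `batch_size=None`, so that the loader hands back the pre-built batches unchanged. The class name `LatentDataset` and the `torch.nn.Linear` stand-in for the GAN generator are assumptions made for illustration, as is the use of `torch.no_grad()`, which presumes the generator is not being trained.

```python
import torch
from torch.utils import data


class LatentDataset(data.IterableDataset):
    def __init__(self, generator, batch_size, iterations, latent_dim=512):
        self.generator = generator
        self.batch_size = batch_size
        self.iterations = iterations
        self.latent_dim = latent_dim

    def __iter__(self):
        for _ in range(self.iterations):
            z = torch.randn(self.batch_size, self.latent_dim)
            # Assumption: the generator is frozen, so gradients are not needed.
            with torch.no_grad():
                images = self.generator(z)
            # Yield outside the no_grad block so the caller's grad mode is untouched.
            yield images


# Hypothetical stand-in for a real GAN generator: maps 16-dim noise to 8 values.
toy_generator = torch.nn.Linear(16, 8)

dataset = LatentDataset(toy_generator, batch_size=4, iterations=3, latent_dim=16)

# batch_size=None disables automatic batching; each yielded tensor is one batch.
loader = data.DataLoader(dataset, batch_size=None)
shapes = [tuple(batch.shape) for batch in loader]
```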
CuCaRot