I have a generator that creates synthetic data. How can I convert this into a PyTorch dataloader?
Viewed 3,021 times
2 Answers
6
You can wrap your generator with a data.IterableDataset:
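from torch.utils import data

class IterDataset(data.IterableDataset):
    def __init__(self, generator):
        # `generator` is a generator function (a callable), not a generator object
        self.generator = generator
    def __iter__(self):
        # Call it to get a fresh generator each time the dataset is iterated
        return self.generator()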
Naturally, you can then wrap this dataset with a data.DataLoader.
Here is a minimal example showing its use:
>>> gen = lambda: (x for x in range(10))
>>> dataset = IterDataset(gen)
>>> for i in data.DataLoader(dataset, batch_size=2):
...     print(i)
tensor([0, 1])
tensor([2, 3])
tensor([4, 5])
tensor([6, 7])
tensor([8, 9])
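
If the generator never terminates, the loader above never stops either, because nothing tells it where an epoch ends. One possible way to get fixed-size epochs is to cut the stream with itertools.islice. The following is only a sketch: the names BoundedIterDataset, samples_per_epoch and endless_noise are made up for illustration and are not part of PyTorch.

```python
import itertools

import torch
from torch.utils import data


class BoundedIterDataset(data.IterableDataset):
    """Wraps a generator function and stops after `samples_per_epoch` items."""

    def __init__(self, generator_fn, samples_per_epoch):
        self.generator_fn = generator_fn
        self.samples_per_epoch = samples_per_epoch

    def __iter__(self):
        # Call the function to get a fresh generator for every epoch,
        # then cut it off after a fixed number of samples.
        return itertools.islice(self.generator_fn(), self.samples_per_epoch)


def endless_noise():
    # Hypothetical infinite source of synthetic samples.
    while True:
        yield torch.randn(3)


dataset = BoundedIterDataset(endless_noise, samples_per_epoch=10)
loader = data.DataLoader(dataset, batch_size=4)

# 10 samples with batch_size=4 give batches of 4, 4 and 2 samples.
batch_shapes = [tuple(batch.shape) for batch in loader]
print(batch_shapes)
```

Note that with num_workers > 0 each worker process iterates its own copy of the dataset, so the number of samples per epoch would be multiplied by the number of workers unless the dataset accounts for that.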

Ivan
- If I do this with an infinite generator, will the dataset/dataloader have any notion of an epoch? Or will the training loop run forever? – Rylan Schaeffer Aug 09 '22 at 16:27
- No, it won't have any notion of length, since `data.IterableDataset` doesn't implement a `__len__` function. – Ivan Aug 09 '22 at 17:37
- The proposed solution did not work for me. I constantly get this error: TypeError: 'generator' object is not callable – Basilique Dec 06 '22 at 14:35
- Did you replicate the code from above? Can you provide your code? – Ivan Dec 06 '22 at 17:01
- It should be `return self.generator` (without parentheses) in `__iter__`, otherwise I also get `TypeError: 'generator' object is not callable` – roygbiv Jun 23 '23 at 19:33
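
The TypeError discussed in the comments above most likely comes from passing a generator object (the result of calling the function) instead of the generator function itself. The plain-Python sketch below shows the difference; as_iterator is a hypothetical helper, not something PyTorch provides.

```python
def numbers():
    # A generator *function*: each call returns a fresh generator object.
    for x in range(3):
        yield x


def as_iterator(source):
    # Hypothetical helper: accept either a generator function or a generator object.
    return source() if callable(source) else source


gen_obj = numbers()  # a generator *object*

try:
    gen_obj()  # generator objects cannot be called
    error = None
except TypeError as exc:
    error = str(exc)

# A generator function can be iterated again and again (one call per epoch).
first_pass = list(as_iterator(numbers))
second_pass = list(as_iterator(numbers))

# A generator object is exhausted after one pass, so a second epoch is empty.
obj_first = list(as_iterator(gen_obj))
obj_second = list(as_iterator(gen_obj))
```

So `return self.generator()` fits when the dataset is given the function, and `return self.generator` fits when it is given the object, with the caveat that the object can only be consumed once.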
3
With the limited information that you provided, this is the simplest solution (I assume that your generator creates images from noise, as in the original GANs):
import torch

def get_data(batch_size, generator, latent_dim=512):
    # Sample latent noise and map it through the generator
    z = torch.randn(batch_size, latent_dim)
    return generator(z)

def dataloader(batch_size, generator, iteration, latent_dim=512):
    # Yield `iteration` batches of generated data
    for i in range(iteration):
        yield get_data(batch_size, generator, latent_dim)

batch_size = 64
generator = GANs(...)  # placeholder: your own generator model
iteration = 100
latent_dim = 512
loader = dataloader(batch_size, generator, iteration, latent_dim)
for images in loader:
    pass  # do something

CuCaRot
- I meant a PyTorch DataLoader specifically, not a generic function named dataloader. My apologies for not being more specific! – Rylan Schaeffer Aug 04 '22 at 07:51
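
For an actual torch.utils.data.DataLoader, the same idea can probably be expressed as an IterableDataset that yields whole batches, with batch_size=None passed to the DataLoader so that automatic batching is disabled. This is a sketch under assumptions: GeneratedBatches is a made-up name, a small nn.Linear stands in for the GAN generator, and the generator is assumed to be used only for producing data (hence no gradients).

```python
import torch
from torch import nn
from torch.utils import data


class GeneratedBatches(data.IterableDataset):
    """Yields `iterations` batches produced by feeding noise to `generator`."""

    def __init__(self, generator, batch_size, iterations, latent_dim=512):
        self.generator = generator
        self.batch_size = batch_size
        self.iterations = iterations
        self.latent_dim = latent_dim

    def __iter__(self):
        for _ in range(self.iterations):
            z = torch.randn(self.batch_size, self.latent_dim)
            # Compute inside no_grad, but yield outside of it so that the
            # consumer's loop body does not run with gradients disabled.
            with torch.no_grad():
                images = self.generator(z)
            yield images


# Stand-in for a real GAN generator (assumption for illustration only).
toy_generator = nn.Linear(8, 4)

dataset = GeneratedBatches(toy_generator, batch_size=16, iterations=5, latent_dim=8)
# batch_size=None disables automatic batching: each yielded item is already a batch.
loader = data.DataLoader(dataset, batch_size=None)

shapes = [tuple(images.shape) for images in loader]
print(shapes)
```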