
The PyTorch DataLoader turns datasets into iterables. I already have a generator that yields data samples for training and testing. I use a generator because the full set of samples is too large to hold in memory. I would like to load the samples in batches for training.

What is the best way to do this? Can I do it without a custom DataLoader? The PyTorch DataLoader does not accept a generator as input. Below is a minimal example of what I want to do; it fails with the error "object of type 'generator' has no len()".

import torch
from torch import nn
from torch.utils.data import DataLoader

def example_generator():
    for i in range(10):
        yield i
    

BATCH_SIZE = 3
train_dataloader = DataLoader(example_generator(),
                              batch_size=BATCH_SIZE,
                              shuffle=False)

print(f"Length of train_dataloader: {len(train_dataloader)} batches of {BATCH_SIZE}")

I want to feed data from an iterator into the PyTorch DataLoader so that I can use its batching functionality.

Edit: I want this to work for complex generators whose length is not known in advance.

bja

1 Answer


PyTorch's DataLoader officially supports iterable datasets, but the object you pass must be an instance of a subclass of torch.utils.data.IterableDataset. The documentation describes it this way:

An iterable-style dataset is an instance of a subclass of IterableDataset that implements the __iter__() protocol, and represents an iterable over data samples

So your code would be written as:

from torch.utils.data import DataLoader, IterableDataset

class MyIterableDataset(IterableDataset):
    def __init__(self, iterable):
        self.iterable = iterable

    def __iter__(self):
        return iter(self.iterable)

...

BATCH_SIZE = 3

train_dataloader = DataLoader(MyIterableDataset(example_generator()),
                              batch_size=BATCH_SIZE,
                              shuffle=False)
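
One caveat with wrapping a generator object directly: a generator can only be consumed once, so a second pass over the DataLoader (a second epoch) would yield no batches. A minimal sketch of one way around this, assuming you can pass the generator function itself rather than a generator it has already created, is shown below. The class name GeneratorDataset and its generator_fn parameter are hypothetical names chosen for illustration; they are not part of PyTorch.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


def example_generator():
    for i in range(10):
        yield i


class GeneratorDataset(IterableDataset):
    """Hypothetical helper: wraps a generator *function*, not a generator object."""

    def __init__(self, generator_fn):
        self.generator_fn = generator_fn

    def __iter__(self):
        # Call the function again so that every epoch starts a fresh generator.
        return self.generator_fn()


BATCH_SIZE = 3

# Note: example_generator is passed without parentheses.
train_dataloader = DataLoader(GeneratorDataset(example_generator),
                              batch_size=BATCH_SIZE)

for epoch in range(2):
    batches = [batch.tolist() for batch in train_dataloader]
    print(f"epoch {epoch}: {batches}")
```

Because the dataset defines no __len__, calling len(train_dataloader) still raises a TypeError, which is expected when the length is not known in advance. If you need the number of batches, count them while iterating.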
blhsing