
I am trying to create a custom IterableDataset in PyTorch and split it into train, validation, and test datasets, following this answer: https://stackoverflow.com/a/61818182/9478434 . My dataset class is:

import random
import torch


class EchoDataset(torch.utils.data.IterableDataset):

  def __init__(self, delay=4, seq_length=15, size=1000):
    super().__init__()
    self.delay = delay
    self.seq_length = seq_length
    self.size = size
  
  def __len__(self):
    return self.size

  def __iter__(self):
    """ Iterable dataset doesn't have to implement __getitem__.
        Instead, we only need to implement __iter__ to return
        an iterator (or generator).
    """
    for _ in range(self.size):
      # N (the upper bound of the random values) is defined elsewhere in the script
      seq = torch.tensor([random.choice(range(1, N + 1)) for _ in range(self.seq_length)], dtype=torch.int64)
      result = torch.cat((torch.zeros(self.delay), seq[:self.seq_length - self.delay])).type(torch.int64)
      yield seq, result
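
For a quick sanity check, the dataset can be iterated directly, since `__iter__` returns a generator (illustrative only; `N = 10` here is just a placeholder for the value defined in my script):

N = 10  # illustrative value for this snippet
ds_check = EchoDataset(delay=4, seq_length=15, size=1000)
seq, result = next(iter(ds_check))
print(seq)     # 15 random values in [1, N]
print(result)  # the same sequence shifted right by `delay`, zero-padded at the front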

And the dataset is created and split as:

DELAY = 4
DATASET_SIZE = 200000
ds = EchoDataset(delay=DELAY, size=DATASET_SIZE)

train_count = int(0.7 * DATASET_SIZE)
valid_count = int(0.2 * DATASET_SIZE)
test_count = DATASET_SIZE - train_count - valid_count
train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(
    ds, (train_count, valid_count, test_count)
)

The problem is that when I try to iterate over the DataLoader built from `train_dataset`:

iterator = iter(train_dataset_loader)
print(next(iterator))

I get:

NotImplementedError: Caught NotImplementedError in DataLoader worker process 0.

Original Traceback (most recent call last):
  File "venv/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "venv/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "venv/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "venv/lib/python3.6/site-packages/torch/utils/data/dataset.py", line 363, in __getitem__
    return self.dataset[self.indices[idx]]
  File "venv/lib/python3.6/site-packages/torch/utils/data/dataset.py", line 69, in __getitem__
    raise NotImplementedError
NotImplementedError

It seems the problem comes from splitting the dataset and creating Subset objects, since I can iterate through a DataLoader created from the original (unsplit) dataset.
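
To illustrate (a quick check; the exact output may differ across PyTorch versions):

print(type(train_dataset))  # <class 'torch.utils.data.dataset.Subset'>
# Indexing the subset delegates to ds[...], which the IterableDataset doesn't implement:
# train_dataset[0]  -> raises NotImplementedError in this version

# whereas a DataLoader over the unsplit dataset iterates fine:
loader = torch.utils.data.DataLoader(ds, batch_size=32)
batch_seq, batch_result = next(iter(loader))
print(batch_seq.shape)  # torch.Size([32, 15])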

  • Since you're inheriting from the `torch.utils.data.IterableDataset` class, you must call `super()` in your `__init__` method to initialize the parent class as well. – dc_Bita98 Dec 26 '21 at 18:08
  • I added `super()` to my class `__init__` but the result did not change. It still returns the same error. – Pouria Dec 26 '21 at 18:39
  • I don't know if you can use `random_split` along with an `IterableDataset`. I've had a look at the PyTorch source code and `random_split` returns instances of `Subset`, which internally must call `__getitem__`, i.e. it must be a map-style dataset and not an iterable one. – dc_Bita98 Dec 26 '21 at 19:00
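
Following up on that last comment: if `random_split` really needs a map-style dataset, I guess a minimal sketch would look something like this (untested; it just moves the generation logic into `__getitem__` so that the `Subset` objects can index it):

class EchoMapDataset(torch.utils.data.Dataset):
  """Map-style variant: implements __getitem__, so random_split's Subset objects can index it."""

  def __init__(self, delay=4, seq_length=15, size=1000):
    super().__init__()
    self.delay = delay
    self.seq_length = seq_length
    self.size = size

  def __len__(self):
    return self.size

  def __getitem__(self, idx):
    # generate one (seq, result) pair on demand; N as in the original class
    seq = torch.tensor([random.choice(range(1, N + 1)) for _ in range(self.seq_length)], dtype=torch.int64)
    result = torch.cat((torch.zeros(self.delay), seq[:self.seq_length - self.delay])).type(torch.int64)
    return seq, result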
