I am trying to create a custom `IterableDataset` in PyTorch and split it into train, validation and test datasets using this answer: https://stackoverflow.com/a/61818182/9478434
My dataset class is:
```python
import random

import torch

N = 10  # hypothetical value: N (largest token id) is used below but not defined in this snippet


class EchoDataset(torch.utils.data.IterableDataset):
    def __init__(self, delay=4, seq_length=15, size=1000):
        super().__init__()
        self.delay = delay
        self.seq_length = seq_length
        self.size = size

    def __len__(self):
        return self.size

    def __iter__(self):
        """ Iterable dataset doesn't have to implement __getitem__.
        Instead, we only need to implement __iter__ to return
        an iterator (or generator).
        """
        for _ in range(self.size):
            # Random sequence of token ids drawn from 1..N
            seq = torch.tensor([random.choice(range(1, N + 1)) for _ in range(self.seq_length)], dtype=torch.int64)
            # Target: the same sequence shifted right by `delay`, front-padded with zeros
            result = torch.cat((torch.zeros(self.delay, dtype=torch.int64), seq[:self.seq_length - self.delay]))
            yield seq, result
```
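To make the target construction concrete: the `torch.cat(...)` line pads the front with `delay` zeros and drops the last `delay` tokens, so `result[i] == seq[i - delay]` for every `i >= delay`. A plain-list sketch of the same transform (the helper name `echo` and the sample sequence are hypothetical, not part of the original code):

```python
def echo(seq, delay):
    """Plain-list equivalent of the tensor expression in __iter__:
    prepend `delay` zeros and drop the last `delay` elements."""
    return [0] * delay + seq[:len(seq) - delay]


seq = [3, 1, 4, 1, 5, 9, 2, 6]
print(echo(seq, 4))  # -> [0, 0, 0, 0, 3, 1, 4, 1]
```

The output has the same length as the input, which is why each `(seq, result)` pair can be batched without extra padding.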
The dataset is created and split as follows:
```python
DELAY = 4
DATASET_SIZE = 200000
ds = EchoDataset(delay=DELAY, size=DATASET_SIZE)

train_count = int(0.7 * DATASET_SIZE)
valid_count = int(0.2 * DATASET_SIZE)
test_count = DATASET_SIZE - train_count - valid_count
train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(
    ds, (train_count, valid_count, test_count)
)
```
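For context, `random_split` does not iterate the dataset at all; roughly, it draws a random permutation of the indices `0 .. len(ds) - 1` (via `torch.randperm`) and wraps consecutive slices of that permutation in `Subset` objects. The sketch below models only that index bookkeeping in plain Python; `split_indices` is a hypothetical helper, not PyTorch's implementation.

```python
import random
from itertools import accumulate


def split_indices(total, lengths, seed=0):
    """Simplified model of random_split's index bookkeeping:
    shuffle 0..total-1 and cut the permutation into consecutive slices."""
    assert sum(lengths) == total, "lengths must add up to the dataset size"
    indices = list(range(total))
    random.Random(seed).shuffle(indices)
    # accumulate(lengths) gives the end offset of each slice
    return [indices[end - n:end] for end, n in zip(accumulate(lengths), lengths)]


parts = split_indices(10, [7, 2, 1])
print([len(p) for p in parts])  # -> [7, 2, 1]
```

Each split is therefore just a list of positions into the original dataset, which only makes sense for a dataset that supports indexing.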
The problem is that I get a `NotImplementedError` when I try to iterate over the DataLoader:
```python
iterator = iter(train_dataset_loader)
print(next(iterator))
```
The full traceback is:
```
NotImplementedError: Caught NotImplementedError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "venv/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "venv/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "venv/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "venv/lib/python3.6/site-packages/torch/utils/data/dataset.py", line 363, in __getitem__
    return self.dataset[self.indices[idx]]
  File "venv/lib/python3.6/site-packages/torch/utils/data/dataset.py", line 69, in __getitem__
    raise NotImplementedError
NotImplementedError
```
The problem seems to come from splitting the dataset into `Subset` objects, since I can iterate through a DataLoader created from the original (unsplit) dataset.
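The traceback appears to confirm this: `Subset.__getitem__` indexes the wrapped dataset (`return self.dataset[self.indices[idx]]`), and because `EchoDataset` only defines `__iter__`, that lookup falls through to the base `Dataset.__getitem__`, which just raises `NotImplementedError`. The pure-Python sketch below reproduces the chain with simplified, hypothetical stand-ins for `Dataset` and `Subset` (they are not PyTorch's classes) and a toy iterable-only dataset in place of `EchoDataset`.

```python
class Dataset:
    """Simplified stand-in for torch.utils.data.Dataset (not the real class)."""

    def __getitem__(self, index):
        # The base class has nothing to index, so it refuses.
        raise NotImplementedError


class Subset(Dataset):
    """Simplified stand-in for torch.utils.data.Subset, mirroring the
    `return self.dataset[self.indices[idx]]` line from the traceback."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]


class IterableOnly(Dataset):
    """Hypothetical toy dataset: defines __iter__ but not __getitem__."""

    def __init__(self, size):
        self.size = size

    def __iter__(self):
        yield from range(self.size)


toy_ds = IterableOnly(5)
print(list(toy_ds))  # iterating the original works -> [0, 1, 2, 3, 4]

subset = Subset(toy_ds, [3, 0, 2])
try:
    subset[0]  # Subset indexes the wrapped dataset...
except NotImplementedError:
    print("NotImplementedError raised")  # ...which falls through to Dataset.__getitem__
```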