I have a DataLoader that is initialised with an IterableDataset. I found that when I use multiprocessing (i.e. num_workers > 0 in the DataLoader), once the dataloader is exhausted after one epoch, it doesn't get reset automatically when I iterate over it again in the second epoch. Below is a small reproducible example.
I am aware that "Workers are shut down once the end of the iteration is reached", according to the documentation. However, I would like to know how to achieve my expected behaviour of the dataloader "resetting automatically". Thanks in advance for any help!
import torch

class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end):
        super().__init__()
        self.start = start
        self.end = end

    def __iter__(self):
        return iter(range(self.start, self.end))

dataset = MyIterableDataset(0, 4)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=False, num_workers=1, drop_last=False)

for epoch in range(2):
    for i, data in enumerate(dataloader):
        print(i, data)
"""
stdout:
0 tensor([0, 1])
1 tensor([2, 3])
2 _IterableDatasetStopIteration(worker_id=0)
"""
whereas the stdout I expect is
"""
0 tensor([0, 1])
1 tensor([2, 3])
0 tensor([0, 1])
1 tensor([2, 3])
"""
I am using the latest PyTorch version (1.6.0).
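
For what it's worth, the only workaround I have found so far is to recreate the DataLoader at the start of every epoch. That spawns fresh worker processes and seems to give the expected output on my machine, but it feels wasteful compared to the dataloader simply resetting itself. A minimal sketch of that workaround (same dataset as above):

import torch

# Same MyIterableDataset as in the example above.
class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end):
        super().__init__()
        self.start = start
        self.end = end

    def __iter__(self):
        return iter(range(self.start, self.end))

dataset = MyIterableDataset(0, 4)

for epoch in range(2):
    # Rebuilding the DataLoader each epoch forces new workers to be spawned,
    # so iteration starts from the beginning again.
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=2, shuffle=False, num_workers=1, drop_last=False
    )
    for i, data in enumerate(dataloader):
        print(i, data)

Is there a cleaner way to get the automatic reset without paying the worker start-up cost every epoch?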