
I have a custom Dataset that loads data from large files. Sometimes, the loaded data are empty and I don't want to use them for training.

In my Dataset I have:

def __getitem__(self, i):
    (x, y) = self.getData(i)  # getData loads data and handles problems
    return (x, y)

which, in the case of bad data, returns (None, None) (both x and y are None). However, this later fails in the DataLoader, and I am not able to skip that batch entirely. I have the batch size set to 1.

trainLoader = DataLoader(trainDataset, batch_size=1, shuffle=False)
for x_batch, y_batch in trainLoader:
    #process and train
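
For reference, here is a minimal sketch of the failure I see (EmptyDataset below is just a stand-in for my real dataset); I believe the error comes from the default collate_fn rejecting the None values:

import torch
from torch.utils.data import Dataset, DataLoader

class EmptyDataset(Dataset):
    # stand-in for my real dataset: odd indices simulate "empty" data
    def __len__(self):
        return 4

    def __getitem__(self, i):
        if i % 2 == 1:
            return (None, None)              # bad data
        return (torch.randn(4), torch.randn(1))

loader = DataLoader(EmptyDataset(), batch_size=1, shuffle=False)
for x_batch, y_batch in loader:
    pass  # fails with a TypeError from the default collate_fn on (None, None)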
  • Can you not preprocess those missing files out of your dataset instead of doing it at runtime? – Ivan Sep 24 '21 at 14:59
  • @Ivan It can be a solution, but it is quite time-consuming to pre-process the entire dataset. Plus, in the future I may have data loaded from the network which I don't have beforehand. So I am looking for a runtime-based solution. – Martin Perry Sep 24 '21 at 15:03

1 Answer


You could implement a custom IterableDataset and define __iter__ and __next__ methods that skip any instance for which your getData function raises an error.

Here is a possible implementation with dummy data:

import torch
from torch.utils.data import IterableDataset, DataLoader

class DS(IterableDataset):
    def __init__(self):
        self.data = torch.randint(0, 3, (20,))  # dummy data: 20 values in {0, 1, 2}
        self._i = -1

    def getData(self, index):
        x = self.data[index]
        if x == 0:            # treat zeros as "bad" data
            raise ValueError
        return x

    def __iter__(self):
        return self

    def __next__(self):
        self._i += 1
        if self._i == len(self.data):  # out of instances
            self._i = -1               # reset the iterable
            raise StopIteration        # stop the iteration
        try:
            return self.getData(self._i)
        except ValueError:             # bad instance: skip to the next one
            return next(self)

You would use it like:

>>> trainLoader = DataLoader(DS(), batch_size=1, shuffle=False)
>>> for x in trainLoader:
...    print(x)
tensor([1])
tensor([2])
tensor([2])
...
tensor([1])
tensor([1])

Here all 0 instances have been skipped in the iterable dataset.

You can adapt this simple example to fit your needs.
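
For instance, here is a rough sketch of the same pattern applied to the (x, y) case from your question; the FileDS name, the nSamples parameter, and the dummy getData body are only placeholders for your real file-loading logic, which is assumed to return (None, None) for bad data:

import torch
from torch.utils.data import IterableDataset, DataLoader

class FileDS(IterableDataset):
    def __init__(self, nSamples):
        self.nSamples = nSamples
        self._i = -1

    def getData(self, index):
        # stand-in for the file-loading getData from the question:
        # every third sample is "bad" and comes back as (None, None)
        if index % 3 == 0:
            return None, None
        return torch.randn(4), torch.randn(1)

    def __iter__(self):
        return self

    def __next__(self):
        self._i += 1
        if self._i == self.nSamples:   # out of samples
            self._i = -1               # reset so the loader can be reused
            raise StopIteration
        x, y = self.getData(self._i)
        if x is None or y is None:     # bad data: skip this sample
            return next(self)
        return x, y

trainLoader = DataLoader(FileDS(10), batch_size=1, shuffle=False)
for x_batch, y_batch in trainLoader:
    pass  # process and train; (None, None) samples never reach this point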
