I am trying to construct a PyTorch Dataset that returns 21x512x512 slices from 3D images of shape ?x512x512. I know how many images there are, but I do not know how many slices there are in each image. Therefore, I would intuitively make the __len__() function of the Dataset return the total number of images I have. I could technically check the shapes of all images beforehand, but the dataset might change over time, so I would greatly prefer a scalable software solution.

Given this, I need some functionality that breaks each image into slices (of the size mentioned above) and returns these instead of the whole image. This is not a problem either: I have a function that can do this.
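
For concreteness, a minimal sketch of such a slicing function is shown here. It assumes non-overlapping chunks along the first axis and drops a trailing remainder shorter than 21 slices; the name `split_into_slices` is hypothetical and the real function may behave differently.

```python
import torch


def split_into_slices(volume: torch.Tensor, depth: int = 21) -> list:
    """Split a (D, H, W) volume into non-overlapping (depth, H, W) chunks.

    Assumption: a trailing remainder with fewer than `depth` slices is dropped.
    """
    n_chunks = volume.shape[0] // depth
    return [volume[i * depth:(i + 1) * depth] for i in range(n_chunks)]
```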

Here comes the problem. If I add this slicing functionality in the __getitem__() function of the Dataset, then I will only get one slice per image, since the PyTorch DataLoader will think there are len(dataset) datapoints, which is not the case anymore. But I also cannot specify the correct number of samples, as I do not know it beforehand.

I tried some solutions:

  1. Return a generator function in __getitem__() which yields slices per image. This does not work because the DataLoader's default collation expects __getitem__() to return something like a tensor, number, list, or tuple, not a generator.
  2. Just return the whole image and break it up in the training loop. This can work, but it is bad programming style (I want to hide data selection in the dataset) and it does not fit the batching of the DataLoader well, as one image might have 100 slices in it while another might have only 5. In this case, making batches from these images would result in only 5 batches with the actual batch_size, and 95 others with fewer samples per batch. Resolving this would require some ugly check to see whether another image needs to be loaded, which I would again like to hide in the Dataset.
  3. Yield results in a for loop in the __getitem__() function of the Dataset. This does not work for the same reason as point 1: a generator cannot be returned in a Dataset.

So in short, what is a clean way to load an unknown number of slices from 3D images in a PyTorch Dataset?

  • why don't you count the number of slices in advance? – Shai Dec 12 '19 at 14:05
  • Because it's a fairly big dataset that may change over time. It is not impossible to do this (and redo it when new data is added) but I would greatly prefer a scalable software solution. Good point though, I will update the question. – int elligence Dec 12 '19 at 14:53
  • Sounds to me like a good example of when to use an [iterable style dataset](https://pytorch.org/docs/stable/data.html#iterable-style-datasets) – jodag Dec 13 '19 at 09:35
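
The iterable-style dataset suggested in the last comment could be sketched roughly as follows. This is only an illustration under stated assumptions: `SliceDataset` and `load_fns` are hypothetical names, each loader is a stand-in for reading one image from disk, and slices are taken as non-overlapping chunks with any short remainder dropped. No `__len__()` is needed, so the number of slices per image does not have to be known in advance.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class SliceDataset(IterableDataset):
    """Yields (depth, H, W) chunks from volumes whose depth is unknown up front."""

    def __init__(self, load_fns, depth=21):
        # load_fns: hypothetical list of zero-argument callables, each returning
        # one (D, H, W) tensor; a stand-in for loading an image from disk.
        self.load_fns = load_fns
        self.depth = depth

    def __iter__(self):
        for load in self.load_fns:
            volume = load()  # loaded lazily, one image at a time
            # Non-overlapping chunks; a short trailing remainder is skipped.
            for start in range(0, volume.shape[0] - self.depth + 1, self.depth):
                yield volume[start:start + self.depth]


# Usage with small dummy volumes of different depths (50 and 30 slices).
volumes = [lambda: torch.zeros(50, 8, 8), lambda: torch.zeros(30, 8, 8)]
loader = DataLoader(SliceDataset(volumes), batch_size=2)
for batch in loader:
    print(batch.shape)
```

Because the DataLoader draws samples from one continuous stream, batches are filled across image boundaries and only the final batch can be smaller than batch_size. Two caveats apply: the DataLoader does not accept shuffle=True for an iterable-style dataset, and with num_workers > 0 each worker iterates over its own copy of the dataset, so the images would need to be split between workers (for example via torch.utils.data.get_worker_info()) to avoid duplicated slices.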
