I am trying to construct a PyTorch Dataset that returns 21x512x512 slices from 3D images of shape ?x512x512. I know how many images there are, but I do not know how many slices there are in each image. Therefore, I would intuitively make the __len__() function of the Dataset return the total number of images I have. I could technically check the shapes of all images beforehand, but the dataset might change over time, so I would greatly prefer a scalable software solution.
Given this, I need some functionality that breaks each image into slices (of the size mentioned above) and returns these instead of the whole image. This is not a problem either: I have a function that can do this.
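For concreteness, here is a minimal sketch of what such a slicing helper might look like. The name split_into_chunks is hypothetical, and the sketch assumes non-overlapping chunks along the first axis with any leftover slices dropped; my actual function may differ on both points. It only relies on len() and slicing, so it should behave the same on a tensor, an array, or a plain list.

```python
def split_into_chunks(volume, depth=21):
    """Split a volume of shape (D, H, W) into chunks of `depth` slices.

    Assumptions (not from the original post): chunks do not overlap,
    and a remainder of fewer than `depth` slices is dropped.
    """
    n_chunks = len(volume) // depth  # number of complete chunks along axis 0
    return [volume[i * depth:(i + 1) * depth] for i in range(n_chunks)]


# Stand-in for a 50x512x512 image: a list of 50 slice placeholders.
volume = list(range(50))
chunks = split_into_chunks(volume)
print(len(chunks))  # 2 complete chunks, 8 leftover slices are dropped
```

The important property is that the number of chunks is only known after the image has been loaded.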
Here comes the problem. If I add this slicing functionality in the __getitem__() function of the Dataset, then I will only get one slice per image, since the PyTorch DataLoader will think there are len(dataset) datapoints, which is not the case anymore. But I also cannot specify the correct number of samples, as I do not know it beforehand.
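The mismatch can be illustrated with a plain-Python stand-in for a map-style dataset. This is only a sketch: VolumeDataset and its depths argument are hypothetical, slices are represented by integers instead of real image data, and the loop at the bottom mimics a sampler that draws each index from 0 to len(dataset) - 1 once.

```python
class VolumeDataset:
    """Plain-Python stand-in following the __len__/__getitem__ protocol."""

    def __init__(self, depths, chunk_depth=21):
        # Hypothetical: slices per image. In the real setting these are
        # unknown until each image has been loaded.
        self.depths = depths
        self.chunk_depth = chunk_depth

    def __len__(self):
        return len(self.depths)  # number of images, not number of chunks

    def __getitem__(self, index):
        volume = list(range(self.depths[index]))  # pretend to load the image
        return volume[:self.chunk_depth]  # only one chunk can be returned


dataset = VolumeDataset(depths=[105, 42, 63])
seen = [dataset[i] for i in range(len(dataset))]  # one pass over all indices
available = sum(d // dataset.chunk_depth for d in dataset.depths)
print(len(seen), available)  # 3 10
```

One pass yields 3 chunks although 10 complete chunks exist, which is exactly the data I am losing.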
I tried some solutions:
- Return a generator function in __getitem__() which yields slices per image. This does not work because __getitem__() needs to return something of type list, tuple, tensor, etc.
- Just return the whole image and break it up in the train loop. This can work, but it is both bad programming style (as I want to hide data selection in the dataset) and not very compatible with the batching of the DataLoader, as one image might have 100 slices in it, while another might only have 5 slices. In this case, making batches from these images would result in only 5 batches with the actual batch_size, and 95 others with fewer samples per batch. Resolving this would require some ugly check to see if another image needs to be loaded, which I would again like to hide in a Dataset.
- Yield results in a for loop in the __getitem__() function of the Dataset. This does not work for the same reason as the first point: a generator cannot be returned from a Dataset.
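The uneven batches from the second option can be reproduced with a small sketch. It assumes a batch of two images (an illustrative choice), one with 100 slices and one with 5, and represents each slice by a (name, index) tuple rather than real data. The helper batches_from_images is hypothetical.

```python
from itertools import zip_longest


def batches_from_images(slices_per_image):
    """Form one batch per step by taking the next slice from every image."""
    batches = []
    for step in zip_longest(*slices_per_image):
        # Images that have run out of slices contribute None; skip those.
        batches.append([s for s in step if s is not None])
    return batches


image_a = [("a", i) for i in range(100)]  # image with 100 slices
image_b = [("b", i) for i in range(5)]    # image with only 5 slices

batches = batches_from_images([image_a, image_b])
full = sum(len(b) == 2 for b in batches)
partial = sum(len(b) < 2 for b in batches)
print(full, partial)  # 5 95
```

Only the first 5 batches contain both images; keeping every batch full would mean loading a new image as soon as one runs out, which is the bookkeeping I would like to keep out of the train loop.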
So in short, what is a clean way to load an unknown number of slices from 3D images in a PyTorch Dataset?