
I'm creating a custom dataset for NLP-related tasks.

In the PyTorch custom dataset tutorial, we see that the `__getitem__()` method leaves room for a transform before it returns a sample:

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)

        ### SOME DATA MANIPULATION HERE ###
        # (the tutorial builds the `landmarks` array from the dataframe here)

        sample = {'image': image, 'landmarks': landmarks}
        if self.transform:
            sample = self.transform(sample)

        return sample

However, the code here:

        if torch.is_tensor(idx):
            idx = idx.tolist()

implies that multiple items can be retrieved at once, which leaves me wondering:

  1. How does that transform work on multiple items? Take the custom transforms in the tutorial for example. They do not look like they could be applied to a batch of samples in a single call (see the tutorial's `ToTensor` transform reproduced after this list).

  2. Related, how does a DataLoader retrieve a batch of multiple samples in parallel and apply said transform if the transform can only be applied to a single sample?
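For reference, the `ToTensor` transform from that tutorial looks roughly like this; it clearly expects one sample dict at a time, not a batch:

    import torch

    class ToTensor(object):
        """Convert ndarrays in a sample to Tensors."""

        def __call__(self, sample):
            image, landmarks = sample['image'], sample['landmarks']

            # numpy image: H x W x C  ->  torch image: C x H x W
            image = image.transpose((2, 0, 1))
            return {'image': torch.from_numpy(image),
                    'landmarks': torch.from_numpy(landmarks)}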

rocksNwaves

2 Answers

  1. How does that transform work on multiple items? It works through the data loader. A transform specifies what should happen to a single sample (i.e., as if batch_size=1). The data loader takes your specified batch_size and makes n separate calls to the dataset's __getitem__ method, applying the transform to each sample as it is fetched. It then collates the n transformed samples into the batch that the data loader emits (see the sketch after these two points).

  2. Related, how does a DataLoader retrieve a batch of multiple samples in parallel and apply said transform if the transform can only be applied to a single sample? Hopefully the above makes sense. Parallelization is handled by the dataset class and the data loader together: you specify num_workers, and torch pickles the dataset and spreads the index fetches across the worker processes.
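A minimal sketch of that per-sample behavior (the dataset, transform, and strings here are made up for illustration; only torch.utils.data is assumed):

    import torch
    from torch.utils.data import Dataset, DataLoader

    class ToUpper:
        """Toy per-sample transform: upper-cases a single string."""
        def __call__(self, sample):
            return sample.upper()

    class TextDataset(Dataset):
        def __init__(self, texts, transform=None):
            self.texts = texts
            self.transform = transform

        def __len__(self):
            return len(self.texts)

        def __getitem__(self, idx):
            # With the default sampler, idx is a single int here,
            # even when the loader's batch_size is larger than 1
            sample = self.texts[idx]
            if self.transform:
                sample = self.transform(sample)
            return sample

    loader = DataLoader(TextDataset(['a', 'b', 'c', 'd'], transform=ToUpper()),
                        batch_size=2)
    for batch in loader:
        print(batch)  # ['A', 'B'] then ['C', 'D']: collated after per-sample transforms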

John Stud
  • "_... makes n calls to `__getitem__`_" is exactly what I wanted to know. Then that means passing a list index will do nothing and should be avoided. I will remove those lines of code from my own implementation. Thank you. – rocksNwaves Feb 25 '21 at 16:35
  • Yes, you do not need to pass it an index list except for likely very rare jobs. Once you instantiate your data set, the loader determines its length with the `__len__` method and then builds the index list itself (see the sketch after these comments). – John Stud Feb 25 '21 at 16:38
  • Right, I thought that was how a `DataLoader` might be accessing `n` items in a `batch_size=n`, since the official tutorial had those lines of code about a list index. I believe a list would throw an error when the transformation of the "batch" failed. – rocksNwaves Feb 25 '21 at 16:42
  • "I will remove those lines of code..." I often see that code in many custom dataset implementations. So is this completely unnecessary? If so, then why is this used everywhere? – krenerd Feb 21 '22 at 04:38

From the documentation of torchvision's transforms:

All transformations accept PIL Image, Tensor Image or batch of Tensor Images as input. Tensor Image is a tensor with (C, H, W) shape, where C is a number of channels, H and W are image height and width. Batch of Tensor Images is a tensor of (B, C, H, W) shape, where B is a number of images in the batch. Deterministic or random transformations applied on the batch of Tensor Images identically transform all the images of the batch.

This means that you can pass a batch of images and the transform will be applied to the whole batch, as long as the input respects the shapes above. The list of indices, in turn, acts on the dataframe's `iloc`, which accepts either a single index or a list of indices and returns the requested subset (a short demonstration follows).
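A quick sketch of both behaviors (assuming a tensor-capable torchvision, v0.8+; the shapes and dataframe contents are illustrative):

    import torch
    import torchvision.transforms as T
    import pandas as pd

    # 1) Tensor-based torchvision transforms accept a (B, C, H, W) batch
    batch = torch.rand(8, 3, 64, 64)  # 8 RGB "images"
    transform = T.Compose([
        T.Resize(32),  # resizes every image in the batch
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    print(transform(batch).shape)  # torch.Size([8, 3, 32, 32])

    # 2) iloc accepts a single position or a list of positions
    df = pd.DataFrame({'file': ['a.png', 'b.png', 'c.png'], 'label': [0, 1, 0]})
    print(df.iloc[0])       # one row (a Series)
    print(df.iloc[[0, 2]])  # a sub-DataFrame with rows 0 and 2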

Maura Pintor
  • Hi Maura, custom datasets are used outside the world of CV, and so I'm interested in the interaction of datasets with data loaders in a more general sense. Specifically, I'm working with text data rather than images. – rocksNwaves Feb 25 '21 at 14:27
  • I have edited my question for clarity if you would like to try to answer it again. – rocksNwaves Feb 25 '21 at 14:33
  • I think you can still get inspiration from the [`Lambda` transform](https://github.com/pytorch/vision/blob/master/torchvision/transforms/transforms.py#L391) to modify your data at loading time. If you want to transform data in batches, maybe you should implement the processing inside the `nn.Module.forward`, however this happens only after loading and might be less efficient. – Maura Pintor Feb 25 '21 at 14:34
  • That's definitely one approach, but it doesn't make the network class extending `nn.Module` very extensible itself. But anyway, my question is somewhat less about "How do I achieve X?" and more of a "How is X achieved in Y example?" For example, if the transforms are only meant for a single sample, why allow a list index to be passed? There is something smelly here, or something fancy going on under the hood. I want to know which it is. – rocksNwaves Feb 25 '21 at 14:38
  • Ok, sorry I didn't get that :) Anyway, I think it is still related to how the transforms are defined in torchvision (yes, the image domain is everywhere, I get your point, but for a long time it has been the main topic). Some of them are actually `nn.Module`s themselves. – Maura Pintor Feb 25 '21 at 14:47