
I am using PyTorch to train a machine learning model, and I have run into a significant issue: iterating over the DataLoader is noticeably slower than accessing the dataset directly. My main goal is to speed up data loading during training, since waiting for the DataLoader to fetch the data takes up a considerable amount of time.

For example, when I iterate over the DataLoader like this:

for inputs, labels in tqdm(dataloader):
    pass

It takes more than 15 seconds to complete.

However, when I iterate directly over the dataset:

for inputs, labels in tqdm(zip(dataloader.dataset.data, dataloader.dataset.targets)):
    pass

It completes in less than 1 second.

I have already disabled shuffling, and I've experimented with adjusting the num_workers parameter, but it didn't significantly reduce the time difference.
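
For reference, the kind of DataLoader configuration I experimented with looks roughly like this (the exact values and the extra flags here are illustrative rather than an exact record of what I tried; trainset refers to the MNIST dataset from the reproducible example below):

import torch
from torch.utils.data import DataLoader

# Illustrative settings; none of them closed the gap noticeably.
trainloader = DataLoader(
    trainset,                 # the MNIST trainset defined below
    batch_size=1,
    shuffle=False,            # shuffling already disabled
    num_workers=4,            # also tried 0 and other small values
    pin_memory=True,          # illustrative extra flag
    persistent_workers=True,  # only valid when num_workers > 0
)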

The issue I am facing is not related to resource constraints, as my CPU and memory utilization are well below their maximum capacities. The waiting time during training occurs specifically when using the PyTorch DataLoader, and I'm seeking solutions to speed up the data loading process for more efficient training.

Moreover, disk I/O is not the bottleneck either. The slowdown occurs inside the PyTorch DataLoader itself, which takes longer than expected despite sufficient I/O capacity.

Reproducible example:

import torch
from tqdm import tqdm
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

trainset = datasets.MNIST('MNINST', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False)
for data, targets in tqdm(trainloader):
    pass

for data, targets in tqdm(zip(trainloader.dataset.data, trainloader.dataset.targets)):
    pass

[image: result (timing comparison)]

Edit: The impact becomes more pronounced as the batch_size increases. For example, I replicated the DataLoader with shuffle=True and batch_size=64, taking MarGenDo's remarks into account:

import torch
from tqdm import tqdm
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

batch_size=64
trainset = datasets.MNIST('MNINST', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

for data, targets in tqdm(trainloader):
    pass

indices = torch.randperm(len(trainset))
for i in tqdm(range(0, len(indices), batch_size)):
    data = []
    targets = []

    for j in range(i, i + batch_size):
        if j < len(indices):
            data.append(trainset.data[indices[j]])
            targets.append(trainset.targets[indices[j]])

    data = torch.utils.data.default_collate(data)
    targets = torch.utils.data.default_collate(targets)

    # Match ToTensor's 1/255 scaling and add the channel dimension so the
    # batch has shape (batch_size, 1, 28, 28), like the DataLoader output
    tensor = (data.to(torch.float) / 255).unsqueeze(1)
    normalized = transforms.functional.normalize(tensor, (0.5,), (0.5,))

[image: result2 (timing comparison with batch_size=64)]

Edit 2: taking MarGenDo's second remark into account:

import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx], self.labels[idx]
        return sample


n = 100000
data = torch.randn(n, 3, 28, 28)
labels = torch.randint(0, 10, (n,))

custom_dataset = CustomDataset(data, labels)

batch_size = 1
dataloader = DataLoader(custom_dataset, batch_size=batch_size, shuffle=False)

for inputs, labels in tqdm(dataloader):
    pass

for inputs, labels in tqdm(zip(dataloader.dataset.data, dataloader.dataset.labels)):
    pass

[image: result3 (timing comparison with the custom dataset)]

1 Answer


The significant time difference is caused by inefficient conversions between PIL images and torch tensors.

In the first case, the DataLoader internally iterates over the dataset (= trainset), whose __getitem__ returns a tuple of a PIL image and a target. The PIL image is converted to a tensor by your transform pipeline (ToTensor, Normalize).

In the second case, you are iterating directly over trainloader.dataset.data, which points to trainset.data (the raw uint8 storage). No conversion or normalization happens.

Comparing one sample from both cases: the first is a normalized float tensor of the image, the second is represented as uint8 (range 0-255).
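
You can verify this with a quick check like the following (a small inspection sketch, assuming the trainset/trainloader from the first reproducible example, i.e. batch_size=1 and shuffle=False):

batch, target = next(iter(trainloader))   # goes through ToTensor + Normalize
raw = trainloader.dataset.data[0]         # raw storage, no transform

print(batch.dtype, batch.shape)           # torch.float32, torch.Size([1, 1, 28, 28])
print(raw.dtype, raw.shape)               # torch.uint8, torch.Size([28, 28])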


This code replicates what the DataLoader is doing in the background when iterating over the dataset. Both of these cases take roughly the same time to run.

# standard dataloader loop
for data, target in tqdm(trainloader):
    pass

# Same loop without dataloader
trainset = datasets.MNIST('MNINST', download=True, train=True) # Remove transforms from the trainset
for data, target in tqdm(trainset):
    tensor = transforms.functional.to_tensor(data)
    normalized = transforms.functional.normalize(tensor, (0.5,), (0.5,))
    # `normalized` now contains the same tensor as `data` in the previous case

Now, to actually improve efficiency, you could start with your solution of iterating over the dataset's underlying data and normalizing it manually:

for data, target in tqdm(zip(trainloader.dataset.data, trainloader.dataset.targets)):
    tensor = (data.to(torch.float) / 255).unsqueeze(0)  # match ToTensor's 1/255 scaling, add the channel dim
    normalized = transforms.functional.normalize(tensor, (0.5,), (0.5,))
    # `normalized` now contains the same tensor as `data` in the previous case

This reduces running time by more than half on my machine.

MarGenDo
  • I have considered your suggestion and edited my question accordingly. The issue is that the difference becomes more significant as the batch_size changes. The DataLoader doesn't show any improvement when the batch_size is increased, whereas my method (detailed in my edited question) does benefit from it. I don't believe it's due to the conversion or normalization, as removing them from my method doesn't yield a substantial time improvement. – triple_double Aug 05 '23 at 09:00
  • Yes, that is what I explained in my answer. Accessing dataset.data is significantly faster because the data is already in torch tensor format. On the other hand, the DataLoader converts each sample from a PIL image to a torch tensor, which is significantly slower. In your fast method there is no conversion from PIL to tensor; that is why it is fast. – MarGenDo Aug 05 '23 at 09:17
  • Good catch! That confirms my answer: the dataset's `__getitem__` method is called each time the DataLoader iterates and fetches new data. As you can see in the source code https://github.com/pytorch/vision/blob/main/torchvision/datasets/mnist.py , there is that slow conversion. It is even worse than I expected, as it converts from torch to numpy to a PIL image and then back to torch via your transforms function (see the sketch after these comments). On the other hand, iterating over `dataset.data` is fast because you are accessing the tensors directly. – MarGenDo Aug 05 '23 at 09:25
  • The issue is that the observation I mentioned holds true even for my own dataset (which I didn't include in my question). In that dataset, the `__getitem__` method doesn't involve any additional processing and simply returns self.data[idx] and self.labels[idx], both of which are already tensors. – triple_double Aug 05 '23 at 09:29
  • I am sorry, but I can't think of any other reason for the slow DataLoader without seeing your source code and your dataset. I think your original question is answered correctly, and the reason for the slowness of your own dataset might even be completely unrelated if what you are saying is correct. – MarGenDo Aug 05 '23 at 09:41
  • I've added a more precise second example about that. – triple_double Aug 05 '23 at 09:44
  • In this case, after increasing the batch size to a larger number (e.g. 64), the DataLoader runs almost as fast as the other loop. It seems that the DataLoader simply has a huge overhead for more advanced features that are not used in this case (memory management, multi-process loading, ...). – MarGenDo Aug 05 '23 at 09:55
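
For reference, the MNIST __getitem__ discussed in the comments above does roughly the following (a paraphrased sketch of the linked torchvision source, not the exact code):

from PIL import Image

# Paraphrased sketch of torchvision.datasets.MNIST.__getitem__: the stored
# uint8 tensor is converted to numpy, then to a PIL image, and only then run
# through `transform`, which converts it back to a tensor and normalizes it.
def mnist_getitem(self, index):
    img, target = self.data[index], int(self.targets[index])
    img = Image.fromarray(img.numpy(), mode="L")   # tensor -> numpy -> PIL
    if self.transform is not None:
        img = self.transform(img)                  # PIL -> tensor (+ Normalize)
    if self.target_transform is not None:
        target = self.target_transform(target)
    return img, target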