
I am loading my training images into a PyTorch DataLoader, and I need to calculate the per-channel mean and standard deviation of the input images. The calculation is taken directly from https://kozodoi.me/python/deep%20learning/pytorch/tutorial/2021/03/08/image-mean-std.html.

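import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# img_size must be an (h, w) tuple so that every image ends up with
# exactly img_size[0] * img_size[1] pixels per channel
T = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor()
])

dataset = datasets.ImageFolder(train_dir, T)
image_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, drop_last=True)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Running per-channel sums of the pixel values and of their squares
psum = torch.zeros(3, device=device)
psum_sq = torch.zeros(3, device=device)

for image, label in image_loader:
    image = image.to(device)
    psum += image.sum(dim=[0, 2, 3])
    psum_sq += (image ** 2).sum(dim=[0, 2, 3])

# The statistics only need the final sums, so compute them once after the loop.
# len(dataset) is correct here because batch_size=1 means drop_last drops nothing.
count = len(image_loader.dataset) * img_size[0] * img_size[1]

total_mean = psum / count
total_var = (psum_sq / count) - (total_mean ** 2)
total_std = torch.sqrt(total_var)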
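For reference, the accumulation above computes the following for each channel c, assuming N images of H × W pixels each (the symbols N, H, W and x are introduced here only for illustration; H × W corresponds to img_size):

```latex
\begin{aligned}
\mu_c &= \frac{1}{NHW} \sum_{n,h,w} x_{n,c,h,w}
       = \frac{\texttt{psum}[c]}{\texttt{count}} \\
\sigma_c^2 &= \frac{1}{NHW} \sum_{n,h,w} x_{n,c,h,w}^2 - \mu_c^2
       = \frac{\texttt{psum\_sq}[c]}{\texttt{count}} - \mu_c^2
\end{aligned}
```

Both sums are additive over images, so they can be accumulated in any order or split across workers; only the final step, which turns the totals into a mean and a variance, needs all of the data. Note that this is the population variance (divided by the pixel count, not the count minus one).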

Profiling revealed that the for loop is a bottleneck. How can I parallelize the operations? I have looked into Dask's `delayed` and came up with something like this:

from dask import delayed

for image, label in image_loader:
    image = image.to(device)
    psum += delayed(calculate_psum)(image)
    psum_sq += delayed(calculate_psum_sq)(image)
    count = delayed(calculate_count)(image_loader, img_size)

    total_mean = delayed(calculate_mean)(psum, count)
    total_var = delayed(calculate_var)(psum_sq, count, total_mean)
    total_std = delayed(calculate_std)(total_var)

How should I parallelize each operation, and where should I call `compute()`? I noticed that the `total_*` values depend on one another. Is that the part that cannot be parallelized?
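One way to see where the dependencies lie is to split the work into a map step and a reduce step. The following is a minimal sketch, not a Dask solution: it uses the standard library's `ThreadPoolExecutor` and NumPy arrays of shape (N, C, H, W) as stand-ins for batches, and the function names `partial_stats`, `combine` and `dataset_stats` are hypothetical. Each chunk yields independent partial sums, and the `total_*` values are derived only once, from the combined sums. The same structure should carry over to Dask by wrapping the per-chunk function and the combine step in `dask.delayed` and calling `.compute()` once on the final result.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def partial_stats(batch):
    """Map step for one chunk of images shaped (N, C, H, W).

    Returns the per-channel sum, sum of squares, and pixel count per channel.
    These depend only on this chunk, so chunks can be processed independently.
    """
    batch = batch.astype(np.float64)
    pixels = batch.shape[0] * batch.shape[2] * batch.shape[3]
    return batch.sum(axis=(0, 2, 3)), (batch ** 2).sum(axis=(0, 2, 3)), pixels


def combine(parts):
    """Reduce step: add up the partial results, then derive mean and std once."""
    psum = sum(p[0] for p in parts)
    psum_sq = sum(p[1] for p in parts)
    count = sum(p[2] for p in parts)
    mean = psum / count
    var = psum_sq / count - mean ** 2
    # Clamp tiny negative values that rounding can produce before the sqrt
    return mean, np.sqrt(np.maximum(var, 0.0))


def dataset_stats(chunks, max_workers=4):
    # Map: chunks are processed concurrently; nothing here needs the totals.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        parts = list(pool.map(partial_stats, chunks))
    # Reduce: the total_* values depend on every chunk, so they are computed here.
    return combine(parts)


# Usage with random data standing in for images
rng = np.random.default_rng(0)
images = rng.random((10, 3, 8, 8))
chunks = [images[i:i + 4] for i in range(0, len(images), 4)]  # last chunk has 2 images
mean, std = dataset_stats(chunks)
```

Threads are only an assumption for this sketch. In a real pipeline the expensive part is often loading and decoding the images, which the DataLoader's `num_workers` option already spreads across worker processes.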

UPDATE: Here is a computation graph showing which parts are easier to parallelize.

[Image: computation graph]

disguisedtoast
  • I don't know anything about Dask, but are you running on multiple images at once? Can your GPU handle a batch size > 1? Also, increase the number of workers for your data loader (the `num_workers` argument); this allows PyTorch to load the next batch while the body of the loop is running. Note that PyTorch already evaluates CUDA operations asynchronously, so as long as everything stays on the GPU, there is likely already some parallelization going on behind the scenes. In practice, batch size and number of workers usually have the most impact. – jodag Jan 21 '22 at 15:19
  • I increased the batch size to 16 and `num_workers` to 8 and saw a modest increase in performance. I guess that is the easier way to get some speedup, but it would be nice if the for loop could be optimized further (e.g. through vectorization). – disguisedtoast Jan 22 '22 at 08:55
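The loop body is already vectorized: each `sum` reduces over all pixels of a batch in a single call, so the Python loop only runs once per batch, and a larger batch size means fewer iterations. The NumPy sketch below (function names are hypothetical, random arrays stand in for images) checks that batched reductions give the same sums as a per-image loop. It also counts pixels from each batch's shape, because with `batch_size=16` and `drop_last=True` the images in an incomplete last batch are never summed, so `len(dataset) * img_size[0] * img_size[1]` would overcount.

```python
import numpy as np


def stats_per_image(images):
    """Loop over single images, like a DataLoader with batch_size=1."""
    psum = np.zeros(images.shape[1])
    psum_sq = np.zeros(images.shape[1])
    for img in images:  # img has shape (C, H, W)
        psum += img.sum(axis=(1, 2))
        psum_sq += (img ** 2).sum(axis=(1, 2))
    return psum, psum_sq


def stats_batched(images, batch_size, drop_last=False):
    """One vectorized reduction per batch; also returns the true pixel count."""
    psum = np.zeros(images.shape[1])
    psum_sq = np.zeros(images.shape[1])
    count = 0
    for start in range(0, len(images), batch_size):
        batch = images[start:start + batch_size]  # shape (B, C, H, W)
        if drop_last and len(batch) < batch_size:
            break  # mimics DataLoader(drop_last=True)
        psum += batch.sum(axis=(0, 2, 3))
        psum_sq += (batch ** 2).sum(axis=(0, 2, 3))
        # Count the pixels actually summed, not len(dataset) * H * W
        count += batch.shape[0] * batch.shape[2] * batch.shape[3]
    return psum, psum_sq, count


rng = np.random.default_rng(1)
images = rng.random((20, 3, 4, 4))
loop_sum, loop_sq = stats_per_image(images)
batch_sum, batch_sq, count = stats_batched(images, batch_size=16)
_, _, dropped_count = stats_batched(images, batch_size=16, drop_last=True)
```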
