
Could you please help me find a solution to my problem? I want to write a `collate_fn` that makes my images the same size, but I don't know how to implement it correctly.

Colab: link

Code:

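from io import BytesIO  # Needed to open the downloaded bytes as an image

import numpy as np
import pandas as pd
import requests  # Needed to download images from their URLs
from PIL import Image

from torchvision import transforms
from torch.utils.data import Dataset  # Base class for custom datasets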


class CustomDataset(Dataset):
    def __init__(self, csv_path):
        self.to_tensor = transforms.ToTensor()
        # Note: despite its name, csv_path must be a pandas DataFrame (.iloc is used below)
        self.data_info = csv_path
        # First column contains the image paths
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        # Second column is the labels
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])
        # Calculate len
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):
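        # Get the image URL from the pandas DataFrame column
        single_image_name = self.image_arr[index]
        # Target size as (width, height), the order PIL's resize expects
        IMAGE_SIZE = (224, 224)

        # Download the image and open it from memory
        response = requests.get(single_image_name)
        # Convert to RGB so every tensor has 3 channels, then resize
        img_as_img = Image.open(BytesIO(response.content)).convert("RGB").resize(IMAGE_SIZE)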

        # Transform image to tensor
        img_as_tensor = self.to_tensor(img_as_img)

        # Get label(class) of the image based on the cropped pandas column
        single_image_label = self.label_arr[index]

        return (img_as_tensor, single_image_label)

    def __len__(self):
        return self.data_len
  • I would recommend resizing your images in the `__getitem__` method (or beforehand, as preprocessing) with the `torchvision` library (https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize), then you won't need to write your own `collate_fn` (the default one which stacks tensors will work). Though it looks like you're already trying to resize the images to be 224x224 in `__getitem__` - what kind of error are you running into, then? – jayelm Jun 21 '21 at 03:39
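For reference, if the samples really do reach the `DataLoader` at different sizes, a custom `collate_fn` could look roughly like the sketch below. This is not the asker's code: the name `resize_collate` and the `size` parameter are made up for illustration, and it assumes every sample is a `(C, H, W)` float tensor with the same channel count paired with a numeric label.

```python
import torch
import torch.nn.functional as F


def resize_collate(batch, size=(224, 224)):
    """Hypothetical collate_fn: resize each image to `size`, then stack.

    Assumes each batch item is (image_tensor, label), where image_tensor
    has shape (C, H, W) and label is numeric.
    """
    images, labels = zip(*batch)
    resized = [
        # interpolate expects a batch dimension, so add one and remove it again
        F.interpolate(
            img.unsqueeze(0), size=size, mode="bilinear", align_corners=False
        ).squeeze(0)
        for img in images
    ]
    # All images now share one shape, so they can be stacked into (N, C, H, W)
    return torch.stack(resized), torch.as_tensor(labels)
```

It would be passed as `DataLoader(dataset, batch_size=..., collate_fn=resize_collate)`. Resizing inside `__getitem__`, as suggested above, is usually simpler because the default collate function then works unchanged.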

1 Answer


To resize your images to the same size, you can use the OpenCV library. To install it, run the following command.

pip install opencv-python

The function that you need to use is the following.

cv2.resize(src, dsize[, dst[, fx[, fy[, interpolation]]]])

You can find the detailed documentation of the library at the following link.

Harris Minhas