
I am new to PyTorch. I am trying to create a DataLoader for a dataset of images where each image has a corresponding ground truth image with the same file name:

root:
--->RGB:
------>img1.png
------>img2.png
------>...
------>imgN.png
--->GT:
------>img1.png
------>img2.png
------>...
------>imgN.png

When I pass the path to the root folder (the one containing the RGB and GT folders) to torchvision.datasets.ImageFolder, it reads all of the images as if they were all inputs, classified into two classes (RGB and GT), and there seems to be no way to pair the RGB and GT images. I would like to pair the RGB-GT images, shuffle them, and divide them into batches of a defined size. How can this be done? Any advice will be appreciated. Thanks.


1 Answer


I think a good starting point is to use the VisionDataset class as a base, following the DatasetFolder source code, and to create something similar. You can notice that this class depends on two other functions from the datasets.folder module: default_loader and make_dataset.

We are not going to modify default_loader, because it already does what we need (it just loads images from disk), so we will simply import it.
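Just to illustrate what it gives us: default_loader takes a path and returns a PIL image (it only switches to the accimage backend if you have configured that). A quick illustration, assuming your data lives under ./data, the path used later in this answer:

from torchvision.datasets.folder import default_loader

# default_loader returns a PIL.Image for the given path
img = default_loader('./data/RGB/img1.png')
print(img.size, img.mode)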

But we need a new make_dataset function that prepares the right pairs of images from the root folder. The original make_dataset pairs each image path with the class index of its folder, so it returns a list of (path, class_to_idx[target]) tuples, while we need (rgb_path, gt_path) tuples instead. Here is the code for the new make_dataset:

import os


def make_dataset(root: str) -> list:
    """Reads a directory with data.
    Returns a dataset as a list of tuples of paired image paths: (rgb_path, gt_path).
    """
    dataset = []

    # Our dir names
    rgb_dir = 'RGB'
    gt_dir = 'GT'

    # Get all the filenames from the RGB folder
    rgb_fnames = sorted(os.listdir(os.path.join(root, rgb_dir)))

    # Compare file names from the GT folder to file names from RGB:
    for gt_fname in sorted(os.listdir(os.path.join(root, gt_dir))):
        if gt_fname in rgb_fnames:
            # if we have a match - create a pair of full paths to the corresponding images
            rgb_path = os.path.join(root, rgb_dir, gt_fname)
            gt_path = os.path.join(root, gt_dir, gt_fname)

            item = (rgb_path, gt_path)
            # append the pair to the dataset list
            dataset.append(item)

    return dataset

What do we have now? Let's compare our function with the original one:

from torchvision.datasets.folder import make_dataset as make_dataset_original


root = './data'  # the root folder containing the RGB and GT subfolders

dataset_original = make_dataset_original(root, {'RGB': 0, 'GT': 1}, extensions='png')
dataset = make_dataset(root)

print('Original make_dataset:')
print(*dataset_original, sep='\n')

print('Our make_dataset:')
print(*dataset, sep='\n')

Out:

Original make_dataset:
('./data/GT/img1.png', 1)
('./data/GT/img2.png', 1)
...
('./data/RGB/img1.png', 0)
('./data/RGB/img2.png', 0)
...
Our make_dataset:
('./data/RGB/img1.png', './data/GT/img1.png')
('./data/RGB/img2.png', './data/GT/img2.png')
...
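Optionally (this is just an extra check, not strictly needed), we can make sure no image was left unpaired, since make_dataset silently skips file names that exist in only one of the two folders:

import os

rgb_files = sorted(os.listdir(os.path.join(root, 'RGB')))
gt_files = sorted(os.listdir(os.path.join(root, 'GT')))

# every image should have found a partner with the same file name
print(f'{len(dataset)} pairs for {len(rgb_files)} RGB / {len(gt_files)} GT images')
assert len(dataset) == len(rgb_files) == len(gt_files)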

I think it works great. Now it's time to create our Dataset class. The most important part here is the __getitem__ method, because it loads the images, applies the transforms, and returns the tensors that can be consumed by a DataLoader. We need to read a pair of images (RGB and GT) and return a tuple of two image tensors:

from torchvision.datasets.folder import default_loader
from torchvision.datasets.vision import VisionDataset


class CustomVisionDataset(VisionDataset):

    def __init__(self,
                 root,
                 loader=default_loader,
                 rgb_transform=None,
                 gt_transform=None):
        super().__init__(root,
                         transform=rgb_transform,
                         target_transform=gt_transform)

        # Prepare dataset
        samples = make_dataset(self.root)

        self.loader = loader
        self.samples = samples
        # list of RGB images
        self.rgb_samples = [s[0] for s in samples]
        # list of GT images
        self.gt_samples = [s[1] for s in samples]

    def __getitem__(self, index):
        """Returns a data sample from our dataset.
        """
        # get the paths to the two images
        rgb_path, gt_path = self.samples[index]

        # load each image using the loader (PIL-based by default)
        rgb_sample = self.loader(rgb_path)
        gt_sample = self.loader(gt_path)

        # here go the transforms, if needed
        # (we may need different transforms for each type of image)
        if self.transform is not None:
            rgb_sample = self.transform(rgb_sample)
        if self.target_transform is not None:
            gt_sample = self.target_transform(gt_sample)

        # now we return the right imported pair of images (tensors)
        return rgb_sample, gt_sample

    def __len__(self):
        return len(self.samples)
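Before wiring the dataset into a DataLoader, we can also sanity-check the class by indexing it directly (a small optional check; without transforms it returns a pair of PIL images):

ds = CustomVisionDataset('./data')

rgb_img, gt_img = ds[0]   # a single (RGB, GT) pair of PIL images
print(len(ds))            # number of pairs
print(rgb_img.size, gt_img.size)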

Let's test it:

from torch.utils.data import DataLoader

from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


bs = 4  # batch size
transforms = ToTensor()  # we need this to convert PIL images to Tensors
shuffle = True

dataset = CustomVisionDataset('./data', rgb_transform=transforms, gt_transform=transforms)
dataloader = DataLoader(dataset, batch_size=bs, shuffle=shuffle)

for i, (rgb, gt) in enumerate(dataloader):
    print(f'batch {i+1}:')
    # some plots
    for j in range(rgb.shape[0]):  # iterate over the actual batch (the last one may be smaller than bs)
        plt.figure(figsize=(10, 5))
        plt.subplot(221)
        plt.imshow(rgb[j].squeeze().permute(1, 2, 0))
        plt.title(f'RGB img{j+1}')
        plt.subplot(222)
        plt.imshow(gt[j].squeeze().permute(1, 2, 0))
        plt.title(f'GT img{j+1}')
        plt.show()

Out:

batch 1:

[RGB / GT image pairs plotted side by side]

...

Here you can find a notebook with the code and a simple dummy dataset.

  • Thank you for your help. When I try to import VisionDataset (from torchvision.datasets.vision import VisionDataset) I get an error: ImportError: No module named vision – yliats Dec 26 '19 at 10:04
  • After upgrading torchvision everything worked great. Thanks again. – yliats Dec 28 '19 at 08:24
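As the comments show, VisionDataset is only available in reasonably recent torchvision releases. If upgrading is not an option, a plain torch.utils.data.Dataset subclass works the same way. A minimal, untested sketch (the class name PairedImageDataset is mine; it reuses make_dataset and default_loader from the answer above):

from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader


class PairedImageDataset(Dataset):
    # Same idea as CustomVisionDataset, but without the VisionDataset base class.

    def __init__(self, root, loader=default_loader, rgb_transform=None, gt_transform=None):
        self.samples = make_dataset(root)   # list of (rgb_path, gt_path) pairs
        self.loader = loader
        self.rgb_transform = rgb_transform
        self.gt_transform = gt_transform

    def __getitem__(self, index):
        rgb_path, gt_path = self.samples[index]
        rgb_sample = self.loader(rgb_path)
        gt_sample = self.loader(gt_path)
        if self.rgb_transform is not None:
            rgb_sample = self.rgb_transform(rgb_sample)
        if self.gt_transform is not None:
            gt_sample = self.gt_transform(gt_sample)
        return rgb_sample, gt_sample

    def __len__(self):
        return len(self.samples)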