I'm trying to load this dataset in PyTorch, but I'm running into a problem while loading it.

As you can see by checking the dataset, the directory looks something like this:

  • root

    • monet_jpg

    • monet_tfrec

    • photo_jpg

    • photo_tfrec

I want to load the photo and Monet images into separate DataLoader variables, but this method doesn't seem to work.

EDIT: By that I mean that both monet_ds and photo_ds return only Monet images (while photo_ds should return images from photo_jpg).

I'm trying to load the data through this code:

import torch
import torchvision.datasets as dset
import torchvision.utils as vutils
from torch.utils.data import Subset
from torchvision import transforms

def load_data(dataroot, image_size, batch_size, workers, ngpu, shuffle=True):
    # Data loading
    # Create the dataset
    dataset = dset.ImageFolder(root=dataroot,
                               transform=transforms.Compose([
                                   transforms.Resize(image_size),
                                   transforms.CenterCrop(image_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
    print(dataset.class_to_idx)
    #print(dataset.imgs)
    monet_ds = Subset(dataset, range(0,299))
    photo_ds = Subset(dataset, range(300,))

    # Create the dataloader
    monet_ds = torch.utils.data.DataLoader(monet_ds, batch_size=batch_size,
                                           num_workers=workers)
    photo_ds = torch.utils.data.DataLoader(photo_ds, batch_size=batch_size,
                                           num_workers=workers)
    # Decide which device we want to run on
    device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")

    print("Data loaded...")
    return monet_ds, photo_ds, device

root = "../input/gan-getting-started"
monet_ds, photo_ds, device = load_data(root, image_size, batch_size, workers, ngpu)

Any help with loading this data correctly in PyTorch would be appreciated. Thank you.


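The reason photo_ds also returns Monet images is that `range(300,)` is the same as `range(0, 300)`, so both subsets cover (almost) the same first 300 samples of the ImageFolder, and since ImageFolder orders its samples by sorted class name, those all come from monet_jpg. The two image sets seem to be completely independent, so loading each folder on its own with the following should work just fine: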

import os

import torch
from torchvision.datasets.folder import default_loader
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class MonetPhotoDataset(Dataset):
    def __init__(self, root, transform=None):
        self.transform = transform
        self.img_paths = sorted(os.path.join(root, x) for x in os.listdir(root) if x.endswith('.jpg'))

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        sample = default_loader(img_path)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample


def load_data(dataroot, image_size, batch_size, workers, ngpu, shuffle=True):
    # set up transform
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    
    # create datasets
    monet_ds = MonetPhotoDataset(root=os.path.join(dataroot, 'monet_jpg'), transform=transform)
    photo_ds = MonetPhotoDataset(root=os.path.join(dataroot, 'photo_jpg'), transform=transform)

    # create dataloaders
    monet_dl = DataLoader(monet_ds, batch_size=batch_size, shuffle=shuffle, num_workers=workers)
    photo_dl = DataLoader(photo_ds, batch_size=batch_size, shuffle=shuffle, num_workers=workers)

    # decide which device we want to run on
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print("Data loaded...")
    return monet_dl, photo_dl, device

root = "../input/gan-getting-started"
monet_dl, photo_dl, device = load_data(root, image_size, batch_size, workers, ngpu)

P.S.: I kept load_data because I assumed you rely on its signature elsewhere in your code; otherwise I wouldn't use it. Also, I didn't test the code above, so expect a typo or two, but the logic is correct.

Note that this dataset returns only the images (no labels), so each batch from monet_dl and photo_dl is a single image tensor.
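If you would rather keep ImageFolder and Subset, the subsets have to be built from the class labels instead of hard-coded positions. A minimal sketch, relying on the fact that ImageFolder exposes one class index per sample in `dataset.targets`; the helper `indices_for_class` is hypothetical (not part of torchvision), and the toy label list below only stands in for a real `dataset.targets`:

```python
from typing import List, Sequence


def indices_for_class(targets: Sequence[int], class_idx: int) -> List[int]:
    """Return the positions of every sample whose label equals class_idx.

    Hypothetical helper, not a torchvision function.
    """
    return [i for i, t in enumerate(targets) if t == class_idx]


# range(300,) is just range(300), i.e. indices 0..299, which overlaps
# almost entirely with range(0, 299) used for monet_ds in the question.
assert range(300,) == range(0, 300)

# Toy stand-in for ImageFolder(...).targets: two classes, labelled 0 and 1.
toy_targets = [0, 0, 0, 1, 1]
print(indices_for_class(toy_targets, 0))
print(indices_for_class(toy_targets, 1))
```

With the real dataset from the question, this would look like `photo_ds = Subset(dataset, indices_for_class(dataset.targets, dataset.class_to_idx['photo_jpg']))`, and likewise with `'monet_jpg'` for monet_ds.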

Berriel
  • This seems to work just fine; there's just one indentation error for those who try to replicate the above code, the rest works well. Thanks. – Prithviraj Kanaujia Aug 17 '21 at 05:44
  • @PrithviRajKanaujia glad it helped... feel free to fix the indentation error by suggesting an edit to the answer :) – Berriel Aug 17 '21 at 11:11
  • Also @Berriel, I'm trying to traverse monet_dl and photo_dl, and I'm using this method: `for i, (data1, data2) in enumerate(zip(cycle(monet_dl), photo_dl)):`. Is this the best way to traverse them, given that monet_dl has 300 images and photo_dl has ~7000 images? In one epoch, I would like to go through every image in photo_dl. – Prithviraj Kanaujia Aug 18 '21 at 07:15
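On the zip(cycle(monet_dl), photo_dl) question: zip stops at its shortest input, so with photo_dl as the second argument every photo batch is visited once per epoch. One caveat is that itertools.cycle saves a copy of every element from its first pass and replays those copies, so the Monet batches are never re-drawn from the DataLoader (no reshuffling, no fresh random transforms) and all of them stay in memory. A minimal sketch of an alternative, assuming the loader can be iterated repeatedly (a DataLoader can); `repeat_loader` is a hypothetical helper name, and plain lists stand in for the two loaders:

```python
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def repeat_loader(loader: Iterable[T]) -> Iterator[T]:
    """Yield items from `loader` forever, starting a fresh pass when it runs out.

    Unlike itertools.cycle, nothing is cached: every pass iterates `loader`
    again, so a DataLoader with shuffle=True would reshuffle on each pass.
    """
    while True:
        empty = True
        for item in loader:
            empty = False
            yield item
        if empty:
            # Guard against spinning forever on an empty (or exhausted one-shot) iterable.
            raise ValueError("loader yielded no items")


# Toy stand-ins: 3 "Monet batches" and 7 "photo batches".
monet_batches = ["m0", "m1", "m2"]
photo_batches = [f"p{i}" for i in range(7)]

# zip stops when photo_batches is exhausted, so each photo batch appears exactly once.
pairs = list(zip(repeat_loader(monet_batches), photo_batches))
for i, (monet_batch, photo_batch) in enumerate(pairs):
    print(i, monet_batch, photo_batch)
```

In the training loop this would read `for i, (data1, data2) in enumerate(zip(repeat_loader(monet_dl), photo_dl)):`.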