
Hey, I've been searching around the web for help on importing a custom image set, but every good tutorial seems to just use MNIST. That's fine, but I don't know how to translate that code to a custom set. I've got a folder structure like this:
SET
|-->Training
-----|-->A
----------|-->8000 items
-----|-->B
----------|-->8000 items
|-->Validation
-----|-->A
----------|-->600 items
-----|-->B
----------|-->600 items

I want to train a GAN on the 8000 input images in Training set A so that it hopefully learns to mimic Training set B.

I've had no luck understanding all the self and inheritance stuff in the MNIST examples, or how to use it with a custom set.

Viktor hj


You need to read your image files with a class that derives from the torch.utils.data.Dataset class in order to have your custom dataset. You can follow this part of the documentation for a basic example of how to populate a custom Dataset. After you define

from torch.utils.data import Dataset
class CustomImageDataset(Dataset):

with the three mandatory methods (see the documentation above):

    def __init__(self, img_dir):  # add any other arguments you need
        ...
    def __len__(self):
        ...
    def __getitem__(self, idx):
        ...
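For reference, here is a minimal sketch of a complete class. It assumes that the images are files Pillow can open, that each subfolder of the directory you pass in (A and B in your tree) is one class, and that the label is the subfolder's index in sorted order. The `IMG_EXTENSIONS` tuple and this labelling scheme are hypothetical choices for illustration. Here the files are only listed in `__init__` and opened lazily in `__getitem__`.

```python
import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

# Hypothetical choice: which file extensions count as images.
IMG_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp")


class CustomImageDataset(Dataset):
    """Images from img_dir/<class_name>/*, labelled by subfolder index."""

    def __init__(self, img_dir):
        # Each subfolder (e.g. "A", "B") is treated as one class.
        self.classes = sorted(
            d for d in os.listdir(img_dir)
            if os.path.isdir(os.path.join(img_dir, d))
        )
        # Collect (path, label) pairs; the files themselves are read later.
        self.samples = []
        for label, class_name in enumerate(self.classes):
            class_dir = os.path.join(img_dir, class_name)
            for fname in sorted(os.listdir(class_dir)):
                if fname.lower().endswith(IMG_EXTENSIONS):
                    self.samples.append((os.path.join(class_dir, fname), label))

    def __len__(self):
        # Number of images found across all subfolders.
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        # Load as RGB, scale to [0, 1], and reorder HWC -> CHW as PyTorch expects.
        with Image.open(path) as img:
            array = np.asarray(img.convert("RGB"), dtype=np.float32) / 255.0
        image = torch.from_numpy(array).permute(2, 0, 1)
        return image, label
```

With your tree, `CustomImageDataset("SET/Training")` would yield 16000 (image, label) items, labelled 0 for A and 1 for B.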

you can create an instance of the class and test your code by checking that the number of files is the expected one and that __getitem__ can fetch images, for example with the following lines:

train_image_dir = 'SET/Training'  # adjust to the path of your training folder
trainset = CustomImageDataset(train_image_dir)
print('N of loaded images: {}'.format(len(trainset)))
first_image, first_label = trainset[0]

Most likely, you want __init__ to read all the files into RAM, so it is within __init__ that you define the logic for exploring the paths and loading the pictures. If I understood correctly what you want, __getitem__ could be defined so that it returns two elements: the first is an image from the A folder, and the second is the corresponding image from the B folder.
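A minimal sketch of that paired variant, under some loud assumptions: `PairedImageDataset` and `load_image` are hypothetical names, A and B are assumed to contain only image files, and the i-th file of A (sorted by name) is assumed to correspond to the i-th file of B. If your pairs share filenames, matching by name would be safer than matching by position.

```python
import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


def load_image(path):
    # Read a file as an RGB float tensor in [0, 1] with shape (C, H, W).
    with Image.open(path) as img:
        array = np.asarray(img.convert("RGB"), dtype=np.float32) / 255.0
    return torch.from_numpy(array).permute(2, 0, 1)


class PairedImageDataset(Dataset):
    """Hypothetical paired dataset: returns (image from A, matching image from B)."""

    def __init__(self, img_dir):
        a_dir = os.path.join(img_dir, "A")
        b_dir = os.path.join(img_dir, "B")
        # Assumption: the i-th file of A (sorted by name) pairs with the i-th file of B.
        a_files = sorted(os.listdir(a_dir))
        b_files = sorted(os.listdir(b_dir))
        if len(a_files) != len(b_files):
            raise ValueError(f"A has {len(a_files)} files but B has {len(b_files)}")
        # Read every image into RAM once, so __getitem__ is just an index lookup.
        self.a_images = [load_image(os.path.join(a_dir, f)) for f in a_files]
        self.b_images = [load_image(os.path.join(b_dir, f)) for f in b_files]

    def __len__(self):
        return len(self.a_images)

    def __getitem__(self, idx):
        # First output: input image from A; second output: target image from B.
        return self.a_images[idx], self.b_images[idx]
```

With your tree, `PairedImageDataset("SET/Training")` would give 8000 (A, B) pairs and `PairedImageDataset("SET/Validation")` 600. Preloading only works if all decoded images fit in memory; if they don't, keep just the file paths in `__init__` and open the files in `__getitem__` instead.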

Then you only need to instantiate the validation dataset, without defining a new class:

valset = CustomImageDataset(valid_image_dir)

From this point on, you have the logic for reading the data. Next, you can let PyTorch handle the batching of images through its own DataLoader implementation, which you do not need to subclass as we did with Dataset; you only need to instantiate train_dataloader and valid_dataloader:

from torch.utils.data import DataLoader
train_dataloader = DataLoader(trainset, batch_size=64, shuffle=True)
valid_dataloader = DataLoader(valset, batch_size=64, shuffle=False)  # no need to shuffle validation data
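To see what the DataLoader hands to your training loop, here is a small sketch that uses a stand-in dataset of random tensors instead of your image files (`RandomPairs` is a hypothetical name, with 10 pairs of 3x32x32 images). With a dataset whose `__getitem__` returns two tensors, each batch is a pair of stacked tensors.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RandomPairs(Dataset):
    """Stand-in for the real dataset: n random (A, B) image pairs."""

    def __init__(self, n=10, shape=(3, 32, 32)):
        self.a = torch.rand(n, *shape)
        self.b = torch.rand(n, *shape)

    def __len__(self):
        return len(self.a)

    def __getitem__(self, idx):
        return self.a[idx], self.b[idx]


loader = DataLoader(RandomPairs(n=10), batch_size=4, shuffle=True)

batch_shapes = []
for real_a, real_b in loader:
    # The default collate function stacks the per-item tensors along a new
    # batch dimension, so each tensor has shape (batch_size, C, H, W).
    batch_shapes.append((tuple(real_a.shape), tuple(real_b.shape)))
    # ... a GAN training step on (real_a, real_b) would go here ...

print(batch_shapes)
```

With 10 items and `batch_size=4`, this gives batches of 4, 4 and 2; passing `drop_last=True` to the DataLoader discards the incomplete last batch.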
Domenico