You need to read your image files with a class that derives from torch.utils.data.Dataset in order to have your own custom dataset.
You can follow this part of the documentation for a basic example of how to populate a custom Dataset.
So, after you define
from torch.utils.data import Dataset
class CustomImageDataset(Dataset):
with the three mandatory methods (look at the documentation above)
def __init__(self, img_dir, ...):
def __len__(self):
def __getitem__(self, idx):
you can create an instance of the class and test your code by verifying that the number of files is the expected one and that __getitem__ can fetch images, for example with the following lines:
trainset = CustomImageDataset(train_image_dir)
print('N of loaded images: {}'.format(len(trainset)))
first_image, first_label = trainset[0]
Most likely, you want __init__ to read all the files into RAM, so it is within __init__ that you define the logic for exploring the paths and loading the pictures. If I understood correctly what you want, __getitem__ could be defined so that it returns two elements: the first being an image from the A folder, and the second the related image from the B folder.
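As a rough sketch of that idea, the class could look something like the following. Several things here are my assumptions rather than anything from your question: the name PairedImageDataset, the layout root/A and root/B, the rule that two images are "related" when they share the same file name, and the load_with_pil helper (which assumes Pillow is installed). The loader argument is only there so that you can swap in a different way of reading files.

```python
import os

from torch.utils.data import Dataset


def load_with_pil(path):
    # Hypothetical default loader; assumes Pillow is installed.
    from PIL import Image
    with Image.open(path) as img:
        return img.convert("RGB")


class PairedImageDataset(Dataset):
    """Assumed layout: root/A/<name> and root/B/<name> hold related images."""

    def __init__(self, root, loader=load_with_pil, transform=None):
        dir_a = os.path.join(root, "A")
        dir_b = os.path.join(root, "B")
        # Keep only the file names present in both folders, in a stable order.
        names = sorted(set(os.listdir(dir_a)) & set(os.listdir(dir_b)))
        # Read every pair into RAM once, inside __init__.
        self.pairs = [
            (loader(os.path.join(dir_a, n)), loader(os.path.join(dir_b, n)))
            for n in names
        ]
        self.transform = transform

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        img_a, img_b = self.pairs[idx]
        if self.transform is not None:
            # Apply the same (optional) transform to both images.
            img_a, img_b = self.transform(img_a), self.transform(img_b)
        return img_a, img_b
```

If the images do not all fit in RAM, the same structure works with only the paths stored in __init__ and the loader called inside __getitem__ instead.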
Afterwards, you only need to instantiate the validation dataset; there is no need to define a new class:
valset = CustomImageDataset(valid_image_dir)
From this point, you have the logic for reading the data. Afterwards, you can let PyTorch handle the batching of images through its own implementation of the DataLoader, which you do not have to subclass like we did for the Dataset; you just instantiate train_dataloader and valid_dataloader:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(trainset, batch_size=64, shuffle=True)
valid_dataloader = DataLoader(valset, batch_size=64, shuffle=False)  # shuffling is not needed for validation
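To see what the DataLoader hands back, here is a small self-contained sketch of iterating over one. RandomPairs is a made-up stand-in for your dataset that returns random tensors instead of real images, and the sizes are arbitrary. Note that the default collate function stacks tensors, arrays and numbers, so if your __getitem__ returns PIL images you would first need to convert them to tensors (for example through a transform).

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RandomPairs(Dataset):
    # Hypothetical stand-in for CustomImageDataset: returns (image_a, image_b).
    def __init__(self, n, shape=(3, 8, 8)):
        self.a = torch.rand(n, *shape)
        self.b = torch.rand(n, *shape)

    def __len__(self):
        return self.a.shape[0]

    def __getitem__(self, idx):
        return self.a[idx], self.b[idx]


loader = DataLoader(RandomPairs(10), batch_size=4, shuffle=True)

batch_shapes = []
for batch_a, batch_b in loader:
    # Each element of the pair is stacked along a new leading batch dimension.
    batch_shapes.append((tuple(batch_a.shape), tuple(batch_b.shape)))

print(batch_shapes)
```

With 10 samples and batch_size=4 the last batch is smaller than the others, because drop_last defaults to False.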