I want to use a dataloader in my script.
Normally, the default call would look like this:
# Build the dataset: images under data_dir go through ColorJitter, are resized
# to img_size_XY x img_size_XY, converted to tensors and then normalised with
# _mean / _std.
dataset = ImageFolderWithPaths(
    root=data_dir,
    transform=transforms.Compose(
        [
            transforms.ColorJitter(
                brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1
            ),
            transforms.Resize(size=(img_size_XY, img_size_XY)),
            transforms.ToTensor(),
            transforms.Normalize(mean=_mean, std=_std),
        ]
    ),
)
# Wrap the dataset in a loader: fixed order (no shuffling), two worker processes.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)
To iterate through this dataloader, I use:
# Fetch a single batch from the loader defined above. Each batch unpacks into
# the inputs, their labels, and the file paths appended by ImageFolderWithPaths.
for inputs, labels, paths in dataloader:
    break
Now I need to collect the path for each image.
I found this code on GitHub (https://gist.github.com/andrewjong/6b02ff237533b3b2c554701fb53d5c4d):
class ImageFolderWithPaths(datasets.ImageFolder):
    """Custom dataset that includes image file paths.

    Extends torchvision.datasets.ImageFolder. No ``__init__`` is defined
    here, so constructor arguments (the root directory, ``transform``, ...)
    are handled by ``ImageFolder`` itself, e.g.
    ``ImageFolderWithPaths(data_dir, transforms.ToTensor())``.
    """

    # Override __getitem__; this is the method the dataloader calls per sample.
    def __getitem__(self, index):
        """Return the parent's sample tuple with the file path appended.

        Args:
            index: Position of the sample in the dataset.

        Returns:
            The tuple produced by ``ImageFolder.__getitem__`` for ``index``,
            extended by one trailing element: the image's file path.
        """
        # This is what ImageFolder normally returns for this index.
        original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
        # self.imgs is indexed per sample; element 0 of an entry is the path.
        path = self.imgs[index][0]
        # Tuples are immutable, so build a new one that ends with the path.
        return original_tuple + (path,)
# EXAMPLE USAGE:
# Instantiate the dataset and dataloader. ImageFolderWithPaths defines no
# __init__ of its own, so a transform is passed exactly as for ImageFolder.
# A tensor-producing transform is needed here: the DataLoader's default
# collation cannot batch raw PIL images.
data_dir = "your/data_dir/here"
dataset = ImageFolderWithPaths(data_dir, transform=transforms.ToTensor())  # our custom dataset
# DataLoader lives in torch.utils.data (there is no torch.utils.DataLoader).
dataloader = torch.utils.data.DataLoader(dataset)
# iterate over data
for inputs, labels, paths in dataloader:
    # use the above variables freely
    print(inputs, labels, paths)
But this code does not take transforms into account, like in my original code.
Can anybody help in how I should go about making it work with that?