
I am trying to create a binary CNN classifier for an unbalanced dataset (class 0 = 4000 images, class 1 = around 250 images), on which I want to perform 5-fold cross-validation. Currently I am loading my training set into an ImageFolder that applies my transformations/augmentations, and then passing it to a DataLoader. However, this results in both my training splits and my validation splits containing augmented data.

I originally applied transformations offline (offline augmentation?) to balance my dataset, but from this thread (https://stats.stackexchange.com/questions/175504/how-to-do-data-augmentation-and-train-validate-split), it seems it would be ideal to augment only the training set. I would also prefer to train my model solely on augmented training data and then validate it on non-augmented data in 5-fold cross-validation.
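To see why augmenting before splitting is a problem, here is a toy sketch in pure Python (no real images). Each of 100 hypothetical source images gets 4 offline-augmented variants, represented only as `(source_id, variant_id)` pairs; the names `leaked_items`, `sources` and `augmented` are made up for illustration. It compares a random split of the already-augmented pool against splitting the source images first and augmenting only the training side.

```python
import random

# Toy stand-in for offline augmentation: 100 source images x 4 variants each.
# variant_id 0 plays the role of the original, un-augmented image.
sources = list(range(100))
augmented = [(s, v) for s in sources for v in range(4)]


def leaked_items(train, valid):
    """Return validation items whose source image also appears in training."""
    train_sources = {s for s, _ in train}
    return [item for item in valid if item[0] in train_sources]


rng = random.Random(42)

# 1) Augment first, then split 80/20: variants of one image end up on both sides.
shuffled = augmented[:]
rng.shuffle(shuffled)
cut = int(0.8 * len(shuffled))
leak_aug_first = leaked_items(shuffled[:cut], shuffled[cut:])

# 2) Split the source images first, then augment only the training side.
src = sources[:]
rng.shuffle(src)
train_src, valid_src = src[:80], src[80:]
train = [(s, v) for s in train_src for v in range(4)]  # augmented training data
valid = [(s, 0) for s in valid_src]                     # originals only
leak_split_first = leaked_items(train, valid)

print(len(leak_aug_first), len(leak_split_first))
```

In the first case nearly every validation item typically has a sibling variant in training, so validation scores look better than they should; in the second case the overlap is zero by construction, which is what splitting by index and augmenting only the training indices achieves.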

My data is organized as root/label/images, where there are 2 label folders (0 and 1) and images sorted into the respective labels.
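With this root/label/images layout, torchvision's ImageFolder treats each subfolder name as a class. A quick way to confirm the class counts it will see is to count image files per label folder. The helper `count_images_per_class` below is hypothetical, and its extension set is an assumption that roughly mirrors the image extensions ImageFolder accepts by default.

```python
from pathlib import Path

# Assumed extension list, approximating ImageFolder's default image extensions.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"}


def count_images_per_class(root):
    """Count image files in each label folder of a root/label/images tree."""
    counts = {}
    # Sorted like ImageFolder, which assigns class indices in sorted folder order.
    for label_dir in sorted(p for p in Path(root).iterdir() if p.is_dir()):
        counts[label_dir.name] = sum(
            1
            for f in label_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_EXTS
        )
    return counts
```

For the dataset described above, `count_images_per_class(ROOT)` should report about 4000 files under `0` and about 250 under `1`.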

My Code So Far

import torch
from sklearn.model_selection import KFold
from torch.utils.data import SubsetRandomSampler
from torchvision import datasets

# ROOT, data_transforms and model are defined elsewhere in my code
total_set = datasets.ImageFolder(ROOT, transform=data_transforms['my_transforms'])

# Eventually I plan to run cross-validation as such:
splits = KFold(n_splits=5, shuffle=True, random_state=42)

for train_idx, valid_idx in splits.split(total_set):
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    # Both loaders read from total_set, so both apply 'my_transforms'
    train_loader = torch.utils.data.DataLoader(total_set, batch_size=32, sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(total_set, batch_size=32, sampler=valid_sampler)

    model.train()
    # Model train/eval works but may be overpredicting

I'm sure I'm doing something suboptimal or wrong in this code, but I can't find any documentation specifically on augmenting only the training splits in cross-validation!

Any help would be appreciated!

jinsom

1 Answer

One approach is to implement a wrapper Dataset class that applies transforms to the output of an ImageFolder dataset created without any transforms of its own. For example:

from torch.utils.data import Dataset


class WrapperDataset(Dataset):
    """Applies its own transforms on top of an untransformed base dataset."""

    def __init__(self, dataset, transform=None, target_transform=None):
        self.dataset = dataset
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        # An ImageFolder created without a transform returns (PIL image, label)
        image, label = self.dataset[index]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

    def __len__(self):
        return len(self.dataset)

Then you can use it in your code by wrapping the same untransformed dataset twice: once with the training transforms and once with the validation transforms.

import torch
from sklearn.model_selection import KFold
from torch.utils.data import SubsetRandomSampler
from torchvision import datasets

# No transform here: WrapperDataset applies the split-specific transforms
total_set = datasets.ImageFolder(ROOT)

splits = KFold(n_splits=5, shuffle=True, random_state=42)

for train_idx, valid_idx in splits.split(total_set):
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    # Training view: augmentations are applied only to the training indices
    train_loader = torch.utils.data.DataLoader(
        WrapperDataset(total_set, transform=data_transforms['train_transforms']),
        batch_size=32, sampler=train_sampler)
    # Validation view: 'valid_transforms' should contain no random augmentation
    valid_loader = torch.utils.data.DataLoader(
        WrapperDataset(total_set, transform=data_transforms['valid_transforms']),
        batch_size=32, sampler=valid_sampler)

    # train/validate now

I haven't tested this code, since I don't have your full code/models, but the concept should be clear.
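With only about 250 class-1 images, plain `KFold` can give folds whose class ratio drifts from the overall 4000:250. One option (not part of the answer above) is stratified splitting, which scikit-learn provides as `StratifiedKFold(n_splits=5, shuffle=True, random_state=42).split(X, y)`, with `y = total_set.targets` (ImageFolder stores each sample's class index in `targets`). The hand-rolled `stratified_folds` below is a hypothetical, dependency-free sketch of the same idea: shuffle each class separately and deal its indices round-robin across the folds.

```python
import random
from collections import defaultdict


def stratified_folds(targets, n_splits=5, seed=42):
    """Yield (train_idx, valid_idx) pairs that keep each class's share per fold."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(targets):
        by_class[label].append(idx)

    folds = [[] for _ in range(n_splits)]
    for label in sorted(by_class):
        idxs = by_class[label]
        rng.shuffle(idxs)
        # Deal this class's indices round-robin so every fold gets a fair share.
        for pos, idx in enumerate(idxs):
            folds[pos % n_splits].append(idx)

    for k in range(n_splits):
        valid_idx = sorted(folds[k])
        train_idx = sorted(i for j, fold in enumerate(folds) if j != k for i in fold)
        yield train_idx, valid_idx
```

The resulting index lists can replace `splits.split(total_set)` in the loop above, e.g. `for train_idx, valid_idx in stratified_folds(total_set.targets):`, and go into `SubsetRandomSampler` unchanged. With 4000 and 250 images, every validation fold gets 800 class-0 and 50 class-1 indices.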

jodag
  • Thanks for the reply. I tried implementing your idea and I think it's close to working for my code. I get a TypeError "img should be PIL Image. Got " related to "image = self.transform(image)" in your WrapperDataset class when I try to iterate through my train_loader while training (code would be: for inputs, labels in train_loader: # train etc.). – jinsom Aug 18 '19 at 02:45
  • For the purposes of my question, I added "image = transforms.ToPILImage()(image)" before "image = self.transform(image)" and this resolved the error. Thanks again for your help! – jinsom Aug 18 '19 at 02:51
  • Interesting, I'm not sure why this would be, since `ImageFolder` should return PIL images by default. You've removed the transforms when defining `total_set`, right? – jodag Aug 18 '19 at 04:08
  • You're right, it works fine as is. I'm using an IPython notebook and likely forgot to reinitialize total_set. Doh! – jinsom Aug 18 '19 at 06:57
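The error in this thread most likely came from a stale `total_set` that still carried its old transforms, so the wrapper's PIL-based transforms received tensors. One way to catch that early is a check like the hypothetical `require_untransformed` below, which could be called in `WrapperDataset.__init__`. It relies on torchvision datasets such as ImageFolder keeping their transform in a `transform` attribute; the helper itself is not part of any library.

```python
def require_untransformed(dataset):
    """Raise early if the base dataset already applies its own transform.

    Hypothetical helper: if a base ImageFolder still holds e.g. ToTensor(),
    the wrapper's transforms receive tensors instead of PIL images.
    """
    if getattr(dataset, "transform", None) is not None:
        raise ValueError(
            "Base dataset already has a transform; rebuild it without one "
            "(e.g. datasets.ImageFolder(ROOT)) before wrapping it."
        )
    return dataset
```

In a notebook, this turns a confusing error deep inside the training loop into an immediate, explicit failure when the loaders are built.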