I'm trying to train/validate a CNN with PyTorch on an unbalanced image dataset (class 1: 250 images, class 0: ~4,000 images). So far I've applied augmentation only to my training set (thanks @jodag), but my model is still learning to favor the class with significantly more images.
I want to find ways to compensate for my unbalanced dataset.
I thought about oversampling/undersampling with the imbalanced dataset sampler (https://github.com/ufoym/imbalanced-dataset-sampler), but I already use a sampler to select indices for my 5-fold cross-validation. Specifically:

- Is there a way to keep the cross-validation in the code below and also add this sampler?
- Is there a way to augment one label more heavily than the other? (Roughly what I mean is sketched below.)
- Are there alternative, easier ways to address the unbalanced dataset that I haven't looked into yet?
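For the augmentation question, this is roughly what I have in mind: a label-aware variant of my WrapperDataset that applies a heavier transform pipeline to the minority class (just a sketch; LabelAwareWrapper, majority_transform, and minority_transform are names I made up):

from torch.utils.data import Dataset

class LabelAwareWrapper(Dataset):
    # Applies a heavier augmentation pipeline to the minority class only
    def __init__(self, dataset, majority_transform, minority_transform,
                 minority_label=1):
        self.dataset = dataset
        self.majority_transform = majority_transform
        self.minority_transform = minority_transform
        self.minority_label = minority_label

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        if label == self.minority_label:
            return self.minority_transform(image), label
        return self.majority_transform(image), label

By itself this only changes how each class is transformed, not how often each class is seen, so I assume it would need to be combined with oversampling.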
Here's an example of what I have so far:
import torch
from torchvision import datasets
from sklearn.model_selection import KFold
from torch.utils.data import SubsetRandomSampler

# PATH, data_transforms, WrapperDataset, and epochs are defined elsewhere
total_set = datasets.ImageFolder(PATH)
KF_splits = KFold(n_splits=5, shuffle=True, random_state=42)

for fold, (train_idx, valid_idx) in enumerate(KF_splits.split(total_set)):
    # Samplers restrict each loader to this fold's indices
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    # WrapperDataset applies augmentation only to the training set;
    # both loaders pull images from the same folder but see disjoint indices.
    # This augments the training set, but it doesn't address the underlying
    # class imbalance. (Swapping in sampler=ImbalancedDatasetSampler(total_set)
    # would oversample, but it ignores the fold's train_idx.)
    train_loader = torch.utils.data.DataLoader(
        WrapperDataset(total_set, transform=data_transforms['train']),
        batch_size=32, sampler=train_sampler)
    valid_loader = torch.utils.data.DataLoader(
        WrapperDataset(total_set, transform=data_transforms['val']),
        batch_size=32, sampler=valid_sampler)

    print("Fold: " + str(fold))

    for epoch in range(epochs):
        # Train/validate the model here
        ...
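For the first question, the closest I've come up with is to drop the GitHub sampler and instead build a per-fold WeightedRandomSampler (PyTorch's built-in), giving zero weight to every index outside the fold's training set so validation images are never drawn. This is only a sketch and the weighting scheme is my own guess; it would replace train_sampler inside the loop above:

import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler

# Oversample the minority class while drawing only from this fold:
# weight every index in the full dataset, zero outside train_idx.
targets = np.array(total_set.targets)            # class label per image
class_counts = np.bincount(targets[train_idx])   # e.g. [~4000, ~250]
class_weights = 1.0 / class_counts               # rarer class -> larger weight

sample_weights = np.zeros(len(total_set))
sample_weights[train_idx] = class_weights[targets[train_idx]]

# Draws len(train_idx) samples per epoch; zero-weight (validation)
# indices are never selected.
train_sampler = WeightedRandomSampler(
    weights=torch.as_tensor(sample_weights, dtype=torch.double),
    num_samples=len(train_idx),
    replacement=True)

If that's sound, it would just be passed as sampler=train_sampler to train_loader as above.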
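On the "alternative easier ways" front, the one other approach I've run across (but haven't looked into yet) is weighting the loss instead of the sampling; something like this, with the ratio taken from my class counts:

import torch
import torch.nn as nn

# Class-weighted loss as an alternative to resampling: errors on the
# rare class (label 1) are upweighted by roughly the 4000:250 ratio.
class_weights = torch.tensor([1.0, 4000.0 / 250.0])
criterion = nn.CrossEntropyLoss(weight=class_weights)

Would weighting the loss like this be a reasonable substitute for resampling?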
Thank you for your time and help!