
I want to implement the following setup for the torchvision MNIST dataset, loading the data with DataLoader:

batch A (unaugmented images): 5, 0, 4, ...
batch B (augmented images): 5*, 5+, 5-, 0*, 0+, 0-, 4*, 4+, 4-, ...

... where every image in batch A has 3 augmentations in batch B, so len(B) = 3*len(A). Both batches should be used within a single iteration, comparing the original images in batch A with their augmented versions in batch B to build a loss.
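To make the intended layout concrete, here is a minimal sketch on toy labels instead of real images. The names `labels_a`, `labels_b`, `repeats` and `owner` are only for illustration:

```python
import torch

# Labels of a toy batch A, standing in for the MNIST digits 5, 0, 4.
labels_a = torch.tensor([5, 0, 4])
repeats = 3  # number of augmentations per original image

# Batch B repeats each entry of A `repeats` times, keeping the order of A.
labels_b = torch.repeat_interleave(labels_a, repeats, dim=0)

# owner[k] is the index in A of the original that entry k of B belongs to.
owner = torch.repeat_interleave(torch.arange(len(labels_a)), repeats)

print(labels_b.tolist())
print(owner.tolist())
```

With this layout, `labels_a[owner]` lines up element by element with `labels_b`, which is the relation the loss needs between the two batches.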

import torch
from torch.utils.data import Dataset
from torchvision import datasets


class MyMNIST(Dataset):

    def __init__(self, mnist_dir, train, augmented, transform=None, repeat=1):
        self.mnist_dir = mnist_dir
        self.train = train
        self.augmented = augmented
        self.repeat = repeat
        self.transform = transform
        self.dataset = None

        if augmented and train:
            self.dataset = datasets.MNIST(self.mnist_dir, train=train, download=True, transform=transform)
            # Repeat every image and its label `repeat` times in place: 5, 5, 5, 0, 0, 0, ...
            self.dataset.data = torch.repeat_interleave(self.dataset.data, repeats=self.repeat, dim=0)
            self.dataset.targets = torch.repeat_interleave(self.dataset.targets, repeats=self.repeat, dim=0)
        elif augmented and not train:
            raise ValueError("Test set should not be augmented.")
        else:
            # Use the constructor argument rather than the global MNIST_DIR.
            self.dataset = datasets.MNIST(self.mnist_dir, train=train, download=True, transform=transform)

With this class, I want to initialize two different dataloaders:

orig_train = MyMNIST(MNIST_DIR, train=True, augmented=False, transform=orig_transforms)
orig_train_loader = torch.utils.data.DataLoader(orig_train.dataset, batch_size=100, shuffle=True)

aug_train = MyMNIST(MNIST_DIR, train=True, augmented=True, transform=aug_transforms, repeat=3)
aug_train_loader = torch.utils.data.DataLoader(aug_train.dataset, batch_size=300, shuffle=True)

My problem is that I also need shuffling, while the order of batch B has to stay aligned with batch A. This is not possible with the code above, because the two DataLoaders shuffle independently and yield different orders. So I tried working with a single DataLoader and repeating each batch manually:

for batch_no, (images, labels) in enumerate(orig_train_loader):
    repeat_images = torch.repeat_interleave(images, 3, dim=0)

This way, I get the order of batch B (repeat_images) right, but now I'm missing the transformations, which I would need to apply within a batch/iteration. This does not seem to be the PyTorch paradigm; at least I did not find a way to do it.
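One direction that might fit, given as a rough sketch rather than a verified solution: let a single dataset return the original image together with its augmented views, and flatten the views in a custom `collate_fn`. The names `PairedAugmentationDataset` and `collate_pairs` are hypothetical, and the toy base data and noise augmentation below only stand in for MNIST and the real transforms. The sketch assumes both transforms return tensors of equal shape.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class PairedAugmentationDataset(Dataset):
    """Hypothetical wrapper: one original plus `repeat` augmented views per item."""

    def __init__(self, base, orig_transform, aug_transform, repeat=3):
        self.base = base  # any indexable dataset yielding (raw_image, label)
        self.orig_transform = orig_transform
        self.aug_transform = aug_transform
        self.repeat = repeat

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        image, label = self.base[index]
        original = self.orig_transform(image)
        # Every call to aug_transform may draw fresh random parameters.
        views = torch.stack([self.aug_transform(image) for _ in range(self.repeat)])
        return original, views, label


def collate_pairs(samples):
    """Build batch A (N, ...) and batch B (N * repeat, ...) in matching order."""
    originals, views, labels = zip(*samples)
    batch_a = torch.stack(originals)
    # The views of sample i occupy rows i * repeat .. i * repeat + repeat - 1.
    batch_b = torch.cat(views, dim=0)
    labels_a = torch.as_tensor(labels)
    labels_b = torch.repeat_interleave(labels_a, views[0].shape[0])
    return batch_a, batch_b, labels_a, labels_b


# Toy stand-in for MNIST: constant 1x2x2 "images" whose pixel value equals the label.
toy_base = [(torch.full((1, 2, 2), float(d)), d) for d in (5, 0, 4, 1)]
paired = PairedAugmentationDataset(
    toy_base,
    orig_transform=lambda x: x,
    aug_transform=lambda x: x + 0.1 * torch.randn_like(x),  # stand-in augmentation
    repeat=3,
)
loader = DataLoader(paired, batch_size=2, shuffle=True, collate_fn=collate_pairs)
batch_a, batch_b, labels_a, labels_b = next(iter(loader))
print(batch_a.shape, batch_b.shape, labels_a.tolist(), labels_b.tolist())
```

For MNIST, `base` would presumably be `datasets.MNIST(MNIST_DIR, train=True, download=True)` without a transform, with `orig_transforms` and `aug_transforms` passed as the two callables. Shuffling then happens once, on the wrapper, so A and B cannot drift apart, and the augmentations run inside `__getitem__`, where DataLoader workers can parallelize them.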

I would be happy if somebody could help me. I am quite new to PyTorch (and also to Stack Overflow), so please feel free to criticize my whole approach, point out performance issues that could arise, etc.

Thanks a lot!
