I used image augmentation in PyTorch before training a U-Net, like this:

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class ProcessTrainDataset(Dataset):
    def __init__(self, x, y):
        self.x = x  # image files, opened with Image.open
        self.y = y  # matching mask files

        # Lists that __getitem__ appends to; they must exist before first use
        self.x_augmented = []
        self.y_augmented = []

        self.pre_process = transforms.Compose([
                            transforms.ToTensor()])

        self.transform_data = transforms.Compose([
                            transforms.ColorJitter(brightness=0.2, contrast=0.2)])

        self.transform_all = transforms.Compose([
                            transforms.RandomVerticalFlip(),
                            transforms.RandomHorizontalFlip(),
                            transforms.RandomRotation(10),
                            transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
                            transforms.RandomAffine(degrees=0, translate=(0.2, 0.2), scale=(0.9, 1.1))])

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        img_x = Image.open(self.x[idx])
        img_y = Image.open(self.y[idx]).convert("L")

        # First get into the right range of 0 - 1, permute channels first, and put to tensor
        img_x = self.pre_process(img_x)
        img_y = self.pre_process(img_y)

        # Apply the geometric transforms to image and mask stacked together;
        # this ensures each pair has the identical transform applied
        img_all = torch.cat([img_x, img_y])
        img_all = self.transform_all(img_all)

        # Split again (the mask is the last channel) and apply the color transforms to the image only
        img_x, img_y = img_all[:-1, ...], img_all[-1:, ...]
        img_x = self.transform_data(img_x)

        # Add augmented data to dataset
        self.x_augmented.append(img_x)
        self.y_augmented.append(img_y)

        return img_x, img_y

But how do I know whether all augmentations have been applied to the dataset, and how can I see the number of samples in the dataset after augmentation?

anastasia

1 Answer

How can we see the length of the dataset after transformation? - PyTorch augmentation transforms, such as the random transforms defined in your __init__, are dynamic: every time __getitem__(idx) is called, a new random transform is sampled and applied to datum idx. Functionally, your dataset can supply an unlimited number of distinct images even if it contains only one example (though these images will be highly correlated with one another, since they are all generated from the same base image). Thinking of PyTorch transforms as increasing the number of elements in the dataset is therefore not really correct. The number of elements is always equal to len(self.x), but each time you retrieve element idx it will be slightly different.

Given this, your implementation need not store the resulting transformed data in x_augmented and y_augmented unless you have a specific downstream use for these data in addition to using the returned values from __getitem__().
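To illustrate the fixed length and the per-access randomness described above, here is a minimal sketch. It is a plain-Python stand-in rather than the asker's class: ToyAugmentedDataset, the toy flip, and the toy brightness shift are all hypothetical names and operations chosen so that the example runs without torch installed. It only follows the same __len__/__getitem__ protocol that a map-style Dataset uses.

```python
import random


class ToyAugmentedDataset:
    """Hypothetical plain-Python stand-in for a map-style dataset (no torch needed)."""

    def __init__(self, samples):
        self.samples = samples  # base data; augmentation never adds to it

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        row = list(self.samples[idx])
        # The random "augmentation" is re-drawn on every access
        if random.random() < 0.5:
            row.reverse()                    # toy horizontal flip
        shift = random.uniform(-0.2, 0.2)    # toy brightness jitter
        return [v + shift for v in row]


random.seed(0)
ds = ToyAugmentedDataset([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
draws = [ds[0] for _ in range(5)]  # same index, five separate accesses
print(len(ds))                     # still 2
for d in draws:
    print(d)                       # a different variant each time
```

The length never changes, while indexing the same element repeatedly yields different values. Over many epochs the model therefore sees many variants of each base sample without the dataset growing.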

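How do I know the transforms have all been applied? - It's safe to say that if you define the transforms and call them in __getitem__, then they are all invoked. Note that some of them only change the image with a certain probability (RandomVerticalFlip and RandomHorizontalFlip default to p=0.5, and you set p=0.5 for RandomPerspective), so not every transform alters every sample. If you want to verify this, you can display the images before and after transformation, and comment out all but one transform at a time to ensure that each transform has an effect.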
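The one-transform-at-a-time check can also be done numerically instead of by eye. The sketch below is an assumption-laden illustration, not part of any library: fraction_changed is a hypothetical helper, and the three toy transforms operate on a plain list so that the example runs without torch.

```python
import operator
import random


def fraction_changed(transform, sample, trials=20, equal=operator.eq):
    """Apply `transform` to `sample` `trials` times and return the
    fraction of calls whose output differs from the input."""
    changed = sum(not equal(transform(sample), sample) for _ in range(trials))
    return changed / trials


# Toy transforms on a list of pixel values (stand-ins for real transforms)
def always_flip(row):
    return row[::-1]


def maybe_flip(row, p=0.5):
    return row[::-1] if random.random() < p else list(row)


def identity(row):
    return list(row)


random.seed(0)
sample = [0.1, 0.2, 0.3]
report = {
    "always_flip": fraction_changed(always_flip, sample),
    "maybe_flip": fraction_changed(maybe_flip, sample),
    "identity": fraction_changed(identity, sample),
}
print(report)
```

A transform that reports 0.0 never had an effect, and one with p=0.5 should land somewhere near one half. With real tensors, the same helper should work by passing equal=torch.equal together with a single torchvision transform and one preprocessed image.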

DerekG