1

I am using PyTorch for image classification with code from GitHub. I need to add data augmentation before training my model, and I chose Albumentations for this. Here is my code after adding Albumentations:

# Albumentations pipelines keyed by split. A Compose object must be called
# with a keyword argument and returns a dict, e.g.
#     data_transform["train"](image=rgb_array)["image"]
# which is why using it as a positionally-called transform raises KeyError.
_imagenet_normalize = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

data_transform = {
    "train": A.Compose(
        [
            A.RandomResizedCrop(224, 224),
            A.HorizontalFlip(p=0.5),
            A.RandomGamma(gamma_limit=(80, 120), p=0.5),
            A.RandomBrightnessContrast(p=0.5),
            A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=0.5),
            A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
            A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
            A.Normalize(**_imagenet_normalize),
            ToTensorV2(),
        ]
    ),
    "val": A.Compose(
        [
            A.Resize(256, 256),
            A.CenterCrop(224, 224),
            A.Normalize(**_imagenet_normalize),
            ToTensorV2(),
        ]
    ),
}

I got this error:

KeyError: Caught KeyError in DataLoader worker process 0.

KeyError: 'You have to pass data to augmentations as named arguments, for example: aug(image=image)'

3 Answers

4

An Albumentations pipeline must be called with the image passed as the keyword argument `image` (not positionally), and it returns a dictionary. Here is a sample showing how to use it:

# Build the pipeline once. Albumentations transforms are called with the
# image as a keyword argument and return a dict holding the result.
transforms = A.Compose([
    A.Rotate(limit=15, p=0.5),
    A.Perspective(scale=[0, 0.1], keep_size=False, fit_output=False, p=1),
    A.Resize(224, 224),
    A.HorizontalFlip(p=0.5),
    A.GaussNoise(var_limit=(10.0, 50.0), mean=0),
    A.RandomToneCurve(scale=0.5, p=1),
    A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.225, 0.225, 0.225]),
    ToTensorV2(),
])

img = cv2.imread("dog.png")
# cv2.imread returns None (it does not raise) when the file is missing or
# cannot be decoded; fail here with a clear message instead of inside cvtColor.
if img is None:
    raise FileNotFoundError("Could not read image: dog.png")
# OpenCV decodes to BGR channel order; convert to RGB.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
transformed_img = transforms(image=img)["image"]
Maxime D.
  • 306
  • 5
  • 17
0

You can do this by writing a custom dataset class like the one below:

import albumentations as A
import cv2 

class ImageDataset(Dataset):
    """Map-style dataset that reads images with OpenCV and applies an
    Albumentations pipeline.

    Albumentations transforms must be called with a keyword argument
    (``transform(image=array)``) and return a dict, so the call is made here
    in ``__getitem__``.

    Args:
        images_filepaths: sequence of image file paths.
        transform: optional callable invoked as ``transform(image=...)`` that
            returns a dict with an ``"image"`` entry (e.g. ``A.Compose``).
        labels: optional sequence of class labels parallel to
            ``images_filepaths``. When given, items are ``(image, label)``
            pairs, the form a classification training loop unpacks. When
            omitted, items are images only.

    Raises:
        ValueError: if ``labels`` and ``images_filepaths`` differ in length.
    """

    def __init__(self, images_filepaths, transform=None, labels=None):
        if labels is not None and len(labels) != len(images_filepaths):
            raise ValueError(
                f"labels has {len(labels)} entries but images_filepaths "
                f"has {len(images_filepaths)}"
            )
        self.images_filepaths = images_filepaths
        self.transform = transform
        self.labels = labels

    def __len__(self):
        return len(self.images_filepaths)

    def __getitem__(self, idx):
        image_filepath = self.images_filepaths[idx]
        image = cv2.imread(image_filepath)
        # cv2.imread signals failure by returning None rather than raising.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_filepath}")
        # OpenCV decodes to BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        if self.labels is None:
            return image
        return image, self.labels[idx]
    

# ImageNet per-channel mean/std, shared by both pipelines.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training augmentation: random resized crop and flip, photometric jitter
# (gamma, brightness/contrast, CLAHE, RGB shift) and a small
# shift/scale/rotate, followed by normalization and tensor conversion.
train_transform = A.Compose(
    [
        A.RandomResizedCrop(224, 224),
        A.HorizontalFlip(p=0.5),
        A.RandomGamma(gamma_limit=(80, 120), p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=0.5),
        A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
        A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
        A.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
        ToTensorV2(),
    ]
)

# Evaluation pipeline: resize, center crop, normalize, convert to tensor.
val_transform = A.Compose(
    [
        A.Resize(256, 256),
        A.CenterCrop(224, 224),
        A.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
        ToTensorV2(),
    ]
)

# Each dataset calls its pipeline as transform(image=...) inside __getitem__.
train_dataset = ImageDataset(
    images_filepaths=train_images_filepaths,
    transform=train_transform,
)
val_dataset = ImageDataset(
    images_filepaths=val_images_filepaths,
    transform=val_transform,
)
I'mahdi
  • 23,382
  • 5
  • 22
  • 30
  • Thanks for your mention. I do as you said but again the same error! – Saeedeh Alebooyeh Mar 16 '22 at 03:48
  • I posted my code as an answer I am not sure whether I am using your suggestion correctly or not – Saeedeh Alebooyeh Mar 16 '22 at 04:02
  • @SaeedehAlebooyeh, welcome, **First** : please update your question and write your code in your question instead of writing an answer. **second** what is your error? :) you say you got an error, what is your error? **third** which line number of your code you get an error? – I'mahdi Mar 16 '22 at 16:14
0

Am I using your suggestion correctly? I have a dataset of good and bad (underwater) images.

import json
import os
import random
import sys

import albumentations as A
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from albumentations.pytorch import ToTensorV2
from torchvision import transforms, datasets
from tqdm import tqdm

from model import resnet34


class ImageDataset(torch.utils.data.Dataset):
    """Map-style dataset that reads images with OpenCV and applies an
    Albumentations pipeline.

    torchvision's ``ImageFolder(transform=...)`` calls its transform
    positionally on a PIL image, while an Albumentations ``Compose`` must be
    called as ``transform(image=numpy_array)`` and returns a dict. Handing an
    Albumentations pipeline straight to ``ImageFolder`` is what produced
    ``KeyError: 'You have to pass data to augmentations as named arguments'``.
    This dataset makes the call the way Albumentations expects.

    Args:
        images_filepaths: sequence of image file paths.
        transform: optional callable invoked as ``transform(image=...)`` that
            returns a dict with an ``"image"`` entry.
        labels: optional sequence of class indices parallel to
            ``images_filepaths``. When given, items are ``(image, label)``
            pairs, which the training loop in ``main`` unpacks.

    Raises:
        ValueError: if ``labels`` and ``images_filepaths`` differ in length.
    """

    def __init__(self, images_filepaths, transform=None, labels=None):
        if labels is not None and len(labels) != len(images_filepaths):
            raise ValueError(
                f"labels has {len(labels)} entries but images_filepaths "
                f"has {len(images_filepaths)}"
            )
        self.images_filepaths = images_filepaths
        self.transform = transform
        self.labels = labels

    def __len__(self):
        return len(self.images_filepaths)

    def __getitem__(self, idx):
        image_filepath = self.images_filepaths[idx]
        image = cv2.imread(image_filepath)
        # cv2.imread signals failure by returning None rather than raising.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_filepath}")
        # OpenCV decodes to BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        if self.labels is None:
            return image
        return image, self.labels[idx]


# Training augmentation; applied by ImageDataset as transform(image=...).
train_transform = A.Compose([
    A.RandomResizedCrop(224, 224),
    A.HorizontalFlip(p=0.5),
    A.RandomGamma(gamma_limit=(80, 120), p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=0.5),
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
    A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
    A.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ToTensorV2(),
])

# Evaluation pipeline: resize, center crop, normalize, convert to tensor.
val_transform = A.Compose([
    A.Resize(256, 256),
    A.CenterCrop(224, 224),
    A.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ToTensorV2(),
])


def _load_split(root, transform):
    """Index an ImageFolder-style directory and wrap it in an ImageDataset.

    ``datasets.ImageFolder`` is used only to discover the
    ``root/<class_name>/<image>`` files and their class indices. It is given
    no transform; ImageDataset reads the images and applies ``transform``.

    Returns:
        ``(dataset, class_to_idx)``
    """
    folder = datasets.ImageFolder(root=root)
    paths = [path for path, _ in folder.samples]
    labels = [label for _, label in folder.samples]
    return ImageDataset(paths, transform=transform, labels=labels), folder.class_to_idx


def main():
    """Fine-tune a pretrained ResNet-34 on the good/bad image dataset and
    save the weights with the best validation accuracy to ./resNet34.pth."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # Dataset root on the mounted Google Drive. It contains train/ and val/,
    # each with one sub-directory per class (e.g. bad/ and good/).
    data_root = "/content/gdrive"
    image_path = os.path.join(data_root, "MyDrive", "totalimages")
    if not os.path.exists(image_path):
        raise FileNotFoundError("{} path does not exist.".format(image_path))

    train_dataset, class_to_idx = _load_split(
        os.path.join(image_path, "train"), train_transform
    )
    validate_dataset, val_class_to_idx = _load_split(
        os.path.join(image_path, "val"), val_transform
    )
    # Labels are only comparable if both splits map class names identically.
    if val_class_to_idx != class_to_idx:
        raise ValueError(
            "train and val class folders differ: {} vs {}".format(
                class_to_idx, val_class_to_idx
            )
        )
    train_num = len(train_dataset)
    val_num = len(validate_dataset)
    num_classes = len(class_to_idx)

    # Write the index -> class-name mapping to class_indices.json.
    cla_dict = {idx: name for name, idx in class_to_idx.items()}
    with open("class_indices.json", "w") as json_file:
        json_file.write(json.dumps(cla_dict, indent=4))

    batch_size = 64
    # os.cpu_count() can return None; fall back to 1 so min() still works.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])
    print("Using {} dataloader workers every process".format(nw))

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=nw
    )
    validate_loader = torch.utils.data.DataLoader(
        validate_dataset, batch_size=batch_size, shuffle=False, num_workers=nw
    )
    print("using {} images for training, {} images for  validation.".format(
        train_num, val_num))

    net = resnet34()
    # Pretrained weights, downloadable from:
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "/content/gdrive/MyDrive/resnet34-333f7ec4.pth"
    if not os.path.exists(model_weight_path):
        raise FileNotFoundError("file {} does not    exist.".format(model_weight_path))
    net.load_state_dict(torch.load(model_weight_path, map_location=device))

    # Replace the final fully connected layer with one output per class
    # directory found under train/.
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, num_classes)
    net.to(device)

    loss_function = nn.CrossEntropyLoss()
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    epochs = 10
    best_acc = 0.0
    save_path = "./resNet34.pth"
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for images, labels in train_bar:
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(
                epoch + 1, epochs, loss.item())

        # validate
        net.eval()
        acc = 0.0  # number of correct predictions this epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_images, val_labels in val_bar:
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()
  • Please use the comment section to clarify things or update your question. Only use the answer section if you clearly answer your own question. – elyptikus Mar 16 '22 at 06:58
  • @SaeedehAlebooyeh, welcome, First : please update your question and write your code in your question instead of writing an answer. second what is your error? :) you say you got an error, what is your error? third which line number of your code you get an error? – I'mahdi Mar 17 '22 at 12:53