
I tried my best to finish my code, but I ran into an error that I can't solve. I have already added image_index to my code, but the error remains, so I'd like to ask for your help. Thank you so much!

This is the code:

import numpy as np
import torch
from torch.utils.data import Dataset
import glob
from PIL import Image
from torchvision import transforms
from skimage.segmentation import mark_boundaries
from torchvision.transforms.functional import to_pil_image
from torchvision.transforms.transforms import Grayscale, RandomHorizontalFlip, Resize, ToTensor
import numpy as np
import matplotlib.pyplot as plt
import os


class InfraredDataset(Dataset):
    def __init__(self, dataset_dir, image_index):
        super(InfraredDataset, self).__init__()
        self.dataset_dir = dataset_dir
        self.image_index = image_index
        self.transformer = transforms.Compose([
            Resize((256, 256)),
            Grayscale(),
            ToTensor(),
            RandomHorizontalFlip(0.5),
        ])

    def __getitem__(self, index):
        image_index = self.image_index[index].strip('\n')
        image_path = os.path.join(self.dataset_dir, 'images', '%s.png' % image_index)
        label_path = os.path.join(self.dataset_dir, 'masks', '%s_pixels0.png' % image_index)
        image = Image.open(image_path)
        label = Image.open(label_path)
        torch.manual_seed(1024)
        tensor_image = self.transformer(image)
        torch.manual_seed(1024)
        label = self.transformer(label)
        label[label > 0] = 1
        return tensor_image, label

    def __len__(self):
        return len(self.image_index)


if __name__ == "__main__":
    f = open('../sirst/idx_427/trainval.txt').readlines()
    ds = InfraredDataset(f)
    # Dataset test
    for i, (image, label) in enumerate(ds):
        image, label = to_pil_image(image), to_pil_image(label)
        image, label = np.array(image), np.array(label)
        print(image.shape, label.shape)
        vis = mark_boundaries(image, label, color=(1, 1, 0))
        image, label = np.stack([image] * 3, -1), np.stack([label] * 3, -1)
        plt.imsave('image_%d.png' % i, vis)

This is the error:

Traceback (most recent call last):
  File "H:/ProgramData/Infrared-detect-by-segmentation-master/Infrared-detect-by-segmentation-master/utils/dataloader.py", line 54, in <module>
    ds = InfraredDataset(f)
TypeError: __init__() missing 1 required positional argument: 'image_index'

I can't understand why I'm getting this error. I added image_index, but that didn't solve the problem.

Osvart
  • Please give more information on what your code is trying to do. – Mikee Apr 07 '22 at 06:41
  • (Tangentially, code you put in `if __name__ == '__main__':` should be trivial; the purpose of this boilerplate is to allow you to `import` the code, which you will not want to do anyway if the logic you need is not available via `import`. See also https://stackoverflow.com/a/69778466/874188) – tripleee Apr 07 '22 at 07:15

1 Answer


It looks like you're not passing the second argument when you instantiate the InfraredDataset object (named ds). As you can see, __init__() requires two arguments: dataset_dir and image_index.

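In the if __name__ == "__main__": block, you instantiate the object as ds = InfraredDataset(f), passing only one argument. Python binds positional arguments in order, so __init__() takes f as the dataset_dir parameter and has nothing left for image_index, which is exactly what the TypeError reports.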
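To see the binding concretely, here is a minimal sketch that reproduces the same error without PyTorch or any image files. DatasetStub is a hypothetical stand-in that only copies the __init__ parameters of InfraredDataset, and the index lines are made-up examples of what trainval.txt might contain.

```python
class DatasetStub:
    # Hypothetical stand-in: same __init__ parameters as InfraredDataset
    def __init__(self, dataset_dir, image_index):
        self.dataset_dir = dataset_dir
        self.image_index = image_index


f = ["Misc_1\n", "Misc_2\n"]  # hypothetical lines read from trainval.txt

try:
    # f is bound to dataset_dir; no value remains for image_index
    DatasetStub(f)
except TypeError as err:
    message = str(err)

print(message)
```

The printed message ends with "missing 1 required positional argument: 'image_index'", the same complaint as in the traceback above.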

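In conclusion, the number and order of the arguments you pass when creating the object must match the parameters of __init__(). You need either to pass another argument when creating the object, or to change __init__() so it accepts only one. Since f holds the lines of trainval.txt (the image indices), it belongs in image_index; the first argument should be the dataset directory that contains the images and masks folders.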
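Both options can be sketched without PyTorch as follows. FixedArgs and SingleArg are hypothetical stand-ins that keep only the constructor logic of InfraredDataset; the split_file parameter, its default of idx_427/trainval.txt relative to dataset_dir, and the temporary directory are assumptions for illustration (in your layout dataset_dir would presumably be '../sirst').

```python
import os
import tempfile


class FixedArgs:
    """Option 1: keep both parameters and pass both when constructing."""

    def __init__(self, dataset_dir, image_index):
        self.dataset_dir = dataset_dir
        self.image_index = image_index

    def __len__(self):
        return len(self.image_index)


class SingleArg:
    """Option 2: accept dataset_dir and read the index file inside __init__.

    split_file is a hypothetical parameter with an assumed default path.
    """

    def __init__(self, dataset_dir, split_file=os.path.join("idx_427", "trainval.txt")):
        self.dataset_dir = dataset_dir
        with open(os.path.join(dataset_dir, split_file)) as fh:
            # Strip newlines and skip blank lines while reading the indices
            self.image_index = [line.strip() for line in fh if line.strip()]

    def __len__(self):
        return len(self.image_index)


with tempfile.TemporaryDirectory() as root:
    # Build a throwaway dataset_dir with a made-up trainval.txt
    split_path = os.path.join(root, "idx_427", "trainval.txt")
    os.makedirs(os.path.dirname(split_path))
    with open(split_path, "w") as fh:
        fh.write("Misc_1\nMisc_2\n")

    # Option 1: read the lines yourself, then pass dataset_dir AND image_index
    with open(split_path) as fh:
        lines = fh.readlines()
    ds1 = FixedArgs(root, lines)

    # Option 2: only dataset_dir is required
    ds2 = SingleArg(root)

print(len(ds1), len(ds2))
```

Option 1 is the smaller change to the original code: its __getitem__ already calls .strip('\n') on each entry, so the raw lines from readlines() can be passed unchanged.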