Problem
In PyTorch, I am trying to write a dataset class that can return the entire data and the entire label set separately, using syntax like `dataset.data` and `dataset.label`. The code skeleton looks like:
class MyDataset(object):
    # Skeleton of a dataset that is meant to expose the full data and label
    # collections as the attributes `data` and `label`.

    # NOTE(review): these two assignments are evaluated once, while the class
    # body is being executed. At that point the `_get_data` / `_get_label`
    # definitions further down have not run yet, and nothing set in
    # `__init__` (e.g. `self.img_list`) exists.
    data = _get_data()
    label = _get_label()

    def __init__(self, dir, transforms):
        self.img_list = ... # all image paths loaded from dir
        # do something

    # NOTE(review): no index parameter is declared, so `dataset[i]` cannot
    # pass `i` to this method.
    def __getitem__(self):
        # do something
        # NOTE(review): bare `data` / `label` are looked up as local or global
        # names here, not as the class attributes defined above.
        return data, label

    def __len__(self):
        # The dataset size is the number of collected image paths.
        return len(self.img_list)

    # NOTE(review): presumably helpers belonging to this class; they take no
    # arguments, so they have no way to reach the image list of an instance.
    def _get_data():
        # do something

    def _get_label():
        # do something
However, when I use `dataset.data` and `dataset.label` to access the corresponding variables, nothing is returned.
I am wondering why this is the case and how I can fix this.
Edit
Thank you for all of your attention.
I have solved this problem myself. The solution is pretty straightforward: it simply makes use of class variables, which are filled in when the dataset is constructed.
class FaceDataset(object):
    """Image dataset that also exposes the fully loaded data and labels.

    Constructing an instance loads every image in the image list, applies the
    transform to it, and stores the stacked result and the label tensor in the
    class attributes ``data`` and ``label``. They can then be read either as
    ``FaceDataset.data`` / ``FaceDataset.label`` or through an instance as
    ``dataset.data`` / ``dataset.label``.

    Because the two attributes live on the class, they are shared by all
    instances: constructing another ``FaceDataset`` overwrites them.
    """

    # class variable
    data = None
    label = None

    def __init__(self, root, transforms=None):
        """Collect the image paths and populate the class-level data/label.

        Args:
            root: Location the image list is read from.
            transforms: Callable applied to each opened image.
        """
        # read img_list from root
        # Stored on the instance: __getitem__ and __len__ read self.img_list,
        # and so do the two calls below.
        self.img_list = ...
        self.transforms = ...
        FaceDataset.data = FaceDataset._get_data(self.img_list, self.transforms)
        FaceDataset.label = FaceDataset._get_label(self.img_list)

    @classmethod
    def _get_data(cls, img_list, transforms):
        """Load and transform every image, returning them as one tensor.

        Args:
            img_list: Iterable of image paths.
            transforms: Callable applied to each opened image; its result must
                support ``.unsqueeze``.

        Returns:
            The transformed images stacked along a new first axis.
        """
        data_list = []
        for img_path in img_list:
            # Close the underlying file as soon as the image is transformed.
            with Image.open(img_path) as img:
                # NOTE(review): unsqueeze(0) adds a leading axis to each item
                # and torch.stack adds another, so the result has two leading
                # axes (N, 1, ...). Confirm this is intended.
                data_list.append(transforms(img).unsqueeze(0))
        return torch.stack(data_list, dim=0)

    @classmethod
    def _get_label(cls, img_list):
        """Build a 1-D tensor holding one label per image path.

        Args:
            img_list: Sequence of image paths.

        Returns:
            Tensor with ``len(img_list)`` entries.
        """
        label = torch.zeros(len(img_list))
        for i, img_path in enumerate(img_list):
            label[i] = ...
        return label

    def __getitem__(self, index):
        """Return the ``(data, label)`` pair for the image at ``index``."""
        img_path = self.img_list[index]
        label = ...
        # read image from file and apply the transform defined in __init__;
        # the file is closed once the transform has produced its result
        with Image.open(img_path) as img:
            data = self.transforms(img)
        return data, label

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.img_list)