I have several custom datasets, each of which implements the `__getitem__`
method.
I created the following `MUXDataset` class, which is supposed to select a random dataset and return an item from that dataset:
class MUXDataset(Dataset):
    """
    Multiplexes several datasets behind a single ``Dataset`` interface.

    The wrapped datasets are read from ``mux_dataset_params.datasets``. A
    per-dataset weight vector is stored in ``self.sampling_vec``; it is either
    derived from a mode name (``'uniform'`` or ``'proportional'``) or taken
    as-is from an explicit sequence of weights.
    """

    def __init__(self, mux_dataset_params: MultiplexDatasetParams) -> None:
        """
        Store the parameters and resolve the sampling weights.

        Args:
            mux_dataset_params: object exposing ``datasets`` (a sequence of
                sized, indexable datasets) and ``sampling`` (a mode name or an
                explicit sequence with one weight per dataset).

        Raises:
            ValueError: if ``sampling`` is an unknown mode name, or an explicit
                weight sequence whose length differs from the number of
                datasets.
        """
        super().__init__()
        self._params = mux_dataset_params
        sampling = mux_dataset_params.sampling
        if isinstance(sampling, str):
            # A string names a sampling mode that must be expanded to weights.
            self.sampling_vec = self.init_sampling()
        else:
            # Anything else is treated as an explicit weight vector.
            weights = list(sampling)
            if len(weights) != len(mux_dataset_params.datasets):
                raise ValueError(
                    f'expected {len(mux_dataset_params.datasets)} sampling weights, got {len(weights)}'
                )
            self.sampling_vec = weights

    def init_sampling(self) -> List[float]:
        """
        Expand the sampling mode name into one weight per dataset.

        ``'uniform'`` gives every dataset the same weight.
        ``'proportional'`` weights each dataset by ``1 - size / total_size``
        and normalizes the result to sum to 1, so smaller datasets receive
        larger weights.

        Returns:
            A list of weights, one per dataset, summing to 1 (empty when there
            are no datasets).

        Raises:
            ValueError: if the mode name is not supported.
        """
        datasets = self._params.datasets
        mode = self._params.sampling
        num_datasets = len(datasets)
        if mode == 'uniform':
            return [1 / num_datasets for _ in datasets]
        if mode == 'proportional':
            size_vec = [len(_d) for _d in datasets]
            total_size = sum(size_vec)
            if total_size == 0:
                # Every dataset is empty: no size information to weight by.
                return [1 / num_datasets for _ in datasets]
            complement = [1 - _s / total_size for _s in size_vec]
            norm = sum(complement)
            if norm == 0:
                # A single dataset owns every sample, so its complement is 0;
                # fall back to uniform instead of dividing by zero.
                return [1 / num_datasets for _ in datasets]
            return [_c / norm for _c in complement]
        raise ValueError(f'{mode} is not supported')

    def __len__(self) -> int:
        """Return the number of wrapped datasets."""
        return len(self._params.datasets)

    def __getitem__(self, idx):
        """
        Return a uniformly random item from the dataset at position ``idx``.

        ``idx`` selects the dataset, not the item inside it.
        """
        curr_ds = self._params.datasets[idx]
        return curr_ds[random.randint(0, len(curr_ds) - 1)]
Now I want the `__getitem__` of `MUXDataset` to choose the dataset according to the sampling vector, but I couldn't find a way to implement this within the class.
I tried the following `__getitem__`:
def __getitem__(self, idx):
    """
    Return a random item from a dataset drawn according to the sampling weights.

    ``idx`` is accepted to satisfy the indexing protocol but is not used: the
    dataset is drawn with ``self.sampling_vec`` as weights, and the item is
    then drawn uniformly from that dataset.
    """
    datasets = self._params.datasets
    # random.choices always returns a list (even for k=1), so unpack the
    # single drawn dataset instead of using the list as an index.
    (curr_ds,) = random.choices(datasets, weights=self.sampling_vec, k=1)
    return curr_ds[random.randint(0, len(curr_ds) - 1)]