0

I have custom datasets that have the __getitem__ method in them. I created the following DatasetMUX class that it supposed to select random dataset and get the item from that dataset:

class MUXDataset(Dataset):
    """
        Defines a dataset class that provides a way to read scenes and also visualization tools
    """

    def __init__(self, mux_dataset_params: MultiplexDatasetParams) -> None:
        self._params = mux_dataset_params
        self.sampling_vec = (
            mux_dataset_params.sampling if isinstance(mux_dataset_params.sampling, str) else self.init_sampling()
        )
        return

    def init_sampling(self) -> List[float]:
        if self._params.sampling == 'uniform':
            num_datasets = len(self._params.datasets)
            sampling = [1 / num_datasets for _ in self._params.datasets]
        elif self._params.sampling == 'proportional':
            size_vec = [len(_d) for _d in self._params.datasets]
            sampling = [1 - _s / sum(size_vec) for _s in size_vec]
            sampling = [_s / sum(sampling) for _s in sampling]
        else:
            raise ValueError(f'{self._params.sampling} is not supported')
        return sampling

    def __len__(self) -> int:
        return len(self._params.datasets)

    def __getitem__(self, idx):
        curr_ds = self._params.datasets[idx]
        return curr_ds.__getitem__(random.randint(0, len(curr_ds) - 1))

Now I want the __getitem__ of the MUXDataset will be based on the sampling vector but I couldn't find a way to implement it within the class

I tried the following inside the __getitem__:

def __getitem__(self, idx):
        ds_idx = random.choices(population, weights=self.sampling, k=1)
        curr_ds = self._params.datasets[ds_idx]
        return curr_ds.__getitem__(random.randint(0, len(curr_ds) - 1))
David
  • 83
  • 7

0 Answers0