Assuming the question is:
- Combine 2+ data sets with potentially overlapping categories of objects (distinguishable by label)
- Each object category has 4 "subcategories", one for each color (distinguishable by label)
- Each batch should only contain a single object category
The first step is to make the object labels consistent across the data sets, if they aren't already. For example, if the dog class is label 0 in the first data set but label 2 in the second, we need to make sure the two dog categories end up merged under a single label. We can do this "translation" with a simple Dataset wrapper:
from torch.utils.data import Dataset

class TranslatedDataset(Dataset):
    """Wraps a dataset and remaps its labels into the combined label space.

    Args:
        dataset: The original dataset.
        translate_label: A function that maps an original dataset label
            to the label it should have in the combined data set.
    """
    def __init__(self, dataset, translate_label):
        super().__init__()
        self._dataset = dataset
        self._translate_label = translate_label

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, idx):
        inputs, target = self._dataset[idx]
        return inputs, self._translate_label(target)
The next step is combining the translated data sets, which can be done easily with a ConcatDataset:
from torch.utils.data import ConcatDataset

first_original_dataset = ...
second_original_dataset = ...

# Swap labels 0 and 2 in the first data set so "dog" lines up across both.
# Note: compare labels with ==, not `is` (identity comparison on ints is
# unreliable), and the class is named TranslatedDataset as defined above.
first_translated = TranslatedDataset(
    first_original_dataset,
    lambda y: 0 if y == 2 else 2 if y == 0 else y,  # or similar
)
# The second data set's labels are already in the combined label space
second_translated = TranslatedDataset(
    second_original_dataset,
    lambda y: y,  # or similar
)
combined = ConcatDataset([first_translated, second_translated])
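As a quick sanity check (just a sketch, assuming both data sets yield (inputs, label) pairs with integer labels), you can tally the labels in the combined data set and confirm the overlapping category now appears under a single label:

from collections import Counter

# count samples per label; the dog class should show up under one label only
label_counts = Counter(target for _, target in combined)
print(label_counts)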
Finally, we need to restrict each batch to samples from a single class, which is possible with a custom Sampler when creating the data loader.
import random

from torch.utils.data import Sampler

class SingleClassSampler(Sampler):
    def __init__(self, dataset, batch_size):
        super().__init__()
        self._batch_size = batch_size
        # Group the data set indices by class so we can emit sequential
        # runs of batch_size indices drawn from a single class.
        # (This iterates the whole data set once up front, which loads
        # every sample; fine for a sketch, but it can be slow.)
        indices_for_target = {}  # list of indices for each target
        for i, (_, target) in enumerate(dataset):
            # converting to string since Tensors hash by reference, not value
            str_targ = str(target)
            indices_for_target.setdefault(str_targ, []).append(i)
        # keep a whole number of batches per class; note that
        # v[:-(len(v) % batch_size)] would drop everything when len(v)
        # is already an exact multiple of batch_size
        trimmed = {
            k: v[:len(v) - (len(v) % batch_size)]
            for k, v in indices_for_target.items()
        }
        # concatenate the per-class index lists (plain sum() with no
        # start value doesn't work on lists)
        self._indices = [i for v in trimmed.values() for i in v]

    def __len__(self):
        return len(self._indices)

    def __iter__(self):
        # shuffle whole batches between epochs so classes are interleaved,
        # while each run of batch_size indices stays single-class
        batches = [
            self._indices[i:i + self._batch_size]
            for i in range(0, len(self._indices), self._batch_size)
        ]
        random.shuffle(batches)
        for batch in batches:
            yield from batch
Then to use the sampler:
from torch.utils.data import DataLoader

loader = DataLoader(
    combined,
    sampler=SingleClassSampler(combined, batch_size=64),
    batch_size=64,
    # shuffle=True is mutually exclusive with a custom sampler and would
    # raise an error; the sampler above already shuffles the batch order
)
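If you want to double-check the wiring, here's a minimal sketch (assuming integer class labels, so the default collate stacks them into a 1-D tensor) that asserts every batch is single-class:

for _, targets in loader:
    # every label in the batch should equal the first label
    assert (targets == targets[0]).all()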
I haven't run this code, so it might not be exactly right, but hopefully it will put you on the right track.
torch.utils.data docs: https://pytorch.org/docs/stable/data.html