If I understand your question correctly, you could have a DataLoader return a sequence of hand-selected batches by passing it a custom batch_sampler (you don't even need to pass a sampler in this case).
Given an arbitrary Dataset:
>>> from torch.utils.data import DataLoader, Dataset
>>> from torch.utils.data.sampler import Sampler
>>> class MyDataset(Dataset):
...     def __getitem__(self, idx):
...         # Toy dataset: each item is just its own index,
...         # so we can see which indices end up in each batch.
...         return idx
you can then define something like:
>>> class MyBatchSampler(Sampler):
...     def __init__(self, batches):
...         # batches: a list of lists of dataset indices, one inner list per batch
...         self.batches = batches
...
...     def __iter__(self):
...         for batch in self.batches:
...             yield batch
...
...     def __len__(self):
...         return len(self.batches)
This just takes a list of lists, each containing the dataset indices to include in one batch.
Then:
>>> dataset = MyDataset()
>>> batch_sampler = MyBatchSampler([[1, 2, 3], [5, 6, 7], [4, 2, 1]])
>>> dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler)
>>> for batch in dataloader:
...     print(batch)
...
tensor([1, 2, 3])
tensor([5, 6, 7])
tensor([4, 2, 1])
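As an aside, as far as I know recent PyTorch versions accept any iterable of index lists for batch_sampler, so for a quick experiment you could even skip the custom class and pass the list of lists directly:
>>> dataloader = DataLoader(dataset=dataset, batch_sampler=[[1, 2, 3], [5, 6, 7], [4, 2, 1]])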
Should be easy to extend to your actual Dataset, etc.
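For instance, here is a minimal sketch of what that might look like with a tensor-backed dataset (TensorPairDataset and the features/labels names are hypothetical, just for illustration; the default collate function stacks the selected items into batch tensors):
>>> import torch
>>> class TensorPairDataset(Dataset):
...     def __init__(self, features, labels):
...         self.features = features  # hypothetical (N, D) float tensor
...         self.labels = labels      # hypothetical (N,) label tensor
...
...     def __getitem__(self, idx):
...         return self.features[idx], self.labels[idx]
...
>>> features = torch.randn(8, 4)
>>> labels = torch.arange(8)
>>> dataset = TensorPairDataset(features, labels)
>>> dataloader = DataLoader(dataset, batch_sampler=MyBatchSampler([[0, 1], [7, 6, 5]]))
>>> for x, y in dataloader:
...     print(x.shape, y)
...
torch.Size([2, 4]) tensor([0, 1])
torch.Size([3, 4]) tensor([7, 6, 5])
Note that the batches can have different sizes, since each inner list is collated independently.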