
I have a torch.utils.data.Dataset object, and I would like a DataLoader (or a similar object) that accepts a list of indices and returns a batch of the corresponding samples.

For example, if I have

list_idxs = [10, 109, 7, 12]

I would like to do something like:

batch = loader.getbatch(list_idxs)

where batch contains:

[sample10, sample109, sample7, sample12]

Is there a simple, elegant, and efficient way to do this?

  • Should be doable in various ways, but how would this be used in practice? Normally the way `DataLoader` works is it simply iterates over batches, where the batches returned are determined by both the `BatchSampler` and any other underlying sampler of the Dataset. You wouldn't need a `DataLoader` for this in principle. If you already have a `Dataset` of some kind this would simply be equivalent to `batch = [dataset[idx] for idx in idxs]` – Iguananaut Sep 09 '21 at 16:52
  • In other words, do you absolutely need any functionality of the `DataLoader` (multiprocessing, collation, etc.) or just some simple custom batching? – Iguananaut Sep 09 '21 at 17:07
  • Thanks, yes I need the functionality of DataLoader for preprocessing stuff... the answer is what I was looking for. – thebesttony Sep 09 '21 at 19:23

1 Answer


If I understand your question correctly, you could have a DataLoader return a sequence of hand-selected batches using a custom batch_sampler (you don't even need to pass it a sampler in this case).

Given an arbitrary Dataset:

>>> from torch.utils.data import DataLoader, Dataset
>>> from torch.utils.data.sampler import Sampler
>>> class MyDataset(Dataset):
...     def __getitem__(self, idx):
...         # Trivial example: each "sample" is just its own index
...         return idx

you can then define something like:

>>> class MyBatchSampler(Sampler):
...     def __init__(self, batches):
...         self.batches = batches
...
...     def __iter__(self):
...         for batch in self.batches:
...             yield batch
...
...     def __len__(self):
...         return len(self.batches)

which just takes a list of lists containing dataset indices to include in each batch.

Then:

>>> dataset = MyDataset()
>>> batch_sampler = MyBatchSampler([[1, 2, 3], [5, 6, 7], [4, 2, 1]])
>>> dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler)
>>> for batch in dataloader:
...     print(batch)
... 
tensor([1, 2, 3])
tensor([5, 6, 7])
tensor([4, 2, 1])

Note that the DataLoader's default collate_fn is what stacks each batch of integer samples into a tensor here. This should be easy to extend to your actual Dataset, etc.
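
If you want something closer to the loader.getbatch(list_idxs) call from the question, here is a minimal sketch of one way to wrap this pattern; IndexedLoader and get_batch are hypothetical names for illustration, not part of the PyTorch API:

>>> class IndexedLoader:
...     # Hypothetical convenience wrapper: builds a one-batch
...     # DataLoader on demand for a given list of indices.
...     def __init__(self, dataset, **loader_kwargs):
...         self.dataset = dataset
...         self.loader_kwargs = loader_kwargs  # e.g. num_workers, collate_fn
...
...     def get_batch(self, idxs):
...         # A batch_sampler containing a single batch yields exactly
...         # one collated batch of the requested samples, in order.
...         loader = DataLoader(self.dataset,
...                             batch_sampler=MyBatchSampler([idxs]),
...                             **self.loader_kwargs)
...         return next(iter(loader))
...
>>> loader = IndexedLoader(dataset)
>>> loader.get_batch([10, 109, 7, 12])
tensor([ 10, 109,   7,  12])

Keep in mind that constructing a DataLoader per call has some overhead (worker processes are spawned each time if num_workers > 0), so if you know all your batches up front, passing them to a single DataLoader via batch_sampler as above is the cheaper option.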

Iguananaut