
I need to use a BatchSampler within a PyTorch DataLoader instead of calling __getitem__ of the dataset multiple times (remote dataset, each query is expensive).
I cannot understand how to use a BatchSampler with any given dataset.

e.g.

class MyDataset(Dataset):

    def __init__(self, remote_ddf, ):
        self.ddf = remote_ddf

    def __len__(self):
        return len(self.ddf)

    def __getitem__(self, idx):
        return self.ddf[idx]  # <-- this is as expensive as a batch call

    def get_batch(self, batch_idx):
        return self.ddf[batch_idx]

my_loader = DataLoader(MyDataset(remote_ddf),
                       batch_sampler=BatchSampler(Sampler(), batch_size=3, drop_last=False))

The thing I do not understand, neither found any example online or in torch docs, is how do I use my get_batch function instead of the __getitem__ function.
Edit: Following the answer of Szymon Maszke, this is what I tried, and yet __getitem__ gets one index each call instead of a list of size batch_size:

class MyDataset(Dataset):

    def __init__(self):
       ...

    def __len__(self):
        ...

    def __getitem__(self, batch_idx):  # <-- here I get only one index
        return self.wiki_df.loc[batch_idx]


loader = DataLoader(
                dataset=dataset,
                batch_sampler=BatchSampler(
                    SequentialSampler(dataset), batch_size=self.hparams.batch_size, drop_last=False),
                num_workers=self.hparams.num_data_workers,
            )
DsCpp

1 Answer


You can't use get_batch instead of __getitem__, and I don't see a point in doing it that way.

torch.utils.data.BatchSampler takes indices from your Sampler() instance (in this case 3 of them) and returns them as a list, so they can be used in your MyDataset __getitem__ method (check the source code; most of the samplers and data-related utilities are easy to follow in case you need them).
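
To make that concrete, here is a minimal standalone sketch (not from the original post) of what BatchSampler yields:

from torch.utils.data import BatchSampler, SequentialSampler

# Each element yielded by BatchSampler is a list of indices, not a single index
batches = list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
print(batches)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]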

I assume your self.ddf supports list indexing (e.g. self.ddf[[25, 44, 115]] returns the values correctly and uses only one expensive call). In that case, simply rename get_batch to __getitem__ and you are good to go.

class MyDataset(Dataset):

    def __init__(self, remote_ddf, ):
        self.ddf = remote_ddf

    def __len__(self):
        return len(self.ddf)

    def __getitem__(self, batch_idx):
        return self.ddf[batch_idx]  # batch_idx is a list

EDIT: You have to pass the BatchSampler as the sampler argument, not as batch_sampler, otherwise the batch will be divided into single indices. This should be fine:

loader = DataLoader(
    dataset=dataset,
    # This line below!
    sampler=BatchSampler(
        SequentialSampler(dataset), batch_size=self.hparams.batch_size, drop_last=False
    ),
    num_workers=self.hparams.num_data_workers,
)
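
For completeness, here is a minimal self-contained sketch of the whole thing wired together (the dataset and data below are illustrative stand-ins, not the remote dataframe from the question). Note that the default collate_fn still adds a leading dimension of size 1 to each batch, which you can squeeze out or handle with a custom collate_fn:

import torch
from torch.utils.data import Dataset, DataLoader, BatchSampler, SequentialSampler

class ToyDataset(Dataset):
    def __init__(self, data):
        self.data = data  # stand-in for the expensive remote_ddf

    def __len__(self):
        return len(self.data)

    def __getitem__(self, batch_idx):
        # batch_idx is a list such as [0, 1, 2]: one "expensive" fetch per batch
        return self.data[batch_idx]

dataset = ToyDataset(torch.arange(10))
loader = DataLoader(
    dataset,
    sampler=BatchSampler(SequentialSampler(dataset), batch_size=3, drop_last=False),
)

for batch in loader:
    print(batch.shape)  # e.g. torch.Size([1, 3]) -- default_collate adds the extra dim of 1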
Szymon Maszke
  • Funny as it sounds, I couldn't understand it from the documentation. `__getitem__` of a dataset sounds like something that returns one sample, in my case a row. – DsCpp Apr 27 '20 at 12:15
  • `torch.utils.data.Dataset` is a rather flexible structure (at least from pytorch version `1.4` IIRC) so `index` can be anything really AFAIK. If you use `batch_sampler` it is responsible for creating whole batch of data. – Szymon Maszke Apr 27 '20 at 12:15
  • Of course, but from the documentation's perspective the collate function (aggregation) is done implicitly for you, meaning `__getitem__` is called k times with one index each and the results are then aggregated. This means that *no* aggregation is being done after `__getitem__` – DsCpp Apr 27 '20 at 12:18
  • `collate_fn` allows you to "post-process" data after it's been returned from the batch. You may return `list[Tensor]` from your Dataset, or a `list[Tensor]` gets returned when using the standard sampler, and you can create a tensor from it. A good use case is padding variable-length tensors to be used with an RNN or the like (see the sketch after these comments). Though I agree `DataLoader` might be a little confusing. – Szymon Maszke Apr 27 '20 at 12:25
  • Are you sure this is working for you? Updated my example with code as you advised – DsCpp Apr 28 '20 at 07:10
  • Yes hahaha! I just understood it now and came to answer myself. Thank you! – DsCpp Apr 28 '20 at 08:16
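
For the padding use case mentioned in the comments, a hedged, purely illustrative sketch of such a collate_fn:

import torch
from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch):
    # batch is a list of variable-length 1D tensors; pad them to a common length
    return pad_sequence(batch, batch_first=True)

print(pad_collate([torch.ones(2), torch.ones(4)]).shape)  # torch.Size([2, 4])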