
The traditional way of applying some arbitrary collate_fn foo() in PyTorch code is:

dataloader = torch.utils.data.DataLoader(
  dataset,
  batch_size=64, # just for example
  collate_fn=foo,
  **other_kwargs
)

for batch in dataloader:
    # incoming batch is already collated
    do_stuff(batch)
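
For concreteness, foo could be any custom collate function; a padding collate like the sketch below is the kind of thing I have in mind (the sample structure and names are just for illustration):

import torch
from torch.nn.utils.rnn import pad_sequence

def foo(batch):
    # batch is a list of (sequence, label) samples with variable-length sequences
    seqs, labels = zip(*batch)
    padded = pad_sequence(list(seqs), batch_first=True)  # pad to the longest sequence in the batch
    return padded, torch.tensor(labels)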

But what if (for whatever reason) I wanted to do it like this:

dataloader = torch.utils.data.DataLoader(
  dataset,
  batch_size=64, # just for example
  collate_fn=lambda batch: batch, # identity collate, so the batch arrives as a raw list of samples
  **other_kwargs
)

for batch in dataloader:
    # incoming batch is not yet collated
    # this lets me do additional pre-collation stuff like
    # batch = do_stuff_precollate(batch) 
    collated_batch = foo(batch)  # finally we collate, outside of the dataloader
    do_stuff(collated_batch)

Is there any reason why the latter is a big no-no? Or why the former is particularly advantageous? I found a blog post that even suggests that, for HF tokenisation, the latter is faster.
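
For context, the HF tokenisation case I have in mind would look roughly like the sketch below under the second pattern (just an illustration; I'm assuming a transformers AutoTokenizer and a dataset of (text, label) samples):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

for batch in dataloader:
    # batch is still the raw list of (text, label) samples
    texts, labels = zip(*batch)
    # tokenise the whole batch in a single call; this pads and tensorises the text side
    encoded = tokenizer(list(texts), padding=True, truncation=True, return_tensors="pt")
    collated_batch = (encoded, torch.tensor(labels))
    do_stuff(collated_batch)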

