The traditional way of applying some arbitrary collate_fn foo() in torch code is:
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,  # just for example
    collate_fn=foo,
    **other_kwargs
)
for batch in dataloader:
    # incoming batch is already collated
    do_stuff(batch)
But what if, for whatever reason, I wanted to do it like this:
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,  # just for example
    collate_fn=lambda samples: samples,  # identity collate, so the batch stays an un-collated list of samples
    **other_kwargs
)
for batch in dataloader:
    # incoming batch is not yet collated
    # this lets me do additional pre-collation stuff like
    # batch = do_stuff_precollate(batch)
    collated_batch = foo(batch)  # finally we collate, outside of the dataloader
    do_stuff(collated_batch)
Is there any reason why the latter is a big no-no? Or why the former is particularly advantageous? I found a blog post that even suggests that for HF tokenisation, the latter is faster.
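For concreteness, the kind of foo() I have in mind is a batch-level HF tokenisation step, roughly like the sketch below. The model name, the "text"/"label" keys, and the dataset layout are just placeholders for whatever the real pipeline uses:

from transformers import AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder model

def foo(batch):
    # batch is a list of samples, e.g. [{"text": ..., "label": ...}, ...]
    texts = [sample["text"] for sample in batch]
    labels = torch.tensor([sample["label"] for sample in batch])
    # tokenise the whole batch in one call, padding to the longest sequence
    encoded = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    encoded["labels"] = labels
    return encoded

Whether foo() runs inside the DataLoader (as collate_fn) or in my own loop, it does the same work; the question is purely about where it runs and whether that matters.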