torch.utils.data.RandomSampler can be used to randomly sample more entries than exist in a dataset (i.e. where num_samples > dataset_size):
sampler = torch.utils.data.RandomSampler(dataset, replacement=True, num_samples=num_samples)
loader = torch.utils.data.DataLoader(dataset=dataloader_dataset, sampler=sampler, batch_size=batch_size)
If sampling from a Hugging Face dataset, the dataloader_dataset class must catch StopIteration and reset its iterator (restarting from the beginning of the dataset), for example:
# parameter selection (user configured)
dataset = load_dataset(...)  # Hugging Face datasets.load_dataset — assumed; confirm against caller
dataset_size = dataset.num_rows
number_of_dataset_repetitions = 5
# total number of samples drawn = dataset repeated this many times
num_samples = dataset_size * number_of_dataset_repetitions
batch_size = 8
drop_last = True  # discard final partial batch
# wrapper that restarts its iterator when the underlying dataset is exhausted
dataloader_dataset = DataloaderDatasetRepeatSampler(dataset, dataset_size)
# replacement=True allows num_samples to exceed dataset_size
sampler = torch.utils.data.RandomSampler(dataset, replacement=True, num_samples=num_samples)
loader = torch.utils.data.DataLoader(dataset=dataloader_dataset, sampler=sampler, batch_size=batch_size, drop_last=drop_last)
loop = tqdm(loader, leave=True)  # progress bar over batches
for batch_index, batch in enumerate(loop):
...
class DataloaderDatasetRepeatSampler(torch.utils.data.Dataset):
    """Map-style dataset wrapper that serves entries sequentially from an
    underlying dataset, restarting from the beginning once it is exhausted.

    Intended for use with ``RandomSampler(replacement=True, num_samples=N)``
    where ``N > dataset_size``, so the dataset can be repeated across epochs
    within a single DataLoader pass.
    """

    def __init__(self, dataset, dataset_size):
        # dataset: the underlying dataset (e.g. a Hugging Face dataset) to iterate
        # dataset_size: number of entries reported by __len__
        self.dataset = dataset
        self.dataset_size = dataset_size
        # live iterator over the dataset; recreated on StopIteration
        self.dataset_iterator = iter(dataset)

    def __len__(self):
        # Fixed: original returned self.datasetSize, but __init__ assigns
        # self.dataset_size — the old name raised AttributeError.
        return self.dataset_size

    def __getitem__(self, i):
        # Note: the index i is intentionally ignored; entries are served in
        # iterator order, wrapping around when the dataset is exhausted.
        try:
            dataset_entry = next(self.dataset_iterator)
        except StopIteration:
            # reset iterator (start from beginning of dataset)
            self.dataset_iterator = iter(self.dataset)
            dataset_entry = next(self.dataset_iterator)
        batch_sample = ...  # placeholder for user conversion, e.g. torch.Tensor(dataset_entry)
        return batch_sample