I started to use pytorch-lightning and ran into a problem with my custom data loaders:
I'm using my own dataset together with a standard torch.utils.data.DataLoader. Basically the dataset takes a path and loads the data corresponding to a given index, and the DataLoader then batches these samples.
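For context, the dataset is a plain map-style dataset. A minimal sketch of what it roughly looks like (the sample shapes and the random placeholder data are assumptions, not my actual loading code):

import torch
from torch.utils.data import Dataset

class TextKeypointsDataset(Dataset):
    # Sketch only: the real class reads text/keypoint pairs from `path`;
    # here the samples are faked so the snippet runs on its own.
    def __init__(self, path, num_samples=84):
        self.path = path
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        source = torch.randn(10)             # placeholder for the loaded features
        target = torch.randint(0, 5, (1,))   # placeholder for the target
        return source, target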
def train_dataloader(self):
    train_set = TextKeypointsDataset(parameters...)
    # batch_size and num_workers passed as keyword arguments
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, num_workers=num_workers)
    return train_loader
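The validation loader is defined analogously (sketch; the parameters are the same placeholders as above):

def val_dataloader(self):
    val_set = TextKeypointsDataset(parameters...)
    return torch.utils.data.DataLoader(val_set, batch_size=batch_size, num_workers=num_workers)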
When I use only the pytorch-lightning hooks train_dataloader and training_step, everything runs fine. As soon as I add val_dataloader and validation_step, I'm facing this error:
Epoch 1: 45%|████▌ | 10/22 [00:02<00:03, 3.34it/s, loss=5.010, v_num=131199]
ValueError: Expected input batch_size (1500) to match target batch_size (5)
In this case my dataset is really small (just to test functionality): 84 samples, and my batch size is 8. The training and validation datasets have the same length (again, just for testing purposes).
So in total that is 84 * 2 = 168 samples, and 168 / 8 (batch size) = 21, which roughly matches the 22 total steps shown above (more precisely, each split yields ceil(84 / 8) = 11 batches, so 11 * 2 = 22 steps). My current understanding is that after running over the training data for 10 batches (10 * 8 = 80 samples) the loader expects another full batch of 8, but only 4 of the 84 samples are left, so I get an error.
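A quick standalone check of what the DataLoader actually yields with these numbers, using a dummy tensor dataset in place of my real one (the feature shape and labels below are placeholders):

import torch
from torch.utils.data import DataLoader, TensorDataset

# 84 dummy samples, batch size 8: mirrors the numbers above
dummy_set = TensorDataset(torch.randn(84, 10), torch.randint(0, 5, (84,)))
loader = DataLoader(dummy_set, batch_size=8)

print(len(loader))                      # 11 batches per epoch
print([x.shape[0] for x, y in loader])  # ten batches of 8, then a final batch of 4
# With drop_last=True the incomplete final batch of 4 would be dropped instead.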
I faced a similar problem in my own implementation (not using pytorch-lightning) and used this pattern to solve it. Basically I am resetting the iterator when it runs out of data:
try:
    data = next(data_iterator)
except StopIteration:
    # reinitialize the iterator once the loader runs out of data, then fetch again
    data_iterator = iter(data_loader)
    data = next(data_iterator)
source_tensor = data[0]
target_tensor = data[1]
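For reference, here is that pattern as a minimal standalone loop over a dummy loader (the dataset, shapes, and iteration count are placeholders):

import torch
from torch.utils.data import DataLoader, TensorDataset

data_loader = DataLoader(
    TensorDataset(torch.randn(84, 10), torch.randint(0, 5, (84,))), batch_size=8
)
data_iterator = iter(data_loader)

for iteration in range(30):  # deliberately more iterations than one epoch provides
    try:
        source_tensor, target_tensor = next(data_iterator)
    except StopIteration:
        # loader exhausted: start a fresh pass and fetch its first batch
        data_iterator = iter(data_loader)
        source_tensor, target_tensor = next(data_iterator)
    # ... the actual training step with source_tensor / target_tensor goes here ...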
Right now it seems like I'm facing something similar. I don't know how to reset/reinitialize the data loader in pytorch-lightning when my train_dataloader runs out of data. I guess there must be a more sophisticated way that I'm not familiar with. Thank you