
I started using pytorch-lightning and ran into a problem with my custom data loaders:

I'm using my own dataset and a standard torch.utils.data.DataLoader. Basically, the dataset takes a path and loads the data corresponding to a given index, and the DataLoader then draws batches from it.

def train_dataloader(self):
    train_set = TextKeypointsDataset(parameters...)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, num_workers=num_workers)
    return train_loader 

When I use the pytorch-lightning hooks train_dataloader and training_step, everything runs fine. As soon as I add val_dataloader and validation_step, I'm facing this error:

Epoch 1:  45%|████▌     | 10/22 [00:02<00:03,  3.34it/s, loss=5.010, v_num=131199]
ValueError: Expected input batch_size (1500) to match target batch_size (5)
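
For reference, the added hooks look roughly like this (a minimal sketch; the self.* attributes and the loss are placeholders, not my exact code):

def val_dataloader(self):
    # same Dataset as for training, just pointed at the validation split
    val_set = TextKeypointsDataset(self.val_path)  # self.val_path is a placeholder argument
    return torch.utils.data.DataLoader(val_set, batch_size=self.batch_size, num_workers=self.num_workers)

def validation_step(self, batch, batch_idx):
    source_tensor, target_tensor = batch
    output = self(source_tensor)                   # forward pass of the LightningModule
    loss = self.criterion(output, target_tensor)   # self.criterion is a placeholder loss
    return {'val_loss': loss}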

In this case my dataset is really small (just to test functionality): 84 samples with a batch size of 8. The training and validation datasets have the same length (again, just for testing purposes).

So in total it's 84 * 2 = 168, and 168 / 8 (batch size) = 21, which roughly matches the total steps (22) shown above. This means that after running over the training dataset for 10 steps (10 * 8 = 80 samples) the loader expects another full batch of 8, but since there are only 84 samples I get an error (at least this is my current understanding).
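
As a quick sanity check of the batch counts (assuming the DataLoader default drop_last=False):

import math

num_samples, batch_size = 84, 8

# With drop_last=False (the default) the last, smaller batch is still returned,
# so each loader yields ceil(84 / 8) = 11 batches rather than 84 // 8 = 10.
train_batches = math.ceil(num_samples / batch_size)  # 11
val_batches = math.ceil(num_samples / batch_size)    # 11
print(train_batches + val_batches)                   # 22 -> the progress-bar total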

I faced a similar problem in my own implementation (not using pytorch-lightning) and used this pattern to solve it. Basically, I reset the iterator when it runs out of data:

try:
    data = next(data_iterator)
    source_tensor = data[0]
    target_tensor = data[1]

except StopIteration:  # reinitialize the iterator once the data loader is exhausted
    data_iterator = iter(data_loader)
    data = next(data_iterator)
    source_tensor = data[0]
    target_tensor = data[1]
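
The same idea can also be wrapped up as a small generator (just a sketch of the pattern above, not Lightning-specific):

def infinite_batches(data_loader):
    # Yield batches forever, restarting the DataLoader whenever it is exhausted.
    while True:
        for batch in data_loader:
            yield batch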

Right now it seems like I'm facing something similar? I don't know how to reset/reinitialize the data loader in pytorch-lightning when my training dataloader runs out of data. I guess there must be a more sophisticated way that I'm not familiar with. Thank you.

Asdf11
  • Implementing your own `Dataset` is pretty standard but defining a custom [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) is probably a mistake since it does all sorts of complicated stuff on the backend (multi-threading etc.). In the most extreme cases you should be able to define your own [`Sampler`](https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler) and possibly a [`collate_fn`](https://pytorch.org/docs/stable/data.html#working-with-collate-fn) (if necessary), both of which would be provided to your `DataLoader` upon construction. – jodag May 25 '20 at 20:16
  • I edited my question to make it clearer. I'm using my own dataset but not a custom dataloader – Asdf11 May 25 '20 at 22:16

1 Answer


The solution was:

I used source_tensor = source_tensor.view(-1, self.batch_size, self.input_size), which led to some errors later on. Now I'm using source_tensor = source_tensor.permute(1, 0, 2), which fixed the problem.
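
To illustrate the difference (a small sketch; the shapes below are made up for the example, not taken from my model):

import torch

batch_size, seq_len, input_size = 8, 25, 60                 # made-up shapes, just for illustration

full_batch = torch.randn(batch_size, seq_len, input_size)   # a regular batch
last_batch = torch.randn(4, seq_len, input_size)             # a smaller final batch (fewer than batch_size samples)

# .view(-1, batch_size, input_size) only reinterprets the memory layout, it does not
# transpose, so samples and time steps get interleaved even when the shape works out,
# and on a smaller last batch the element count no longer divides evenly at all.
print(full_batch.view(-1, batch_size, input_size).shape)     # torch.Size([25, 8, 60]), but data is scrambled
try:
    last_batch.view(-1, batch_size, input_size)
except RuntimeError as err:
    print("view() fails on the smaller batch:", err)

# .permute(1, 0, 2) actually swaps the batch and sequence dimensions and works for
# any batch size (the (seq_len, batch, input_size) layout RNNs expect by default).
print(full_batch.permute(1, 0, 2).shape)                      # torch.Size([25, 8, 60])
print(last_batch.permute(1, 0, 2).shape)                      # torch.Size([25, 4, 60])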

Asdf11