I want to know how to use torch.utils.data.DataLoader
in PyTorch, especially in the multi-worker case.
I found that each batch output by the DataLoader
always comes from a single worker.
I expected the DataLoader to keep a queue holding data from all of the workers and to shuffle that queue, so that each batch is drawn from random workers. I believe this is how tf.data.Dataset
works in TensorFlow.
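If I understand it correctly, this is roughly the tf.data pattern I mean (the shard file names here are just hypothetical placeholders):

import tensorflow as tf

# Hypothetical shard file names, only for illustration.
files = ["shard-0.tfrecord", "shard-1.tfrecord"]
ds = tf.data.Dataset.from_tensor_slices(files)
# Read the shards in parallel and interleave their records.
ds = ds.interleave(tf.data.TFRecordDataset, cycle_length=2,
                   num_parallel_calls=tf.data.AUTOTUNE)
# The shuffle buffer mixes records coming from different source files,
# so a single batch contains records from several shards.
ds = ds.shuffle(buffer_size=1000)
ds = ds.batch(5)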
Can we implement similar behavior in PyTorch? I want to load a dataset from large serialized files (like TFRecord
) using multiple workers. In that case, mixing the source files within one batch, which means mixing the workers that produced the samples, is important.
Please refer to the following code:
import random
import time

import torch


class MyDataset(torch.utils.data.Dataset):
    def __len__(self):
        return 50

    def __getitem__(self, idx):
        # Report which worker produced which index, with a random delay
        # to simulate real loading time.
        info = torch.utils.data.get_worker_info()
        time.sleep(random.uniform(0, 1))
        print("[{}]:{}".format(info.id, idx))
        return idx, info.id


if __name__ == '__main__':
    dataset = MyDataset()
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=False, num_workers=2)
    for batch in dataloader:
        print(batch)
Output:
[0]:0
[1]:5
[0]:1
[1]:6
[0]:2
[0]:3
[1]:7
[0]:4
[tensor([0, 1, 2, 3, 4]), tensor([0, 0, 0, 0, 0])]
[1]:8
[1]:9
[tensor([5, 6, 7, 8, 9]), tensor([1, 1, 1, 1, 1])]
[0]:10
[0]:11
[1]:15
[1]:16
[0]:12
[1]:17
...
Here, [0, 1, 2, 3, 4]
and [0, 0, 0, 0, 0]
in [tensor([0, 1, 2, 3, 4]), tensor([0, 0, 0, 0, 0])]
mean that this batch contains the 0th to 4th samples, all of which came from worker id 0.
Note that shuffle=True
does not solve this problem; it only shuffles the indices of the data, so each batch is still produced by a single worker.
Instead, I want to get a batch like [tensor([0, 5, 1, 6, 2]), tensor([0, 1, 0, 1, 0])], with samples mixed from both workers.
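For reference, here is a minimal sketch of the kind of cross-worker mixing I have in mind, assuming a PyTorch version where torch.utils.data.default_collate is exposed publicly. Setting batch_size=None makes the DataLoader yield individual samples from all workers, and a small shuffle buffer then assembles batches from them. I am not sure this is the idiomatic way to do it:

import random

import torch


def mixed_batches(sample_loader, batch_size, buffer_size=20):
    # Keep a buffer of individual samples coming from all workers,
    # shuffle it, and emit batches drawn from the shuffled buffer.
    buffer = []
    for sample in sample_loader:  # with batch_size=None the loader yields single samples
        buffer.append(sample)
        if len(buffer) >= buffer_size:
            random.shuffle(buffer)
            batch, buffer = buffer[:batch_size], buffer[batch_size:]
            yield torch.utils.data.default_collate(batch)
    # Flush whatever is left at the end of the epoch.
    random.shuffle(buffer)
    while buffer:
        batch, buffer = buffer[:batch_size], buffer[batch_size:]
        yield torch.utils.data.default_collate(batch)


if __name__ == '__main__':
    dataset = MyDataset()  # MyDataset as defined above
    sample_loader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=2)
    for batch in mixed_batches(sample_loader, batch_size=5):
        print(batch)

Is there a built-in way to get this behavior, or is a manual buffer like this the expected approach?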