
I am trying to use information from functions outside the dataset to decide which data to return. Below is a simplified example that demonstrates the problem. With num_workers=0 I get the desired behavior (after 4 epochs the output is 18). But when I increase num_workers, the output is the same in every epoch, and the global variable x remains unchanged.

from torch.utils.data import Dataset, DataLoader

x = 6
def getx():
    global x
    x+=1
    print("x: ", x)
    return x

class MyDataset(Dataset):
    def __init__(self):
        pass

    def __getitem__(self, index):
        global x
        x = getx()
        return x
    
    def __len__(self):
        return 3

dataset = MyDataset()
loader = DataLoader(
    dataset,
    num_workers=0,
    shuffle=False
)

for epoch in range(4):
    for idx, data in enumerate(loader):
        print('Epoch {}, idx {}, val: {}'.format(epoch, idx, data))

With num_workers=0 the final value of x is 18, as expected. But with num_workers>0, x remains unchanged in the main process (its final value is 6).

How can I get the same behavior as with num_workers=0 while using num_workers>0? In other words, how can I make the dataset's __getitem__ change the value of the global variable x?

Mercury
deep s. pandey

1 Answer


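The reason for this is the underlying nature of multiprocessing in Python. Setting num_workers makes the DataLoader create that number of sub-processes. Each sub-process is effectively a separate Python instance with its own global state, and has no idea what is going on in the other processes, including the main one. On top of that, with the default persistent_workers=False the workers are re-created at the start of every epoch, so every epoch starts again from x = 6.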
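The same effect can be reproduced without PyTorch. The sketch below is a standard-library analogue, not DataLoader code: it assumes a POSIX system with the "fork" start method (the Linux default), and `child` and `run_once` are hypothetical helper names. Each call starts a fresh child, much like a new epoch starts fresh workers; the child increments its own copy of the global, while the parent's copy never changes.

```python
import multiprocessing as mp

# Assumption: a POSIX system with the "fork" start method (the Linux default),
# so each child starts as a copy of the parent, globals included.
ctx = mp.get_context("fork")

x = 6  # module-level global, as in the question


def child(results):
    """Hypothetical helper: runs in the child process."""
    global x
    x += 1  # changes the child's private copy only
    results.put(x)


def run_once():
    """Hypothetical helper: returns (child's x, parent's x)."""
    results = ctx.Queue()
    p = ctx.Process(target=child, args=(results,))
    p.start()
    child_x = results.get()  # read before join() so the child can flush the queue
    p.join()
    return child_x, x


if __name__ == "__main__":
    for epoch in range(3):
        print(epoch, run_once())  # the tuple is (7, 6) every time
```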

A typical solution for this in Python's multiprocessing is a Manager. However, since the worker processes are created by the DataLoader rather than by your own code, there is no obvious place to plug one in.

Fortunately, something else can be done. DataLoader actually relies on torch.multiprocessing, which in turn allows sharing of tensors between processes as long as they are in shared memory.

So what you can do is simply make x a shared-memory tensor.

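from torch.utils.data import Dataset, DataLoader
import torch

class MyDataset(Dataset):
    def __init__(self, x):
        # Keep a reference to the shared tensor on the dataset. The dataset is
        # handed to every worker, so each worker uses the same shared storage,
        # whether workers are started with "fork" (Linux) or "spawn" (Windows, macOS).
        # A module-level global alone is only inherited under "fork".
        self.x = x

    def getx(self):
        self.x += 1  # in-place add: writes into the shared storage
        print("x: ", self.x.item())
        return self.x

    def __getitem__(self, index):
        return self.getx()

    def __len__(self):
        return 3

if __name__ == "__main__":  # required when workers are started with "spawn"
    x = torch.tensor([6])
    x.share_memory_()  # move the storage into shared memory

    dataset = MyDataset(x)
    loader = DataLoader(
        dataset,
        num_workers=2,
        shuffle=False
    )

    for epoch in range(4):
        for idx, data in enumerate(loader):
            print('Epoch {}, idx {}, val: {}'.format(epoch, idx, data))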

Out:

x:  7
x:  8
x:  9
Epoch 0, idx 0, val: tensor([[7]])
Epoch 0, idx 1, val: tensor([[8]])
Epoch 0, idx 2, val: tensor([[9]])
x:  10
x:  11
x:  12
Epoch 1, idx 0, val: tensor([[10]])
Epoch 1, idx 1, val: tensor([[12]])
Epoch 1, idx 2, val: tensor([[12]])
x:  13
x:  14
x:  15
Epoch 2, idx 0, val: tensor([[13]])
Epoch 2, idx 1, val: tensor([[15]])
Epoch 2, idx 2, val: tensor([[14]])
x:  16
x:  17
x:  18
Epoch 3, idx 0, val: tensor([[16]])
Epoch 3, idx 1, val: tensor([[18]])
Epoch 3, idx 2, val: tensor([[17]])

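While this works, it isn't perfect. Look at epoch 1: there are two 12s rather than 11 and 12, even though the printed x values are 10, 11 and 12. __getitem__ returns the shared tensor itself, and its value is only copied into the batch afterwards; in between, the other worker had already executed x += 1 again. Without synchronization this kind of race is unavoidable, since parallel processes are reading and writing the same memory.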

If you're familiar with operating system concepts, you may be able to implement some sort of semaphore with an extra variable to control access to x as needed, but since this goes beyond the scope of the question, I won't elaborate further.
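A minimal sketch of that locking idea, again using the standard library rather than PyTorch: `multiprocessing.Value` with its default `lock=True` provides `get_lock()`, and each worker increments and reads back the counter while holding that lock, so no other process can slip in between the two steps. It assumes the POSIX "fork" start method, and `take_next`/`run` are hypothetical names. Applied to the DataLoader code, the analogous (untested) change would be to hold a lock around the in-place increment and a `.clone()` of the tensor inside `__getitem__`, returning the clone.

```python
import multiprocessing as mp

# Assumption: a POSIX system with the "fork" start method (the Linux default).
ctx = mp.get_context("fork")


def take_next(counter, results, n):
    """Hypothetical worker: claims n values from the shared counter."""
    for _ in range(n):
        # Increment and read back under the lock, so another process
        # cannot change the counter between the two steps.
        with counter.get_lock():
            counter.value += 1
            snapshot = counter.value
        results.put(snapshot)


def run(num_workers=2, items_per_worker=3, start=6):
    """Hypothetical driver: returns (sorted snapshots, final counter value)."""
    counter = ctx.Value("i", start)  # lock=True by default, hence get_lock()
    results = ctx.Queue()
    procs = [
        ctx.Process(target=take_next, args=(counter, results, items_per_worker))
        for _ in range(num_workers)
    ]
    for p in procs:
        p.start()
    values = [results.get() for _ in range(num_workers * items_per_worker)]
    for p in procs:
        p.join()
    return sorted(values), counter.value


if __name__ == "__main__":
    print(run())  # ([7, 8, 9, 10, 11, 12], 12): no duplicates, no gaps
```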

Mercury