I was debugging my PyTorch code and found that an instance of the class DataLoader seems to be a global variable by default. I don't understand why this is the case, so I've set up a minimal working example below that should reproduce my observation:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, df, n_feats, mode):
        data = np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]).transpose()
        x = data[:, list(range(n_feats))]  # features
        y = data[:, -1]                    # target
        self.x = torch.FloatTensor(x)
        self.y = torch.FloatTensor(y)

    def __getitem__(self, index):
        return self.x[index], self.y[index]

    def __len__(self):
        return len(self.x)

def prep_dataloader(df, n_feats, mode, batch_size):
    dataset = MyDataset(df, n_feats, mode)
    dataloader = DataLoader(dataset, batch_size, shuffle=False)
    return dataloader

df = None  # df and mode are unused in this minimal example
tr_set = prep_dataloader(df, 1, 'train', 200)

def test():
    print(tr_set)  # tr_set is neither a parameter nor defined locally

test()
As shown above, tr_set was created before the function test and is not passed to test as an argument. However, running the code above, I got the following result:
<torch.utils.data.dataloader.DataLoader object at 0x7fb6c2ea7610>
Originally, I was expecting an error like "NameError: name 'tr_set' is not defined". However, the function was aware of tr_set and printed the object, even though tr_set was never passed in as an argument. I'm confused by this because tr_set appears to behave like a global variable here.
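For comparison, here is an even smaller sketch (my own reduction, with no DataLoader involved) that shows the same behavior with a plain variable, so it doesn't seem specific to PyTorch:

x = 42  # defined at the top level of the module/notebook

def show():
    # x is neither a parameter nor assigned locally,
    # yet the function can still read it
    print(x)

show()  # prints 42

This prints 42 with no error, just like my DataLoader example.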
I'm wondering what the reason for this is, and whether there are ways to prevent tr_set from becoming a global variable. Thank you!
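For instance, I assume that restructuring the script so that tr_set is only created inside a function would make it a local variable and bring back the NameError, roughly like this (a sketch of what I have in mind, not something I've verified in my actual project):

def main():
    # tr_set is now local to main, so test() should raise
    # NameError when it tries to read tr_set
    tr_set = prep_dataloader(None, 1, 'train', 200)
    test()

main()  # expected: NameError: name 'tr_set' is not defined

Is something like this the recommended way to avoid it?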
(Update: in case this matters, I was running the code above in a Jupyter notebook.)