
I am training a Siamese neural network with PyTorch on a very large dataset. Loading the data is the biggest bottleneck, and the dataset doesn't fit in RAM, so I can't cache all of it to speed things up.

What I would like to do is cache part of the data and repeat it within the same epoch to speed up training. Would it be possible to have some kind of double-ended queue to sample from, where I append elements as I read them and remove them after they have been included in training a few times? Conceptually, something like the sketch below is what I'm after (untested; DequeCacheDataset and load_item are hypothetical placeholders, not real torchdata APIs):
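import random
from collections import deque

from torch.utils.data import IterableDataset


class DequeCacheDataset(IterableDataset):
    # Hypothetical sketch: keep a bounded deque of recently read samples and
    # serve each of them `reuse` times, in random order, before evicting it.
    def __init__(self, paths, cache_size=1000, reuse=4):
        self.paths = paths
        self.cache_size = cache_size
        self.reuse = reuse

    def __iter__(self):
        cache = deque()  # entries are [sample, remaining_uses]
        for path in self.paths:
            cache.append([load_item(path), self.reuse])  # load_item = your decode step
            if len(cache) < self.cache_size:
                continue  # keep filling the cache first
            i = random.randrange(len(cache))  # pick a random cached sample
            cache[i][1] -= 1
            yield cache[i][0]
            if cache[i][1] == 0:
                del cache[i]  # evict once it has been used `reuse` times
        while cache:  # drain the remaining uses at the end of the epoch
            sample, uses = cache.popleft()
            for _ in range(uses):
                yield sample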

Unfortunately none of the normal functions in either torchdata or torch.utils.data.Dataset seem to allow this. It's either caching a complete epoch of data, or none at all.

rmeertens

1 Answer


I think reusing the same sample multiple times within one epoch will make training messy; it's usually better to build a data generator that uses each sample only once per epoch.

If you still want to use a sample many times in one epoch, I hope this little example I made can help you:

# set batch_size=1 in the DataLoader: this dataset already returns whole batches

import numpy as np
from torch.utils.data import Dataset


class SuperDynamicDataset(Dataset):
    """A dataset that caches the previous batch and only loads from disk
    the files that were not already part of it."""

    def __init__(self, list_of_paths, real_batch, step=2):
        """
        Args:
            list_of_paths (list): list of files to read
            real_batch (int): batch size
            step (int, optional): offset between consecutive batches.
                A step smaller than real_batch makes batches overlap, so
                the overlapping files are reused from memory. Defaults to 2.
        """
        self.list_of_paths = list_of_paths
        self.data = list(range(len(list_of_paths)))
        self.real_batch = real_batch
        self.actual_batch = np.zeros((real_batch, 256, 256, 3))
        self.batch_ids = []
        # only keep full batches so no batch is left zero-padded at the end
        for start in range(0, len(self.data) - real_batch + 1, step):
            self.batch_ids.append(self.data[start:start + real_batch])
        self.old_batch_ids = []

    def __len__(self):
        return len(self.batch_ids)

    def __getitem__(self, idx):
        actual_batch_ids = self.batch_ids[idx]
        actual_batch = np.zeros((self.real_batch, 256, 256, 3))
        # if an element was in the previous batch, take it from memory;
        # otherwise read it from disk
        for i, element in enumerate(actual_batch_ids):
            if element in self.old_batch_ids:
                actual_batch[i] = self.actual_batch[self.old_batch_ids.index(element)]
            else:
                # _read_the_file is a placeholder for your own loading function
                actual_batch[i] = _read_the_file(self.list_of_paths[element])

        self.actual_batch = actual_batch
        self.old_batch_ids = actual_batch_ids
        return actual_batch

I couldn't test the code, but it should give you the general idea. Driving it could look roughly like the sketch below (also untested): shuffle must stay off so that consecutive batches actually overlap, and num_workers must stay 0 because the cache lives on the dataset object itself.
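from torch.utils.data import DataLoader

dataset = SuperDynamicDataset(list_of_paths, real_batch=32, step=2)
# batch_size=1 because the dataset already returns complete batches
loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

for batch in loader:
    batch = batch.squeeze(0)  # drop the extra dim the DataLoader adds
    # ... training step ...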

Ghassen Sultana