
I defined a custom Dataset and a DataLoader with a custom collate_fn, and I want to access all the batches using for i, batch in enumerate(loader). But this loop gives me a different number of batches in every epoch, and all of them are far smaller than the actual number of batches (which equals number_of_samples / batch_size).

Here is how I define my dataset and dataloader:



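import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

class UsptoDataset(Dataset):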
    def __init__(self, csv_file):
        df = pd.read_csv(csv_file)
        self.rea_trees = df['reactants_trees'].to_numpy()
        self.syn_trees = df['synthons_trees'].to_numpy()
        self.syn_smiles = df['synthons'].to_numpy()
        self.product_smiles = df['product'].to_numpy()

    def __len__(self):
        return len(self.rea_trees)

    def __getitem__(self, item):
        rea_tree = self.rea_trees[item]
        syn_tree = self.syn_trees[item]
        syn_smile = self.syn_smiles[item]
        pro_smile = self.product_smiles[item]
        # omit the snippet used to process the data here, which gives us the variables used in the return statement.
        return {'input_words': input_words,
                'input_chars': input_chars,
                'syn_tree_indices': syn_tree_indices,
                'syn_rule_nl_left': syn_rule_nl_left,
                'syn_rule_nl_right': syn_rule_nl_right,
                'rea_tree_indices': rea_tree_indices,
                'rea_rule_nl_left': rea_rule_nl_left,
                'rea_rule_nl_right': rea_rule_nl_right,
                'class_mask': class_mask,
                'query_paths': query_paths,
                'labels': labels,
                'parent_matrix': parent_matrix,
                'syn_parent_matrix': syn_parent_matrix,
                'path_lens': path_lens,
                'syn_path_lens': syn_path_lens}

    @staticmethod
    def collate_fn(batch):
        input_words = torch.tensor(np.stack([_['input_words'] for _ in batch], axis=0), dtype=torch.long)
        input_chars = torch.tensor(np.stack([_['input_chars'] for _ in batch], axis=0), dtype=torch.long)
        syn_tree_indices = torch.tensor(np.stack([_['syn_tree_indices'] for _ in batch], axis=0), dtype=torch.long)
        syn_rule_nl_left = torch.tensor(np.stack([_['syn_rule_nl_left'] for _ in batch], axis=0), dtype=torch.long)
        syn_rule_nl_right = torch.tensor(np.stack([_['syn_rule_nl_right'] for _ in batch], axis=0), dtype=torch.long)
        rea_tree_indices = torch.tensor(np.stack([_['rea_tree_indices'] for _ in batch], axis=0), dtype=torch.long)
        rea_rule_nl_left = torch.tensor(np.stack([_['rea_rule_nl_left'] for _ in batch], axis=0), dtype=torch.long)
        rea_rule_nl_right = torch.tensor(np.stack([_['rea_rule_nl_right'] for _ in batch], axis=0), dtype=torch.long)
        class_mask = torch.tensor(np.stack([_['class_mask'] for _ in batch], axis=0), dtype=torch.float32)
        query_paths = torch.tensor(np.stack([_['query_paths'] for _ in batch], axis=0), dtype=torch.long)
        labels = torch.tensor(np.stack([_['labels'] for _ in batch], axis=0), dtype=torch.long)
        parent_matrix = torch.tensor(np.stack([_['parent_matrix'] for _ in batch], axis=0), dtype=torch.float)
        syn_parent_matrix = torch.tensor(np.stack([_['syn_parent_matrix'] for _ in batch], axis=0), dtype=torch.float)
        path_lens = torch.tensor(np.stack([_['path_lens'] for _ in batch], axis=0), dtype=torch.long)
        syn_path_lens = torch.tensor(np.stack([_['syn_path_lens'] for _ in batch], axis=0), dtype=torch.long)

        return_dict = {'input_words': input_words,
                       'input_chars': input_chars,
                       'syn_tree_indices': syn_tree_indices,
                       'syn_rule_nl_left': syn_rule_nl_left,
                       'syn_rule_nl_right': syn_rule_nl_right,
                       'rea_tree_indices': rea_tree_indices,
                       'rea_rule_nl_left': rea_rule_nl_left,
                       'rea_rule_nl_right': rea_rule_nl_right,
                       'class_mask': class_mask,
                       'query_paths': query_paths,
                       'labels': labels,
                       'parent_matrix': parent_matrix,
                       'syn_parent_matrix': syn_parent_matrix,
                       'path_lens': path_lens,
                       'syn_path_lens': syn_path_lens}
        return return_dict


train_dataset=UsptoDataset("train_trees.csv")

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=1, collate_fn=UsptoDataset.collate_fn)         

When I use the DataLoader as follows, it yields a different number of batches in every epoch:

epoch_steps = len(train_loader)
for e in range(epochs):
    for j, batch_data in enumerate(train_loader):
        step = e * epoch_steps + j

The log shows that the first epoch has only 5 batches, the second has 3, the third has 5, and so on:

Config:
Namespace(batch_size_per_gpu=4, epochs=400, eval_every_epoch=1, hidden_size=128, keep=10, log_every_step=1, lr=0.001, new_model=False, save_dir='saved_model/', workers=1)
2021-01-06 15:33:17,909 - __main__ - WARNING - Checkpoints not found in dir saved_model/, creating a new model.
2021-01-06 15:33:18,340 - __main__ - INFO - Step: 0, Loss: 5.4213, Rule acc: 0.1388
2021-01-06 15:33:18,686 - __main__ - INFO - Step: 1, Loss: 4.884, Rule acc: 0.542
2021-01-06 15:33:18,941 - __main__ - INFO - Step: 2, Loss: 4.6205, Rule acc: 0.6122
2021-01-06 15:33:19,174 - __main__ - INFO - Step: 3, Loss: 4.4442, Rule acc: 0.61
2021-01-06 15:33:19,424 - __main__ - INFO - Step: 4, Loss: 4.3033, Rule acc: 0.6211
2021-01-06 15:33:20,684 - __main__ - INFO - Dev Loss: 3.5034, Dev Sample Acc: 0.0, Dev Rule Acc: 0.5970844200679234, in epoch 0
2021-01-06 15:33:22,203 - __main__ - INFO - Test Loss: 3.4878, Test Sample Acc: 0.0, Test Rule Acc: 0.6470248053471247
2021-01-06 15:33:22,394 - __main__ - INFO - Found better dev sample accuracy 0.0 in epoch 0
2021-01-06 15:33:22,803 - __main__ - INFO - Step: 10002, Loss: 3.6232, Rule acc: 0.6555
2021-01-06 15:33:23,046 - __main__ - INFO - Step: 10003, Loss: 3.53, Rule acc: 0.6442
2021-01-06 15:33:23,286 - __main__ - INFO - Step: 10004, Loss: 3.4907, Rule acc: 0.6498
2021-01-06 15:33:24,617 - __main__ - INFO - Dev Loss: 3.3081, Dev Sample Acc: 0.0, Dev Rule Acc: 0.5980878387178693, in epoch 1
2021-01-06 15:33:26,215 - __main__ - INFO - Test Loss: 3.2859, Test Sample Acc: 0.0, Test Rule Acc: 0.6466992994149526
2021-01-06 15:33:26,857 - __main__ - INFO - Step: 20004, Loss: 3.3965, Rule acc: 0.6493
2021-01-06 15:33:27,093 - __main__ - INFO - Step: 20005, Loss: 3.3797, Rule acc: 0.6314
2021-01-06 15:33:27,353 - __main__ - INFO - Step: 20006, Loss: 3.3959, Rule acc: 0.5727
2021-01-06 15:33:27,609 - __main__ - INFO - Step: 20007, Loss: 3.3632, Rule acc: 0.6279
2021-01-06 15:33:27,837 - __main__ - INFO - Step: 20008, Loss: 3.3331, Rule acc: 0.6158
2021-01-06 15:33:29,122 - __main__ - INFO - Dev Loss: 3.0911, Dev Sample Acc: 0.0, Dev Rule Acc: 0.6016287207603455, in epoch 2
2021-01-06 15:33:30,689 - __main__ - INFO - Test Loss: 3.0651, Test Sample Acc: 0.0, Test Rule Acc: 0.6531393428643545
2021-01-06 15:33:32,143 - __main__ - INFO - Dev Loss: 3.0911, Dev Sample Acc: 0.0, Dev Rule Acc: 0.6016287207603455, in epoch 3
2021-01-06 15:33:33,765 - __main__ - INFO - Test Loss: 3.0651, Test Sample Acc: 0.0, Test Rule Acc: 0.6531393428643545
2021-01-06 15:33:34,359 - __main__ - INFO - Step: 40008, Loss: 3.108, Rule acc: 0.6816
2021-01-06 15:33:34,604 - __main__ - INFO - Step: 40009, Loss: 3.0756, Rule acc: 0.6732
2021-01-06 15:33:35,823 - __main__ - INFO - Dev Loss: 3.0419, Dev Sample Acc: 0.0, Dev Rule Acc: 0.613776079245976, in epoch 4
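Decoding the logged step numbers makes the pattern explicit. Since the loop computes step = e * epoch_steps + j, divmod(step, epoch_steps) recovers (e, j). A small sketch, with the step values copied from the log above and epoch_steps = 10002 taken from the value of len(train_loader) given in the question:

```python
from collections import defaultdict

# Step numbers copied from the training log in the question.
logged_steps = [0, 1, 2, 3, 4,
                10002, 10003, 10004,
                20004, 20005, 20006, 20007, 20008,
                40008, 40009]
epoch_steps = 10002  # len(train_loader), as reported in the question

batches_seen = defaultdict(list)
for step in logged_steps:
    # Invert step = e * epoch_steps + j to get the epoch and the batch index j.
    epoch, j = divmod(step, epoch_steps)
    batches_seen[epoch].append(j)

for epoch in range(5):
    print(epoch, batches_seen.get(epoch, []))
# 0 [0, 1, 2, 3, 4]
# 1 [0, 1, 2]
# 2 [0, 1, 2, 3, 4]
# 3 []
# 4 [0, 1]
```

In every epoch the batch indices start at 0 and are consecutive, which suggests each epoch begins normally and the loop then stops early (epoch 3 logs no steps at all). This reading assumes every step is logged, as log_every_step=1 in the config suggests.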

FYI, the values of len(train_loader.dataset), batch_size and len(train_loader) are 40008, 4 and 10002, respectively, which is exactly what I expected. So it is very confusing that enumerate gives me only a few batches, such as 3 or 5, when 10002 are expected.
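One way to separate a loader problem from a training-loop problem is to count batches with nothing else in the loop body. A minimal sketch, assuming a synthetic TensorDataset as a hypothetical stand-in for UsptoDataset (the CSV is not available) and num_workers=0 for simplicity; the question uses num_workers=1, and running the same count against the real train_loader would be the more telling test:

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for UsptoDataset: 40008 samples, one integer each.
dataset = TensorDataset(torch.arange(40008))
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

# With drop_last=False (the default), len(loader) == ceil(len(dataset) / batch_size).
expected = math.ceil(len(dataset) / 4)

# Iterate with an empty loop body for two epochs. If these counts match
# len(loader), the loader itself is fine and the early exit points at the
# code inside the training loop.
counts = [sum(1 for _ in loader) for _ in range(2)]
print(len(loader), expected, counts)  # 10002 10002 [10002, 10002]
```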

pyxies
  • Can you please turn this into a minimal reproducible example? We do not have the input file, and the methods look excessively large. – MisterMiyagi Jan 10 '21 at 14:24
  • @MisterMiyagi Thanks for the quick reply, but I am afraid that it is hard to provide an example with little code. I will try to do so, and post it here after it is done. – pyxies Jan 10 '21 at 14:30
  • Could you go through why you are implementing `collate_fn` on your dataset? – Ivan Jan 10 '21 at 14:43
  • @Ivan Sorry for the late reply. I am not very familiar with PyTorch, so forgive me if I am doing something stupid. I implemented the `collate_fn` method so that I can get a dictionary from the dataloader, as follows: `for i, batch in enumerate(loader): input_words = batch['input_words'].to(device)` – pyxies Jan 11 '21 at 03:48
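The access pattern in the last comment can be extended to every field of a dict-shaped batch at once. A small sketch; batch_to_device is a hypothetical helper name, and the example field shapes are made up:

```python
import torch


def batch_to_device(batch, device):
    """Move every tensor in a dict-shaped batch to `device`; non-tensors are left as-is."""
    return {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hypothetical batch with two of the question's keys.
batch = {'input_words': torch.zeros(4, 7, dtype=torch.long),
         'class_mask': torch.ones(4, 3)}
moved = batch_to_device(batch, device)
print(moved['input_words'].device.type, moved['class_mask'].shape)
```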

1 Answer


I am not exactly sure what the problem is with your code. From what I can read, what you are trying to do in collate_fn is gather and stack data of the same feature type across the batch.

You are using input_words, input_chars, syn_tree_indices, syn_rule_nl_left, syn_rule_nl_right, rea_tree_indices, rea_rule_nl_left, rea_rule_nl_right, class_mask, query_paths, labels, parent_matrix, syn_parent_matrix, path_lens, and syn_path_lens as keys. In my example, we will keep it simple with only a, b, c, and d.

  • __getitem__ returns a single data point from your dataset. In your case, it will be a dictionary: {'a': ..., 'b': ..., 'c': ..., 'd': ...}.

  • collate_fn is an intermediate layer between the dataset and the dataloader. It takes a list of batch elements (elements that have been fetched one by one with __getitem__) and returns a collated batch: it converts [{'a': ..., 'b': ..., 'c': ..., 'd': ...}, ...] into {'a': [...], 'b': [...], 'c': [...], 'd': [...]}, where key 'a' contains the a feature of every element in the batch.
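The conversion described in the second bullet can be written once, by looping over keys instead of spelling out each feature. A minimal sketch, assuming every value is a NumPy array with the same shape across the batch; dict_collate and FLOAT_KEYS are hypothetical names, and FLOAT_KEYS mirrors the fields the question casts to float:

```python
import numpy as np
import torch

# Fields the question's collate_fn casts to float32; everything else becomes long.
FLOAT_KEYS = {'class_mask', 'parent_matrix', 'syn_parent_matrix'}


def dict_collate(batch):
    """Turn [{'a': x0, ...}, {'a': x1, ...}] into {'a': stack([x0, x1]), ...}."""
    out = {}
    for key in batch[0]:
        dtype = torch.float32 if key in FLOAT_KEYS else torch.long
        out[key] = torch.tensor(np.stack([sample[key] for sample in batch], axis=0), dtype=dtype)
    return out


batch = [{'a': np.array([1, 2]), 'class_mask': np.array([0, 1])},
         {'a': np.array([3, 4]), 'class_mask': np.array([1, 1])}]
collated = dict_collate(batch)
print(collated['a'].shape)           # torch.Size([2, 2])
print(collated['class_mask'].dtype)  # torch.float32
```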

Now, what you might not know is that for this simple type of collating, you don't actually need collate_fn. Tuples and dictionaries are handled automatically by PyTorch's default collate function, which means that if you return a dictionary from __getitem__, the dataloader will collate by key automatically.

Here, still with our minimal example:

from torch.utils.data import Dataset, DataLoader

class D(Dataset):
    def __init__(self):
        super().__init__()
        self.a = [1,11,111,1111,11111]
        self.b = [2,22,222,2222,22222]
        self.c = [3,33,333,3333,33333]
        self.d = [4,44,444,4444,44444]

    def __getitem__(self, i):
        return {
            'a': self.a[i],
            'b': self.b[i],
            'c': self.c[i],
            'd': self.d[i]
        }

    def __len__(self):
        return len(self.a)

As you can see in the following print, the data is gathered by key.

>>> ds = D()
>>> dl = DataLoader(ds, batch_size=2, shuffle=True)

>>> for i, x in enumerate(dl):
...     print(i, x)
...
0 {'a': tensor([11, 1111]), 'b': tensor([22, 2222]), 'c': tensor([33, 3333]), 'd': tensor([44, 4444])}
1 {'a': tensor([1, 11111]), 'b': tensor([2, 22222]), 'c': tensor([3, 33333]), 'd': tensor([4, 44444])}
2 {'a': tensor([111]), 'b': tensor([222]), 'c': tensor([333]), 'd': tensor([444])}

Passing a custom collate_fn argument replaces this automatic collation.
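The question's __getitem__ returns NumPy arrays rather than plain integers. A minimal sketch, using a hypothetical ArrayDataset with made-up field shapes, suggests the default collation handles that case as well: each array is converted to a tensor with its NumPy dtype preserved, and the arrays are stacked along a new batch dimension.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class ArrayDataset(Dataset):
    """Hypothetical dataset returning a dict of NumPy arrays, like the question's __getitem__."""

    def __len__(self):
        return 6

    def __getitem__(self, i):
        return {'input_words': np.full(3, i, dtype=np.int64),
                'class_mask': np.ones(5, dtype=np.float32)}


loader = DataLoader(ArrayDataset(), batch_size=4)  # no collate_fn: default collation
batch = next(iter(loader))
print(batch['input_words'].shape, batch['input_words'].dtype)  # torch.Size([4, 3]) torch.int64
print(batch['class_mask'].dtype)  # torch.float32
sizes = [b['input_words'].shape[0] for b in loader]
print(sizes)  # [4, 2]
```

Because the dtype is preserved, arrays built as float64 would come out as float64 tensors; casting them in __getitem__ (for example with .astype(np.float32)) keeps the torch.float32 fields the question's collate_fn produces.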

Ivan
  • Thanks for the answer, the snippet you provided is indeed more concise, but I do not think providing a custom `collate_fn` is what causes the problem. I used the implementation before and it was fine. Anyway, I will try to simplify my code and keep debugging, thanks for your help! – pyxies Jan 12 '21 at 04:47