I'm training on multiple datasets with pytorch_lightning. The datasets have different lengths, so their DataLoaders yield different numbers of batches. For now I have kept them separate using dictionaries, because my ultimate goal is to weight the loss function per dataset:

def train_dataloader(self):
    # Return a dict of DataLoaders, one per dataset, keyed like self.train_dict
    train_loaders = {}
    for key, value in self.train_dict.items():
        train_loaders[key] = DataLoader(value,
                                        batch_size=self.batch_size,
                                        collate_fn=collate)
    return train_loaders

Then, in training_step(), I do the following:

def training_step(self, batch, batch_idx):
    # batch is a dict: dataset key -> (anchor, positive, negative) graph batches
    total_batch_loss = 0

    for key, value in batch.items():
        anc, pos, neg = value
        emb_anc = F.normalize(self.forward(anc.x,
                                           anc.edge_index,
                                           anc.weights,
                                           anc.batch,
                                           training=True
                                           ), 2, dim=1)
    
        emb_pos = F.normalize(self.forward(pos.x,
                                           pos.edge_index,
                                           pos.weights,
                                           pos.batch,
                                           training=True
                                           ), 2, dim=1)
    
        emb_neg = F.normalize(self.forward(neg.x,
                                           neg.edge_index,
                                           neg.weights,
                                           neg.batch,
                                           training=True
                                           ), 2, dim=1)
                                
        loss_dataset = LossFunc(emb_anc, emb_pos, emb_neg, anc.y, pos.y, neg.y)
        total_batch_loss += loss_dataset
        
    self.log("Loss", total_batch_loss, prog_bar=True, on_epoch=True)
    return total_batch_loss
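Since the stated goal is to weight the loss per dataset, the plain accumulation in the loop above could be replaced by a weighted sum. This is a minimal sketch, assuming a hypothetical weights dict keyed like self.train_dict (how the weights are chosen, e.g. by hand or from dataset sizes, is up to you). Keys without a weight default to 1.0, and a dataset whose loss is None (no batch on this step) is skipped. The helper name weighted_total_loss is hypothetical; it works with plain floats or with torch tensors.

```python
from typing import Mapping, Optional


def weighted_total_loss(losses: Mapping[str, Optional[float]],
                        weights: Mapping[str, float]):
    """Combine per-dataset losses into one scalar (hypothetical helper).

    losses:  dataset key -> loss for this step, or None if that dataset
             produced no batch on this step.
    weights: dataset key -> per-dataset weight (assumed, user-chosen);
             keys missing from it default to 1.0.
    """
    total = 0.0
    for key, loss in losses.items():
        if loss is None:  # dataset had no batch this step; contributes nothing
            continue
        total = total + weights.get(key, 1.0) * loss
    return total


# Example: 0.5 * 2.0 + 0.25 * 4.0 = 2.0, dataset "c" is skipped
print(weighted_total_loss({"a": 2.0, "b": 4.0, "c": None}, {"a": 0.5, "b": 0.25}))
```

Inside training_step, you would collect loss_dataset into a dict keyed by key and return weighted_total_loss(losses, self.loss_weights), where self.loss_weights is an assumed attribute set in __init__.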

The problem is that once the smallest dataset is exhausted, Lightning raises StopIteration, so the remaining batches from the other datasets are never used for training. I have considered concatenating everything into a single training DataLoader, as suggested in the docs, but I don't see how I could then weight the loss function differently for each dataset.

James Arten

1 Answer

You can use the CombinedLoader class with mode="max_size", which keeps iterating until the longest dataloader is exhausted instead of stopping at the shortest one.
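Assuming Lightning 2.x, CombinedLoader can be imported with from lightning.pytorch.utilities import CombinedLoader, and train_dataloader could return CombinedLoader(train_loaders, mode="max_size") instead of the bare dict. In that mode, loaders that have already run out contribute None instead of a batch. The sketch below does not use Lightning at all; it only emulates those max_size semantics in plain Python to show what training_step would receive on each step. The function name iterate_max_size is hypothetical.

```python
from itertools import zip_longest


def iterate_max_size(loaders):
    """Yield one dict per step until the longest iterable is exhausted.

    Plain-Python emulation (not Lightning's implementation): iterables
    that have already run out are filled with None for the remaining steps.
    """
    keys = list(loaders)
    for step in zip_longest(*(loaders[k] for k in keys), fillvalue=None):
        yield dict(zip(keys, step))


for batch in iterate_max_size({"small": [1, 2], "big": [10, 20, 30]}):
    print(batch)
# {'small': 1, 'big': 10}
# {'small': 2, 'big': 20}
# {'small': None, 'big': 30}
```

Because exhausted datasets show up as None, the loop in training_step needs a guard such as "if value is None: continue" before unpacking anc, pos, neg; per-dataset weighting then works as before, since the keys are preserved.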

Aniket Maurya