I'm training on multiple datasets with pytorch_lightning. The datasets have different lengths, and therefore a different number of batches in their corresponding DataLoaders. For now I keep them separate in dictionaries, since my ultimate goal is to weight the loss function differently for each dataset:
def train_dataloader(self):  # returns a dict of dataloaders
    train_loaders = {}
    for key, value in self.train_dict.items():
        train_loaders[key] = DataLoader(value,
                                        batch_size=self.batch_size,
                                        collate_fn=collate)
    return train_loaders
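For context, here is a minimal sketch of the kind of train_dict I'm working with (the dataset names and sizes are made up; in reality each value is a graph dataset, but TensorDataset is enough to show the length mismatch):

```python
import torch
from torch.utils.data import TensorDataset

# Hypothetical stand-ins for my real datasets: one Dataset per source,
# keyed by name, with deliberately different lengths.
train_dict = {
    "dataset_a": TensorDataset(torch.zeros(100, 3)),  # 100 samples
    "dataset_b": TensorDataset(torch.zeros(40, 3)),   # 40 samples -> fewer batches
}

print({k: len(v) for k, v in train_dict.items()})
# {'dataset_a': 100, 'dataset_b': 40}
```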
Then, in training_step() I do the following:
def training_step(self, batch, batch_idx):
    total_batch_loss = 0
    for key, value in batch.items():
        anc, pos, neg = value
        emb_anc = F.normalize(self.forward(anc.x,
                                           anc.edge_index,
                                           anc.weights,
                                           anc.batch,
                                           training=True
                                           ), 2, dim=1)
        emb_pos = F.normalize(self.forward(pos.x,
                                           pos.edge_index,
                                           pos.weights,
                                           pos.batch,
                                           training=True
                                           ), 2, dim=1)
        emb_neg = F.normalize(self.forward(neg.x,
                                           neg.edge_index,
                                           neg.weights,
                                           neg.batch,
                                           training=True
                                           ), 2, dim=1)
        loss_dataset = LossFunc(emb_anc, emb_pos, emb_neg, anc.y, pos.y, neg.y)
        total_batch_loss += loss_dataset
    self.log("Loss", total_batch_loss, prog_bar=True, on_epoch=True)
    return total_batch_loss
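To be concrete about the weighting I'm aiming for: each dataset's loss would be scaled by a fixed per-dataset weight before summing. A toy sketch (the weight and loss values below are hypothetical stand-ins, not my real numbers):

```python
# Hypothetical per-dataset weights and per-dataset losses; in the real
# training_step the losses would come from LossFunc for each key.
dataset_weights = {"dataset_a": 1.0, "dataset_b": 0.5}
losses = {"dataset_a": 2.0, "dataset_b": 4.0}

# Weighted sum replacing the plain accumulation of total_batch_loss.
total = sum(dataset_weights[k] * losses[k] for k in losses)
print(total)  # 1.0*2.0 + 0.5*4.0 = 4.0
```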
The problem is that once the smallest dataset is exhausted, Lightning raises a StopIteration, so I never finish training on the remaining batches of the other datasets. I have considered concatenating everything into a single train DataLoader, as suggested in the docs, but with that approach I don't see how I could weight the loss function per dataset.
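The root cause can be reproduced without Lightning at all: iterating a dict of loaders in lockstep stops as soon as the shortest one is exhausted. A minimal sketch (dataset sizes are made up):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

small = TensorDataset(torch.arange(4).float().unsqueeze(1))   # 4 samples
large = TensorDataset(torch.arange(10).float().unsqueeze(1))  # 10 samples

loaders = {
    "small": DataLoader(small, batch_size=2),  # 2 batches
    "large": DataLoader(large, batch_size=2),  # 5 batches
}

# Lockstep iteration over both loaders, analogous to what happens
# with the dict of dataloaders above:
steps = 0
for batches in zip(*loaders.values()):
    steps += 1

print(steps)  # 2 -- the last 3 batches of "large" are never seen
```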