0

When I am training a model, I need to use only 10% of the data for `trainer.fit(model, datamodule)`, so the DataModule should supply just 10% of the data. Part of the DataModule is:

class DataModule(pl.LightningDataModule):
  """Lightning data module wrapping pre-built training and validation datasets.

  Args:
    train_dataset: Dataset served by ``train_dataloader``.
    val_dataset: Dataset served by ``val_dataloader``.
    batch_size: Number of samples per batch, used by both loaders.
    num_workers: Worker count passed to both loaders. Defaults to 2, the
      value that was previously hard-coded.
  """

  def __init__(self, train_dataset, val_dataset, batch_size=1, num_workers=2):
    super().__init__()
    self.train_dataset = train_dataset
    self.val_dataset = val_dataset
    self.batch_size = batch_size
    self.num_workers = num_workers

  def _make_loader(self, dataset, shuffle):
    """Build a DataLoader with the settings shared by both splits.

    ``collate_fn`` is looked up as a module-level name; it is assumed to be
    defined elsewhere in this file.
    """
    return DataLoader(
      dataset,
      batch_size=self.batch_size,
      collate_fn=collate_fn,
      shuffle=shuffle,
      num_workers=self.num_workers,
      pin_memory=True,
    )

  def train_dataloader(self):
    """Return a shuffling DataLoader over the training dataset."""
    return self._make_loader(self.train_dataset, shuffle=True)

  def val_dataloader(self):
    """Return a non-shuffling DataLoader over the validation dataset."""
    return self._make_loader(self.val_dataset, shuffle=False)

So I tried using a for loop:

datamodule = DataModule(train_ds, val_ds)
# train_dataloader() takes no arguments: it returns the loader, and iterating
# that loader yields the batches. Calling it again with (i, data) raised a
# TypeError, so print the index and batch produced by enumerate() directly.
# NOTE(review): to train on only 10% of the training batches, Lightning's
# Trainer accepts limit_train_batches=0.1, so no manual loop is needed; confirm
# against the installed Lightning version.
for i, data in enumerate(datamodule.train_dataloader()):
    print(i, data)

But it doesn't work. How can I change it?

diamond
  • 11
  • 2

0 Answers0