
So basically, I am subclassing PyTorch Lightning's LightningModule. My issue is that I'm loading my data with a PyTorch DataLoader:

# assumes: from collections import Counter; from torch.utils.data import DataLoader, WeightedRandomSampler;
#          from torchvision.datasets import ImageFolder; import torchvision.transforms as T
def train_dataloader(self):
    train_dir = f"{self.img_dir_gender}/train"
    # train_transforms: from PIL to TENSOR + DATA AUG
    train_transforms = T.Compose([
        T.ToTensor(),
        # T.Pad(25, padding_mode='symmetric'),
        # T.RandomHorizontalFlip(),
        # T.RandomVerticalFlip()
    ])
    train_dataset = ImageFolder(train_dir, transform=train_transforms)

    print(train_dataset.class_to_idx)
    print(Counter(train_dataset.targets))

    # oversampling: give each sample a weight inversely proportional to its class frequency
    class_counts = Counter(train_dataset.targets)
    sample_weights = [0] * len(train_dataset)
    # iterating over .targets avoids loading and transforming every image just to read its label
    for idx, label in enumerate(train_dataset.targets):
        # inverse frequency gives more weight to minority classes
        sample_weights[idx] = 1 / class_counts[label]
    sampler = WeightedRandomSampler(sample_weights, num_samples=self.num_samples, replacement=True)

    # shuffle must stay False when a custom sampler is passed
    train_loader = DataLoader(train_dataset, batch_size=self.hparams.batch_size, num_workers=4, sampler=sampler, shuffle=False)
    return train_loader

There I manage to retrieve my class counts and oversample the minority classes.

However, I cannot figure out how to retrieve those class weights (say, the inverse class frequencies) and pass them to the cross_entropy loss function inside my training_step and validation_step methods, with the aim of tackling the class imbalance in my validation set:

def training_step(self, batch, batch_idx):
    # torch.Size([bs, 3, 224, 224])
    # x = batch["pixel_values"]
    # torch.Size([bs])
    # y = batch["labels"]
    x, y = batch
    # keep the backbone frozen until self.hparams.unfreeze_epoch;
    # running it under no_grad means no autograd graph is kept for the backbone (memory efficient)
    if self.trainer.current_epoch < self.hparams.unfreeze_epoch:
        with torch.no_grad():
            features = self.backbone(x)
    else:
        features = self.backbone(x)
    preds = self.finetune_layer(features)
    # pred_probs = softmax(preds, dim=-1)
    # pred_labels = torch.argmax(pred_probs, dim=-1)
    train_loss = cross_entropy(preds, y, weight=?)
    self.log("train_loss", train_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    self.log("train_accuracy", self.train_accuracy(preds, y), on_step=True, on_epoch=True, prog_bar=True, logger=True)
    self.log("train_f1_score", self.train_f1(preds, y), on_step=True, on_epoch=True, prog_bar=True, logger=True)
    #self.log("train_accuracy", self.train_accuracy(preds, y), prog_bar=True)
    #self.log("train_precision", self.train_precision(preds, y), prog_bar=True)
    #self.log("train_recall", self.train_recall(preds, y), prog_bar=True)
    #self.log("train_f1", self.train_f1(preds, y), prog_bar=True)
    return train_loss

So I know that I should use the weight= parameter in the cross_entropy function, but how can I retrieve my class weights from my training dataset?

Let me know if I should add some clarifications.

martineau

1 Answer


You could:

dm = DataModule()
# write a weights getter function in your pl.LightningDataModule
weights = dm.get_weights()
# set the loss function in your pl.LightningModule's __init__:
#
#        self.loss = nn.CrossEntropyLoss(weight=weights)
#
# and then call it in training_step as self.loss(preds, y)
model = MyLightningModule(weights)  # your pl.LightningModule
trainer.fit(model, dm)

No need to pass the weights to your loss function on every call.
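
A minimal sketch of how that getter and wiring could look. The class names GenderDataModule/GenderClassifier, the img_dir_gender attribute, and get_weights itself are illustrative (get_weights is your own helper, not part of Lightning's API); backbone and finetune_layer are assumed to be defined as in the question:

import torch
import pytorch_lightning as pl
from torch import nn
from torchvision.datasets import ImageFolder

class GenderDataModule(pl.LightningDataModule):
    # ... train_dataloader() etc. as in the question ...

    def get_weights(self):
        # build the same ImageFolder and count labels without loading any images
        train_dataset = ImageFolder(f"{self.img_dir_gender}/train")
        counts = torch.bincount(torch.tensor(train_dataset.targets))
        # inverse frequency: rarer classes get larger weights
        weights = 1.0 / counts.float()
        # normalize so the weights sum to the number of classes (optional)
        return weights * len(counts) / weights.sum()

class GenderClassifier(pl.LightningModule):
    def __init__(self, weights, **kwargs):
        super().__init__()
        # weighted loss defined once, reused in every step
        self.loss = nn.CrossEntropyLoss(weight=weights)

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self.finetune_layer(self.backbone(x))
        train_loss = self.loss(preds, y)
        self.log("train_loss", train_loss, prog_bar=True)
        return train_loss

dm = GenderDataModule()
model = GenderClassifier(weights=dm.get_weights())
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, dm)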

Mike B
  • The problem with this approach is that it will apply the weights to your validation set too, which is not what you want. It is better to pass the weights to the `LightningModule` on `init` but then use the functional `F.cross_entropy` call with the weights passed in only on training (see the sketch after these comments). – System123 Sep 19 '22 at 07:34
  • It is hardly misleading; it all depends on what you want. Generally you want your validation losses to provide insight into how well your model will perform on the natural (test) data distribution. Thus if you weight your losses on validation you lose that interpretation, and instead your validation reflects an artificially balanced representation of your data distribution. – System123 Oct 06 '22 at 09:16
  • But this way you can't apply the PyTorch Lightning CLI :/ Is there any other way the Lightning CLI could be used? – Dániel Terbe Nov 03 '22 at 08:57
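
A hedged sketch of the alternative suggested in the first comment above: keep the weights on the module but apply them only in training, so validation stays unweighted. Class and attribute names are again illustrative, and backbone/finetune_layer are assumed to exist as in the question:

import torch.nn.functional as F
import pytorch_lightning as pl

class GenderClassifier(pl.LightningModule):
    def __init__(self, class_weights, **kwargs):
        super().__init__()
        # register as a buffer so the weights move to the same device as the module
        self.register_buffer("class_weights", class_weights)

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self.finetune_layer(self.backbone(x))
        # weighted loss only for training
        return F.cross_entropy(preds, y, weight=self.class_weights)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self.finetune_layer(self.backbone(x))
        # unweighted loss so validation reflects the natural class distribution
        self.log("val_loss", F.cross_entropy(preds, y), prog_bar=True)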