I'm using PyTorch Lightning, and at the end of each epoch I create a confusion matrix with torchmetrics.ConfusionMatrix (see code below). I would like to log it to Wandb, but the Wandb confusion matrix logger only accepts y_targets and y_predictions. Does anyone know how to recover y_targets and y_predictions from the accumulated confusion matrix, or alternatively how to pass the matrix itself to Wandb in a form it can render, e.g. as a heatmap?

class ClassificationTask(pl.LightningModule):
    def __init__(self, model, lr=1e-4, augmentor=augmentor):
        super().__init__()
        self.model = model
        self.lr = lr
        self.save_hyperparameters()  # not being used at the moment, good to have there in the future
        self.augmentor=augmentor
        
        self.matrix = torchmetrics.ConfusionMatrix(num_classes=9)
        
        self.y_trues=[]
        self.y_preds=[]
        
    def training_step(self, batch, batch_idx):
        x, y = batch
        x=self.augmentor(x)#.to('cuda')
        y_pred = self.model(x)
        loss = F.cross_entropy(y_pred, y,)  #weights=class_weights_tensor
        

        acc = accuracy(y_pred, y)
        metrics = {"train_acc": acc, "train_loss": loss}
        self.log_dict(metrics)
        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, acc = self._shared_eval_step(batch, batch_idx)
        metrics = {"val_acc": acc, "val_loss": loss, }
        self.log_dict(metrics)
        return metrics
    
    def _shared_eval_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        acc = accuracy(y_hat, y)
        self.matrix.update(y_hat,y)
        return loss, acc
    
    def validation_epoch_end(self, outputs):
        confusion_matrix = self.matrix.compute()
        wandb.log({"my_conf_mat_id" : confusion_matrix})
        
    def configure_optimizers(self):
        return torch.optim.Adam((self.model.parameters()), lr=self.lr)
Chris92

2 Answers

I spent a while grappling with this, and here is what worked for me. I plot the confusion matrix with seaborn during validation and log the figure to wandb at the end of each validation epoch.

I am doing multi-class classification, and I set the class names and the confusion matrix metric as attributes in my model's __init__ (self.confmat = ConfusionMatrix(task="multiclass", num_classes=9), self.class_names = class_names).

For the torchmetrics confusion matrix, it was not obvious to me from the documentation which way round the axes are. In my case, summing along the horizontal direction (axis=1) recovered the number of samples for each class, so each row corresponds to a true class. I then created the normalised CM and, as a check, made sure it summed to the number of classes (each row sums to 1). I also print the total number of samples as a second check against the number of validation samples from my dataloader.
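To make the orientation check concrete, here is a minimal pure-Python sketch, not the torchmetrics implementation. It assumes the convention that rows index the true class and columns the predicted class, which matches the row-sum behaviour described above. `build_confusion_matrix` and `normalise_rows` are hypothetical helper names, and the labels are made up.

```python
def build_confusion_matrix(targets, preds, num_classes):
    """Count matrix with rows = true class, columns = predicted class (assumed convention)."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(targets, preds):
        cm[t][p] += 1
    return cm


def normalise_rows(cm):
    """Divide each row by its total so every non-empty row sums to 1."""
    out = []
    for row in cm:
        total = sum(row)
        # A class with no samples would divide by zero; leave its row as zeros.
        out.append([c / total if total else 0.0 for c in row])
    return out


# Made-up labels for three classes.
targets = [0, 0, 1, 1, 1, 2]
preds = [0, 1, 1, 1, 2, 2]

cm = build_confusion_matrix(targets, preds, num_classes=3)
row_sums = [sum(row) for row in cm]  # samples per true class
norm_cm = normalise_rows(cm)
total_of_norm = sum(sum(row) for row in norm_cm)

print(cm)
print(row_sums)
print(round(total_of_norm, 6))
```

If the row sums match the per-class sample counts from the dataloader, the rows are the true classes. If only the column sums match, the matrix is transposed relative to this assumption.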

I logged the confusion matrix from my validation loop as follows:

import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sn
import torch
import wandb
from torchmetrics import ConfusionMatrix

    # Both methods below belong to the LightningModule.
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)

        # validation metrics
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

        # accumulate this batch into the running confusion matrix
        self.confmat.update(preds, y)

        return loss

    def on_validation_epoch_end(self):
        confmat = self.confmat.compute()
        class_names = self.class_names
        num_classes = len(class_names)

        # rows are the true classes, columns the predicted classes
        df_cm = pd.DataFrame(confmat.cpu().numpy(), index=class_names, columns=class_names)

        print('Num of val samples: {}. Check this aligns with the numbers from the dataloader'.format(df_cm.sum(axis=1).sum()))
        # df_cm.to_csv('raw_nums.csv') # you can use this to validate the number of samples is correct

        # normalise the confusion matrix so that each row sums to 1
        norm = df_cm.sum(axis=1)
        normalized_cm = (df_cm.T / norm).T
        # validate the confusion matrix sums to num of classes
        # (compared with a tolerance, because exact float equality can fail spuriously)
        if not np.isclose(normalized_cm.sum(axis=1).sum(), num_classes):
            print('Error with confusion matrix')
            sys.exit()

        normalized_cm.to_csv('norm_cdf.csv')  # saved locally so that I could validate outside of wandb

        # log to wandb
        f, ax = plt.subplots(figsize=(15, 10))
        sn.heatmap(normalized_cm, annot=True, ax=ax)
        wandb.log({"plot": wandb.Image(f)})
        plt.close(f)  # free the figure so figures do not pile up across epochs

        self.confmat.reset()  # This was NEEDED, otherwise the confusion matrix kept stacking the results after each epoch
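
The question also asked how to get y_targets and y_predictions back out of an accumulated matrix. Here is a minimal sketch of that idea, assuming rows are true classes and columns are predicted classes: each cell count is expanded into that many (true, predicted) pairs. `expand_confusion_matrix` is a hypothetical helper, not part of torchmetrics or wandb, and the matrix values are made up. The resulting lists have one entry per validation sample, so they can get large.

```python
def expand_confusion_matrix(cm):
    """Turn a square count matrix into label lists that reproduce it.

    cm[i][j] is assumed to be the number of samples with true class i
    that were predicted as class j.
    """
    y_true, y_pred = [], []
    for true_idx, row in enumerate(cm):
        for pred_idx, count in enumerate(row):
            count = int(count)
            y_true.extend([true_idx] * count)
            y_pred.extend([pred_idx] * count)
    return y_true, y_pred


# Made-up 3-class matrix; with torchmetrics this could come from
# self.confmat.compute().cpu().tolist().
cm = [[5, 1, 0],
      [2, 3, 1],
      [0, 0, 4]]

y_true, y_pred = expand_confusion_matrix(cm)
print(len(y_true), len(y_pred))
```

The two lists could then be passed to wandb.plot.confusion_matrix(y_true=y_true, preds=y_pred, class_names=class_names) and the result logged with wandb.log, which is the form the built-in W&B chart expects.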
graceebc95

I'm currently working on the same issue. I found this feature request on the Lightning-AI GitHub, which may help. I think a possible solution is to use the torchmetrics confusion matrix, update it in your train/val/test steps, and log the result.

https://github.com/Lightning-AI/metrics/issues/880
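
One way to read "logging the result" is to send the counts themselves rather than an image. The sketch below flattens a confusion matrix into (actual, predicted, count) rows that fit a generic table. `confusion_matrix_rows` is a hypothetical helper, the class names and counts are made up, and rows are assumed to be the true classes.

```python
def confusion_matrix_rows(cm, class_names):
    """Flatten a count matrix into [actual, predicted, count] rows."""
    rows = []
    for i, row in enumerate(cm):
        for j, count in enumerate(row):
            rows.append([class_names[i], class_names[j], int(count)])
    return rows


# Made-up class names and counts for illustration.
class_names = ["cat", "dog", "bird"]
cm = [[5, 1, 0],
      [2, 3, 1],
      [0, 0, 4]]

rows = confusion_matrix_rows(cm, class_names)
for r in rows[:3]:
    print(r)
```

These rows could be wrapped as wandb.Table(columns=["actual", "predicted", "count"], data=rows) and passed to wandb.log. How the table is then rendered as a heatmap depends on the chart chosen in the W&B UI, which this sketch does not cover.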