8

The official doc only states

>>> from pytorch_lightning.metrics import ConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(num_classes=2)
>>> confmat(preds, target)

This doesn't show how to use the metric with the framework.

My attempt (methods are not complete and only show relevant parts):

def __init__(...):
    self.val_confusion = pl.metrics.classification.ConfusionMatrix(num_classes=self._config.n_clusters)

def validation_step(self, batch, batch_index):
    ...
    log_probs = self.forward(orig_batch)
    loss = self._criterion(log_probs, label_batch)
   
    self.val_confusion.update(log_probs, label_batch)
    self.log('validation_confusion_step', self.val_confusion, on_step=True, on_epoch=False)

def validation_step_end(self, outputs):
    return outputs

def validation_epoch_end(self, outs):
    self.log('validation_confusion_epoch', self.val_confusion.compute())

After the 0th epoch, this gives

    Traceback (most recent call last):
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 521, in train
        self.train_loop.run_training_epoch()
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\training_loop.py", line 588, in run_training_epoch
        self.trainer.run_evaluation(test_mode=False)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 613, in run_evaluation
        self.evaluation_loop.log_evaluation_step_metrics(output, batch_idx)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\evaluation_loop.py", line 346, in log_evaluation_step_metrics
        self.__log_result_step_metrics(step_log_metrics, step_pbar_metrics, batch_idx)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\evaluation_loop.py", line 350, in __log_result_step_metrics
        cached_batch_pbar_metrics, cached_batch_log_metrics = cached_results.update_logger_connector()
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 378, in update_logger_connector
        batch_log_metrics = self.get_latest_batch_log_metrics()
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 418, in get_latest_batch_log_metrics
        batch_log_metrics = self.run_batch_from_func_name("get_batch_log_metrics")
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 414, in run_batch_from_func_name
        results = [func(include_forked_originals=False) for func in results]
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 414, in <listcomp>
        results = [func(include_forked_originals=False) for func in results]
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 122, in get_batch_log_metrics
        return self.run_latest_batch_metrics_with_func_name("get_batch_log_metrics", *args, **kwargs)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 115, in run_latest_batch_metrics_with_func_name
        for dl_idx in range(self.num_dataloaders)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 115, in <listcomp>
        for dl_idx in range(self.num_dataloaders)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 100, in get_latest_from_func_name
        results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs))
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\core\step_result.py", line 298, in get_batch_log_metrics
        result[dl_key] = self[k]._forward_cache.detach()
    AttributeError: 'NoneType' object has no attribute 'detach'


It does pass the sanity validation check before training.

The failure happens on the return in `validation_step_end`, which makes little sense to me.

The exact same method of using metrics works fine with `Accuracy`.
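
For reference, the Accuracy usage that does work is presumably the standard pattern from the metrics doc. A minimal sketch using the same old `pl.metrics` API as above; `self.val_accuracy` and the logged key are just illustrative names, not code from the question:

    def __init__(...):
        self.val_accuracy = pl.metrics.Accuracy()

    def validation_step(self, batch, batch_index):
        ...
        log_probs = self.forward(orig_batch)
        loss = self._criterion(log_probs, label_batch)

        # updating the metric and logging the metric object itself is fine here,
        # because Accuracy reduces to a scalar that self.log can handle
        self.val_accuracy(log_probs, label_batch)
        self.log('validation_accuracy', self.val_accuracy, on_step=True, on_epoch=True)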

How do I get a correct confusion matrix?

Gulzar
  • Please provide the expected [MRE](https://stackoverflow.com/help/minimal-reproducible-example). Show where the intermediate results deviate from the ones you expect. We should be able to paste a single block of your code into a file, run it, and reproduce your problem. This also lets us test any suggestions in your context. – Prune Dec 29 '20 at 21:24
  • The docs link you provide gives more information than you provide in the question, as well as a more complete example. As best I can see, your `update` in `validation_step` assumes an implementation that isn't consistent with the structure of a `ConfusionMatrix` object. Since you've omitted so much code, we can't tell; you've left us to eye-check your untraced code fragments, rather than testing. – Prune Dec 29 '20 at 21:25
  • @Prune An MRE is not doable; code running machine learning takes at least a dataset and config. This is simply a lacking-docs question, and my reproducible example is actually useless anyway; I just want to see the correct usage. Please tell me what part of the doc I am missing? Obviously my implementation is not as expected, but I also don't understand what is expected, as I am using the exact same pattern as the fuller accuracy example. – Gulzar Dec 29 '20 at 21:32
  • The accuracy example *in the doc itself* is not an MRE, because it is just less readable that way... https://pytorch-lightning.readthedocs.io/en/stable/metrics.html – Gulzar Dec 29 '20 at 21:33

3 Answers

13

Obsolete for lightning>=2.0.0

You can report the figure using `self.logger.experiment.add_figure(tag, figure)`.

The variable self.logger.experiment is actually a SummaryWriter (from PyTorch, not Lightning). This class has the method add_figure (documentation).

You can use it as follows (MNIST example):

    # assumes: import torch, torch.nn.functional as F, pandas as pd,
    # seaborn as sns, matplotlib.pyplot as plt, pytorch_lightning as pl
    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        loss = F.nll_loss(preds, y)
        # return everything the epoch-end hook needs to build the confusion matrix
        return {'loss': loss, 'preds': preds, 'target': y}

    def validation_epoch_end(self, outputs):
        # stack the per-batch predictions and targets collected in validation_step
        preds = torch.cat([tmp['preds'] for tmp in outputs])
        targets = torch.cat([tmp['target'] for tmp in outputs])
        confusion_matrix = pl.metrics.functional.confusion_matrix(preds, targets, num_classes=10)

        # move to CPU before converting to numpy in case validation ran on the GPU
        df_cm = pd.DataFrame(confusion_matrix.cpu().numpy(), index=range(10), columns=range(10))
        plt.figure(figsize=(10, 7))
        fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
        plt.close(fig_)

        self.logger.experiment.add_figure("Confusion matrix", fig_, self.current_epoch)
Yorick
  • Coming back to this, it makes no sense. The confusion matrix is not updated every step, how can it be correct? Please explain. – Gulzar Aug 16 '22 at 14:29
  • Also some calls seem outdated, namely `.confusion_matrix(preds, targets, num_classes=10)` and the device seems to be wrong for the `confusion_matrix` itself – Gulzar Aug 16 '22 at 17:48
  • also `pl` has no `metrics` now – Gulzar Aug 16 '22 at 18:11
  • See a [working version](https://stackoverflow.com/a/73388839/913098) based also on this. – Gulzar Aug 17 '22 at 12:59
  • I am not surprised if this doesn't work with the latest versions. PL has been changing quite rapidly. Concerning the update frequency, the matrix is generated after each validation step and added to TB. Isn't that what you want? – Yorick Aug 17 '22 at 15:35
  • It actually is what I want, but I expected updating every step to be handled and somehow zeroed by PL at the start of each epoch. Otherwise, how does any metric work when updated only in the step? [see doc of i.e Accuracy](https://torchmetrics.readthedocs.io/en/stable/pages/lightning.html) – Gulzar Aug 17 '22 at 16:45
  • If update on step accumulates too much, then what's the difference? Looks like this is an important detail – Gulzar Aug 17 '22 at 16:45
  • torch doesn't allow you to put `outputs` in the function arguments anymore. How to access outputs? – dorien Jun 19 '23 at 03:55
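
Regarding the accumulation question in the comments above: the standard torchmetrics pattern is to keep the metric as a module attribute, call `update()` every step without passing the object to `self.log`, and then `compute()` and `reset()` it yourself at epoch end. A minimal sketch, assuming a torchmetrics version where `ConfusionMatrix(num_classes=...)` is the accepted constructor (newer releases want `task='multiclass'`) and an illustrative `self.n_labels` attribute:

    import torchmetrics

    def __init__(...):
        # registered as a submodule, so Lightning moves it to the right device
        self.val_confmat = torchmetrics.ConfusionMatrix(num_classes=self.n_labels)

    def validation_step(self, batch, batch_index):
        ...
        self.val_confmat.update(log_probs, label_batch)  # accumulates internal state

    def validation_epoch_end(self, outs):
        cm = self.val_confmat.compute()  # matrix over the whole validation epoch
        ...  # plot cm / add it to TensorBoard as in the answers on this page
        self.val_confmat.reset()  # zero the state so consecutive epochs don't mix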
6

This is obsolete.

See the much better version in the updated August 2022 answer below.


This took a lot of time to find.

This is the most minimal code I could paste that is still readable and reproducible.

I didn't want to put the entire model, dataset, and parameters here, as they are of no interest to readers of this question and are just noise.


That said, here is the required code for creating a confusion matrix per epoch and displaying it in TensorBoard.

This is a single example frame:

(screenshot: per-epoch confusion-matrix heatmap in TensorBoard)


import io
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import seaborn as sn
import torchvision
from PIL import Image
from pytorch_lightning.loggers import TensorBoardLogger

def __init__(self, config, trained_vae, latent_dim):
    # stateful metric: update() is called every validation step, compute() once per epoch
    self.val_confusion = pl.metrics.classification.ConfusionMatrix(num_classes=self._config.n_clusters)
    self.logger: Optional[TensorBoardLogger] = None

def forward(self, x):
    ...
    return log_probs

def validation_step(self, batch, batch_index):
    if self._config.dataset == "mnist":
        orig_batch, label_batch = batch
        orig_batch = orig_batch.reshape(-1, 28 * 28)

    log_probs = self.forward(orig_batch)
    loss = self._criterion(log_probs, label_batch)

    # accumulate the confusion-matrix state; do not pass the metric object to self.log here
    self.val_confusion.update(log_probs, label_batch)
    return {"loss": loss, "labels": label_batch}

def validation_step_end(self, outputs):
    return outputs

def validation_epoch_end(self, outs):
    tb = self.logger.experiment

    # confusion matrix over the whole validation epoch
    conf_mat = self.val_confusion.compute().detach().cpu().numpy().astype(int)
    df_cm = pd.DataFrame(
        conf_mat,
        index=np.arange(self._config.n_clusters),
        columns=np.arange(self._config.n_clusters))
    plt.figure()
    sn.set(font_scale=1.2)
    sn.heatmap(df_cm, annot=True, annot_kws={"size": 16}, fmt='d')

    # render the matplotlib figure to an image tensor and log it to TensorBoard
    buf = io.BytesIO()
    plt.savefig(buf, format='jpeg')
    buf.seek(0)
    im = Image.open(buf)
    im = torchvision.transforms.ToTensor()(im)
    tb.add_image("val_confusion_matrix", im, global_step=self.current_epoch)

and the call to the trainer:

logger = TensorBoardLogger(save_dir=tb_logs_folder, name='Classifier')
trainer = Trainer(
    deterministic=True,
    max_epochs=10,
    default_root_dir=classifier_checkpoints_path,
    logger=logger,
    gpus=1
)
Gulzar
4

Updated answer, August 2022


import io

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sn
import torch
import torchmetrics
import torchvision
from PIL import Image
from pytorch_lightning import LightningModule


class IntHandler:
    # lets matplotlib use plain ints as legend handles (maps class indices to names)
    def legend_artist(self, legend, orig_handle, fontsize, handlebox):
        x0, y0 = handlebox.xdescent, handlebox.ydescent
        text = plt.matplotlib.text.Text(x0, y0, str(orig_handle))
        handlebox.add_artist(text)
        return text



class LightningClassifier(LightningModule):
    ...

    def _common_step(self, batch, batch_nb, stage: str):
        assert stage in ("train", "val", "test")
        augmented_image, labels = batch

        outputs, aux_outputs = self(augmented_image)
        loss = self._criterion(outputs, labels)

        return outputs, labels, loss

    def validation_step(self, batch, batch_nb):
        stage = "val"
        outputs, labels, loss = self._common_step(batch, batch_nb, stage=stage)
        self._common_log(loss, stage=stage)

        return {"loss": loss, "outputs": outputs, "labels": labels}


    def validation_epoch_end(self, outs):
        # see https://github.com/Lightning-AI/metrics/blob/ff61c482e5157b43e647565fa0020a4ead6e9d61/docs/source/pages/lightning.rst:
        # logging the metric object each forward pass can lead to wrong accumulation,
        # so in practice do the following:
        tb = self.logger.experiment  # noqa

        outputs = torch.cat([tmp['outputs'] for tmp in outs])
        labels = torch.cat([tmp['labels'] for tmp in outs])

        confusion = torchmetrics.ConfusionMatrix(num_classes=self.n_labels).to(outputs.device)
        confusion(outputs, labels)
        computed_confusion = confusion.compute().detach().cpu().numpy().astype(int)

        # confusion matrix as a labelled DataFrame
        df_cm = pd.DataFrame(
            computed_confusion,
            index=self._label_ind_by_names.values(),
            columns=self._label_ind_by_names.values(),
        )

        fig, ax = plt.subplots(figsize=(10, 5))
        fig.subplots_adjust(left=0.05, right=.65)
        sn.set(font_scale=1.2)
        sn.heatmap(df_cm, annot=True, annot_kws={"size": 16}, fmt='d', ax=ax)
        ax.legend(
            self._label_ind_by_names.values(),
            self._label_ind_by_names.keys(),
            handler_map={int: IntHandler()},
            loc='upper left',
            bbox_to_anchor=(1.2, 1)
        )

        # render the figure to an image tensor and log it to TensorBoard
        buf = io.BytesIO()
        plt.savefig(buf, format='jpeg', bbox_inches='tight')
        buf.seek(0)
        im = Image.open(buf)
        im = torchvision.transforms.ToTensor()(im)
        tb.add_image("val_confusion_matrix", im, global_step=self.current_epoch)

Output:

(screenshot: confusion-matrix heatmap with class-name legend in TensorBoard)

Also based on Yorick's answer above.

Gulzar
  • Currently, I cannot put 'outs' in the function argument: TypeError: LSTMClassifier.on_test_epoch_end() missing 1 required positional argument: 'outs'. How to access outs? – dorien Jun 19 '23 at 03:51
  • @dorien your error is unclear. It seems you didn't put outs as an argument whereas it was required, as in the above code. – Gulzar Jun 19 '23 at 17:57
  • Lightning 2.x does not support passing outs as arguments. A workaround that works with the new version can be found here: https://lightning.ai/forums/t/confusion-matrix-in-on-test-epoch-end-argument-error/2423 – dorien Jun 20 '23 at 01:35
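
A minimal sketch of that Lightning >= 2.0 workaround: the epoch-end hook no longer receives `outs`, so collect the step outputs in a plain list yourself and use `on_validation_epoch_end`. The class, the tiny linear model, and names like `_val_step_outputs` are only illustrative; the plotting part is the same as in the answer above.

    import torch
    import torchmetrics
    from pytorch_lightning import LightningModule

    class TinyClassifier(LightningModule):
        def __init__(self, n_features: int = 20, n_labels: int = 4):
            super().__init__()
            self.model = torch.nn.Linear(n_features, n_labels)
            self.criterion = torch.nn.CrossEntropyLoss()
            self.n_labels = n_labels
            # Lightning 2.x no longer passes `outs` to the epoch-end hook,
            # so accumulate the step outputs in a plain list ourselves
            self._val_step_outputs = []

        def forward(self, x):
            return self.model(x)

        def validation_step(self, batch, batch_nb):
            x, labels = batch
            outputs = self(x)
            loss = self.criterion(outputs, labels)
            self._val_step_outputs.append({"outputs": outputs.detach(), "labels": labels})
            return loss

        def on_validation_epoch_end(self):
            outputs = torch.cat([o["outputs"] for o in self._val_step_outputs])
            labels = torch.cat([o["labels"] for o in self._val_step_outputs])
            self._val_step_outputs.clear()  # start fresh next epoch

            confusion = torchmetrics.classification.MulticlassConfusionMatrix(
                num_classes=self.n_labels).to(outputs.device)
            cm = confusion(outputs, labels).cpu().numpy().astype(int)
            # build the seaborn heatmap from `cm` and log it to TensorBoard
            # exactly as in the August 2022 answer above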