Search code examples
Tags: python, deep-learning, pytorch, tensorboard, pytorch-lightning

How to dump confusion matrix using TensorBoard logger in pytorch-lightning?


The official documentation only shows this example:

>>> from pytorch_lightning.metrics import ConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(num_classes=2)
>>> confmat(preds, target)

This doesn't show how to use the metric with the framework.

My attempt (methods are not complete and only show relevant parts):

def __init__(...):
    """Create the validation confusion-matrix metric (question fragment).

    ``...`` stands for the elided real constructor parameters; this line is not
    valid Python as written.
    """
    # Metric object shared by validation_step (update) and validation_epoch_end
    # (compute); one class per configured cluster.
    self.val_confusion = pl.metrics.classification.ConfusionMatrix(num_classes=self._config.n_clusters)

def validation_step(self, batch, batch_index):
    """Compute the validation loss and feed the batch into ``self.val_confusion``.

    Question fragment: ``...`` elides code that presumably unpacks ``batch``
    into ``orig_batch`` / ``label_batch`` (TODO confirm).
    """
    ...
    log_probs = self.forward(orig_batch)
    loss = self._criterion(log_probs, label_batch)

    # update() accumulates the metric's state without calling the metric itself.
    self.val_confusion.update(log_probs, label_batch)
    # NOTE(review): logging the metric *object* per step makes Lightning read its
    # `_forward_cache` (step_result.py line 298 in the traceback), which is None
    # here, presumably because only update() was called, not the metric's
    # __call__/forward. A confusion matrix is also not a scalar value, so per-step
    # logging via self.log is unlikely to work either way — confirm.
    self.log('validation_confusion_step', self.val_confusion, on_step=True, on_epoch=False)

def validation_step_end(self, outputs):
    """Hand the per-step outputs back to the framework untouched."""
    step_outputs = outputs
    return step_outputs

def validation_epoch_end(self, outs):
    """Log the confusion matrix accumulated over the validation epoch (question fragment)."""
    # NOTE(review): compute() presumably returns the full matrix rather than a
    # scalar; self.log is intended for scalar values, so this likely needs a
    # different logging route (e.g. rendering an image, as the answer below
    # does) — confirm. No reset() of self.val_confusion is visible in this
    # snippet, so the accumulated state may carry over between epochs.
    self.log('validation_confusion_epoch', self.val_confusion.compute())

After the 0th epoch, this fails with the following traceback:

    Traceback (most recent call last):
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 521, in train
        self.train_loop.run_training_epoch()
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\training_loop.py", line 588, in run_training_epoch
        self.trainer.run_evaluation(test_mode=False)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 613, in run_evaluation
        self.evaluation_loop.log_evaluation_step_metrics(output, batch_idx)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\evaluation_loop.py", line 346, in log_evaluation_step_metrics
        self.__log_result_step_metrics(step_log_metrics, step_pbar_metrics, batch_idx)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\evaluation_loop.py", line 350, in __log_result_step_metrics
        cached_batch_pbar_metrics, cached_batch_log_metrics = cached_results.update_logger_connector()
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 378, in update_logger_connector
        batch_log_metrics = self.get_latest_batch_log_metrics()
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 418, in get_latest_batch_log_metrics
        batch_log_metrics = self.run_batch_from_func_name("get_batch_log_metrics")
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 414, in run_batch_from_func_name
        results = [func(include_forked_originals=False) for func in results]
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 414, in <listcomp>
        results = [func(include_forked_originals=False) for func in results]
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 122, in get_batch_log_metrics
        return self.run_latest_batch_metrics_with_func_name("get_batch_log_metrics", *args, **kwargs)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 115, in run_latest_batch_metrics_with_func_name
        for dl_idx in range(self.num_dataloaders)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 115, in <listcomp>
        for dl_idx in range(self.num_dataloaders)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 100, in get_latest_from_func_name
        results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs))
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\core\step_result.py", line 298, in get_batch_log_metrics
        result[dl_key] = self[k]._forward_cache.detach()
    AttributeError: 'NoneType' object has no attribute 'detach'

                                                      

It does pass the sanity validation check before training.

The failure happens at the `return` in `validation_step_end`, which makes little sense to me.

The exact same way of using metrics works fine with accuracy.

How to get a correct confusion matrix?


Solution

  • Updated answer, August 2022

    
    class IntHandler:
        """Matplotlib legend handler that draws an integer legend handle as text.

        Used with ``handler_map={int: IntHandler()}`` so that integer label ids
        can be passed directly as legend handles.
        """

        def legend_artist(self, legend, orig_handle, fontsize, handlebox):
            """Add and return a Text artist showing ``str(orig_handle)``.

            ``legend`` and ``fontsize`` are accepted for the handler protocol
            but are not used.
            """
            # Anchor the text at the handle box's descent offsets.
            anchor_x = handlebox.xdescent
            anchor_y = handlebox.ydescent
            label_artist = plt.matplotlib.text.Text(anchor_x, anchor_y, str(orig_handle))
            handlebox.add_artist(label_artist)
            return label_artist
    
    
    
    class LightningClassifier(LightningModule):
        """Classifier that logs a validation confusion matrix image to TensorBoard.

        Only the members relevant to the confusion matrix are shown. ``...``
        stands for the elided rest of the module (``__init__``, ``forward``,
        ``_criterion``, ``_common_log``, ``n_labels``, ``_label_ind_by_names``, ...).
        """

        ...

        def _common_step(self, batch, batch_nb, stage: str):
            """Run the model on one batch and return ``(outputs, labels, loss)``.

            ``batch`` is unpacked as ``(augmented_image, labels)``. The model's
            forward pass returns a pair; only its first element is used and the
            auxiliary output is ignored. ``batch_nb`` is currently unused.
            """
            # Guards against typos in the stage literal at call sites.
            assert stage in ("train", "val", "test")
            augmented_image, labels = batch

            outputs, _aux_outputs = self(augmented_image)
            loss = self._criterion(outputs, labels)

            return outputs, labels, loss

        def validation_step(self, batch, batch_nb):
            """Compute and log the validation loss; keep outputs/labels for epoch end."""
            stage = "val"
            outputs, labels, loss = self._common_step(batch, batch_nb, stage=stage)
            self._common_log(loss, stage=stage)

            # These dicts are what validation_epoch_end receives as `outs`.
            return {"loss": loss, "outputs": outputs, "labels": labels}

        def validation_epoch_end(self, outs):
            """Render the epoch's confusion matrix as an image and log it to TensorBoard.

            ``outs`` is the list of dicts returned by ``validation_step``. Instead of
            logging a stateful metric object at every step (which fails with
            ``self.log``), a fresh ``ConfusionMatrix`` is computed once from the
            concatenated predictions of the whole epoch, following
            https://github.com/Lightning-AI/metrics/blob/ff61c482e5157b43e647565fa0020a4ead6e9d61/docs/source/pages/lightning.rst
            """
            if not outs:
                # No validation batches ran; torch.cat([]) below would raise.
                return
            tb = self.logger.experiment  # noqa  # TensorBoard SummaryWriter (uses add_image)

            outputs = torch.cat([tmp['outputs'] for tmp in outs])
            labels = torch.cat([tmp['labels'] for tmp in outs])

            # Use `.device`: `get_device()` returns -1 for CPU tensors, which `.to()` rejects.
            # NOTE(review): torchmetrics >= 0.11 also requires task="multiclass" here — confirm
            # against the installed version.
            confusion = torchmetrics.ConfusionMatrix(num_classes=self.n_labels).to(outputs.device)
            confusion(outputs, labels)
            computed_confusion = confusion.compute().detach().cpu().numpy().astype(int)

            label_ids = list(self._label_ind_by_names.values())
            df_cm = pd.DataFrame(computed_confusion, index=label_ids, columns=label_ids)

            fig, ax = plt.subplots(figsize=(10, 5))
            try:
                fig.subplots_adjust(left=0.05, right=.65)
                sn.set(font_scale=1.2)
                sn.heatmap(df_cm, annot=True, annot_kws={"size": 16}, fmt='d', ax=ax)
                # Legend mapping each integer label id (drawn by IntHandler) to its name.
                ax.legend(
                    label_ids,
                    list(self._label_ind_by_names.keys()),
                    handler_map={int: IntHandler()},
                    loc='upper left',
                    bbox_to_anchor=(1.2, 1)
                )
                with io.BytesIO() as buf:
                    # Save *this* figure (not pyplot's "current" one); 'tight' keeps the
                    # legend, which sits outside the axes, inside the saved image.
                    fig.savefig(buf, format='jpeg', bbox_inches='tight')
                    buf.seek(0)
                    with Image.open(buf) as im:
                        image_tensor = torchvision.transforms.ToTensor()(im)
            finally:
                # Without this, one matplotlib figure leaks per validation epoch.
                plt.close(fig)

            tb.add_image("val_confusion_matrix", image_tensor, global_step=self.current_epoch)
    
    

    output:

    [image: the resulting confusion matrix as displayed in TensorBoard]

    Also based on this (the original link is missing from this copy).