From a80b466fd2346ca24cf3a53664b703b81de89b4e Mon Sep 17 00:00:00 2001 From: 1pha <1phantasmas@korea.ac.kr> Date: Tue, 5 Mar 2024 01:19:26 +0000 Subject: [PATCH] Fix confusion matrix logging --- config/train_cls.yaml | 7 ++++++- sage/trainer/trainer.py | 7 ++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/config/train_cls.yaml b/config/train_cls.yaml index 543318b..764961b 100644 --- a/config/train_cls.yaml +++ b/config/train_cls.yaml @@ -91,4 +91,9 @@ callbacks: _target_: pytorch_lightning.callbacks.LearningRateMonitor logging_interval: step log_momentum: False - \ No newline at end of file + + richsummary: + _target_: pytorch_lightning.callbacks.RichModelSummary + + richpbar: + _target_: pytorch_lightning.callbacks.RichProgressBar \ No newline at end of file diff --git a/sage/trainer/trainer.py b/sage/trainer/trainer.py index c2e3067..f3a7126 100644 --- a/sage/trainer/trainer.py +++ b/sage/trainer/trainer.py @@ -251,15 +251,16 @@ def validation_step(self, batch, batch_idx): output: dict = self.valid_metric(result["pred"], result["target"]) self.log_result(output, unit="step", prog_bar=False) - self.validation_step_outputs.append(result) def on_validation_epoch_end(self): output: dict = self.valid_metric.compute() self.log_result(output, unit="epoch", prog_bar=True) - if output["pred"].ndim == 2: + + result = utils._sort_outputs(outputs=self.validation_step_outputs) + if result["pred"].ndim == 2: """ Assuming prediction with (B, C) shape is a classification task""" - self.log_confusion_matrix(result=output) + self.log_confusion_matrix(result=result) self.validation_step_outputs.clear() def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):