Skip to content

Commit

Permalink
Fix confusion matrix logging
Browse files Browse the repository at this point in the history
  • Loading branch information
1pha committed Mar 5, 2024
1 parent b383b77 commit a80b466
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
7 changes: 6 additions & 1 deletion config/train_cls.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,9 @@ callbacks:
_target_: pytorch_lightning.callbacks.LearningRateMonitor
logging_interval: step
log_momentum: False


richsummary:
_target_: pytorch_lightning.callbacks.RichModelSummary

richpbar:
_target_: pytorch_lightning.callbacks.RichProgressBar
7 changes: 4 additions & 3 deletions sage/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,15 +251,16 @@ def validation_step(self, batch, batch_idx):

output: dict = self.valid_metric(result["pred"], result["target"])
self.log_result(output, unit="step", prog_bar=False)

self.validation_step_outputs.append(result)

def on_validation_epoch_end(self):
output: dict = self.valid_metric.compute()
self.log_result(output, unit="epoch", prog_bar=True)
if output["pred"].ndim == 2:

result = utils._sort_outputs(outputs=self.validation_step_outputs)
if result["pred"].ndim == 2:
""" Assuming prediction with (B, C) shape is a classification task"""
self.log_confusion_matrix(result=output)
self.log_confusion_matrix(result=result)
self.validation_step_outputs.clear()

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
Expand Down

0 comments on commit a80b466

Please sign in to comment.