From 6d92ff82875ed8e4d00ff94e375a3243946e5127 Mon Sep 17 00:00:00 2001 From: 1pha <1phantasmas@korea.ac.kr> Date: Fri, 19 Jan 2024 04:31:22 +0000 Subject: [PATCH] Confusion matrix finalize --- sage/trainer/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sage/trainer/utils.py b/sage/trainer/utils.py index 32b4471..6104e59 100644 --- a/sage/trainer/utils.py +++ b/sage/trainer/utils.py @@ -162,10 +162,12 @@ def _get_norm_cf_reg(preds, target, root_dir, run_name) -> None: fig, ax = plt.subplots(figsize=(13, 6), ncols=2) labelsize = "large" titlesize = "xx-large" + cmap = "Blues" # Normalized Confusion Matrix cf = confusion_matrix(_preds, _target) - sns.heatmap(cf, annot=True, fmt="d", xticklabels=_target.categories, ax=ax[0]) + sns.heatmap(cf, annot=True, fmt="d", + xticklabels=_target.categories, cmap=cmap, ax=ax[0]) ax[0].set_xlabel("Target", size=labelsize) ax[0].set_yticklabels(_preds.categories, rotation=270) ax[0].set_ylabel("Prediction", size=labelsize) @@ -173,7 +175,8 @@ def _get_norm_cf_reg(preds, target, root_dir, run_name) -> None: # 0-1 row-wise Noramlized Confusion. norm_cf = (cf.T / cf.sum(axis=1)).T norm_cf = np.nan_to_num(x=norm_cf, nan=0.0) - sns.heatmap(norm_cf, annot=True, fmt="0.2f", xticklabels=_target.categories, ax=ax[1]) + sns.heatmap(norm_cf, annot=True, fmt="0.2f", + xticklabels=_target.categories, cmap=cmap, ax=ax[1]) ax[1].set_xlabel("Target", size=labelsize) ax[1].set_yticklabels(_preds.categories, rotation=270) ax[1].set_ylabel("Prediction", size=labelsize)