Skip to content

Commit

Permalink
Confusion matrix finalize
Browse files Browse the repository at this point in the history
  • Loading branch information
1pha committed Jan 19, 2024
1 parent 9b0878f commit 6d92ff8
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions sage/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,18 +162,21 @@ def _get_norm_cf_reg(preds, target, root_dir, run_name) -> None:
fig, ax = plt.subplots(figsize=(13, 6), ncols=2)
labelsize = "large"
titlesize = "xx-large"
cmap = "Blues"

# Normalized Confusion Matrix
cf = confusion_matrix(_preds, _target)
sns.heatmap(cf, annot=True, fmt="d", xticklabels=_target.categories, ax=ax[0])
sns.heatmap(cf, annot=True, fmt="d",
xticklabels=_target.categories, cmap=cmap, ax=ax[0])
ax[0].set_xlabel("Target", size=labelsize)
ax[0].set_yticklabels(_preds.categories, rotation=270)
ax[0].set_ylabel("Prediction", size=labelsize)

# 0-1 row-wise Noramlized Confusion.
norm_cf = (cf.T / cf.sum(axis=1)).T
norm_cf = np.nan_to_num(x=norm_cf, nan=0.0)
sns.heatmap(norm_cf, annot=True, fmt="0.2f", xticklabels=_target.categories, ax=ax[1])
sns.heatmap(norm_cf, annot=True, fmt="0.2f",
xticklabels=_target.categories, cmap=cmap, ax=ax[1])
ax[1].set_xlabel("Target", size=labelsize)
ax[1].set_yticklabels(_preds.categories, rotation=270)
ax[1].set_ylabel("Prediction", size=labelsize)
Expand Down

0 comments on commit 6d92ff8

Please sign in to comment.