Skip to content

Commit

Permalink
Add normalized regression age
Browse files Browse the repository at this point in the history
  • Loading branch information
1pha committed Jan 19, 2024
1 parent 469d504 commit 9b0878f
Showing 1 changed file with 35 additions and 8 deletions.
43 changes: 35 additions & 8 deletions sage/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from pathlib import Path
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
sns.set_theme()
import torch
Expand Down Expand Up @@ -91,7 +93,7 @@ def finalize_inference(prediction: list,
preds, target = prediction["pred"], prediction["target"]
if name.startswith("C"):
logger.info("Classification data given:")
_cls_infrence(preds=preds, target=target, root_dir=root_dir, run_name=run_name)
_cls_inference(preds=preds, target=target, root_dir=root_dir, run_name=run_name)
elif name[0] in set(["R", "M"]):
logger.info("Regression data given:")
_reg_inference(preds=preds, target=target, root_dir=root_dir, run_name=run_name)
Expand Down Expand Up @@ -135,8 +137,9 @@ def _get_norm_cf_reg(preds, target, root_dir, run_name) -> None:
2. Interval of each bin: 5 or 10.
"""
AGEMIN, AGEMAX = 0, 100
min_, max_ = target.min(), target.max()
min_, max_ = max(AGEMIN, min_), min(AGEMAX, max_)
pmin_, pmax_ = preds.min(), preds.max()
tmin_, tmax_ = target.min(), target.max()
min_, max_ = max(AGEMIN, min(tmin_, pmin_)), min(AGEMAX, max(tmax_, pmax_))
int_ = max_ - min_ # interval

# Check two intervals 5 and 10
Expand All @@ -149,14 +152,38 @@ def _get_norm_cf_reg(preds, target, root_dir, run_name) -> None:
while (ub % bin_) != 0:
ub += 1
bins = [lb + bin_ * idx for idx in range(0, (ub - lb) // bin_ + 1)]
labels = [f"{left+1}-{right}" for left, right in zip(bins, bins[1:])]
labels = [f"{left}-{right - 1}" for left, right in zip(bins, bins[1:])]

cut_kwargs = dict(bins=bins, labels=labels, include_lowest=True)
_preds = pd.cut(x=preds, **cut_kwargs)
_target = pd.cut(x=target, **cut_kwargs)
assert (_preds.isna().sum() + _target.isna().sum()) == 0, f"nan value in binning."

fig, ax = plt.subplots(figsize=(13, 6), ncols=2)
labelsize = "large"
titlesize = "xx-large"

# Raw (count) confusion matrix
cf = confusion_matrix(_preds, _target)
sns.heatmap(cf, annot=True, fmt="d", xticklabels=_target.categories, ax=ax[0])
ax[0].set_xlabel("Target", size=labelsize)
ax[0].set_yticklabels(_preds.categories, rotation=270)
ax[0].set_ylabel("Prediction", size=labelsize)

# 0-1 row-wise normalized confusion matrix.
norm_cf = (cf.T / cf.sum(axis=1)).T
norm_cf = np.nan_to_num(x=norm_cf, nan=0.0)
sns.heatmap(norm_cf, annot=True, fmt="0.2f", xticklabels=_target.categories, ax=ax[1])
ax[1].set_xlabel("Target", size=labelsize)
ax[1].set_yticklabels(_preds.categories, rotation=270)
ax[1].set_ylabel("Prediction", size=labelsize)

_preds = pd.cut(x=preds, bins=bins, labels=labels)
_target = pd.cut(x=target, bins=bins, labels=labels)
breakpoint()
fig.suptitle(run_name, size=titlesize)
fig.tight_layout()
fig.savefig(root_dir / f"{run_name}-cf.png")


def _cls_infrence(preds, target, root_dir, run_name) -> None:
def _cls_inference(preds, target, root_dir, run_name) -> None:
metrics_input = dict(preds=preds,
target=target.int(),
task="binary")
Expand Down

0 comments on commit 9b0878f

Please sign in to comment.