From e9908b273a5a2edcd11610c5d20b9f843e9198c9 Mon Sep 17 00:00:00 2001 From: 1pha <1phantasmas@korea.ac.kr> Date: Mon, 11 Mar 2024 08:26:27 +0000 Subject: [PATCH] Fix logger.name based cls/reg branching --- sage/trainer/utils.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/sage/trainer/utils.py b/sage/trainer/utils.py index ac70259..27eee2a 100644 --- a/sage/trainer/utils.py +++ b/sage/trainer/utils.py @@ -97,16 +97,14 @@ def finalize_inference(prediction: list, run_name = save_name[:-4] + "_" + timestamp() preds, target = prediction["pred"], prediction["target"] infer_kwargs = dict(preds=preds, target=target, root_dir=root_dir, run_name=run_name) - if name.startswith("C"): - logger.info("Classification data given:") + result = dict(pred=preds, target=target) + if check_classification(result=result): + logger.info("Classification data given: ") metric = _cls_inference(**infer_kwargs) - elif name[0] in set(["R", "M"]): - logger.info("Regression data given:") + else: + logger.info("Regression data given: ") metric = _reg_inference(**infer_kwargs) _get_norm_cf_reg(**infer_kwargs) - else: - logger.info("Failed to inference. Check the run name for the task.") - metric = None return metric