Skip to content

Commit

Permalink
Fix logger.name based cls/reg branching
Browse files Browse the repository at this point in the history
  • Loading branch information
1pha committed Mar 11, 2024
1 parent 0f5f551 commit e9908b2
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions sage/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,16 +97,14 @@ def finalize_inference(prediction: list,
run_name = save_name[:-4] + "_" + timestamp()
preds, target = prediction["pred"], prediction["target"]
infer_kwargs = dict(preds=preds, target=target, root_dir=root_dir, run_name=run_name)
if name.startswith("C"):
logger.info("Classification data given:")
result = dict(pred=preds, target=target)
if check_classification(result=result):
logger.info("Classification data given: ")
metric = _cls_inference(**infer_kwargs)
elif name[0] in set(["R", "M"]):
logger.info("Regression data given:")
else:
logger.info("Regression data given: ")
metric = _reg_inference(**infer_kwargs)
_get_norm_cf_reg(**infer_kwargs)
else:
logger.info("Failed to inference. Check the run name for the task.")
metric = None
return metric


Expand Down

0 comments on commit e9908b2

Please sign in to comment.