From bff43810815b21542e62ee8bb35ea348055f8f2c Mon Sep 17 00:00:00 2001 From: szmazurek Date: Fri, 6 Dec 2024 09:19:07 +0100 Subject: [PATCH] Remove redundant todos, change assert to warn --- GANDLF/metrics/__init__.py | 13 +++++++++---- GANDLF/metrics/metric_calculators.py | 1 - GANDLF/models/lightning_module.py | 1 - 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/GANDLF/metrics/__init__.py b/GANDLF/metrics/__init__.py index ab0e3bbfe..2a8255e20 100644 --- a/GANDLF/metrics/__init__.py +++ b/GANDLF/metrics/__init__.py @@ -1,6 +1,7 @@ """ All the metrics are to be called from here """ +from warnings import warn from typing import Union from GANDLF.losses.regression import MSE_loss, CEL @@ -115,10 +116,14 @@ def get_metrics(params: dict) -> dict: metric_calculators = {} for metric_name in params["metrics"]: metric_name = metric_name.lower() - assert ( - metric_name in global_metrics_dict - ), f"Could not find the requested metric '{metric_name}'" - metric_calculators[metric_name] = global_metrics_dict[metric_name] + if metric_name not in global_metrics_dict: + warn( + f"Metric {metric_name} not found in global metrics dictionary, it will not be used.", + UserWarning, + ) + continue + else: + metric_calculators[metric_name] = global_metrics_dict[metric_name] return metric_calculators diff --git a/GANDLF/metrics/metric_calculators.py b/GANDLF/metrics/metric_calculators.py index f6a7165d2..d1bac8c61 100644 --- a/GANDLF/metrics/metric_calculators.py +++ b/GANDLF/metrics/metric_calculators.py @@ -32,7 +32,6 @@ def __init__(self, params): def __call__(self, prediction: torch.Tensor, target: torch.Tensor, *args): metric_results = {} - # TODO what do we do with edge case in GaNDLF/GANDLF/compute/loss_and_metric.py? for metric_name, metric_calculator in self.metrics_calculators.items(): metric_value = ( metric_calculator(prediction, target, self.params).detach().cpu() diff --git a/GANDLF/models/lightning_module.py b/GANDLF/models/lightning_module.py index 147e7b70b..cb7a7480d 100644 --- a/GANDLF/models/lightning_module.py +++ b/GANDLF/models/lightning_module.py @@ -25,7 +25,6 @@ def _initialize_loss(self): self.loss = LossCalculatorFactory(self.params).get_loss_calculator() def _initialize_metric_calculators(self): - # TODO can we have situation that metrics are empty? self.metric_calculators = MetricCalculatorFactory( self.params ).get_metric_calculator()