Skip to content

Commit

Permalink
Remove redundant todos, change assert to warn
Browse files Browse the repository at this point in the history
  • Loading branch information
szmazurek committed Dec 6, 2024
1 parent 8ff6cbb commit bff4381
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
13 changes: 9 additions & 4 deletions GANDLF/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
All the metrics are to be called from here
"""
from warnings import warn
from typing import Union

from GANDLF.losses.regression import MSE_loss, CEL
Expand Down Expand Up @@ -115,10 +116,14 @@ def get_metrics(params: dict) -> dict:
metric_calculators = {}
for metric_name in params["metrics"]:
metric_name = metric_name.lower()
assert (
metric_name in global_metrics_dict
), f"Could not find the requested metric '{metric_name}'"
metric_calculators[metric_name] = global_metrics_dict[metric_name]
if metric_name not in global_metrics_dict:
warn(
f"Metric {metric_name} not found in global metrics dictionary, it will not be used.",
UserWarning,
)
continue
else:
metric_calculators[metric_name] = global_metrics_dict[metric_name]
return metric_calculators


Expand Down
1 change: 0 additions & 1 deletion GANDLF/metrics/metric_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def __init__(self, params):

def __call__(self, prediction: torch.Tensor, target: torch.Tensor, *args):
metric_results = {}
# TODO what do we do with edge case in GaNDLF/GANDLF/compute/loss_and_metric.py?
for metric_name, metric_calculator in self.metrics_calculators.items():
metric_value = (
metric_calculator(prediction, target, self.params).detach().cpu()
Expand Down
1 change: 0 additions & 1 deletion GANDLF/models/lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def _initialize_loss(self):
self.loss = LossCalculatorFactory(self.params).get_loss_calculator()

def _initialize_metric_calculators(self):
# TODO can we have situation that metrics are empty?
self.metric_calculators = MetricCalculatorFactory(
self.params
).get_metric_calculator()
Expand Down

0 comments on commit bff4381

Please sign in to comment.