diff --git a/src/gnn_tracking/graph_construction/k_scanner.py b/src/gnn_tracking/graph_construction/k_scanner.py index 3aa1b481..9722e3b3 100644 --- a/src/gnn_tracking/graph_construction/k_scanner.py +++ b/src/gnn_tracking/graph_construction/k_scanner.py @@ -15,10 +15,14 @@ from torch_geometric.data import Data from tqdm import tqdm -from gnn_tracking.analysis.graphs import get_largest_segment_fracs +from gnn_tracking.analysis.graphs import get_cc_labels, get_largest_segment_fracs +from gnn_tracking.metrics.cluster_metrics import ( + flatten_track_metrics, + tracking_metrics_data, +) from gnn_tracking.metrics.graph_construction import get_efficiency_purity_edges from gnn_tracking.models.graph_construction import knn_with_max_radius -from gnn_tracking.utils.dictionaries import pivot_record_list +from gnn_tracking.utils.dictionaries import add_key_prefix, pivot_record_list from gnn_tracking.utils.log import logger # ruff: noqa: ARG002 @@ -204,6 +208,17 @@ def __call__( break self._results.append(r) + def _evaluate_tracking_metrics_upper_bounds(self, data: Data) -> dict[str, float]: + y = data.y.bool() + ei = data.edge_index[:, y] + labels = get_cc_labels(ei, num_nodes=data.num_nodes) + return add_key_prefix( + flatten_track_metrics( + tracking_metrics_data(data, labels.detach().cpu().numpy(), [0.9]), + ), + "max_", + ) + def _evaluate_graph(self, data: Data, k: int) -> dict[str, float] | None: """Evaluate metrics for single graphs @@ -243,4 +258,5 @@ def _evaluate_graph(self, data: Data, k: int) -> dict[str, float] | None: **get_efficiency_purity_edges( data, pt_thld=self.hparams.pt_thld, max_eta=self.hparams.max_eta ), + **self._evaluate_tracking_metrics_upper_bounds(data), }