Skip to content

Commit

Permalink
k-scanner to give EC upper limits
Browse files Browse the repository at this point in the history
  • Loading branch information
klieret committed Dec 5, 2023
1 parent f21a06c commit eaef425
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions src/gnn_tracking/graph_construction/k_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@
from torch_geometric.data import Data
from tqdm import tqdm

from gnn_tracking.analysis.graphs import get_largest_segment_fracs
from gnn_tracking.analysis.graphs import get_cc_labels, get_largest_segment_fracs
from gnn_tracking.metrics.cluster_metrics import (
flatten_track_metrics,
tracking_metrics_data,
)
from gnn_tracking.metrics.graph_construction import get_efficiency_purity_edges
from gnn_tracking.models.graph_construction import knn_with_max_radius
from gnn_tracking.utils.dictionaries import pivot_record_list
from gnn_tracking.utils.dictionaries import add_key_prefix, pivot_record_list
from gnn_tracking.utils.log import logger

# ruff: noqa: ARG002
Expand Down Expand Up @@ -204,6 +208,17 @@ def __call__(
break
self._results.append(r)

def _evaluate_tracking_metrics_upper_bounds(self, data: Data) -> dict[str, float]:
y = data.y.bool()
ei = data.edge_index[:, y]
labels = get_cc_labels(ei, num_nodes=data.num_nodes)
return add_key_prefix(
flatten_track_metrics(
tracking_metrics_data(data, labels.detach().cpu().numpy(), [0.9]),
),
"max_",
)

def _evaluate_graph(self, data: Data, k: int) -> dict[str, float] | None:
"""Evaluate metrics for single graphs
Expand Down Expand Up @@ -243,4 +258,5 @@ def _evaluate_graph(self, data: Data, k: int) -> dict[str, float] | None:
**get_efficiency_purity_edges(
data, pt_thld=self.hparams.pt_thld, max_eta=self.hparams.max_eta
),
**self._evaluate_tracking_metrics_upper_bounds(data),
}

0 comments on commit eaef425

Please sign in to comment.