From 6313095e8868c53c02443e972d4bf984c5c7072b Mon Sep 17 00:00:00 2001 From: anna-grim Date: Tue, 2 Apr 2024 19:27:35 +0000 Subject: [PATCH] feat: isolated proposal detection --- .../groundtruth_generation.py | 68 +++++++--- src/deep_neurographs/neurograph.py | 126 ++++++++++++------ 2 files changed, 131 insertions(+), 63 deletions(-) diff --git a/src/deep_neurographs/machine_learning/groundtruth_generation.py b/src/deep_neurographs/machine_learning/groundtruth_generation.py index 429856d..8aeccdf 100644 --- a/src/deep_neurographs/machine_learning/groundtruth_generation.py +++ b/src/deep_neurographs/machine_learning/groundtruth_generation.py @@ -17,7 +17,7 @@ from deep_neurographs import utils from deep_neurographs.geometry import dist as get_dist -CLOSE_THRESHOLD = 3.5 +ALIGNED_THRESHOLD = 3.5 MIN_INTERSECTION = 10 @@ -42,10 +42,10 @@ def get_valid_proposals(target_neurograph, pred_neurograph): invalid_proposals = set() node_to_target = dict() for component in nx.connected_components(pred_neurograph): - aligned_bool, target_id = is_component_aligned( + aligned, target_id = is_component_aligned( target_neurograph, pred_neurograph, component ) - if not aligned_bool: + if not aligned: i = utils.sample_singleton(component) invalid_proposals.add(pred_neurograph.nodes[i]["swc_id"]) else: @@ -104,31 +104,50 @@ def is_component_aligned(target_neurograph, pred_neurograph, component): d = get_dist(hat_xyz, xyz) dists = utils.append_dict_value(dists, hat_swc_id, d) - # Check whether there's a merge - hits = [] - for key in dists.keys(): - if len(dists[key]) > 10 and np.mean(dists[key]) < CLOSE_THRESHOLD: - hits.append(key) - if len(hits) > 1: - print("pred_swc_id:", pred_neurograph.edges[edge]["swc_id"]) - print("target_swc_id:", list(dists.keys())) - print("") - # Deterine whether aligned hat_swc_id = utils.find_best(dists) dists = np.array(dists[hat_swc_id]) - aligned_score = np.mean(dists[dists < np.percentile(dists, 90)]) - if aligned_score < 4 and hat_swc_id: + intersects = True if len(dists) > MIN_INTERSECTION else False + aligned_score = np.mean(dists[dists < np.percentile(dists, 85)]) + if (aligned_score < ALIGNED_THRESHOLD and hat_swc_id) and intersects: return True, hat_swc_id else: return False, None def is_valid(target_neurograph, pred_neurograph, target_id, edge): + """ + Determines whether a proposal is valid, meaning it must be consistent and + aligned. + + Parameters + ---------- + target_neurograph : NeuroGraph + Graph built from ground truth swc files. + pred_neurograph : NeuroGraph + Graph build from predicted swc files. + target_id : str + swc id of target that the proposal "edge" corresponds to. + edge : frozenset + Edge proposal to be checked. + + Returns + ------- + bool + Indication of whether proposal is consistent + """ + #aligned = is_proposal_aligned(target_neurograph, pred_neurograph, edge) + consistent = is_consistent( + target_neurograph, pred_neurograph, target_id, edge + ) + return True if consistent else False + + +def is_consistent(target_neurograph, pred_neurograph, target_id, edge): """ Determines whether the proposal connects two branches that correspond to either the same or adjacent branches on the ground truth. If either - condition holds, then the proposal is said to be valid. + condition holds, then the proposal is said to be consistent. Parameters ---------- @@ -137,14 +156,14 @@ def is_valid(target_neurograph, pred_neurograph, target_id, edge): pred_neurograph : NeuroGraph Graph build from predicted swc files. target_id : str - ... + swc id of target that the proposal "edge" corresponds to. edge : frozenset Edge proposal to be checked. Returns ------- bool - Indication of whether proposal is valid + Indication of whether proposal is consistent """ # Find closest edges from target_neurograph @@ -169,6 +188,15 @@ def is_valid(target_neurograph, pred_neurograph, target_id, edge): return False +def is_proposal_aligned(target_neurograph, pred_neurograph, edge): + xyz_0, xyz_1 = pred_neurograph.proposal_xyz(edge) + proj_dists = [] + for xyz in geometry.make_line(xyz_0, xyz_1, 10): + hat_xyz = target_neurograph.get_projection(tuple(xyz)) + proj_dists.append(get_dist(hat_xyz, xyz)) + return True if np.mean(proj_dists) < ALIGNED_THRESHOLD else False + + def proj_branch(target_neurograph, pred_neurograph, target_id, i): # Compute projections hits = dict() @@ -196,10 +224,6 @@ def proj_branch(target_neurograph, pred_neurograph, target_id, i): return best_edge -def is_proposal_aligned(target_neurograph, pred_neurograph, edge): - pass - - def is_adjacent(neurograph, edge_i, edge_j): """ Determines whether "edge_i" and "edge_j" are adjacent, meaning there diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index 8957a4e..1776bff 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -22,6 +22,7 @@ from deep_neurographs.generate_proposals import is_valid from deep_neurographs.geometry import check_dists from deep_neurographs.geometry import dist as get_dist +from deep_neurographs.geometry import get_midpoint from deep_neurographs.machine_learning.groundtruth_generation import ( init_targets, ) @@ -85,7 +86,7 @@ def copy_graph(self, add_attrs=False): graph.add_edges_from(self.edges) return graph - # --- Add to graph -- + # --- Edit Graph -- def add_component(self, irreducibles): """ Adds a connected component to "self". @@ -187,6 +188,21 @@ def __add_edge(self, edge, attrs, idxs, swc_id): for xyz in attrs["xyz"][idxs]: self.xyz_to_edge[tuple(xyz)] = edge + def absorb_node(self, i, nb_1, nb_2): + # Get attributes + xyz = self.get_branches(i, key="xyz") + radius = self.get_branches(i, key="radius") + + # Edit graph + self.remove_node(i) + self.add_edge( + nb_1, + nb_2, + xyz=np.vstack([np.flip(xyz[1], axis=0), xyz[0][1:, :]]), + radius=np.concatenate((radius[0], np.flip(radius[1]))), + swc_id=self.nodes[nb_1]["swc_id"], + ) + def split_edge(self, edge, attrs, idx): """ Splits "edge" into two distinct edges by making the subnode at "idx" a @@ -248,11 +264,13 @@ def add_proposal(self, i, j): edge = frozenset((i, j)) self.nodes[i]["proposals"].add(j) self.nodes[j]["proposals"].add(i) + self.xyz_to_proposal[tuple(self.nodes[i]["xyz"])] = edge + self.xyz_to_proposal[tuple(self.nodes[j]["xyz"])] = edge self.proposals[edge] = { "xyz": np.array([self.nodes[i]["xyz"], self.nodes[j]["xyz"]]) } - # --- Proposal and Ground Truth Generation --- + # --- Proposal Generation --- def generate_proposals( self, radius, @@ -312,11 +330,13 @@ def generate_proposals( # Finish self.filter_nodes() + self.init_proposal_kdtree() if optimize: self.run_optimization() def reset_proposals(self): self.proposals = dict() + self.xyz_to_proposal = dict() for i in self.nodes: self.nodes[i]["proposals"] = set() @@ -351,7 +371,7 @@ def run_optimization(self): xyz_1, xyz_2 = geometry.optimize_alignment(self, img, edge) self.proposals[edge]["xyz"] = np.array([xyz_1, xyz_2]) - # -- kdtree -- + # -- KDTree -- def init_kdtree(self): """ Builds a KD-Tree from the (x,y,z) coordinates of the subnodes of @@ -366,8 +386,22 @@ def init_kdtree(self): None """ - if not self.kdtree: - self.kdtree = KDTree(list(self.xyz_to_edge.keys())) + self.kdtree = KDTree(list(self.xyz_to_edge.keys())) + + def init_proposal_kdtree(self): + """ + Builds a KD-Tree from the (x,y,z) coordinates of the proposals. + + Parameters + ---------- + None + + Returns + ------- + None + + """ + self.proposal_kdtree = KDTree(list(self.xyz_to_proposal.keys())) def query_kdtree(self, xyz, d): """ @@ -392,7 +426,7 @@ def get_projection(self, xyz): _, idx = self.kdtree.query(xyz, k=1) return tuple(self.kdtree.data[idx]) - # --- utils --- + # --- Proposal Utils --- def n_proposals(self): """ Computes number of edges proposals in the graph. @@ -412,13 +446,58 @@ def n_proposals(self): def get_proposals(self): return list(self.proposals.keys()) + def get_simple_proposals(self): + return set([e for e in self.get_proposals() if self.is_simple(e)]) + + def get_complex_proposals(self): + return set([e for e in self.get_proposals() if not self.is_simple(e)]) + + def get_isolated_proposals(self, radius): + isolated_proposals = set() + for edge in self.proposals.keys(): + xyz = self.proposal_midpoint(edge) + if len(self.proposal_kdtree.query_ball_point(xyz, radius)) <= 2: + isolated_proposals.add(edge) + return isolated_proposals + + def is_simple(self, edge): + i, j = tuple(edge) + return True if self.is_leaf(i) and self.is_leaf(j) else False + def proposal_xyz(self, edge): return tuple(self.proposals[edge]["xyz"]) def proposal_length(self, edge): i, j = tuple(edge) return get_dist(self.nodes[i]["xyz"], self.nodes[j]["xyz"]) + + def proposal_midpoint(self, edge): + i, j = tuple(edge) + return get_midpoint(self.nodes[i]["xyz"], self.nodes[j]["xyz"]) + def merge_proposal(self, edge): + # Attributes + i, j = tuple(edge) + xyz = np.vstack([self.nodes[i]["xyz"], self.nodes[j]["xyz"]]) + radius = np.array([self.nodes[i]["radius"], self.nodes[j]["radius"]]) + + # Add + self.add_edge(i, j, xyz=xyz, radius=radius, swc_id="merged") + del self.proposals[edge] + # delete from kdtree + + def remove_nonisolated_proposals(self, radius): + isolated_proposals = self.get_isolated_proposals(radius) + proposals = self.get_proposals() + while len(proposals) > 0: + edge = proposals.pop() + if edge not in isolated_proposals: + i, j = tuple(edge) + self.nodes[i]["proposals"].remove(j) + self.nodes[j]["proposals"].remove(i) + del self.proposals[edge] + + # --- Utils --- def get_branches(self, i, key="xyz"): branches = [] for j in self.neighbors(i): @@ -464,16 +543,6 @@ def get_edge_attr(self, edge, key): xyz_arr = gutils.get_edge_attr(self, edge, key) return xyz_arr[0], xyz_arr[-1] - def get_complex_proposals(self): - return set([e for e in self.get_proposals() if not self.is_simple(e)]) - - def get_simple_proposals(self): - return set([e for e in self.get_proposals() if self.is_simple(e)]) - - def is_simple(self, edge): - i, j = tuple(edge) - return True if self.is_leaf(i) and self.is_leaf(j) else False - def to_patch_coords(self, edge, midpoint, chunk_size): patch_coords = [] for xyz in self.edges[edge]["xyz"]: @@ -582,31 +651,6 @@ def filter_nodes(self): nbs = list(self.neighbors(i)) self.absorb_node(i, nbs[0], nbs[1]) - def absorb_node(self, i, nb_1, nb_2): - # Get attributes - xyz = self.get_branches(i, key="xyz") - radius = self.get_branches(i, key="radius") - - # Edit graph - self.remove_node(i) - self.add_edge( - nb_1, - nb_2, - xyz=np.vstack([np.flip(xyz[1], axis=0), xyz[0][1:, :]]), - radius=np.concatenate((radius[0], np.flip(radius[1]))), - swc_id=self.nodes[nb_1]["swc_id"], - ) - - def merge_proposal(self, edge): - # Attributes - i, j = tuple(edge) - xyz = np.vstack([self.nodes[i]["xyz"], self.nodes[j]["xyz"]]) - radius = np.array([self.nodes[i]["radius"], self.nodes[j]["radius"]]) - - # Add - self.add_edge(i, j, xyz=xyz, radius=radius, swc_id="merged") - del self.proposals[edge] - def to_swc(self, path): with ThreadPoolExecutor() as executor: threads = []