diff --git a/src/deep_neurographs/graph_artifact_removal.py b/src/deep_neurographs/fragment_filtering.py similarity index 63% rename from src/deep_neurographs/graph_artifact_removal.py rename to src/deep_neurographs/fragment_filtering.py index 30dd784..c5ac37d 100644 --- a/src/deep_neurographs/graph_artifact_removal.py +++ b/src/deep_neurographs/fragment_filtering.py @@ -5,7 +5,7 @@ @email: anna.grim@alleninstitute.org Module that removes doubled fragments and trims branches that pass by each -other from a NeuroGraph. +other from a FragmentsGraph. """ from collections import defaultdict @@ -21,57 +21,70 @@ QUERY_DIST = 15 +# --- Curvy Removal --- +def remove_curvy(graph, max_length, ratio=0.5): + deleted_ids = set() + components = [c for c in connected_components(graph) if len(c) == 2] + for nodes in tqdm(components, desc="Curvy Filter:"): + if len(nodes) == 2: + i, j = tuple(nodes) + length = graph.edges[i, j]["length"] + endpoint_dist = graph.dist(i, j) + if endpoint_dist / length < ratio and length < max_length: + deleted_ids.add(graph.edges[i, j]["swc_id"]) + delete_fragment(graph, i, j) + return len(deleted_ids) + + # --- Doubles Removal --- -def remove_doubles(neurograph, max_size, node_spacing, output_dir=None): +def remove_doubles(graph, max_length, node_spacing, output_dir=None): """ Removes connected components from "neurgraph" that are likely to be a double. Parameters ---------- - neurograph : NeuroGraph + graph : FragmentsGraph Graph to be searched for doubles. - max_size : int + max_length : int Maximum size of connected components to be searched. node_spacing : int - Expected distance in microns between nodes in "neurograph". + Expected distance in microns between nodes in "graph". output_dir : str or None, optional Directory that doubles will be written to. The default is None. Returns ------- - NeuroGraph + graph Graph with doubles removed. """ # Initializations - components = [c for c in connected_components(neurograph) if len(c) == 2] - deleted = set() - kdtree = neurograph.get_kdtree() + components = [c for c in connected_components(graph) if len(c) == 2] + deleted_ids = set() + kdtree = graph.get_kdtree() if output_dir: util.mkdir(output_dir, delete=True) # Main - desc = "Doubles Detection" + desc = "Doubles Filtering" for idx in tqdm(np.argsort([len(c) for c in components]), desc=desc): i, j = tuple(components[idx]) - swc_id = neurograph.nodes[i]["swc_id"] - if swc_id not in deleted: - if len(neurograph.edges[i, j]["xyz"]) * node_spacing < max_size: + swc_id = graph.nodes[i]["swc_id"] + if swc_id not in deleted_ids: + if graph.edges[i, j]["length"] < max_length: # Check doubles criteria - n_points = len(neurograph.edges[i, j]["xyz"]) - hits = compute_projections(neurograph, kdtree, (i, j)) + n_points = len(graph.edges[i, j]["xyz"]) + hits = compute_projections(graph, kdtree, (i, j)) if check_doubles_criteria(hits, n_points): if output_dir: - neurograph.to_swc( - output_dir, components[idx], color=COLOR - ) - neurograph = delete(neurograph, components[idx], swc_id) - deleted.add(swc_id) - return len(deleted) + graph.to_swc(output_dir, [i, j], color=COLOR) + delete_fragment(graph, i, j) + deleted_ids.add(swc_id) + return len(deleted_ids) -def compute_projections(neurograph, kdtree, edge): +def compute_projections(graph, kdtree, edge): """ Given a fragment defined by "edge", this routine iterates of every xyz in the fragment and projects it onto the closest fragment. For each detected @@ -80,11 +93,11 @@ def compute_projections(neurograph, kdtree, edge): Parameters ---------- - neurograph : NeuroGraph + graph : graph Graph that contains "edge". kdtree : KDTree KD-Tree that contains all xyz coordinates of every fragment in - "neurograph". + "graph". edge : tuple Pair of leaf nodes that define a fragment. @@ -96,13 +109,13 @@ def compute_projections(neurograph, kdtree, edge): """ hits = defaultdict(list) - query_id = neurograph.edges[edge]["swc_id"] - for i, xyz in enumerate(neurograph.edges[edge]["xyz"]): + query_id = graph.edges[edge]["swc_id"] + for i, xyz in enumerate(graph.edges[edge]["xyz"]): # Compute projections best_id = None best_dist = np.inf for hit_xyz in geometry.query_ball(kdtree, xyz, QUERY_DIST): - hit_id = neurograph.xyz_to_swc(hit_xyz) + hit_id = graph.xyz_to_swc(hit_xyz) if hit_id is not None and hit_id != query_id: if geometry.dist(hit_xyz, xyz) < best_dist: best_dist = geometry.dist(hit_xyz, xyz) @@ -144,56 +157,54 @@ def check_doubles_criteria(hits, n_points): return False -def delete(neurograph, nodes, swc_id): +def delete_fragment(graph, i, j): """ - Deletes "nodes" from "neurograph". + Deletes nodes "i" and "j" from "graph", where these nodes form a connected + component. Parameters ---------- - neurograph : NeuroGraph - Graph that contains "nodes". - nodes : list[int] - Nodes to be removed. - swc_id : str - swc id corresponding to nodes which comprise a connected component in - "neurograph". + graph : FragmentsGraph + Graph that contains nodes to be deleted. + i : int + Node to be removed. + j : int + Node to be removed. Returns ------- - NeuroGraph + graph Graph with nodes removed. """ - i, j = tuple(nodes) - neurograph = remove_xyz_entries(neurograph, i, j) - neurograph.remove_nodes_from([i, j]) - neurograph.swc_ids.remove(swc_id) - return neurograph + graph = remove_xyz_entries(graph, i, j) + graph.swc_ids.remove(graph.nodes[i]["swc_id"]) + graph.remove_nodes_from([i, j]) -def remove_xyz_entries(neurograph, i, j): +def remove_xyz_entries(graph, i, j): """ - Removes dictionary entries from "neurograph.xyz_to_edge" corresponding to + Removes dictionary entries from "graph.xyz_to_edge" corresponding to the edge {i, j}. Parameters ---------- - neurograph : NeuroGraph + graph : graph Graph to be updated. i : int - Node in "neurograph". + Node in "graph". j : int - Node in "neurograph". + Node in "graph". Returns ------- - NeuroGraph + graph Updated graph. """ - for xyz in neurograph.edges[i, j]["xyz"]: - del neurograph.xyz_to_edge[tuple(xyz)] - return neurograph + for xyz in graph.edges[i, j]["xyz"]: + del graph.xyz_to_edge[tuple(xyz)] + return graph def upd_hits(hits, key, value): diff --git a/src/deep_neurographs/generate_proposals.py b/src/deep_neurographs/generate_proposals.py index 9267b8a..1f23c0d 100644 --- a/src/deep_neurographs/generate_proposals.py +++ b/src/deep_neurographs/generate_proposals.py @@ -21,7 +21,7 @@ def run( - neurograph, + graph, radius, complex_bool=False, long_range_bool=True, @@ -33,7 +33,7 @@ def run( Parameters ---------- - neurograph : NeuroGraph + graph : FragmentsGraph Graph that proposals will be generated for. radius : float Maximum Euclidean distance between endpoints of proposal. @@ -58,29 +58,29 @@ def run( """ # Initializations connections = dict() - kdtree = init_kdtree(neurograph, complex_bool) + kdtree = init_kdtree(graph, complex_bool) radius *= RADIUS_SCALING_FACTOR if trim_endpoints_bool else 1.0 if progress_bar: - iterable = tqdm(neurograph.get_leafs(), desc="Proposals") + iterable = tqdm(graph.get_leafs(), desc="Proposals") else: - iterable = neurograph.get_leafs() + iterable = graph.get_leafs() # Main for leaf in iterable: # Generate potential proposals candidates = get_candidates( - neurograph, + graph, leaf, kdtree, radius, - neurograph.proposals_per_leaf, + graph.proposals_per_leaf, complex_bool, ) # Generate long range proposals (if applicable) if len(candidates) == 0 and long_range_bool: candidates = get_candidates( - neurograph, + graph, leaf, kdtree, radius * RADIUS_SCALING_FACTOR, @@ -90,36 +90,39 @@ def run( # Determine which potential proposals to keep for i in candidates: - leaf_swc_id = neurograph.nodes[leaf]["swc_id"] - pair_id = frozenset((leaf_swc_id, neurograph.nodes[i]["swc_id"])) + leaf_swc_id = graph.nodes[leaf]["swc_id"] + pair_id = frozenset((leaf_swc_id, graph.nodes[i]["swc_id"])) if pair_id in connections.keys(): cur_proposal = connections[pair_id] - cur_dist = neurograph.proposal_length(cur_proposal) - if neurograph.dist(leaf, i) < cur_dist: - neurograph.remove_proposal(cur_proposal) + cur_dist = graph.proposal_length(cur_proposal) + if graph.dist(leaf, i) < cur_dist: + graph.remove_proposal(cur_proposal) del connections[pair_id] else: continue # Add proposal - neurograph.add_proposal(leaf, i) + graph.add_proposal(leaf, i) connections[pair_id] = frozenset({leaf, i}) # Trim endpoints (if applicable) + n_trimmed = 0 if trim_endpoints_bool: radius /= RADIUS_SCALING_FACTOR - long_range, in_range = separate_proposals(neurograph, radius) - neurograph = run_trimming(neurograph, long_range, radius, progress_bar) - neurograph = run_trimming(neurograph, in_range, radius, progress_bar) + long_range, in_range = separate_proposals(graph, radius) + graph, n_trimmed_1 = run_trimming(graph, long_range, radius) + graph, n_trimmed_2 = run_trimming(graph, in_range, radius) + n_trimmed = n_trimmed_1 + n_trimmed_2 + return n_trimmed -def init_kdtree(neurograph, complex_bool): +def init_kdtree(graph, complex_bool): """ Initializes a KD-Tree used to generate proposals. Parameters ---------- - neurograph : NeuroGraph + graph : FragmentsGraph Graph that proposals will be generated for. complex_bool : bool Indication of whether to generate complex proposals. @@ -131,37 +134,37 @@ def init_kdtree(neurograph, complex_bool): """ if complex_bool: - return neurograph.get_kdtree() + return graph.get_kdtree() else: - return neurograph.get_kdtree(node_type="leaf") + return graph.get_kdtree(node_type="leaf") def get_candidates( - neurograph, leaf, kdtree, radius, max_proposals, complex_bool + graph, leaf, kdtree, radius, max_proposals, complex_bool ): # Generate candidates candidates = list() - for xyz in search_kdtree(neurograph, leaf, kdtree, radius, max_proposals): - i = get_connecting_node(neurograph, leaf, xyz, radius, complex_bool) + for xyz in search_kdtree(graph, leaf, kdtree, radius, max_proposals): + i = get_connecting_node(graph, leaf, xyz, radius, complex_bool) if i is not None: - if neurograph.is_valid_proposal(leaf, i, complex_bool): + if graph.is_valid_proposal(leaf, i, complex_bool): candidates.append(i) # Process the results if max_proposals < 0 and len(candidates) == 1: - return candidates if neurograph.is_leaf(candidates[0]) else [] + return candidates if graph.is_leaf(candidates[0]) else [] else: return list() if max_proposals < 0 else candidates -def search_kdtree(neurograph, leaf, kdtree, radius, max_proposals): +def search_kdtree(graph, leaf, kdtree, radius, max_proposals): """ - Generates proposals for node "leaf" in "neurograph" by finding candidate + Generates proposals for node "leaf" in "graph" by finding candidate xyz points on distinct connected components nearby. Parameters ---------- - neurograph : NeuroGraph + graph : FragmentsGraph Graph built from swc files. kdtree : ... ... @@ -180,10 +183,10 @@ def search_kdtree(neurograph, leaf, kdtree, radius, max_proposals): """ # Generate candidates candidates = dict() - leaf_xyz = neurograph.nodes[leaf]["xyz"] + leaf_xyz = graph.nodes[leaf]["xyz"] for xyz in geometry.query_ball(kdtree, leaf_xyz, radius): - swc_id = neurograph.xyz_to_swc(xyz) - if swc_id != neurograph.nodes[leaf]["swc_id"]: + swc_id = graph.xyz_to_swc(xyz) + if swc_id != graph.nodes[leaf]["swc_id"]: d = geometry.dist(leaf_xyz, xyz) if swc_id not in candidates.keys(): candidates[swc_id] = {"dist": d, "xyz": tuple(xyz)} @@ -228,13 +231,13 @@ def get_best(candidates, max_proposals): return list_candidates_xyz(candidates) -def get_connecting_node(neurograph, leaf, xyz, radius, complex_bool): +def get_connecting_node(graph, leaf, xyz, radius, complex_bool): """ Gets the node that proposal with leaf will connect to. Parameters ---------- - neurograph : NeuroGraph + graph : FragmentsGraph Graph containing "leaf". leaf : int Leaf node. @@ -247,25 +250,25 @@ def get_connecting_node(neurograph, leaf, xyz, radius, complex_bool): Node id. """ - edge = neurograph.xyz_to_edge[xyz] - node = get_closer_endpoint(neurograph, edge, xyz) - if neurograph.dist(leaf, node) < radius: + edge = graph.xyz_to_edge[xyz] + node = get_closer_endpoint(graph, edge, xyz) + if graph.dist(leaf, node) < radius: return node elif complex_bool: - attrs = neurograph.get_edge_data(*edge) + attrs = graph.get_edge_data(*edge) idx = np.where(np.all(attrs["xyz"] == xyz, axis=1))[0][0] if type(idx) is int: - return neurograph.split_edge(edge, attrs, idx) + return graph.split_edge(edge, attrs, idx) return None -def get_closer_endpoint(neurograph, edge, xyz): +def get_closer_endpoint(graph, edge, xyz): """ Gets the node from "edge" that is closer to "xyz". Parameters ---------- - neurograph : NeuroGraph + graph : FragmentsGraph Graph containing "edge". edge : tuple Edge to be checked. @@ -279,17 +282,17 @@ def get_closer_endpoint(neurograph, edge, xyz): """ i, j = tuple(edge) - d_i = geometry.dist(neurograph.nodes[i]["xyz"], xyz) - d_j = geometry.dist(neurograph.nodes[j]["xyz"], xyz) + d_i = geometry.dist(graph.nodes[i]["xyz"], xyz) + d_j = geometry.dist(graph.nodes[j]["xyz"], xyz) return i if d_i < d_j else j -def separate_proposals(neurograph, radius): +def separate_proposals(graph, radius): long_range_proposals = list() proposals = list() - for proposal in neurograph.proposals: + for proposal in graph.proposals: i, j = tuple(proposal) - if neurograph.dist(i, j) > radius: + if graph.dist(i, j) > radius: long_range_proposals.append(proposal) else: proposals.append(proposal) @@ -297,30 +300,28 @@ def separate_proposals(neurograph, radius): # --- Trim Endpoints --- -def run_trimming(neurograph, proposals, radius, progress_bar): +def run_trimming(graph, proposals, radius): n_endpoints_trimmed = 0 long_radius = radius * RADIUS_SCALING_FACTOR for proposal in deepcopy(proposals): i, j = tuple(proposal) - is_simple = neurograph.is_simple(proposal) - is_single = neurograph.is_single_proposal(proposal) + is_simple = graph.is_simple(proposal) + is_single = graph.is_single_proposal(proposal) trim_bool = False if is_simple and is_single: - neurograph, trim_bool = trim_endpoints( - neurograph, i, j, long_radius + graph, trim_bool = trim_endpoints( + graph, i, j, long_radius ) - elif neurograph.dist(i, j) > radius: - neurograph.remove_proposal(proposal) + elif graph.dist(i, j) > radius: + graph.remove_proposal(proposal) n_endpoints_trimmed += 1 if trim_bool else 0 - if progress_bar: - print("# Endpoints Trimmed:", n_endpoints_trimmed) - return neurograph + return graph, n_endpoints_trimmed -def trim_endpoints(neurograph, i, j, radius): +def trim_endpoints(graph, i, j, radius): # Initializations - branch_i = neurograph.branch(i) - branch_j = neurograph.branch(j) + branch_i = graph.branch(i) + branch_j = graph.branch(j) # Check both orderings idx_i, idx_j = trim_endpoints_ordered(branch_i, branch_j) @@ -333,14 +334,14 @@ def trim_endpoints(neurograph, i, j, radius): # Update branches (if applicable) if min(d1, d2) > radius: - neurograph.remove_proposal(frozenset((i, j))) - return neurograph, False + graph.remove_proposal(frozenset((i, j))) + return graph, False elif min(d1, d2) + 2 < geometry.dist(branch_i[0], branch_j[0]): if compute_dot(branch_i, branch_j, idx_i, idx_j) < DOT_THRESHOLD: - neurograph = trim_to_idx(neurograph, i, idx_i) - neurograph = trim_to_idx(neurograph, j, idx_j) - return neurograph, True - return neurograph, False + graph = trim_to_idx(graph, i, idx_i) + graph = trim_to_idx(graph, j, idx_j) + return graph, True + return graph, False def trim_endpoints_ordered(branch_1, branch_2): @@ -375,13 +376,13 @@ def trim_endpoint(branch_1, branch_2): return 0 if best_idx is None else best_idx -def trim_to_idx(neurograph, i, idx): +def trim_to_idx(graph, i, idx): """ Trims the branch emanating from "i". Parameters ---------- - neurograph : NeuroGraph + graph : FragmentsGraph Graph containing node "i" i : int Leaf node. @@ -394,21 +395,21 @@ def trim_to_idx(neurograph, i, idx): """ # Update node - branch_xyz = neurograph.branch(i, key="xyz") - branch_radii = neurograph.branch(i, key="radius") - neurograph.nodes[i]["xyz"] = branch_xyz[idx] - neurograph.nodes[i]["radius"] = branch_radii[idx] + branch_xyz = graph.branch(i, key="xyz") + branch_radii = graph.branch(i, key="radius") + graph.nodes[i]["xyz"] = branch_xyz[idx] + graph.nodes[i]["radius"] = branch_radii[idx] # Update edge - j = neurograph.leaf_neighbor(i) - neurograph.edges[i, j]["xyz"] = branch_xyz[idx::] - neurograph.edges[i, j]["radius"] = branch_radii[idx::] + j = graph.leaf_neighbor(i) + graph.edges[i, j]["xyz"] = branch_xyz[idx::] + graph.edges[i, j]["radius"] = branch_radii[idx::] for k in range(idx): try: - del neurograph.xyz_to_edge[tuple(branch_xyz[k])] + del graph.xyz_to_edge[tuple(branch_xyz[k])] except KeyError: pass - return neurograph + return graph # --- utils --- @@ -439,9 +440,9 @@ def compute_dot(branch_1, branch_2, idx_1, idx_2): Parameters ---------- branch_1 : numpy.ndarray - xyz coordinates of some branch from a neurograph. + xyz coordinates of some branch from a graph. branch_2 : numpy.ndarray - xyz coordinates of some branch from a neurograph. + xyz coordinates of some branch from a graph. idx_1 : int Index that "branch_1" would be trimmed to (i.e. xyz coordinates from 0 to "idx_1" would be deleted from "branch_1"). diff --git a/src/deep_neurographs/inference.py b/src/deep_neurographs/inference.py index 32f59d6..0aaaeae 100644 --- a/src/deep_neurographs/inference.py +++ b/src/deep_neurographs/inference.py @@ -19,7 +19,7 @@ from torch.nn.functional import sigmoid from tqdm import tqdm -from deep_neurographs.graph_artifact_removal import remove_doubles +from deep_neurographs import fragment_filtering from deep_neurographs.machine_learning.feature_generation import ( FeatureGenerator, ) @@ -221,11 +221,15 @@ def build_graph(self, fragments_pointer): self.graph = graph_builder.run(fragments_pointer) # Remove doubles (if applicable) + n_curvy = fragment_filtering.remove_curvy(self.graph, 100) + n_curvy = util.reformat_number(n_curvy) if self.graph_config.remove_doubles_bool: - n_doubles = remove_doubles( + n_doubles = fragment_filtering.remove_doubles( self.graph, 200, self.graph_config.node_spacing ) - self.report(f"# Doubles Detected: {n_doubles}") + n_doubles = util.reformat_number(n_doubles) + self.report(f"# Double Fragments Deleted: {n_doubles}") + self.report(f"# Curvy Fragments Deleted: {n_curvy}") # Save valid labels and current graph swcs_path = os.path.join(self.output_dir, "processed-swcs.zip") @@ -238,7 +242,7 @@ def build_graph(self, fragments_pointer): t, unit = util.time_writer(time() - t0) self.report(f"Module Runtime: {round(t, 4)} {unit}") - # Report graph overview + # Report graph overview self.report("\nInitial Graph...") self.report_graph() @@ -263,17 +267,20 @@ def generate_proposals(self, radius=None): # Main t0 = time() - self.graph.generate_proposals( + n_trimmed = self.graph.generate_proposals( radius, complex_bool=self.graph_config.complex_bool, long_range_bool=self.graph_config.long_range_bool, proposals_per_leaf=self.graph_config.proposals_per_leaf, trim_endpoints_bool=self.graph_config.trim_endpoints_bool, ) + n_proposals = util.reformat_number(self.graph.n_proposals()) + n_trimmed = util.reformat_number(n_trimmed) # Report results t, unit = util.time_writer(time() - t0) + self.report(f"# Trimmed: {n_trimmed}") self.report(f"# Proposals: {n_proposals}") self.report(f"Module Runtime: {round(t, 4)} {unit}\n") @@ -575,9 +582,7 @@ def get_batch_dataset(self, neurograph, batch): ... """ - #t0 = time() features = self.feature_generator.run(neurograph, batch, self.radius) - #print("Feature Generation:", time() - t0) computation_graph = batch["graph"] if type(batch) is dict else None dataset = ml_util.init_dataset( neurograph, diff --git a/src/deep_neurographs/machine_learning/heterograph_models.py b/src/deep_neurographs/machine_learning/heterograph_models.py index 6a98427..6cc3053 100644 --- a/src/deep_neurographs/machine_learning/heterograph_models.py +++ b/src/deep_neurographs/machine_learning/heterograph_models.py @@ -9,9 +9,9 @@ """ import re + import torch import torch.nn.init as init - from torch import nn from torch.nn import Dropout, LeakyReLU from torch_geometric.nn import GATv2Conv as GATConv @@ -255,6 +255,7 @@ def forward(self, x_dict, edge_index_dict, edge_attr_dict): # --- Utils --- def reformat_edge_key(key): + print(key) return tuple([rm_non_alphanumeric(s) for s in key.split(",")]) diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index 09d5047..ab82dd2 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -337,7 +337,7 @@ def generate_proposals( # Main self.reset_proposals() self.set_proposals_per_leaf(proposals_per_leaf) - generate_proposals.run( + n_trimmed = generate_proposals.run( self, search_radius, complex_bool=complex_bool, @@ -351,6 +351,7 @@ def generate_proposals( self.gt_accepts = init_targets(self, groundtruth_graph) else: self.gt_accepts = set() + return n_trimmed def reset_proposals(self): """ diff --git a/src/deep_neurographs/utils/graph_util.py b/src/deep_neurographs/utils/graph_util.py index cd316dc..d97663c 100644 --- a/src/deep_neurographs/utils/graph_util.py +++ b/src/deep_neurographs/utils/graph_util.py @@ -37,7 +37,7 @@ MIN_SIZE = 30 NODE_SPACING = 1 SMOOTH_BOOL = True -PRUNE_DEPTH = 25 +PRUNE_DEPTH = 20 class GraphLoader: @@ -135,22 +135,21 @@ def run( # --- Graph structure extraction --- def schedule_processes(self, swc_dicts): """ - Gets irreducible components of each graph stored in "swc_dicts" by setting - up a parellelization scheme that sends each swc_dict to a CPU and calling - the subroutine "get_irreducibles". + Gets irreducible components of each graph stored in "swc_dicts" by + setting up a parellelization scheme that sends each swc_dict to a CPU + and calls the subroutine "get_irreducibles". Parameters ---------- swc_dicts : list[dict] - List of dictionaries such that each entry contains the conents of an - swc file. + List of dictionaries such that each entry contains the conents of + an swc file. Returns ------- list[dict] - List of irreducibles stored in a dictionary where key-values are type - of irreducible (i.e. leaf, branching, or edge) and the corresponding - set of all irreducibles from the graph of that type. + List of dictionaries such that each is the set of irreducibles in + a connected component of the graph corresponding to "swc_dicts". """ with ProcessPoolExecutor() as executor: @@ -178,10 +177,8 @@ def schedule_processes(self, swc_dicts): def get_irreducibles(self, swc_dict): """ - Gets irreducible components of the graph stored in "swc_dict" by building - the graph store in the swc_dict and parsing it. In addition, this function - also calls routines prunes spurious branches and short paths connecting - branches (i.e. possible merge mistakes). + Gets the irreducible components of graph stored in "swc_dict". This + routine also calls routines prunes short paths. Parameters ---------- @@ -191,28 +188,26 @@ def get_irreducibles(self, swc_dict): Returns ------- list - List of irreducibles stored in a dictionary where key-values are type - of irreducible (i.e. leaf, branching, or edge) and corresponding set of - all irreducibles from the graph of that type. + List of dictionaries such that each is the set of irreducibles in + a connected component of the graph corresponding to "swc_dict". """ # Build dense graph swc_dict["idx"] = dict(zip(swc_dict["id"], range(len(swc_dict["id"])))) graph, _ = swc_util.to_graph(swc_dict, set_attrs=True) - graph = self.clip_branches(graph) - graph = self.prune_branches(graph) + self.clip_branches(graph) + self.prune_branches(graph) # Extract irreducibles irreducibles = list() if graph.number_of_nodes() > 1: for nodes in nx.connected_components(graph): if len(nodes) > 1: - subgraph = graph.subgraph(nodes) - irreducibles_i = self.get_component_irreducibles( - subgraph, swc_dict + result = self.get_component_irreducibles( + graph.subgraph(nodes), swc_dict ) - if irreducibles_i: - irreducibles.append(irreducibles_i) + if result: + irreducibles.append(result) return irreducibles def clip_branches(self, graph): @@ -223,14 +218,10 @@ def clip_branches(self, graph): ---------- graph : networkx.Graph Graph to be searched - img_bbox : dict - Dictionary with the keys "min" and "max" which specify a bounding box - in the image. The default is None. Returns ------- - networkx.Graph - "graph" with nodes deleted that were not contained in "img_bbox". + None """ if self.img_bbox: @@ -240,45 +231,6 @@ def clip_branches(self, graph): if not util.is_contained(self.img_bbox, xyz): delete_nodes.add(i) graph.remove_nodes_from(delete_nodes) - return graph - - def prune_branches(self, graph): - """ - Prunes all short branches from "graph". A short branch is a path between a - leaf and branching node where the path length is less than "prune_depth". - - Parameters - ---------- - graph : networkx.Graph - Graph to be pruned. - prune_depth : float - Path length microns that determines whether a branch is short. The - default is 16.0. - - Returns - ------- - networkx.Graph - Graph with short branches pruned. - - """ - for leaf in get_leafs(graph): - branch = [leaf] - depth = 2 * self.prune_depth - length = 0 - for (i, j) in nx.dfs_edges(graph, source=leaf, depth_limit=depth): - # Visit edge - length += compute_dist(graph, i, j) - if graph.degree(j) == 2: - branch.append(j) - elif graph.degree(j) > 2: - graph.remove_nodes_from(branch) - break - - # Check whether to stop - if length > self.prune_depth: - graph.remove_nodes_from(branch[0:min(5, len(branch))]) - break - return graph def get_component_irreducibles(self, graph, swc_dict): """ @@ -349,6 +301,50 @@ def get_component_irreducibles(self, graph, swc_dict): else: return False + def prune_branches(self, graph): + """ + Prunes all short branches from "graph". A short branch is a path + between a leaf and branching node where the path length is less than + "self.prune_depth". + + Parameters + ---------- + graph : networkx.Graph + Graph to be pruned. + + Returns + ------- + networkx.Graph + Graph with short branches pruned. + + """ + first_pass = True + deleted_nodes = list() + n_passes = 0 + while len(deleted_nodes) > 0 or first_pass: + # Visit leafs + n_passes += 1 + deleted_nodes = list() + for leaf in get_leafs(graph): + branch = [leaf] + length = 0 + for (i, j) in nx.dfs_edges(graph, source=leaf): + # Visit edge + length += compute_dist(graph, i, j) + if graph.degree(j) == 2: + branch.append(j) + elif graph.degree(j) > 2: + deleted_nodes.extend(branch) + graph.remove_nodes_from(branch) + break + + # Check whether to stop + if length > self.prune_depth or first_pass: + graph.remove_nodes_from(branch[0:min(3, len(branch))]) + break + + first_pass = False + # --- Utils --- def get_irreducible_nodes(graph):