diff --git a/src/deep_neurographs/doubles_removal.py b/src/deep_neurographs/doubles_removal.py
new file mode 100644
index 0000000..440c82d
--- /dev/null
+++ b/src/deep_neurographs/doubles_removal.py
@@ -0,0 +1,206 @@
+"""
+Created on Sat June 25 9:00:00 2024
+
+@author: Anna Grim
+@email: anna.grim@alleninstitute.org
+
+Module that removes doubled fragments from a NeuroGraph.
+
+"""
+
+import networkx as nx
+
+from deep_neurographs import utils
+
+
+def run(neurograph, max_size, node_spacing):
+    """
+    Removes connected components from "neurograph" that are likely to be a
+    double.
+
+    Parameters
+    ----------
+    neurograph : NeuroGraph
+        Graph to be searched for doubles.
+    max_size : int
+        Maximum size of connected components to be searched.
+    node_spacing : int
+        Expected distance in microns between nodes in "neurograph".
+
+    Returns
+    -------
+    NeuroGraph
+        Graph with doubles removed.
+
+    """
+    doubles_cnt = 0
+    neurograph.init_kdtree()
+    not_doubles = set()
+    # Snapshot the components because the graph is mutated inside the loop
+    for nodes in list(nx.connected_components(neurograph)):
+        # Skip fragments already matched as the main copy of a double
+        swc_id = get_swc_id(neurograph, nodes)
+        if swc_id in not_doubles:
+            continue
+
+        # Only sufficiently small fragments are candidate doubles
+        xyz_arr = inspect_component(neurograph, nodes)
+        if len(xyz_arr) == 0 or len(xyz_arr) * node_spacing >= max_size:
+            continue
+
+        not_double_id = is_double(neurograph, xyz_arr, swc_id)
+        if not_double_id:
+            doubles_cnt += 1
+            neurograph = remove_component(neurograph, nodes, swc_id)
+            not_doubles.add(not_double_id)
+    print("# Doubles detected:", doubles_cnt)
+    return neurograph
+
+
+def is_double(neurograph, xyz_arr, swc_id_i):
+    """
+    Determines whether the fragment with coordinates "xyz_arr" is a double of
+    another connected component.
+
+    Parameters
+    ----------
+    neurograph : NeuroGraph
+        Graph to be searched for doubles.
+    xyz_arr : numpy.ndarray
+        Array containing xyz coordinates corresponding to some fragment (i.e.
+        connected component in neurograph).
+    swc_id_i : str
+        swc id corresponding to fragment.
+
+    Returns
+    -------
+    str or None
+        If the fragment is a double, the swc_id of the main fragment (i.e.
+        non double) is returned. Otherwise, the value None is returned to
+        indicate that query fragment is not a double.
+
+    """
+    # Compute projections
+    hits = dict()
+    for xyz_i in xyz_arr:
+        for xyz_j in neurograph.query_kdtree(xyz_i, 6):
+            try:
+                swc_id_j = neurograph.xyz_to_swc(xyz_j)
+            except KeyError:
+                # Point has no entry in the graph's lookup tables, e.g. it
+                # belonged to a component that was already removed
+                continue
+            if swc_id_i != swc_id_j:
+                hits = utils.append_dict_value(hits, swc_id_j, 1)
+
+    # Check criteria
+    if len(hits) == 0:
+        return None
+    swc_id_j = utils.find_best(hits)
+    if swc_id_j is None:
+        return None
+    percent_hit = len(hits[swc_id_j]) / len(xyz_arr)
+    return swc_id_j if percent_hit > 0.5 else None
+
+
+# --- utils ---
+def get_swc_id(neurograph, nodes):
+    """
+    Gets the swc id corresponding to "nodes".
+
+    Parameters
+    ----------
+    neurograph : NeuroGraph
+        Graph containing "nodes".
+    nodes : list[int]
+        Nodes to be checked.
+
+    Returns
+    -------
+    str
+        swc id of "nodes".
+
+    """
+    i = utils.sample_singleton(nodes)
+    return neurograph.nodes[i]["swc_id"]
+
+
+def inspect_component(neurograph, nodes):
+    """
+    Gets the xyz coordinates of a component that should be inspected for
+    doubles, namely one that consists of exactly two nodes.
+
+    Parameters
+    ----------
+    neurograph : NeuroGraph
+        Graph to be searched.
+    nodes : iterable
+        Nodes that comprise a connected component.
+
+    Returns
+    -------
+    numpy.ndarray or list
+        Array containing xyz coordinates of the edge between the two nodes.
+        An empty list is returned if the component should not be inspected.
+
+    """
+    if len(nodes) == 2:
+        i, j = tuple(nodes)
+        return neurograph.edges[i, j]["xyz"]
+    else:
+        return []
+
+
+def remove_component(neurograph, nodes, swc_id):
+    """
+    Removes "nodes" from "neurograph".
+
+    Parameters
+    ----------
+    neurograph : NeuroGraph
+        Graph that contains "nodes".
+    nodes : list[int]
+        Two nodes to be removed that are connected by an edge.
+    swc_id : str
+        swc id corresponding to nodes which comprise a connected component in
+        "neurograph".
+
+    Returns
+    -------
+    NeuroGraph
+        Graph with nodes removed.
+
+    """
+    i, j = tuple(nodes)
+    neurograph = remove_xyz_entries(neurograph, i, j)
+    neurograph.remove_nodes_from([i, j])
+    neurograph.leafs.remove(i)
+    neurograph.leafs.remove(j)
+    neurograph.swc_ids.remove(swc_id)
+    return neurograph
+
+
+def remove_xyz_entries(neurograph, i, j):
+    """
+    Removes dictionary entries from "neurograph.xyz_to_edge" corresponding to
+    the edge {i, j}.
+
+    Parameters
+    ----------
+    neurograph : NeuroGraph
+        Graph to be updated.
+    i : int
+        Node in "neurograph".
+    j : int
+        Node in "neurograph".
+
+    Returns
+    -------
+    NeuroGraph
+        Updated graph.
+
+    """
+    for xyz in neurograph.edges[i, j]["xyz"]:
+        # pop() tolerates coordinates that repeat within the edge
+        neurograph.xyz_to_edge.pop(tuple(xyz), None)
+    return neurograph
diff --git a/src/deep_neurographs/generate_proposals.py b/src/deep_neurographs/generate_proposals.py
index d66dbff..30142ad 100644
--- a/src/deep_neurographs/generate_proposals.py
+++ b/src/deep_neurographs/generate_proposals.py
@@ -199,35 +199,6 @@ def get_conection(neurograph, leaf, xyz):
     return neurograph, node
 
 
-def is_valid(neurograph, i, filter_doubles):
-    """
-    Determines whether is a valid node to generate proposals from. A node is
-    considered valid if it is not contained in a doubled connected component
-    (if applicable).
-
-    Parameters
-    ----------
-    neurograph : NeuroGraph
-        Graph built from swc files.
-    i : int
-        Node to be validated.
-    filter_doubles : bool
-        Indication of whether to prevent proposals from being connected to a
-        doubled connected component.
-
-    Returns
-    -------
-    bool
-        Indication of whether node is valid.
- - """ - if filter_doubles: - neurograph.upd_doubles(i) - swc_id = neurograph.nodes[i]["swc_id"] - return True if swc_id in neurograph.doubles else False - - -# -- utils -- def get_closer_endpoint(neurograph, edge, xyz): i, j = tuple(edge) d_i = geometry.dist(neurograph.nodes[i]["xyz"], xyz) diff --git a/src/deep_neurographs/intake.py b/src/deep_neurographs/intake.py index 8847557..b436ce6 100644 --- a/src/deep_neurographs/intake.py +++ b/src/deep_neurographs/intake.py @@ -252,7 +252,7 @@ def download_gcs_zips(bucket_name, gcs_path, min_size, anisotropy): ) if i > cnt * chunk_size: cnt, t1 = report_progress( - i, len(zip_paths), chunk_size, cnt, t0, t1 + i + 1, len(zip_paths), chunk_size, cnt, t0, t1 ) # Store results @@ -306,11 +306,14 @@ def build_neurograph( t0, t1 = utils.init_timers() chunk_size = max(int(n_components * 0.02), 1) cnt, i = 1, 0 + n_components = len(irreducibles) while len(irreducibles): irreducible_set = irreducibles.pop() neurograph.add_component(irreducible_set) if i > cnt * chunk_size and progress_bar: - cnt, t1 = report_progress(i, n_components, chunk_size, cnt, t0, t1) + cnt, t1 = report_progress( + i + 2, n_components, chunk_size, cnt, t0, t1 + ) i += 1 if progress_bar: t, unit = utils.time_writer(time() - t0) @@ -364,7 +367,7 @@ def get_irreducibles( n_edges += count_edges(irreducibles_i) if i > progress_cnt * chunk_size and progress_bar: progress_cnt, t1 = report_progress( - i, n_components, chunk_size, progress_cnt, t0, t1 + i + 1, n_components, chunk_size, progress_cnt, t0, t1 ) if progress_bar: t, unit = utils.time_writer(time() - t0) diff --git a/src/deep_neurographs/machine_learning/inference.py b/src/deep_neurographs/machine_learning/inference.py index 4ec1b0a..dc465fe 100644 --- a/src/deep_neurographs/machine_learning/inference.py +++ b/src/deep_neurographs/machine_learning/inference.py @@ -142,7 +142,7 @@ def run_without_seeds( # Report progress if i > progress_cnt * chunk_size and progress_bar: progress_cnt, t1 = 
utils.report_progress( - i, n_batches, chunk_size, progress_cnt, t0, t1 + i + 1, n_batches, chunk_size, progress_cnt, t0, t1 ) t0, t1 = utils.init_timers() diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index 0947591..1c8a37e 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -12,7 +12,6 @@ from concurrent.futures import ThreadPoolExecutor from copy import deepcopy from io import StringIO -from random import sample import networkx as nx import numpy as np @@ -59,6 +58,7 @@ def __init__( self.junctions = set() self.proposals = dict() self.target_edges = set() + self.node_cnt = 0 self.node_spacing = node_spacing # Initialize data structures for proposals @@ -105,19 +105,20 @@ def add_component(self, irreducibles): None """ - # Nodes - ids = self.__add_nodes(irreducibles, "leafs", dict()) - ids = self.__add_nodes(irreducibles, "junctions", ids) - - # Edges swc_id = irreducibles["swc_id"] - self.swc_ids.add(swc_id) - for (i, j), attrs in irreducibles["edges"].items(): - edge = (ids[i], ids[j]) - idxs = np.arange(0, attrs["xyz"].shape[0], self.node_spacing) - if idxs[-1] != attrs["xyz"].shape[0] - 1: - idxs = np.append(idxs, attrs["xyz"].shape[0] - 1) - self.__add_edge(edge, attrs, idxs, swc_id) + if swc_id not in self.swc_ids: + # Nodes + self.swc_ids.add(swc_id) + ids = self.__add_nodes(irreducibles, "leafs", dict()) + ids = self.__add_nodes(irreducibles, "junctions", ids) + + # Edges + for (i, j), attrs in irreducibles["edges"].items(): + edge = (ids[i], ids[j]) + idxs = np.arange(0, attrs["xyz"].shape[0], self.node_spacing) + if idxs[-1] != attrs["xyz"].shape[0] - 1: + idxs = np.append(idxs, attrs["xyz"].shape[0] - 1) + self.__add_edge(edge, attrs, idxs, swc_id) def __add_nodes(self, irreducibles, node_type, node_ids): """ @@ -142,7 +143,7 @@ def __add_nodes(self, irreducibles, node_type, node_ids): """ for i in irreducibles[node_type].keys(): - cur_id = self.number_of_nodes() + 1 + 
cur_id = self.node_cnt + 1 self.add_node( cur_id, proposals=set(), @@ -150,6 +151,7 @@ def __add_nodes(self, irreducibles, node_type, node_ids): swc_id=irreducibles["swc_id"], xyz=irreducibles[node_type][i]["xyz"], ) + self.node_cnt += 1 if node_type == "leafs": self.leafs.add(cur_id) else: @@ -230,7 +232,7 @@ def split_edge(self, edge, attrs, idx): self.remove_edge(i, j) # Create node - node_id = len(self.nodes) + 1 + node_id = self.node_cnt + 1 swc_id = attrs["swc_id"] self.add_node( node_id, @@ -239,6 +241,7 @@ def split_edge(self, edge, attrs, idx): swc_id=swc_id, xyz=tuple(attrs["xyz"][idx]), ) + self.node_cnt += 1 # Create edges idxs_1 = np.arange(0, idx + 1) @@ -638,52 +641,6 @@ def get_reconstruction(self, proposals, upd_self=False): ) return reconstruction - def upd_doubles(self, i): - swc_id_i = self.nodes[i]["swc_id"] - if swc_id_i not in self.doubles: - if self.is_double(i): - self.doubles.add(swc_id_i) - - def is_double(self, i): - """ - Determines whether the connected component corresponding to "root" is - a double of another connected component. - - Paramters - --------- - root : int - Node of connected component to be evaluated. - - Returns - ------- - bool - Indication of whether connected component is a double. 
- - """ - nb = list(self.neighbors(i))[0] - if self.degree[i] == 1 and self.degree[nb] == 1: - # Find near components - swc_id_i = self.nodes[i]["swc_id"] - hits = dict() # near components - segment_i = self.get_branches(i)[0] - for xyz_i in segment_i: - for xyz_j in self.query_kdtree(xyz_i, 8): - swc_id_j, node = self.xyz_to_swc(xyz_j, return_node=True) - if swc_id_i != swc_id_j: - hits = utils.append_dict_value(hits, swc_id_j, node) - break - - # Parse queried components - swc_id_j, n_close = utils.find_best(hits) - percent_close = n_close / len(segment_i) - if swc_id_j is not None and percent_close > 0.5: - j = sample(hits[swc_id_j], 1)[0] - length_i = len(segment_i) - length_j = self.component_cardinality(j) - if length_i / length_j < 0.6: - return True - return False - def xyz_to_swc(self, xyz, return_node=False): edge = self.xyz_to_edge[tuple(xyz)] i, j = tuple(edge)