diff --git a/src/deep_neurographs/doubles_removal.py b/src/deep_neurographs/doubles_removal.py index 440c82d..4a90002 100644 --- a/src/deep_neurographs/doubles_removal.py +++ b/src/deep_neurographs/doubles_removal.py @@ -12,7 +12,7 @@ import networkx as nx -def run(neurograph, max_size, node_spacing): +def run(neurograph, min_size, max_size, node_spacing, output_dir=None): """ Removes connected components from "neurgraph" that are likely to be a double. @@ -25,6 +25,8 @@ def run(neurograph, max_size, node_spacing): Maximum size of connected components to be searched. node_spacing : int Expected distance in microns between nodes in "neurograph". + output_dir : str or None, optional + Directory that doubles will be written to. The default is None. Returns ------- @@ -41,16 +43,20 @@ def run(neurograph, max_size, node_spacing): swc_id = get_swc_id(neurograph, nodes) if swc_id not in not_doubles: xyz_arr = inspect_component(neurograph, nodes) - if len(xyz_arr) > 0 and len(xyz_arr) * node_spacing < max_size: + upper_bound = len(xyz_arr) * node_spacing < max_size + lower_bound = len(xyz_arr) * node_spacing > min_size + if upper_bound and lower_bound: not_double_id = is_double(neurograph, xyz_arr, swc_id) if not_double_id: doubles_cnt += 1 + if output_dir: + neurograph.to_swc(output_dir, nodes, color="1.0 0.0 0.0") neurograph = remove_component(neurograph, nodes, swc_id) not_doubles.add(not_double_id) print("# Doubles detected:", doubles_cnt) -def is_double(neurograph, xyz_arr, swc_id_i): +def is_double(neurograph, fragment, swc_id_i): """ Determines whether the connected component corresponding to "root" is a double of another connected component. @@ -59,7 +65,7 @@ def is_double(neurograph, xyz_arr, swc_id_i): --------- neurograph : NeuroGraph Graph to be searched for doubles. - xyz_arr : numpy.ndarray + fragment : numpy.ndarray Array containing xyz coordinates corresponding to some fragment (i.e. connected component in neurograph). swc_id_i : str @@ -76,8 +82,8 @@ def is_double(neurograph, xyz_arr, swc_id_i): """ # Compute projections hits = dict() - for xyz_i in xyz_arr: - for xyz_j in neurograph.query_kdtree(xyz_i, 6): + for xyz_i in fragment: + for xyz_j in neurograph.query_kdtree(xyz_i, 5): try: swc_id_j = neurograph.xyz_to_swc(xyz_j) if swc_id_i != swc_id_j: @@ -88,7 +94,7 @@ def is_double(neurograph, xyz_arr, swc_id_i): # Check criteria if len(hits) > 0: swc_id_j = utils.find_best(hits) - percent_hit = len(hits[swc_id_j]) / len(xyz_arr) + percent_hit = len(hits[swc_id_j]) / len(fragment) else: percent_hit = 0 return swc_id_j if swc_id_j is not None and percent_hit > 0.5 else None diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index 1c8a37e..a7e2994 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -666,18 +666,23 @@ def component_cardinality(self, root): queue.append((j, k)) return cardinality - def to_zipped_swcs(self, zip_path): + def to_zipped_swcs(self, zip_path, color=None): n_components = gutils.count_components(self) print(f"Writing {n_components} swcs to local machine!") with zipfile.ZipFile(zip_path, "w") as zipf: for nodes in nx.connected_components(self): - self.to_zipped_swc(zipf, nodes) + self.to_zipped_swc(zipf, nodes, color) - def to_zipped_swc(self, zipf, nodes): + def to_zipped_swc(self, zipf, nodes, color): with StringIO() as text_buffer: + # Preamble n_entries = 0 node_to_idx = dict() + if color: + text_buffer.write("# COLOR " + color) text_buffer.write("# id, type, z, y, x, r, pid") + + # Write entries for i, j in nx.dfs_edges(self.subgraph(nodes)): # Initialize if n_entries == 0: @@ -718,16 +723,18 @@ def to_swcs(self, swc_dir): for i, nodes in enumerate(nx.connected_components(self)): threads.append(executor.submit(self.to_swc, swc_dir, nodes)) - def to_swc(self, swc_dir, nodes): + def to_swc(self, swc_dir, nodes, color=None): """ Generates list of swc entries for a given connected component. Parameters ---------- - path : str - Path that swc for component will be written to. - Component : list[int] - List of nodes contained in "component". + swc_dir : str + Directory that swc will be written to. + nodes : list[int] + Nodes to be written to an swc file. + color : None or str + Color that swc files should be given. Returns ------- @@ -753,7 +760,7 @@ def to_swc(self, swc_dir, nodes): node_to_idx[j] = len(entry_list) # Write - swc_utils.write(path, entry_list) + swc_utils.write(path, entry_list, color=color) def branch_to_entries(self, entry_list, i, j, parent): # Orient branch diff --git a/src/deep_neurographs/swc_utils.py b/src/deep_neurographs/swc_utils.py index c7b0146..5af7c2d 100644 --- a/src/deep_neurographs/swc_utils.py +++ b/src/deep_neurographs/swc_utils.py @@ -235,7 +235,7 @@ def write_list(path, entry_list, color=None): with open(path, "w") as f: # Preamble if color is not None: - f.write("# COLOR" + color) + f.write("# COLOR " + color) else: f.write("# id, type, z, y, x, r, pid")