diff --git a/src/deep_neurographs/generate_proposals.py b/src/deep_neurographs/generate_proposals.py index d6d6313..1fd4c86 100644 --- a/src/deep_neurographs/generate_proposals.py +++ b/src/deep_neurographs/generate_proposals.py @@ -40,7 +40,7 @@ def run(neurograph, search_radius, complex_bool=True): neurograph, connections = run_on_leaf( neurograph, connections, leaf, search_radius, complex_bool ) - neurograph.filter_nodes() + # neurograph.filter_nodes() return neurograph @@ -154,7 +154,7 @@ def get_conection(neurograph, leaf, xyz, search_radius): edge = neurograph.xyz_to_edge[xyz] node, d = get_closer_endpoint(neurograph, edge, xyz) if d > ENDPOINT_DIST: - #or neurograph.dist(leaf, node) > search_radius: + # or neurograph.dist(leaf, node) > search_radius: attrs = neurograph.get_edge_data(*edge) idx = np.where(np.all(attrs["xyz"] == xyz, axis=1))[0][0] node = neurograph.split_edge(edge, attrs, idx) diff --git a/src/deep_neurographs/graph_utils.py b/src/deep_neurographs/graph_utils.py index ba79da3..b9c4f33 100644 --- a/src/deep_neurographs/graph_utils.py +++ b/src/deep_neurographs/graph_utils.py @@ -33,11 +33,10 @@ def get_irreducibles( swc_dict, bbox=None, - min_size=0, prune_connectors=False, - prune_spurious=True, connector_length=8, prune_depth=16, + trim_depth=0, smooth=True, ): """ @@ -52,18 +51,14 @@ def get_irreducibles( Contents of an swc file. bbox : dict, optional ... - min_size : int, optional - Minimum number of nodes in graph to continue processing it after - pruning spurious branches. The default is 0. prune_connectors : bool, optional Indication of whether to prune short paths connecting branches. The default is False. - prune_spurious : bool, optional - Indication of whether to prune short branches (i.e. spurious branhces). - The default is True. - prune_depth : int, optional - Path length that determines whether a branch is short. The default is - 16. + prune_depth : float, optional + Path length microns that determines whether a branch is short and + should be pruned. The default is 16. + trim_depth : float, optional + Depth in microns to trim branch. The default is 0. smooth : bool, optional Indication of whether to smooth each branch. The default is True. @@ -78,28 +73,27 @@ def get_irreducibles( # Build dense graph swc_dict["idx"] = dict(zip(swc_dict["id"], range(len(swc_dict["id"])))) graph, _ = swc_utils.to_graph(swc_dict, set_attrs=True) - graph = trim_branches(graph, bbox) + graph = clip_branches(graph, bbox) graph, connector_centroids = prune_branches( graph, - prune_connectors=prune_connectors, - prune_spurious=prune_spurious, connector_length=connector_length, + prune_connectors=prune_connectors, prune_depth=prune_depth, + trim_depth=trim_depth, ) # Extract irreducibles irreducibles = [] for node_subset in nx.connected_components(graph): - if len(node_subset) > prune_depth: + if len(node_subset) > 10: subgraph = graph.subgraph(node_subset) irreducibles_i = __get_irreducibles(subgraph, swc_dict, smooth) if irreducibles_i: irreducibles.append(irreducibles_i) + return irreducibles, [] - return irreducibles, connector_centroids - -def trim_branches(graph, bbox): +def clip_branches(graph, bbox): """ Deletes all nodes from "graph" that are not contained in "bbox". @@ -128,10 +122,10 @@ def trim_branches(graph, bbox): def prune_branches( graph, - prune_connectors=False, - prune_spurious=True, connector_length=8, + prune_connectors=False, prune_depth=16, + trim_depth=0, ): """ Prunes spurious branches and short paths connecting branches @@ -141,12 +135,16 @@ def prune_branches( ---------- graph : networkx.Graph Graph to be pruned. + connector_length : float, optional + ... prune_connectors : bool, optional Indication of whether to prune short paths connecting branches. The default is False. - prune_spurious : bool, optional - Indication of whether to prune short branches (i.e. spurious - branches). The default is True + prune_depth : float, optional + Path length microns that determines whether a branch is short and + should be pruned. The default is 16. + trim_depth : float, optional + Depth in microns to trim branch. The default is 0. Returns ------- @@ -156,9 +154,9 @@ def prune_branches( List of xyz coordinates of the centerpoint of the connector path. """ - # Prune spurious branches - if prune_spurious or prune_connectors: - graph = prune_short_branches(graph, prune_depth) + # Prune/Trim branches + if prune_depth > 0 or prune_connectors: + graph = prune_trim_branches(graph, prune_depth, trim_depth) # Prune connectors connector_xyz = [] @@ -188,14 +186,17 @@ def __get_irreducibles(graph, swc_dict, smooth): """ # Extract nodes leafs, junctions = get_irreducible_nodes(graph) - if len(leafs) == 0: - return False + assert len(leafs), "No leaf nodes!" + if len(leafs) > 0: + source = sample(leafs, 1)[0] + else: + source = sample(junctions, 1)[0] # Extract edges edges = dict() nbs = dict() root = None - for (i, j) in nx.dfs_edges(graph, source=sample(leafs, 1)[0]): + for (i, j) in nx.dfs_edges(graph, source=source): # Check if start of path is valid if root is None: root = i @@ -253,16 +254,17 @@ def get_irreducible_nodes(graph): return leafs, junctions -# --- edit graph --- -def prune_short_branches(graph, depth): +# --- Refine graph --- +def prune_trim_branches(graph, depth, trim_depth): """ - Prunes all short branches from "graph". A short branch is a path between a - leaf and junction node with a path length smaller than depth. + Prunes all short branches from "graph" and trims branchs if applicable. A + short branch is a path between a leaf and junction node with a path length + smaller than depth. Parameters ---------- graph : networkx.Graph - Graph to be searched + Graph to be searched. depth : int Path length that determines whether a branch is short. @@ -274,14 +276,16 @@ def prune_short_branches(graph, depth): """ remove_nodes = [] for leaf in get_leafs(graph): - remove_nodes.extend(inspect_branch(graph, leaf, depth)) + remove_nodes.extend(inspect_branch(graph, leaf, depth, trim_depth)) graph.remove_nodes_from(remove_nodes) return graph -def inspect_branch(graph, leaf, depth): +def inspect_branch(graph, leaf, depth, trim_depth): """ - Determines whether the branch emanating from "leaf" should be pruned. + Determines whether the branch emanating from "leaf" should be pruned and + returns nodes that should be pruned. If applicable (i.e. trim_depth > 0), + trims the branch by "trim_depth" microns. Parameters ---------- @@ -291,7 +295,9 @@ def inspect_branch(graph, leaf, depth): Leaf node being inspected to determine whether it is the endpoint of a short branch that should be pruned. depth : int - Path length that determines whether a branch is short. + Path length microns that determines whether a branch is short. + trim_depth : float + Depth in microns to trim branch. Returns ------- @@ -300,13 +306,56 @@ def inspect_branch(graph, leaf, depth): Otherwise, an empty list is returned. """ + # Check whether to prune path = [leaf] + node_spacing = [] for (i, j) in nx.dfs_edges(graph, source=leaf, depth_limit=depth): + node_spacing.append(compute_dist(graph, i, j)) if graph.degree(j) > 2: return path elif graph.degree(j) == 2: path.append(j) - return path[0:max(10, len(path))] + elif np.sum(node_spacing) > depth: + break + + # Check whether to trim + spacing = np.mean(node_spacing) + if trim_depth > 0 and graph.number_of_nodes() > 3 * trim_depth / spacing: + return trim_branch(graph, path, trim_depth) + else: + return [] + + +def trim_branch(graph, path, trim_depth): + branch_length = 0 + for i in range(1, len(path)): + xyz_1 = graph.nodes[path[i - 1]]["xyz"] + xyz_2 = graph.nodes[path[i]]["xyz"] + branch_length += geometry.dist(xyz_1, xyz_2) + if branch_length > trim_depth: + break + return path[0:i] + + +def compute_dist(graph, i, j): + """ + Computes Euclidean distance between nodes i and j. + + Parameters + ---------- + graph : netowrkx.Graph + Graph containing nodes i and j. + i : int + Node. + j : int + Node. + + Returns + ------- + Euclidean distance between i and j. + + """ + return geometry.dist(graph.nodes[i]["xyz"], graph.nodes[j]["xyz"]) def prune_short_connectors(graph, length=8): @@ -316,7 +365,7 @@ def prune_short_connectors(graph, length=8): Parameters ---------- - graph : netowrkx.graph + graph : netowrkx.Graph Graph to be inspected. length : int, optional Upper bound on the distance that defines a connector path to be diff --git a/src/deep_neurographs/img_utils.py b/src/deep_neurographs/img_utils.py index a17023b..c9f59c9 100644 --- a/src/deep_neurographs/img_utils.py +++ b/src/deep_neurographs/img_utils.py @@ -140,9 +140,10 @@ def read_tensorstore_with_bbox(img, bbox): ) except Exception as e: print(type(e), e) - shape = [end[i] - start[i] +1 for i in range(3)] + shape = [end[i] - start[i] + 1 for i in range(3)] return np.zeros(shape) + def read_chunk(img, xyz, shape, from_center=True): """ Reads a chunk of data from arr"", given the xyz coordinates and shape of diff --git a/src/deep_neurographs/intake.py b/src/deep_neurographs/intake.py index 5605206..8847557 100644 --- a/src/deep_neurographs/intake.py +++ b/src/deep_neurographs/intake.py @@ -18,14 +18,12 @@ from deep_neurographs.neurograph import NeuroGraph from deep_neurographs.swc_utils import process_gcs_zip, process_local_paths -import pickle - MIN_SIZE = 30 NODE_SPACING = 2 SMOOTH = True PRUNE_CONNECTORS = False -PRUNE_SPURIOUS = True PRUNE_DEPTH = 16 +TRIM_DEPTH = 0 CONNECTOR_LENGTH = 16 @@ -39,9 +37,9 @@ def build_neurograph_from_local( node_spacing=NODE_SPACING, progress_bar=False, prune_connectors=PRUNE_CONNECTORS, - prune_spurious=PRUNE_SPURIOUS, connector_length=CONNECTOR_LENGTH, prune_depth=PRUNE_DEPTH, + trim_depth=TRIM_DEPTH, smooth=SMOOTH, swc_dir=None, swc_paths=None, @@ -76,10 +74,6 @@ def build_neurograph_from_local( Indication of whether to prune connectors (see graph_utils.py), sites that are likely to be false merges. The default is the global variable "PRUNE_CONNECTORS". - prune_spurious : bool, optional - Indication of whether to prune spurious branches, these are short - branches which are an artifical from skeletonization. The default is - the global variable "PRUNE_SPURIOUS". connector_length : int, optional Maximum length of connecting paths pruned (see graph_utils.py). The default is the global variable "CONNECTOR_LENGTH". @@ -125,9 +119,9 @@ def build_neurograph_from_local( node_spacing=node_spacing, progress_bar=progress_bar, prune_connectors=prune_connectors, - prune_spurious=prune_spurious, connector_length=connector_length, prune_depth=prune_depth, + trim_depth=trim_depth, smooth=smooth, swc_paths=paths, ) @@ -142,9 +136,9 @@ def build_neurograph_from_gcs_zips( min_size=MIN_SIZE, node_spacing=NODE_SPACING, prune_connectors=PRUNE_CONNECTORS, - prune_spurious=PRUNE_SPURIOUS, connector_length=CONNECTOR_LENGTH, prune_depth=PRUNE_DEPTH, + trim_depth=TRIM_DEPTH, smooth=SMOOTH, ): """ @@ -172,10 +166,6 @@ def build_neurograph_from_gcs_zips( Indication of whether to prune connectors (see graph_utils.py), sites that are likely to be false merges. The default is the global variable "PRUNE_CONNECTORS". - prune_spurious : bool, optional - Indication of whether to prune spurious branches, these are short - branches which are an artifical from skeletonization. The default is - the global variable "PRUNE_SPURIOUS". connector_length : int, optional Maximum length of connecting paths pruned (see graph_utils.py). The default is the global variable "CONNECTOR_LENGTH". @@ -208,9 +198,9 @@ def build_neurograph_from_gcs_zips( min_size=min_size, node_spacing=node_spacing, prune_connectors=prune_connectors, - prune_spurious=prune_spurious, connector_length=connector_length, prune_depth=prune_depth, + trim_depth=trim_depth, smooth=smooth, ) t, unit = utils.time_writer(time() - t0) @@ -249,13 +239,16 @@ def download_gcs_zips(bucket_name, gcs_path, min_size, anisotropy): # Assign processes cnt = 1 t0, t1 = utils.init_timers() - chunk_size = int(len(zip_paths) * 0.02) + chunk_size = int(len(zip_paths) * 0.1) with ProcessPoolExecutor() as executor: processes = [] + print("# zips:", len(zip_paths)) for i, path in enumerate(zip_paths): zip_content = bucket.blob(path).download_as_bytes() processes.append( - executor.submit(process_gcs_zip, zip_content, anisotropy, min_size) + executor.submit( + process_gcs_zip, zip_content, anisotropy, min_size + ) ) if i > cnt * chunk_size: cnt, t1 = report_progress( @@ -279,9 +272,9 @@ def build_neurograph( swc_paths=None, progress_bar=True, prune_connectors=PRUNE_CONNECTORS, - prune_spurious=PRUNE_SPURIOUS, connector_length=CONNECTOR_LENGTH, prune_depth=PRUNE_DEPTH, + trim_depth=TRIM_DEPTH, smooth=SMOOTH, ): # Extract irreducibles @@ -295,9 +288,9 @@ def build_neurograph( min_size=min_size, progress_bar=progress_bar, prune_connectors=prune_connectors, - prune_spurious=prune_spurious, connector_length=connector_length, prune_depth=prune_depth, + trim_depth=trim_depth, smooth=smooth, ) @@ -308,10 +301,7 @@ def build_neurograph( print("# edges:", utils.reformat_number(n_edges)) neurograph = NeuroGraph( - img_bbox=img_bbox, - img_path=img_path, - node_spacing=node_spacing, - swc_paths=swc_paths, + img_path=img_path, node_spacing=node_spacing, swc_paths=swc_paths ) t0, t1 = utils.init_timers() chunk_size = max(int(n_components * 0.02), 1) @@ -325,6 +315,7 @@ def build_neurograph( if progress_bar: t, unit = utils.time_writer(time() - t0) print("\n" + f"add_irreducibles(): {round(t, 4)} {unit}") + return neurograph @@ -334,13 +325,13 @@ def get_irreducibles( min_size=MIN_SIZE, progress_bar=True, prune_connectors=PRUNE_CONNECTORS, - prune_spurious=PRUNE_SPURIOUS, connector_length=CONNECTOR_LENGTH, prune_depth=PRUNE_DEPTH, + trim_depth=TRIM_DEPTH, smooth=SMOOTH, ): n_components = len(swc_dicts) - chunk_size = max(int(n_components * 0.02), 1) + chunk_size = max(int(n_components * 0.25), 1) with ProcessPoolExecutor() as executor: # Assign Processes i = 0 @@ -351,11 +342,10 @@ def get_irreducibles( gutils.get_irreducibles, swc_dict, bbox, - min_size, prune_connectors, - prune_spurious, connector_length, prune_depth, + trim_depth, smooth, ) i += 1 diff --git a/src/deep_neurographs/machine_learning/hetero_graph_models.py b/src/deep_neurographs/machine_learning/hetero_graph_models.py index 4f4fa9c..0faf440 100644 --- a/src/deep_neurographs/machine_learning/hetero_graph_models.py +++ b/src/deep_neurographs/machine_learning/hetero_graph_models.py @@ -210,8 +210,6 @@ def __init__( self.output = Linear(heads_1 * heads_2 * hidden_dim) # Convolutional layers - n_node_types = len(node_dict.keys()) - n_edge_type = len(edge_dict.keys()) self.conv1 = HEATConv( hidden_dim, hidden_dim, @@ -300,10 +298,8 @@ def forward(self, x_dict, edge_index_dict, edge_attr_dict, metadata): # Convolutional layers x_dict = self.conv1(x_dict, edge_index_dict, metadata) - x_dict = self.conv2( - x_dict, edge_index_dict, metadata - ) + x_dict = self.conv2(x_dict, edge_index_dict, metadata) # Output x_dict = self.output(x_dict["proposal"]) - return x_dict \ No newline at end of file + return x_dict diff --git a/src/deep_neurographs/machine_learning/inference.py b/src/deep_neurographs/machine_learning/inference.py index 4ec1b0a..5612d14 100644 --- a/src/deep_neurographs/machine_learning/inference.py +++ b/src/deep_neurographs/machine_learning/inference.py @@ -136,6 +136,7 @@ def run_without_seeds( ) # Merge proposals + print("# components:", len(list(nx.connected_components(neurograph)))) neurograph = build.fuse_branches(neurograph, accepts_i) accepts.extend(accepts_i) diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index 8559af1..7d0626e 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -42,7 +42,7 @@ def __init__( swc_paths=None, img_path=None, label_mask=None, - node_spacing=2, + node_spacing=1, train_model=False, ): super(NeuroGraph, self).__init__() @@ -703,21 +703,10 @@ def component_cardinality(self, root): queue.append((j, k)) return cardinality - def filter_nodes(self): - # Find nodes to filter - ingest_nodes = set() - for i in [i for i in self.nodes if self.degree[i] == 2]: - if len(self.nodes[i]["proposals"]) == 0: - ingest_nodes.add(i) - - # Ingest nodes to be filtered - for i in ingest_nodes: - nbs = list(self.neighbors(i)) - self.absorb_node(i, nbs[0], nbs[1]) - def to_swc(self, path): with ThreadPoolExecutor() as executor: threads = [] + print("# swcs:", len(list(nx.connected_components(self)))) for i, component in enumerate(nx.connected_components(self)): node = sample(component, 1)[0] swc_id = self.nodes[node]["swc_id"] @@ -743,7 +732,7 @@ def component_to_swc(self, path, component): node_to_idx[j] = len(entry_list) # Write - if len(entry_list) > 30: + if len(entry_list) > 0: swc_utils.write(path, entry_list) def branch_to_entries(self, entry_list, i, j, parent): diff --git a/src/deep_neurographs/reconstruction.py b/src/deep_neurographs/reconstruction.py index 7fc12ab..49dcd6c 100644 --- a/src/deep_neurographs/reconstruction.py +++ b/src/deep_neurographs/reconstruction.py @@ -208,7 +208,7 @@ def save_prediction(neurograph, accepted_proposals, output_dir): # Write Result neurograph.to_swc(swc_dir) - #save_corrections(neurograph, accepted_proposals, corrections_dir) + save_corrections(neurograph, accepted_proposals, corrections_dir) save_connections(neurograph, accepted_proposals, connections_path) diff --git a/src/deep_neurographs/swc_utils.py b/src/deep_neurographs/swc_utils.py index 6bef833..c7b0146 100644 --- a/src/deep_neurographs/swc_utils.py +++ b/src/deep_neurographs/swc_utils.py @@ -21,7 +21,7 @@ # -- io utils -- def process_local_paths( - paths, anisotropy=[1.0, 1.0, 1.0], min_size=3, img_bbox=None + paths, anisotropy=[1.0, 1.0, 1.0], min_size=5, img_bbox=None ): """ Iterates over a list of swc paths to swc file, then builds a dictionary