diff --git a/src/deep_neurographs/utils/graph_util.py b/src/deep_neurographs/utils/graph_util.py index d514cb5..28e6efa 100644 --- a/src/deep_neurographs/utils/graph_util.py +++ b/src/deep_neurographs/utils/graph_util.py @@ -8,33 +8,20 @@ Overview -------- Code that reads and preprocesses neuron fragments stored as swc files, then -constructs a custom graph object called a "FragmentsGraph". +constructs a custom graph object called a "FragmentsGraph" from the fragments. Graph Construction Algorithm: 1. Read Neuron Fragments to do... - 2. Preprocess Fragments and Extract Irreducibles + 2. Extract Irreducibles to do... 3. Build FragmentsGraph to do... -Terminology ------------- - -Leaf: a node with degree 1. - -Branching: a node with degree > 2. - -Irreducibles: the irreducibles of a graph consists of 1) leaf nodes, -2) branching nodes, and 3) edges connecting (1) and (2). - -Branch: a sequence of nodes between two irreducible nodes. - """ -from collections import defaultdict from concurrent.futures import ProcessPoolExecutor, as_completed from random import sample @@ -118,10 +105,10 @@ def run(self, fragments_pointer): from deep_neurographs.fragments_graph import FragmentsGraph # Step 1: Read Neuron Fragments - swc_dicts = self.reader.load(fragments_pointer) + graph_list = self.reader.load(fragments_pointer) - # Step: Preprocess Fragments and Extract Irreducibles - irreducibles = self.schedule_processes(swc_dicts) + # Step: Extract Irreducibles + irreducibles = self.process_graphs(graph_list) # Step 3: Build FragmentsGraph fragments_graph = FragmentsGraph(node_spacing=self.node_spacing) @@ -130,81 +117,97 @@ def run(self, fragments_pointer): fragments_graph.add_component(irreducible_set) return fragments_graph - # --- Graph structure extraction --- - def schedule_processes(self, swc_dicts): + def process_graphs(self, graphs_list): """ - Gets irreducible components of each graph stored in "swc_dicts" by - setting up a parellelization scheme that sends each swc_dict to a CPU - and calls the subroutine "get_irreducibles". + Processes a list of graphs in parallel and extracts irreducible + subgraphs from each graph. Parameters ---------- - swc_dicts : list[dict] - List of dictionaries such that each entry contains the conents of - an swc file. + graphs_list : List[network.Graph] + List of graphs to be processed. Each graph is passed to the + "process_graph" method, which extracts the irreducible subgraphs + from each graph. Returns ------- - list[dict] - List of dictionaries such that each is the set of irreducibles in - a connected component of the graph corresponding to "swc_dicts". + List[dict] + List of irreducible subgraphs extracted from the input graphs. """ # Initializations if self.progress_bar: - pbar = tqdm(total=len(swc_dicts), desc="Extract Graphs") + pbar = tqdm(total=len(graphs_list), desc="Process Graphs") # Main - with ProcessPoolExecutor() as executor: + with ProcessPoolExecutor(max_workers=1) as executor: # Assign Processes - i = 0 - processes = [None] * len(swc_dicts) - while swc_dicts: - swc_dict = swc_dicts.pop() - processes[i] = executor.submit(self.get_irreducibles, swc_dict) - i += 1 + processes = list() + while graphs_list: + graph = graphs_list.pop() + processes.append( + executor.submit(self.extract_irreducibles, graph) + ) # Store results irreducibles = list() for process in as_completed(processes): - irreducibles.extend(process.result()) + result = process.result() + if result is not None: + irreducibles.append(result) if self.progress_bar: pbar.update(1) return irreducibles - def get_irreducibles(self, swc_dict): + def extract_irreducibles(self, graph): """ - Gets the irreducible components of graph stored in "swc_dict". This - routine also calls routines prunes short paths. + Gets the irreducible subgraph from the input graph. Parameters ---------- - swc_dict : dict - Contents of an swc file. + graph : dict + Graph that irreducible subgraph is to be extracted from. Returns ------- List[dict] - List of dictionaries such that each is the set of irreducibles in - a connected component of the graph corresponding to "swc_dict". + List of dictionaries such that each is the set of irreducibles + from the input graph. """ - # Build dense graph - swc_dict["idx"] = dict(zip(swc_dict["id"], range(len(swc_dict["id"])))) - graph, _ = swc_util.to_graph(swc_dict, set_attrs=True) + irreducibles = None self.prune_branches(graph) - - # Extract irreducibles - irreducibles = list() - path_length = compute_path_length(graph) - if path_length > self.min_size and graph.number_of_nodes() > 1: - for nodes in nx.connected_components(graph): - if len(nodes) > 1: - result = self.get_component_irreducibles( - graph.subgraph(nodes), swc_dict - ) - if result: - irreducibles.append(result) + if compute_path_length(graph) > self.min_size: + # Extract irreducible nodes + leafs, branchings = get_irreducible_nodes(graph) + assert len(leafs) > 0, "No leaf nodes!" + + # Extract irreducible edges + edges = dict() + root = None + for (i, j) in nx.dfs_edges(graph, source=util.sample_once(leafs)): + # Check for start of irreducible edge + if root is None: + root = i + path = [i] + xyz_list = [graph.nodes[i]["xyz"]] + + # Check for end of irreducible edge + path.append(j) + xyz_list.append(graph.nodes[j]["xyz"]) + if j in leafs or j in branchings: + edges[(root, j)] = path + if self.smooth_bool: + graph = smooth_path(graph, path, xyz_list) + root = None + + # Set irreducible attributes + irreducibles = { + "leaf": set_node_attrs(graph, leafs), + "branching": set_node_attrs(graph, branchings), + "edge": set_edge_attrs(graph, edges), + "swc_id": graph.graph["swc_id"], + } return irreducibles def prune_branches(self, graph): @@ -250,72 +253,8 @@ def prune_branches(self, graph): graph.remove_nodes_from(branch[0:k]) break - def get_component_irreducibles(self, graph, swc_dict): - """ - Gets the irreducible components of "graph". - - Parameters - ---------- - graph : networkx.Graph - Graph to be searched. - swc_dict : dict - Dictionary used to build "graph". - Returns - ------- - dict - Dictionary containing irreducible components of "graph". - - """ - # Extract nodes - leafs, branchings = get_irreducible_nodes(graph) - assert len(leafs) > 0, "No leaf nodes!" - - # Extract edges - edges = dict() - nbs = defaultdict(list) - root = None - branch_length = 0 - for (i, j) in nx.dfs_edges(graph, source=util.sample_once(leafs)): - # Check if starting new or continuing current path - if root is None: - root = i - branch_length = 0 - attrs = init_edge_attrs(swc_dict, root) - - # Vist i - xyz_i = swc_dict["xyz"][swc_dict["idx"][i]] - xyz_j = swc_dict["xyz"][swc_dict["idx"][j]] - branch_length += geometry.dist(xyz_i, xyz_j) - - # Visit j - attrs = upd_edge_attrs(swc_dict, attrs, j) - if j in leafs or j in branchings: - attrs["length"] = branch_length - attrs = to_numpy(attrs) - if self.smooth_bool: - swc_dict, edges = smooth_branch( - swc_dict, attrs, edges, nbs, root, j - ) - else: - edges[(root, j)] = attrs - - # Finish - nbs[root].append(j) - nbs[j].append(root) - root = None - - # Output - irreducibles = { - "leaf": set_node_attrs(swc_dict, leafs), - "branching": set_node_attrs(swc_dict, branchings), - "edge": edges, - "swc_id": swc_dict["swc_id"], - } - return irreducibles - - -# --- Utils --- +# --- Extract Irreducibles --- def get_irreducible_nodes(graph): """ Gets irreducible nodes (i.e. leafs and branchings) of a graph. @@ -341,202 +280,7 @@ def get_irreducible_nodes(graph): return leafs, branchings -def smooth_branch(swc_dict, attrs, edges, nbs, root, j): - """ - Smoothes a branch then updates "swc_dict" and "edges" with the new xyz - coordinates of the branch end points. Note that this branch is an edge - in the irreducible graph being built. - - Parameters - ---------- - swc_dict : dict - Contents of an swc file. - attrs : dict - Attributes (from "swc_dict") of edge being smoothed. - edges : dict - Dictionary where the keys are edges in irreducible graph and values - are the corresponding attributes. - nbs : dict - Dictionary where the keys are nodes and values are the neighbors. - root : int - End point of branch to be smoothed. - j : int - End point of branch to be smoothed. - - Returns - ------- - dict, dict - Dictionaries that have been updated with respect to smoothed edges. - - """ - attrs["xyz"] = geometry.smooth_branch(attrs["xyz"], s=2) - swc_dict, edges = upd_xyz(swc_dict, attrs, edges, nbs, root, 0) - swc_dict, edges = upd_xyz(swc_dict, attrs, edges, nbs, j, -1) - edges[(root, j)] = attrs - return swc_dict, edges - - -def upd_xyz(swc_dict, attrs, edges, nbs, i, endpoint): - """ - Updates "swc_dict" and "edges" with the new xyz coordinates of the branch - end points. - - Parameters - ---------- - swc_dict : dict - Contents of an swc file. - attrs : dict - Attributes (from "swc_dict") of edge being smoothed. - edges : dict - Dictionary where the keys are edges in irreducible graph and values - are the corresponding attributes. - nbs : dict - Dictionary where the keys are nodes and values are the neighbors. - endpoint : int - End point of branch to be smoothed. - - Returns - ------- - dict - Updated with new xyz coordinates. - dict - Updated with new xyz coordinates. - - """ - idx = swc_dict["idx"][i] - if i in nbs.keys(): - for j in nbs[i]: - key = (i, j) if (i, j) in edges.keys() else (j, i) - edges = upd_endpoint_xyz( - edges, key, swc_dict["xyz"][idx], attrs["xyz"][endpoint] - ) - swc_dict["xyz"][idx] = attrs["xyz"][endpoint] - return swc_dict, edges - - -def upd_endpoint_xyz(edges, key, old_xyz, new_xyz): - """ - Updates "edges" with the new xyz coordinates of the branch - end points. - - Parameters - ---------- - edges : dict - Dictionary where the keys are edges in irreducible graph and values - are the corresponding attributes. - key : tuple - The edge id of the entry in "edges" which needs to be updated. - old_xyz : numpy.ndarray - Current xyz coordinate of end point. - new_xyz : numpy.ndarray - New xyz coordinate of end point. - - Returns - ------- - dict - Updated with new xyz coordinates. - - """ - if all(edges[key]["xyz"][0] == old_xyz): - edges[key]["xyz"][0] = new_xyz - elif all(edges[key]["xyz"][-1] == old_xyz): - edges[key]["xyz"][-1] = new_xyz - return edges - - -def init_edge_attrs(swc_dict, i): - """ - Initializes edge attribute dictionary with attributes from node "i" which - is an end point of the edge. Note: the following assertion error may be - useful: assert i in swc_dict["idx"].keys(), f"{swc_dict["swc_id"]} - {i}" - - Parameters - ---------- - swc_dict : dict - Contents of an swc file. - i : int - End point of edge and the swc attributes of this node are used to - initialize the edge attriubte dictionary. - - Returns - ------- - dict - Edge attribute dictionary. - - """ - j = swc_dict["idx"][i] - return {"radius": [swc_dict["radius"][j]], "xyz": [swc_dict["xyz"][j]]} - - -def upd_edge_attrs(swc_dict, attrs, i): - """ - Updates an edge attribute dictionary with attributes of node i. - - Parameters - ---------- - swc_dict : dict - Contents of an swc file. - attrs : dict - Attributes (from "swc_dict") of edge being updated. - i : int - Node of edge whose attributes will be added to "attrs". - - Returns - ------- - dict - Edge attribute dictionary. - - """ - swc_id = swc_dict["swc_id"] - assert i != -1, f"{swc_id} - {i}" - j = swc_dict["idx"][i] - attrs["radius"].append(swc_dict["radius"][j]) - attrs["xyz"].append(swc_dict["xyz"][j]) - return attrs - - -def get_edge_attr(graph, edge, attr): - """ - Gets the attribute "attr" of "edge". - - Parameters - ---------- - graph : networkx.Graph - Graph which "edge" belongs to. - edge : tuple - Edge to be queried for its attributes. - attr : str - Attribute to be queried. - - Returns - ------- - Attribute "attr" of "edge" - - """ - return graph.edges[edge][attr] - - -def to_numpy(attrs): - """ - Converts edge attributes from a list to NumPy array. - - Parameters - ---------- - attrs : dict - Dictionary containing attributes of some edge. - - Returns - ------- - dict - Updated edge attribute dictionary. - - """ - attrs["xyz"] = np.array(attrs["xyz"], dtype=np.float32) - attrs["radius"] = np.array(attrs["radius"], dtype=np.float16) - return attrs - - -def set_node_attrs(swc_dict, nodes): +def set_node_attrs(graph, nodes): """ Set node attributes by extracting values from "swc_dict". @@ -545,7 +289,7 @@ def set_node_attrs(swc_dict, nodes): swc_dict : dict Contents of an swc file. nodes : list - List of nodes to set attributes. + List of node ids to set attributes. Returns ------- @@ -554,47 +298,59 @@ def set_node_attrs(swc_dict, nodes): attributes extracted from "swc_dict". """ - attrs = dict() + node_attrs = dict() for i in nodes: - j = swc_dict["idx"][i] - attrs[i] = {"radius": swc_dict["radius"][j], "xyz": swc_dict["xyz"][j]} - return attrs + node_attrs[i] = { + "radius": graph.nodes[i]["radius"], "xyz": graph.nodes[i]["xyz"] + } + return node_attrs + + +def set_edge_attrs(graph, edges): + edge_attrs = dict() + for edge, path in edges.items(): + # Extract attributes + radius_list, xyz_list = list(), list() + for i in path: + radius_list.append(graph.nodes[i]["radius"]) + xyz_list.append(graph.nodes[i]["xyz"]) + + # Set attributes + edge_attrs[edge] = { + "length": 1000, + "radius": np.array(radius_list), + "xyz": np.array(xyz_list) + } + return edge_attrs -def upd_node_attrs(swc_dict, leafs, branchings, i): +# --- Miscellaneous --- +def smooth_path(graph, path, xyz_list): """ - Updates node attributes by extracting values from "swc_dict". + Smooths a given path on a graph by applying smoothing to the coordinates + of the nodes along the path and updating the graph with the smoothed + coordinates. Parameters ---------- - swc_dict : dict - Contents of an swc file that contains the smoothed xyz coordinates of - corresponding to "leafs" and "branchings". Note xyz coordinates are - smoothed during edge extraction. - leafs : dict - Dictionary where keys are leaf node ids and values are attribute - dictionaries. - branchings : dict - Dictionary where keys are branching node ids and values are attribute - dictionaries. - i : int - Node to be updated. + graph : networkx.Graph + Graph containing path to be smoothed. + path : List[int] + List of node indices representing the path in the graph. + xyz_list : List[Tuple[float]] + List of xyz coordinates of path in the graph to be smoothed. Returns ------- - dict - Updated dictionary if "i" was contained in "leafs.keys()". - dict - Updated dictionary if "i" was contained in "branchings.keys()". + networkx.Graph + Input graph with updated "xyz" attributes for the nodes from the input + path. """ - j = swc_dict["idx"][i] - upd_attrs = {"radius": swc_dict["radius"][j], "xyz": swc_dict["xyz"][j]} - if i in leafs: - leafs[i] = upd_attrs - else: - branchings[i] = upd_attrs - return leafs, branchings + smoothed_xyz_list = geometry.smooth_branch(np.array(xyz_list), s=2) + for i, xyz in zip(path, smoothed_xyz_list): + graph.nodes[i]["xyz"] = xyz + return graph def compute_path_length(graph): diff --git a/src/deep_neurographs/utils/swc_util.py b/src/deep_neurographs/utils/swc_util.py index c759d54..ef020c1 100644 --- a/src/deep_neurographs/utils/swc_util.py +++ b/src/deep_neurographs/utils/swc_util.py @@ -141,9 +141,9 @@ def load_from_local_path(self, path): """ content = util.read_txt(path) if len(content) > self.min_size - 10: - result = self.parse(content) - result["swc_id"] = util.get_swc_id(path) - return result + graph = self.parse(content) + graph.graph["swc_id"] = util.get_swc_id(path) + return graph else: return False @@ -268,14 +268,45 @@ def load_from_zipped_file(self, zip_file, path): """ content = util.read_zip(zip_file, path).splitlines() if len(content) > self.min_size - 10: - result = self.parse(content) - result["swc_id"] = util.get_swc_id(path) - return result + graph = self.parse(content) + graph.graph["swc_id"] = util.get_swc_id(path) + return graph else: return False # --- Process swc content --- def parse(self, content): + """ + Reads an swc file and builds an undirected graph from it. + + Parameters + ---------- + path : str + Path to swc file to be read. + + Returns + ------- + networkx.Graph + Graph built from an swc file. + + """ + graph = nx.Graph() + content, offset = self.process_content(content) + for line in content: + # Extract node info + parts = line.split() + child = int(parts[0]) + parent = int(parts[-1]) + radius = read_radius(parts[-2]) + xyz = self.read_xyz(parts[2:5], offset=offset) + + # Add node + graph.add_node(child, radius=radius, xyz=xyz) + if parent != -1: + graph.add_edge(parent, child) + return graph + + def parse_old(self, content): """ Parses an swc file to extract the content which is stored in a dict. Note that node_ids from swc are refactored to index from 0 to n-1 @@ -618,6 +649,11 @@ def set_radius(graph, i): # --- Miscellaneous --- +def read_radius(radius_str): + radius = float(radius_str) + return radius / 1000 if radius > 100 else radius + + def to_graph(swc_dict, swc_id=None, set_attrs=False): """ Converts an dictionary containing swc attributes to a graph.