diff --git a/src/deep_neurographs/utils/graph_util.py b/src/deep_neurographs/utils/graph_util.py index 28e6efa..7bf40c8 100644 --- a/src/deep_neurographs/utils/graph_util.py +++ b/src/deep_neurographs/utils/graph_util.py @@ -178,11 +178,10 @@ def extract_irreducibles(self, graph): irreducibles = None self.prune_branches(graph) if compute_path_length(graph) > self.min_size: - # Extract irreducible nodes + # Irreducible nodes leafs, branchings = get_irreducible_nodes(graph) - assert len(leafs) > 0, "No leaf nodes!" - # Extract irreducible edges + # Irreducible edges edges = dict() root = None for (i, j) in nx.dfs_edges(graph, source=util.sample_once(leafs)): @@ -282,20 +281,20 @@ def get_irreducible_nodes(graph): def set_node_attrs(graph, nodes): """ - Set node attributes by extracting values from "swc_dict". + Set node attributes by extracting information from "graph". Parameters ---------- - swc_dict : dict - Contents of an swc file. + graph : networkx.Graph + Graph that contains "nodes". nodes : list List of node ids to set attributes. Returns ------- dict - Dictionary in which keys are node ids and values are a dictionary of - attributes extracted from "swc_dict". + Dictionary where keys are node ids and values are a dictionary of + attributes extracted from the input graph. """ node_attrs = dict() @@ -310,14 +309,17 @@ def set_edge_attrs(graph, edges): edge_attrs = dict() for edge, path in edges.items(): # Extract attributes + length = 0 radius_list, xyz_list = list(), list() - for i in path: + for idx, i in enumerate(path): radius_list.append(graph.nodes[i]["radius"]) xyz_list.append(graph.nodes[i]["xyz"]) + if idx > 0: + length += compute_dist(graph, path[idx], path[idx - 1]) # Set attributes edge_attrs[edge] = { - "length": 1000, + "length": length, "radius": np.array(radius_list), "xyz": np.array(xyz_list) } diff --git a/src/deep_neurographs/utils/swc_util.py b/src/deep_neurographs/utils/swc_util.py index ef020c1..242b588 100644 --- a/src/deep_neurographs/utils/swc_util.py +++ b/src/deep_neurographs/utils/swc_util.py @@ -5,7 +5,7 @@ @email: anna.grim@alleninstitute.org -Routines for working with swc files. +Routines for reading and writing swc files. """ @@ -46,8 +46,7 @@ def __init__(self, anisotropy=[1.0, 1.0, 1.0], min_size=0): anisotropy of the microscope. The default is [1.0, 1.0, 1.0]. min_size : int, optional Threshold on the number of nodes in swc file. Only swc files with - more than "min_size" nodes are stored in "xyz_coords". The default - is 0. + more than "min_size" nodes are processed. The default is 0. Returns ------- @@ -59,19 +58,19 @@ def __init__(self, anisotropy=[1.0, 1.0, 1.0], min_size=0): def load(self, swc_pointer): """ - Load data based on the type and format of the provided "swc_pointer". + Loads swc files specififed by "swc_pointer" and builds an attributed + graphs from them. Parameters ---------- swc_pointer : dict, list, str - Object that points to swcs to be read, see class documentation for - details. + Object that points to swc files to be read, see class documentation + for details. Returns ------- - List[dict] - List of dictionaries whose keys and values are the attribute name - and values from an swc file. + List[networkx.Graph] or networkx.Graph + Attributed graphs. """ if type(swc_pointer) is dict: @@ -88,55 +87,52 @@ def load(self, swc_pointer): return self.load_from_local_paths(paths) raise Exception("SWC Pointer is not Valid!") - # --- Load subroutines --- - def load_from_local_paths(self, swc_paths): + def load_from_local_paths(self, path_list): """ - Reads swc files from local machine, then returns either the xyz - coordinates or graphs. + Reads swc files from local machine and builds an attributed graph + from them. Paramters --------- - swc_paths : list - List of paths to swc files stored on the local machine. + path_list : List[str] + Paths to swc files on the local machine. Returns ------- - List[dict] - List of dictionaries whose keys and values are the attribute name - and values from an swc file. + List[networkx.Graph] + Attributed graphs. """ with ProcessPoolExecutor(max_workers=1) as executor: # Assign processes processes = list() - for path in swc_paths: + for path in path_list: processes.append( executor.submit(self.load_from_local_path, path) ) # Store results - swc_dicts = list() + graphs = list() for process in as_completed(processes): result = process.result() if result: - swc_dicts.append(result) - return swc_dicts + graphs.append(result) + return graphs def load_from_local_path(self, path): """ - Reads a single swc file from local machine, then returns either the - xyz coordinates or graphs. + Reads a single swc file on local machine and builds an attributed + graph from it. Paramters --------- path : str - Path to swc file stored on the local machine. + Path to swc file on the local machine. Returns ------- - List[dict] - List of dictionaries whose keys and values are the attribute name - and values from an swc file. + networkx.Graph + Attributed graph. """ content = util.read_txt(path) @@ -145,112 +141,105 @@ def load_from_local_path(self, path): graph.graph["swc_id"] = util.get_swc_id(path) return graph else: - return False + return None def load_from_local_zip(self, zip_path): """ - Reads swc files from zip on the local machine, then returns either the - xyz coordinates or graph. Note this routine is hard coded for computing - projected run length. + Reads swc files from a zip file and builds attributed graphs from + them. Paramters --------- - swc_paths : Container - If swc files are on local machine, list of paths to swc files where - each file corresponds to a neuron in the prediction. If swc files - are on cloud, then dict with keys "bucket_name" and "path". + zip_path : str + Path to zip file to be read. Returns ------- - dict - Dictionary that maps an swc_id to the the xyz coordinates read from - that swc file. + List[networkx.Graph] + Attributed graphs. """ with ZipFile(zip_path, "r") as zip_file: - swc_dicts = list() + graphs = list() swc_files = [f for f in zip_file.namelist() if f.endswith(".swc")] for f in tqdm(swc_files, desc="Loading Fragments"): - result = self.load_from_zipped_file(zip_file, f) + result = self.load_from_zip(zip_file, f) if result: - swc_dicts.append(result) - return swc_dicts + graphs.append(result) + return graphs def load_from_gcs(self, gcs_dict): """ - Reads swc files from zips on a GCS bucket. + Reads swc files from zips on a GCS bucket and builds attributed + graphs from them. Parameters ---------- gcs_dict : dict - Dictionary where keys are "bucket_name" and "path". + Dictionary with the keys "bucket_name" and "path" used to read + swcs from GCS bucket Returns ------- - dict - Dictionary that maps an swc_id to the the xyz coordinates read from - that swc file. + List[networkx.Graph] + Attributed graphs. """ - # Initializations bucket = storage.Client().bucket(gcs_dict["bucket_name"]) zip_paths = util.list_gcs_filenames(bucket, gcs_dict["path"], ".zip") - - # Main with ProcessPoolExecutor() as executor: # Assign processes processes = list() for path in tqdm(zip_paths, desc="Download SWCs"): - zip_content = bucket.blob(path).download_as_bytes() + zip_bytes = bucket.blob(path).download_as_bytes() processes.append( - executor.submit(self.load_from_cloud_zip, zip_content) + executor.submit(self.load_from_cloud_zip, zip_bytes) ) # Store results - swc_dicts = list() + graphs = list() for process in as_completed(processes): - swc_dicts.extend(process.result()) - return swc_dicts + graphs.extend(process.result()) + return graphs - def load_from_cloud_zip(self, zip_content): + def load_from_cloud_zip(self, zip_bytes): """ - Reads swc files from a zip that has been downloaded from a cloud - bucket. + Reads swc files from a zip and builds attributed graphs from them. Parameters ---------- - zip_content : ... - content of a zip file. + zip_bytes : bytes + Contents of a zip file in byte format. Returns ------- - dict - Dictionary that maps an swc_id to the the xyz coordinates read from - that swc file. + List[networkx.Graph] + Attributed graphs. """ - with ZipFile(BytesIO(zip_content)) as zip_file: + with ZipFile(BytesIO(zip_bytes)) as zip_file: with ThreadPoolExecutor() as executor: # Assign threads threads = list() - for f in util.list_files_in_zip(zip_content): + for f in util.list_files_in_zip(zip_bytes): threads.append( executor.submit( - self.load_from_zipped_file, zip_file, f + self.load_from_zip, zip_file, f ) ) # Process results - swc_dicts = list() + graphs = list() for thread in as_completed(threads): result = thread.result() if result: - swc_dicts.append(result) - return swc_dicts + graphs.append(result) + return graphs - def load_from_zipped_file(self, zip_file, path): + def load_from_zip(self, zip_file, path): """ - Reads swc file stored at "path" which points to a file in a zip. + Reads swc files at in a zip file at "path" and builds attributed + graphs from them. Parameters ---------- @@ -261,9 +250,8 @@ def load_from_zipped_file(self, zip_file, path): Returns ------- - dict - Dictionary that maps an swc_id to the the xyz coordinates or graph - read from that swc file. + networkx.Graph + Attributed graph. """ content = util.read_zip(zip_file, path).splitlines() @@ -274,10 +262,10 @@ def load_from_zipped_file(self, zip_file, path): else: return False - # --- Process swc content --- + # --- Process SWC Contents --- def parse(self, content): """ - Reads an swc file and builds an undirected graph from it. + Reads an swc file and builds an attributed graphs from it. Parameters ---------- @@ -297,7 +285,7 @@ def parse(self, content): parts = line.split() child = int(parts[0]) parent = int(parts[-1]) - radius = read_radius(parts[-2]) + radius = self.read_radius(parts[-2]) xyz = self.read_xyz(parts[2:5], offset=offset) # Add node @@ -306,54 +294,16 @@ def parse(self, content): graph.add_edge(parent, child) return graph - def parse_old(self, content): - """ - Parses an swc file to extract the content which is stored in a dict. - Note that node_ids from swc are refactored to index from 0 to n-1 - where n is the number of entries in the swc file. - - Parameters - ---------- - content : List[str] - List of entries from an swc file. - - Returns - ------- - dict - Dictionaries whose keys and values are the attribute name - and values from an swc file. - - """ - # Parse swc content - content, offset = self.process_content(content) - swc_dict = { - "id": np.zeros((len(content)), dtype=np.int32), - "radius": np.zeros((len(content)), dtype=np.float32), - "pid": np.zeros((len(content)), dtype=np.int32), - "xyz": np.zeros((len(content), 3), dtype=np.float32), - } - for i, line in enumerate(content): - parts = line.split() - swc_dict["id"][i] = parts[0] - swc_dict["radius"][i] = float(parts[-2]) - swc_dict["pid"][i] = parts[-1] - swc_dict["xyz"][i] = self.read_xyz(parts[2:5], offset) - - # Check whether radius is in nanometers - if swc_dict["radius"][0] > 100: - swc_dict["radius"] /= 1000 - return swc_dict - def process_content(self, content): """ - Processes lines of text from a content source, extracting an offset - value and returning the remaining content starting from the line - immediately after the last commented line. + Processes lines of text from an swc file by iterating over commented + lines to extract offset (if present) and finds the line after the last + commented line. Parameters ---------- content : List[str] - List of strings where each string represents a line of text. + List of strings that represent a line of a text file. Returns ------- @@ -393,19 +343,38 @@ def read_xyz(self, xyz_str, offset=[0.0, 0.0, 0.0]): xyz[i] = self.anisotropy[i] * (float(xyz_str[i]) + offset[i]) return xyz + def read_radius(self, radius_str): + """ + Converts a radius string to a float and adjusts it if the value is in + nanometers. + + Parameters + ---------- + radius_str : str + A string representing the radius value. + + Returns + ------- + float + Radius. + + """ + radius = float(radius_str) + return radius / 1000 if radius > 100 else radius + # --- Write --- def write(path, content, color=None): """ - Write content to a specified file in a format based on the type o - f content. + Writes an swc from the given "content" which is either a list of entries + or a graph. Parameters ---------- path : str - File path where the content will be written. - content : list, dict, nx.Graph - The content to be written. + Path where the content is to be written. + content : List[str] or networkx.Graph + Content of swc file to be written. color : str, optional Color of swc to be written. The default is None. @@ -416,8 +385,6 @@ def write(path, content, color=None): """ if type(content) is list: write_list(path, content, color=color) - elif type(content) is dict: - write_dict(path, content, color=color) elif type(content) is nx.Graph: write_graph(path, content, color=color) else: @@ -432,8 +399,8 @@ def write_list(path, entry_list, color=None): ---------- path : str Path that swc will be written to. - entry_list : list[str] - List of entries that will be written to an swc file. + entry_list : List[str] + List of entries to be written to an swc file. color : str, optional Color of swc to be written. The default is None. @@ -443,7 +410,7 @@ def write_list(path, entry_list, color=None): """ with open(path, "w") as f: - # Preamble + # Comments if color is not None: f.write("# COLOR " + color) else: @@ -454,33 +421,10 @@ def write_list(path, entry_list, color=None): f.write("\n" + entry) -def write_dict(path, swc_dict, color=None): - """ - Writes the dictionary to an swc file. - - Parameters - ---------- - path : str - Path that swc will be written to. - swc_dict : dict - Dictionaries whose keys and values are the attribute name and values - from an swc file. - color : str, optional - Color of swc to be written. The default is None. - - Returns - ------- - None - - """ - graph, _ = to_graph(swc_dict, set_attrs=True) - write_graph(path, graph, color=color) - - def write_graph(path, graph, color=None): """ - Makes a list of entries to be written in an swc file. This routine assumes - that "graph" has a single connected components. + Writes a graph to an swc file. This routine assumes that "graph" has a + single connected component. Parameters ---------- @@ -491,8 +435,7 @@ def write_graph(path, graph, color=None): Returns ------- - List[str] - List of swc file entries to be written. + None """ node_to_idx = {-1: -1} @@ -646,75 +589,3 @@ def set_radius(graph, i): except ValueError: radius = 1.0 return radius - - -# --- Miscellaneous --- -def read_radius(radius_str): - radius = float(radius_str) - return radius / 1000 if radius > 100 else radius - - -def to_graph(swc_dict, swc_id=None, set_attrs=False): - """ - Converts an dictionary containing swc attributes to a graph. - - Parameters - ---------- - swc_dict : dict - Dictionaries whose keys and values are the attribute name and values - from an swc file. - swc_id : str, optional - Identifier that dictionary was generated from. The default is None. - set_attrs : bool, optional - Indication of whether to set attributes. The default is False. - - Returns - ------- - networkx.Graph - Graph generated from "swc_dict". - - """ - graph = nx.Graph(graph_id=swc_id) - graph.add_edges_from(zip(swc_dict["id"][1:], swc_dict["pid"][1:])) - if set_attrs: - xyz = swc_dict["xyz"] - if type(swc_dict["xyz"]) is np.ndarray: - xyz = util.numpy_to_hashable(swc_dict["xyz"]) - graph = __add_attributes(swc_dict, graph) - xyz_to_node = dict(zip(xyz, swc_dict["id"])) - return graph, xyz_to_node - return graph - - -def __add_attributes(swc_dict, graph): - """ - Adds node attributes to a NetworkX graph based on information from - "swc_dict". - - Parameters: - ---------- - swc_dict : dict - A dictionary containing SWC data. It must have the following keys: - - "id": A list of node identifiers (unique for each node). - - "xyz": A list of 3D coordinates (x, y, z) for each node. - - "radius": A list of radii for each node. - - graph : networkx.Graph - A NetworkX graph object to which the attributes will be added. - The graph must contain nodes that correspond to the IDs in - "swc_dict["id"]". - - Returns: - ------- - networkx.Graph - The modified graph with added node attributes for each node. - - """ - attrs = dict() - for idx, node in enumerate(swc_dict["id"]): - attrs[node] = { - "xyz": swc_dict["xyz"][idx], - "radius": swc_dict["radius"][idx], - } - nx.set_node_attributes(graph, attrs) - return graph