From 3cec96907e2e714f67237c5290050193b4ad03c6 Mon Sep 17 00:00:00 2001 From: Anna Grim <108307071+anna-grim@users.noreply.github.com> Date: Mon, 11 Nov 2024 17:54:53 -0800 Subject: [PATCH] bug: circular arg fixed, class rename (#279) Co-authored-by: anna-grim --- src/deep_neurographs/neurograph.py | 9 ++-- src/deep_neurographs/utils/graph_util.py | 31 ++++++------ src/deep_neurographs/visualization.py | 62 +++++++++++++----------- 3 files changed, 55 insertions(+), 47 deletions(-) diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index d5af008..c958bed 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -18,13 +18,12 @@ from numpy import concatenate from scipy.spatial import KDTree -from deep_neurographs import generate_proposals, geometry +from deep_neurographs import generate_proposals, geometry, utils from deep_neurographs.groundtruth_generation import init_targets -from deep_neurographs.utils import graph_util as gutil from deep_neurographs.utils import img_util, util -class NeuroGraph(nx.Graph): +class FragmentsGraph(nx.Graph): """ A class of graphs whose nodes correspond to irreducible nodes from the predicted swc files. @@ -48,7 +47,7 @@ def __init__(self, img_bbox=None, node_spacing=1): None """ - super(NeuroGraph, self).__init__() + super(FragmentsGraph, self).__init__() # General class attributes self.leaf_kdtree = None self.node_cnt = 0 @@ -97,7 +96,7 @@ def set_proxy_soma_ids(self, k): None """ - for i in gutil.largest_components(self, k): + for i in utils.graph_util.largest_components(self, k): self.soma_ids[self.nodes[i]["swc_id"]] = i def get_leafs(self): diff --git a/src/deep_neurographs/utils/graph_util.py b/src/deep_neurographs/utils/graph_util.py index 2201508..f8ef9c3 100644 --- a/src/deep_neurographs/utils/graph_util.py +++ b/src/deep_neurographs/utils/graph_util.py @@ -5,7 +5,7 @@ @email: anna.grim@alleninstitute.org -Routines for loading fragments and building a neurograph. +Routines for loading fragments and building a fragments_graph. Terminology @@ -31,7 +31,6 @@ from tqdm import tqdm from deep_neurographs import geometry -from deep_neurographs.neurograph import NeuroGraph from deep_neurographs.utils import img_util, swc_util, util MIN_SIZE = 30 @@ -82,8 +81,7 @@ def __init__( Returns ------- - FragmentsGraph - FragmentsGraph generated from swc files. + None """ self.anisotropy = anisotropy @@ -120,6 +118,8 @@ def run( FragmentsGraph generated from swc files. """ + from deep_neurographs.neurograph import FragmentsGraph + # Load fragments and extract irreducibles self.img_bbox = img_util.init_bbox(img_patch_origin, img_patch_shape) swc_dicts = self.reader.load(fragments_pointer) @@ -129,13 +129,13 @@ def run( if self.progress_bar: pbar = tqdm(total=len(irreducibles), desc="Combine Graphs") - neurograph = NeuroGraph(node_spacing=self.node_spacing) + fragments_graph = FragmentsGraph(node_spacing=self.node_spacing) while len(irreducibles): irreducible_set = irreducibles.pop() - neurograph.add_component(irreducible_set) + fragments_graph.add_component(irreducible_set) if self.progress_bar: pbar.update(1) - return neurograph + return fragments_graph # --- Graph structure extraction --- def schedule_processes(self, swc_dicts): @@ -645,7 +645,8 @@ def compute_dist(graph, i, j): Returns ------- - Euclidean distance between i and j. + float + Euclidean distance between i and j. """ return geometry.dist(graph.nodes[i]["xyz"], graph.nodes[j]["xyz"]) @@ -686,6 +687,7 @@ def get_leafs(graph): ------- list Leaf nodes "graph". + """ return [i for i in graph.nodes if graph.degree[i] == 1] @@ -746,20 +748,21 @@ def count_components(graph): Graph to be searched. Returns - ------- - Number of connected components. + -------' + int + Number of connected components. """ return nx.number_connected_components(graph) -def largest_components(neurograph, k): +def largest_components(graph, k): """ - Finds the "k" largest connected components in "neurograph". + Finds the "k" largest connected components in "graph". Parameters ---------- - neurograph : NeuroGraph + graph : nx.Graph Graph to be searched. k : int Number of largest connected components to return. @@ -773,7 +776,7 @@ def largest_components(neurograph, k): """ component_cardinalities = k * [-1] node_ids = k * [-1] - for nodes in nx.connected_components(neurograph): + for nodes in nx.connected_components(graph): if len(nodes) > component_cardinalities[-1]: i = 0 while i < k: diff --git a/src/deep_neurographs/visualization.py b/src/deep_neurographs/visualization.py index 4f8b9cf..5ca805d 100644 --- a/src/deep_neurographs/visualization.py +++ b/src/deep_neurographs/visualization.py @@ -16,7 +16,7 @@ def visualize_connected_components( - graph, line_width=3, return_data=False, title="" + graph, width=3, return_data=False, title="" ): """ Visualizes the connected components in "graph". @@ -25,8 +25,8 @@ def visualize_connected_components( ---------- graph : networkx.Graph Graph to be visualized. - line_width : int, optional - Line width used to plot "subset". The default is 5. + width : int, optional + Line width used to plot edges in "subset". The default is 5. return_data : bool, optional Indication of whether to return data object that is used to generate plot. The default is False. @@ -50,7 +50,7 @@ def visualize_connected_components( color = colors[cnt % len(colors)] data.extend( plot_edges( - graph, subgraph.edges, color=color, line_width=line_width + graph, subgraph.edges, color=color, width=width ) ) cnt += 1 @@ -85,16 +85,18 @@ def visualize_graph(graph, title=""): plot(data, title) -def visualize_proposals(graph, target_graph=None, title="Proposals"): +def visualize_proposals( + graph, color=None, groundtruth_graph=None, title="Proposals" +): """ - Visualizes a graph with proposals. + Visualizes a graph and its proposals. Parameters ---------- graph : networkx.Graph Graph to be visualized. - target_graph : networkx.Graph, optional - Graph generated from ground truth tracings. The default is None. + groundtruth_graph : networkx.Graph, optional + Graph generated from groundtruth tracings. The default is None. title : str, optional Title of the plot. Default is "Proposals". @@ -106,24 +108,25 @@ def visualize_proposals(graph, target_graph=None, title="Proposals"): visualize_subset( graph, graph.proposals, + color=color, proposal_subset=True, - target_graph=target_graph, + groundtruth_graph=groundtruth_graph, title=title, ) -def visualize_targets( - graph, target_graph=None, title="Ground Truth - Accepted Proposals" +def visualize_groundtruth( + graph, groundtruth_graph=None, title="Ground Truth - Accepted Proposals" ): """ - Visualizes a graph and its ground truth accept proposals. + Visualizes a graph and its groundtruth accepted proposals. Parameters ---------- graph : networkx.Graph Graph to be visualized. - target_graph : networkx.Graph, optional - Graph generated from ground truth tracings. The default is None. + groundtruth_graph : networkx.Graph, optional + Graph generated from groundtruth tracings. The default is None. title : str, optional Title of the plot. Default is "Ground Truth - Accepted Proposals". @@ -136,7 +139,7 @@ def visualize_targets( graph, graph.target_edges, proposal_subset=True, - target_graph=target_graph, + groundtruth_graph=groundtruth_graph, title=title, ) @@ -144,9 +147,10 @@ def visualize_targets( def visualize_subset( graph, subset, - line_width=5, + color=None, + width=5, proposal_subset=False, - target_graph=None, + groundtruth_graph=None, title="", ): """ @@ -158,12 +162,12 @@ def visualize_subset( Graph to be visualized. subset : container Subset of edges or proposals to be visualized. - line_width : int, optional + width : int, optional Line width used to plot "subset". The default is 5. proposals_subset : bool, optional Indication of whether "subset" is a subset of proposals. The default is False. - target_graph : networkx.Graph, optional + groundtruth_graph : networkx.Graph, optional Graph generated from ground truth tracings. The default is None. title : str, optional Title of the plot. Default is "Proposals". @@ -177,13 +181,15 @@ def visualize_subset( data = plot_edges(graph, graph.edges, color="black") data.append(plot_nodes(graph)) if proposal_subset: - data.extend(plot_proposals(graph, subset, line_width=line_width)) + data.extend( + plot_proposals(graph, subset, color=color, width=width) + ) else: - data.extend(plot_edges(graph, subset, line_width=line_width)) + data.extend(plot_edges(graph, subset, width=width)) # Add target graph (if applicable) - if target_graph: - cc = visualize_connected_components(target_graph, return_data=True) + if groundtruth_graph: + cc = visualize_connected_components(groundtruth_graph, return_data=True) data.extend(cc) plot(data, title) @@ -202,12 +208,12 @@ def plot_nodes(graph): ) -def plot_proposals(graph, proposals, color=None, line_width=5): +def plot_proposals(graph, proposals, color=None, width=5): # Set preferences if color is None: - line = dict(width=line_width) + line = dict(width=width) else: - line = dict(color=color, width=line_width) + line = dict(color=color, width=width) # Add traces traces = [] @@ -225,10 +231,10 @@ def plot_proposals(graph, proposals, color=None, line_width=5): return traces -def plot_edges(graph, edges, color=None, line_width=3): +def plot_edges(graph, edges, color=None, width=3): traces = [] line = ( - dict(width=5) if color is None else dict(color=color, width=line_width) + dict(width=5) if color is None else dict(color=color, width=width) ) for i, j in edges: trace = go.Scatter3d(