-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: iteratively prune branches (#272)
Co-authored-by: anna-grim <[email protected]>
- Loading branch information
Showing
6 changed files
with
220 additions
and
205 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
@email: [email protected] | ||
Module that removes doubled fragments and trims branches that pass by each | ||
other from a NeuroGraph. | ||
other from a FragmentsGraph. | ||
""" | ||
from collections import defaultdict | ||
|
@@ -21,57 +21,70 @@ | |
QUERY_DIST = 15 | ||
|
||
|
||
# --- Curvy Removal --- | ||
def remove_curvy(graph, max_length, ratio=0.5): | ||
deleted_ids = set() | ||
components = [c for c in connected_components(graph) if len(c) == 2] | ||
for nodes in tqdm(components, desc="Curvy Filter:"): | ||
if len(nodes) == 2: | ||
i, j = tuple(nodes) | ||
length = graph.edges[i, j]["length"] | ||
endpoint_dist = graph.dist(i, j) | ||
if endpoint_dist / length < ratio and length < max_length: | ||
deleted_ids.add(graph.edges[i, j]["swc_id"]) | ||
delete_fragment(graph, i, j) | ||
return len(deleted_ids) | ||
|
||
|
||
# --- Doubles Removal --- | ||
def remove_doubles(neurograph, max_size, node_spacing, output_dir=None): | ||
def remove_doubles(graph, max_length, node_spacing, output_dir=None): | ||
""" | ||
Removes connected components from "neurgraph" that are likely to be a | ||
double. | ||
Parameters | ||
---------- | ||
neurograph : NeuroGraph | ||
graph : FragmentsGraph | ||
Graph to be searched for doubles. | ||
max_size : int | ||
max_length : int | ||
Maximum size of connected components to be searched. | ||
node_spacing : int | ||
Expected distance in microns between nodes in "neurograph". | ||
Expected distance in microns between nodes in "graph". | ||
output_dir : str or None, optional | ||
Directory that doubles will be written to. The default is None. | ||
Returns | ||
------- | ||
NeuroGraph | ||
graph | ||
Graph with doubles removed. | ||
""" | ||
# Initializations | ||
components = [c for c in connected_components(neurograph) if len(c) == 2] | ||
deleted = set() | ||
kdtree = neurograph.get_kdtree() | ||
components = [c for c in connected_components(graph) if len(c) == 2] | ||
deleted_ids = set() | ||
kdtree = graph.get_kdtree() | ||
if output_dir: | ||
util.mkdir(output_dir, delete=True) | ||
|
||
# Main | ||
desc = "Doubles Detection" | ||
desc = "Doubles Filtering" | ||
for idx in tqdm(np.argsort([len(c) for c in components]), desc=desc): | ||
i, j = tuple(components[idx]) | ||
swc_id = neurograph.nodes[i]["swc_id"] | ||
if swc_id not in deleted: | ||
if len(neurograph.edges[i, j]["xyz"]) * node_spacing < max_size: | ||
swc_id = graph.nodes[i]["swc_id"] | ||
if swc_id not in deleted_ids: | ||
if graph.edges[i, j]["length"] < max_length: | ||
# Check doubles criteria | ||
n_points = len(neurograph.edges[i, j]["xyz"]) | ||
hits = compute_projections(neurograph, kdtree, (i, j)) | ||
n_points = len(graph.edges[i, j]["xyz"]) | ||
hits = compute_projections(graph, kdtree, (i, j)) | ||
if check_doubles_criteria(hits, n_points): | ||
if output_dir: | ||
neurograph.to_swc( | ||
output_dir, components[idx], color=COLOR | ||
) | ||
neurograph = delete(neurograph, components[idx], swc_id) | ||
deleted.add(swc_id) | ||
return len(deleted) | ||
graph.to_swc(output_dir, [i, j], color=COLOR) | ||
delete_fragment(graph, i, j) | ||
deleted_ids.add(swc_id) | ||
return len(deleted_ids) | ||
|
||
|
||
def compute_projections(neurograph, kdtree, edge): | ||
def compute_projections(graph, kdtree, edge): | ||
""" | ||
Given a fragment defined by "edge", this routine iterates of every xyz in | ||
the fragment and projects it onto the closest fragment. For each detected | ||
|
@@ -80,11 +93,11 @@ def compute_projections(neurograph, kdtree, edge): | |
Parameters | ||
---------- | ||
neurograph : NeuroGraph | ||
graph : graph | ||
Graph that contains "edge". | ||
kdtree : KDTree | ||
KD-Tree that contains all xyz coordinates of every fragment in | ||
"neurograph". | ||
"graph". | ||
edge : tuple | ||
Pair of leaf nodes that define a fragment. | ||
|
@@ -96,13 +109,13 @@ def compute_projections(neurograph, kdtree, edge): | |
""" | ||
hits = defaultdict(list) | ||
query_id = neurograph.edges[edge]["swc_id"] | ||
for i, xyz in enumerate(neurograph.edges[edge]["xyz"]): | ||
query_id = graph.edges[edge]["swc_id"] | ||
for i, xyz in enumerate(graph.edges[edge]["xyz"]): | ||
# Compute projections | ||
best_id = None | ||
best_dist = np.inf | ||
for hit_xyz in geometry.query_ball(kdtree, xyz, QUERY_DIST): | ||
hit_id = neurograph.xyz_to_swc(hit_xyz) | ||
hit_id = graph.xyz_to_swc(hit_xyz) | ||
if hit_id is not None and hit_id != query_id: | ||
if geometry.dist(hit_xyz, xyz) < best_dist: | ||
best_dist = geometry.dist(hit_xyz, xyz) | ||
|
@@ -144,56 +157,54 @@ def check_doubles_criteria(hits, n_points): | |
return False | ||
|
||
|
||
def delete(neurograph, nodes, swc_id): | ||
def delete_fragment(graph, i, j): | ||
""" | ||
Deletes "nodes" from "neurograph". | ||
Deletes nodes "i" and "j" from "graph", where these nodes form a connected | ||
component. | ||
Parameters | ||
---------- | ||
neurograph : NeuroGraph | ||
Graph that contains "nodes". | ||
nodes : list[int] | ||
Nodes to be removed. | ||
swc_id : str | ||
swc id corresponding to nodes which comprise a connected component in | ||
"neurograph". | ||
graph : FragmentsGraph | ||
Graph that contains nodes to be deleted. | ||
i : int | ||
Node to be removed. | ||
j : int | ||
Node to be removed. | ||
Returns | ||
------- | ||
NeuroGraph | ||
graph | ||
Graph with nodes removed. | ||
""" | ||
i, j = tuple(nodes) | ||
neurograph = remove_xyz_entries(neurograph, i, j) | ||
neurograph.remove_nodes_from([i, j]) | ||
neurograph.swc_ids.remove(swc_id) | ||
return neurograph | ||
graph = remove_xyz_entries(graph, i, j) | ||
graph.swc_ids.remove(graph.nodes[i]["swc_id"]) | ||
graph.remove_nodes_from([i, j]) | ||
|
||
|
||
def remove_xyz_entries(neurograph, i, j): | ||
def remove_xyz_entries(graph, i, j): | ||
""" | ||
Removes dictionary entries from "neurograph.xyz_to_edge" corresponding to | ||
Removes dictionary entries from "graph.xyz_to_edge" corresponding to | ||
the edge {i, j}. | ||
Parameters | ||
---------- | ||
neurograph : NeuroGraph | ||
graph : graph | ||
Graph to be updated. | ||
i : int | ||
Node in "neurograph". | ||
Node in "graph". | ||
j : int | ||
Node in "neurograph". | ||
Node in "graph". | ||
Returns | ||
------- | ||
NeuroGraph | ||
graph | ||
Updated graph. | ||
""" | ||
for xyz in neurograph.edges[i, j]["xyz"]: | ||
del neurograph.xyz_to_edge[tuple(xyz)] | ||
return neurograph | ||
for xyz in graph.edges[i, j]["xyz"]: | ||
del graph.xyz_to_edge[tuple(xyz)] | ||
return graph | ||
|
||
|
||
def upd_hits(hits, key, value): | ||
|
Oops, something went wrong.