Skip to content

Commit

Permalink
feat: doubles removal
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim committed Jun 25, 2024
1 parent dd5fece commit 2784743
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 94 deletions.
194 changes: 194 additions & 0 deletions src/deep_neurographs/doubles_removal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
"""
Created on Sat June 25 9:00:00 2024
@author: Anna Grim
@email: [email protected]
Module that removes doubled fragments from a NeuroGraph.
"""

from deep_neurographs import utils
import networkx as nx


def run(neurograph, max_size, node_spacing):
"""
Removes connected components from "neurgraph" that are likely to be a
double.
Parameters
----------
neurograph : NeuroGraph
Graph to be searched for doubles.
max_size : int
Maximum size of connected components to be searched.
node_spacing : int
Expected distance in microns between nodes in "neurograph".
Returns
-------
NeuroGraph
Graph with doubles removed.
"""
# Assign processes
doubles_cnt = 0
neurograph.init_kdtree()
not_doubles = set()
for nodes in list(nx.connected_components(neurograph)):
# Determine whether to inspect fragment
swc_id = get_swc_id(neurograph, nodes)
if swc_id not in not_doubles:
xyz_arr = inspect_component(neurograph, nodes)
if len(xyz_arr) > 0 and len(xyz_arr) * node_spacing < max_size:
not_double_id = is_double(neurograph, xyz_arr, swc_id)
if not_double_id:
doubles_cnt += 1
neurograph = remove_component(neurograph, nodes, swc_id)
not_doubles.add(not_double_id)
print("# Doubles detected:", doubles_cnt)


def is_double(neurograph, xyz_arr, swc_id_i):
"""
Determines whether the connected component corresponding to "root" is a
double of another connected component.
Paramters
---------
neurograph : NeuroGraph
Graph to be searched for doubles.
xyz_arr : numpy.ndarray
Array containing xyz coordinates corresponding to some fragment (i.e.
connected component in neurograph).
swc_id_i : str
swc id corresponding to fragment.
Returns
-------
str or None
Indication of whether connected component is a double. If True, the
swc_id of the main fragment (i.e. non doubles) is returned. Otherwise,
the value None is returned to indicate that query fragment is not a
double.
"""
# Compute projections
hits = dict()
for xyz_i in xyz_arr:
for xyz_j in neurograph.query_kdtree(xyz_i, 6):
try:
swc_id_j = neurograph.xyz_to_swc(xyz_j)
if swc_id_i != swc_id_j:
hits = utils.append_dict_value(hits, swc_id_j, 1)
except:
pass

# Check criteria
if len(hits) > 0:
swc_id_j = utils.find_best(hits)
percent_hit = len(hits[swc_id_j]) / len(xyz_arr)
else:
percent_hit = 0
return swc_id_j if swc_id_j is not None and percent_hit > 0.5 else None


# --- utils ---
def get_swc_id(neurograph, nodes):
"""
Gets the swc id corresponding to "nodes".
Parameters
----------
neurograph : NeuroGraph
Graph containing "nodes".
nodes : list[int]
Nodes to be checked.
Returns
-------
str
swc id of "nodes".
"""
i = utils.sample_singleton(nodes)
return neurograph.nodes[i]["swc_id"]


def inspect_component(neurograph, nodes):
"""
Determines whether to inspect component for doubles.
Parameters
----------
neurograph : NeuroGraph
Graph to be searched.
nodes : iterable
Nodes that comprise a connected component.
Returns
-------
numpy.ndarray or list
Array containing xyz coordinates of nodes.
"""
if len(nodes) == 2:
i, j = tuple(nodes)
return neurograph.edges[i, j]["xyz"]
else:
return []


def remove_component(neurograph, nodes, swc_id):
"""
Removes "nodes" from "neurograph".
Parameters
----------
neurograph : NeuroGraph
Graph that contains "nodes".
nodes : list[int]
Nodes to be removed.
swc_id : str
swc id corresponding to nodes which comprise a connected component in
"neurograph".
Returns
-------
NeuroGraph
Graph with nodes removed.
"""
i, j = tuple(nodes)
neurograph = remove_xyz_entries(neurograph, i, j)
neurograph.remove_nodes_from([i, j])
neurograph.leafs.remove(i)
neurograph.leafs.remove(j)
neurograph.swc_ids.remove(swc_id)
return neurograph


def remove_xyz_entries(neurograph, i, j):
"""
Removes dictionary entries from "neurograph.xyz_to_edge" corresponding to
the edge {i, j}.
Parameters
----------
neurograph : NeuroGraph
Graph to be updated.
i : int
Node in "neurograph".
j : int
Node in "neurograph".
Returns
-------
NeuroGraph
Updated graph.
"""
for xyz in neurograph.edges[i, j]["xyz"]:
del neurograph.xyz_to_edge[tuple(xyz)]
return neurograph
29 changes: 0 additions & 29 deletions src/deep_neurographs/generate_proposals.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,35 +199,6 @@ def get_conection(neurograph, leaf, xyz):
return neurograph, node


def is_valid(neurograph, i, filter_doubles):
"""
Determines whether is a valid node to generate proposals from. A node is
considered valid if it is not contained in a doubled connected component
(if applicable).
Parameters
----------
neurograph : NeuroGraph
Graph built from swc files.
i : int
Node to be validated.
filter_doubles : bool
Indication of whether to prevent proposals from being connected to a
doubled connected component.
Returns
-------
bool
Indication of whether node is valid.
"""
if filter_doubles:
neurograph.upd_doubles(i)
swc_id = neurograph.nodes[i]["swc_id"]
return True if swc_id in neurograph.doubles else False


# -- utils --
def get_closer_endpoint(neurograph, edge, xyz):
i, j = tuple(edge)
d_i = geometry.dist(neurograph.nodes[i]["xyz"], xyz)
Expand Down
9 changes: 6 additions & 3 deletions src/deep_neurographs/intake.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def download_gcs_zips(bucket_name, gcs_path, min_size, anisotropy):
)
if i > cnt * chunk_size:
cnt, t1 = report_progress(
i, len(zip_paths), chunk_size, cnt, t0, t1
i + 1, len(zip_paths), chunk_size, cnt, t0, t1
)

# Store results
Expand Down Expand Up @@ -306,11 +306,14 @@ def build_neurograph(
t0, t1 = utils.init_timers()
chunk_size = max(int(n_components * 0.02), 1)
cnt, i = 1, 0
n_components = len(irreducibles)
while len(irreducibles):
irreducible_set = irreducibles.pop()
neurograph.add_component(irreducible_set)
if i > cnt * chunk_size and progress_bar:
cnt, t1 = report_progress(i, n_components, chunk_size, cnt, t0, t1)
cnt, t1 = report_progress(
i + 2, n_components, chunk_size, cnt, t0, t1
)
i += 1
if progress_bar:
t, unit = utils.time_writer(time() - t0)
Expand Down Expand Up @@ -364,7 +367,7 @@ def get_irreducibles(
n_edges += count_edges(irreducibles_i)
if i > progress_cnt * chunk_size and progress_bar:
progress_cnt, t1 = report_progress(
i, n_components, chunk_size, progress_cnt, t0, t1
i + 1, n_components, chunk_size, progress_cnt, t0, t1
)
if progress_bar:
t, unit = utils.time_writer(time() - t0)
Expand Down
2 changes: 1 addition & 1 deletion src/deep_neurographs/machine_learning/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def run_without_seeds(
# Report progress
if i > progress_cnt * chunk_size and progress_bar:
progress_cnt, t1 = utils.report_progress(
i, n_batches, chunk_size, progress_cnt, t0, t1
i + 1, n_batches, chunk_size, progress_cnt, t0, t1
)
t0, t1 = utils.init_timers()

Expand Down
Loading

0 comments on commit 2784743

Please sign in to comment.