Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: doubles removal #172

Merged
merged 1 commit into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 194 additions & 0 deletions src/deep_neurographs/doubles_removal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
"""
Created on Sat June 25 9:00:00 2024

@author: Anna Grim
@email: [email protected]

Module that removes doubled fragments from a NeuroGraph.

"""

from deep_neurographs import utils
import networkx as nx


def run(neurograph, max_size, node_spacing):
"""
Removes connected components from "neurgraph" that are likely to be a
double.

Parameters
----------
neurograph : NeuroGraph
Graph to be searched for doubles.
max_size : int
Maximum size of connected components to be searched.
node_spacing : int
Expected distance in microns between nodes in "neurograph".

Returns
-------
NeuroGraph
Graph with doubles removed.

"""
# Assign processes
doubles_cnt = 0
neurograph.init_kdtree()
not_doubles = set()
for nodes in list(nx.connected_components(neurograph)):
# Determine whether to inspect fragment
swc_id = get_swc_id(neurograph, nodes)
if swc_id not in not_doubles:
xyz_arr = inspect_component(neurograph, nodes)
if len(xyz_arr) > 0 and len(xyz_arr) * node_spacing < max_size:
not_double_id = is_double(neurograph, xyz_arr, swc_id)
if not_double_id:
doubles_cnt += 1
neurograph = remove_component(neurograph, nodes, swc_id)
not_doubles.add(not_double_id)
print("# Doubles detected:", doubles_cnt)


def is_double(neurograph, xyz_arr, swc_id_i):
"""
Determines whether the connected component corresponding to "root" is a
double of another connected component.

Paramters
---------
neurograph : NeuroGraph
Graph to be searched for doubles.
xyz_arr : numpy.ndarray
Array containing xyz coordinates corresponding to some fragment (i.e.
connected component in neurograph).
swc_id_i : str
swc id corresponding to fragment.

Returns
-------
str or None
Indication of whether connected component is a double. If True, the
swc_id of the main fragment (i.e. non doubles) is returned. Otherwise,
the value None is returned to indicate that query fragment is not a
double.

"""
# Compute projections
hits = dict()
for xyz_i in xyz_arr:
for xyz_j in neurograph.query_kdtree(xyz_i, 6):
try:
swc_id_j = neurograph.xyz_to_swc(xyz_j)
if swc_id_i != swc_id_j:
hits = utils.append_dict_value(hits, swc_id_j, 1)
except:
pass

# Check criteria
if len(hits) > 0:
swc_id_j = utils.find_best(hits)
percent_hit = len(hits[swc_id_j]) / len(xyz_arr)
else:
percent_hit = 0
return swc_id_j if swc_id_j is not None and percent_hit > 0.5 else None


# --- utils ---
def get_swc_id(neurograph, nodes):
"""
Gets the swc id corresponding to "nodes".

Parameters
----------
neurograph : NeuroGraph
Graph containing "nodes".
nodes : list[int]
Nodes to be checked.

Returns
-------
str
swc id of "nodes".

"""
i = utils.sample_singleton(nodes)
return neurograph.nodes[i]["swc_id"]


def inspect_component(neurograph, nodes):
"""
Determines whether to inspect component for doubles.

Parameters
----------
neurograph : NeuroGraph
Graph to be searched.
nodes : iterable
Nodes that comprise a connected component.

Returns
-------
numpy.ndarray or list
Array containing xyz coordinates of nodes.

"""
if len(nodes) == 2:
i, j = tuple(nodes)
return neurograph.edges[i, j]["xyz"]
else:
return []


def remove_component(neurograph, nodes, swc_id):
"""
Removes "nodes" from "neurograph".

Parameters
----------
neurograph : NeuroGraph
Graph that contains "nodes".
nodes : list[int]
Nodes to be removed.
swc_id : str
swc id corresponding to nodes which comprise a connected component in
"neurograph".

Returns
-------
NeuroGraph
Graph with nodes removed.

"""
i, j = tuple(nodes)
neurograph = remove_xyz_entries(neurograph, i, j)
neurograph.remove_nodes_from([i, j])
neurograph.leafs.remove(i)
neurograph.leafs.remove(j)
neurograph.swc_ids.remove(swc_id)
return neurograph


def remove_xyz_entries(neurograph, i, j):
"""
Removes dictionary entries from "neurograph.xyz_to_edge" corresponding to
the edge {i, j}.

Parameters
----------
neurograph : NeuroGraph
Graph to be updated.
i : int
Node in "neurograph".
j : int
Node in "neurograph".

Returns
-------
NeuroGraph
Updated graph.

"""
for xyz in neurograph.edges[i, j]["xyz"]:
del neurograph.xyz_to_edge[tuple(xyz)]
return neurograph
29 changes: 0 additions & 29 deletions src/deep_neurographs/generate_proposals.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,35 +199,6 @@ def get_conection(neurograph, leaf, xyz):
return neurograph, node


def is_valid(neurograph, i, filter_doubles):
"""
Determines whether is a valid node to generate proposals from. A node is
considered valid if it is not contained in a doubled connected component
(if applicable).

Parameters
----------
neurograph : NeuroGraph
Graph built from swc files.
i : int
Node to be validated.
filter_doubles : bool
Indication of whether to prevent proposals from being connected to a
doubled connected component.

Returns
-------
bool
Indication of whether node is valid.

"""
if filter_doubles:
neurograph.upd_doubles(i)
swc_id = neurograph.nodes[i]["swc_id"]
return True if swc_id in neurograph.doubles else False


# -- utils --
def get_closer_endpoint(neurograph, edge, xyz):
i, j = tuple(edge)
d_i = geometry.dist(neurograph.nodes[i]["xyz"], xyz)
Expand Down
9 changes: 6 additions & 3 deletions src/deep_neurographs/intake.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def download_gcs_zips(bucket_name, gcs_path, min_size, anisotropy):
)
if i > cnt * chunk_size:
cnt, t1 = report_progress(
i, len(zip_paths), chunk_size, cnt, t0, t1
i + 1, len(zip_paths), chunk_size, cnt, t0, t1
)

# Store results
Expand Down Expand Up @@ -306,11 +306,14 @@ def build_neurograph(
t0, t1 = utils.init_timers()
chunk_size = max(int(n_components * 0.02), 1)
cnt, i = 1, 0
n_components = len(irreducibles)
while len(irreducibles):
irreducible_set = irreducibles.pop()
neurograph.add_component(irreducible_set)
if i > cnt * chunk_size and progress_bar:
cnt, t1 = report_progress(i, n_components, chunk_size, cnt, t0, t1)
cnt, t1 = report_progress(
i + 2, n_components, chunk_size, cnt, t0, t1
)
i += 1
if progress_bar:
t, unit = utils.time_writer(time() - t0)
Expand Down Expand Up @@ -364,7 +367,7 @@ def get_irreducibles(
n_edges += count_edges(irreducibles_i)
if i > progress_cnt * chunk_size and progress_bar:
progress_cnt, t1 = report_progress(
i, n_components, chunk_size, progress_cnt, t0, t1
i + 1, n_components, chunk_size, progress_cnt, t0, t1
)
if progress_bar:
t, unit = utils.time_writer(time() - t0)
Expand Down
2 changes: 1 addition & 1 deletion src/deep_neurographs/machine_learning/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def run_without_seeds(
# Report progress
if i > progress_cnt * chunk_size and progress_bar:
progress_cnt, t1 = utils.report_progress(
i, n_batches, chunk_size, progress_cnt, t0, t1
i + 1, n_batches, chunk_size, progress_cnt, t0, t1
)
t0, t1 = utils.init_timers()

Expand Down
Loading
Loading