Skip to content

Commit

Permalink
minor upds (#185)
Browse files Browse the repository at this point in the history
Co-authored-by: anna-grim <[email protected]>
  • Loading branch information
anna-grim and anna-grim authored Jul 1, 2024
1 parent af2800c commit 432f71f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 16 deletions.
35 changes: 20 additions & 15 deletions src/deep_neurographs/doubles_removal.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
Module that removes doubled fragments from a NeuroGraph.
"""

import networkx as nx
import numpy as np
from time import time
import sys

from deep_neurographs import geometry
from deep_neurographs import utils


COLOR = "1.0 0.0 0.0"

def run(neurograph, max_size, node_spacing, output_dir=None):
Expand Down Expand Up @@ -41,8 +43,7 @@ def run(neurograph, max_size, node_spacing, output_dir=None):
# Initializations
components = list(nx.connected_components(neurograph))
doubles_cnt = 0
neurograph.init_kdtree()


cnt = 1
t0, t1 = utils.init_timers()
chunk_size = int(len(components) * 0.02)
Expand All @@ -51,15 +52,14 @@ def run(neurograph, max_size, node_spacing, output_dir=None):
for i, idx in enumerate(np.argsort([len(c) for c in components])):
# Determine whether to inspect fragment
nodes = components[idx]
swc_id = get_swc_id(neurograph, nodes)
if swc_id not in not_doubles:
xyz_arr = inspect_component(neurograph, nodes)
if len(xyz_arr) * node_spacing < max_size:
if is_double(neurograph, xyz_arr, swc_id):
doubles_cnt += 1
if output_dir:
neurograph.to_swc(output_dir, nodes, color=COLOR)
neurograph = remove_component(neurograph, nodes, swc_id)
xyz_arr = inspect_component(neurograph, nodes)
if len(xyz_arr) * node_spacing < max_size:
swc_id = get_swc_id(neurograph, nodes)
if is_double(neurograph, xyz_arr, swc_id):
if output_dir:
neurograph.to_swc(output_dir, nodes, color=COLOR)
neurograph = remove_component(neurograph, nodes, swc_id)
doubles_cnt += 1

# Update progress bar
if i >= cnt * chunk_size:
Expand Down Expand Up @@ -102,7 +102,7 @@ def is_double(neurograph, fragment, swc_id_i):
swc_id_j = neurograph.xyz_to_swc(xyz_j)
if swc_id_i != swc_id_j:
d = geometry.dist(xyz_i, xyz_j)
hits_i = check_hits(hits_i, swc_id_j, d)
hits_i = upd_hits(hits_i, swc_id_j, d)
except:
pass
if len(hits_i) > 0:
Expand All @@ -111,7 +111,7 @@ def is_double(neurograph, fragment, swc_id_i):
hits = utils.append_dict_value(hits, best_swc_id, best_dist)

# Check criteria
for swc_id_j, dists in hits.items():
for dists in hits.value():
percent_hit = len(dists) / len(fragment)
std = np.std(dists)
if percent_hit > 0.5 and std < 2:
Expand Down Expand Up @@ -220,7 +220,12 @@ def remove_xyz_entries(neurograph, i, j):
del neurograph.xyz_to_edge[tuple(xyz)]
return neurograph

def check_hits(hits, key, value):

def upd_hits(hits, key, value):
"""
Updates "hits" by adding ("key", "value") if this item does not exist.
Otherwise, checks whether "value" is less than "hits[key"]".
"""
if key in hits.keys():
if value < hits[key]:
hits[key] = value
Expand Down
1 change: 0 additions & 1 deletion src/deep_neurographs/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,6 @@ def prune_short_connectors(graph, length=8):
cnt += 1

# Finish
print("# Potential Merge Sites Detected:", cnt)
graph.remove_nodes_from(list(pruned_nodes))
return graph

Expand Down

0 comments on commit 432f71f

Please sign in to comment.