Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat trim passes #187

Merged
merged 3 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/deep_neurographs/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,24 @@ def dist(v_1, v_2, metric="l2"):
return distance.euclidean(v_1, v_2)


def length(path):
"""
Computes the length of "path".

Parameters
----------
path : list
xyz coordinates that form a path.

Returns
-------
float
Length of "path".

"""
return np.sum([dist(path[i], path[i - 1]) for i in range(1, len(path))])


def make_line(xyz_1, xyz_2, n_steps):
"""
Generates a series of points representing a straight line between two 3D
Expand Down
149 changes: 112 additions & 37 deletions src/deep_neurographs/graph_artifact_removal.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
from time import time
import sys

from deep_neurographs import graph_utils as gutils
from deep_neurographs import utils
from deep_neurographs.geometry import dist
from deep_neurographs.geometry import dist, length

COLOR = "1.0 0.0 0.0"
MAX_DEPTH = 16
Expand Down Expand Up @@ -228,49 +229,105 @@ def trim_passings(neurograph):
# Main
n_endpoints_trimmed = 0
for leaf in neurograph.leafs:
trim_bool = inspect_branch(neurograph, leaf)
if trim_bool:
n_endpoints_trimmed += 1
if leaf in neurograph.nodes:
hits = search_along_branch(neurograph, leaf)
if len(hits) == 1:
simple_trim(neurograph, leaf, hits)
n_endpoints_trimmed += 1
print("")
elif len(hits) > 1:
print("complex trim")
print("# hits:", n_endpoints_trimmed)


def inspect_branch(neurograph, i):
trim_bool = False
swc_id = neurograph.nodes[i]["swc_id"]
for xyz, radius in get_branch(neurograph, i):
hits = search_along_branch(neurograph, swc_id, xyz, radius)
if len(hits) > 0:
trim_bool = True

hits = keep_passings(hits)

if trim_bool:
print("")
return False


def search_along_branch(neurograph, swc_id_leaf, xyz_leaf, radius):
def search_along_branch(neurograph, leaf):
hits = dict()
for xyz in neurograph.query_kdtree(xyz_leaf, radius):
try:
swc_id = neurograph.xyz_to_swc(xyz)
if swc_id != swc_id_leaf:
hits = utils.append_dict_value(hits, swc_id, xyz)
d = dist(xyz, xyz_leaf)
print(swc_id_leaf, swc_id, d, xyz)
except:
pass
return hits
swc_id_leaf = neurograph.nodes[leaf]["swc_id"]
for xyz_leaf, radius in get_branch(neurograph, leaf):
for xyz in neurograph.query_kdtree(xyz_leaf, radius):
try:
swc_id = neurograph.xyz_to_swc(xyz)
if swc_id != swc_id_leaf:
hits = utils.append_dict_value(hits, swc_id, xyz)
except:
pass
return keep_passings(hits) if len(hits) > 0 else hits


def keep_passings(hits):
rm_keys = list()
for swc_id, xyz_coords in hits.items():
if compute_length(xyz_coords) < 5:
if length(xyz_coords) < 3:
rm_keys.append(swc_id)
return utils.remove_items(hits, rm_keys)


def simple_trim(neurograph, leaf, hits):
# Initializations
swc_id, xyz_coords = unpack_dict(hits)
i, j = get_edge(neurograph, xyz_coords)
if not (neurograph.is_leaf(i) or neurograph.is_leaf(j)):
trim_from_leaf(neurograph, leaf, xyz_coords)

# Check for significant difference in radii
radius_leaf = get_branch_avg_radii(neurograph, leaf)
radius_edge = np.mean(neurograph.edges[i, j]["radius"])
if radius_leaf < radius_edge - 1:
print("1 - trim from leaf")
trim_from_leaf(neurograph, leaf, xyz_coords)
elif radius_edge < radius_leaf - 1:
print("1 - trim_from_edge")
#trim_from_edge(neurograph, leaf, (i, j))
else:
# Determine smaller fragment
leaf_component_size = len(gutils.get_component(neurograph, leaf))
edge_component_size = len(gutils.get_component(neurograph, i))
if leaf_component_size < edge_component_size:
print("2 - trim from leaf")
trim_from_leaf(neurograph, leaf, xyz_coords)
else:
print("2 - trim_from_edge")
#trim_from_edge(neurograph, leaf, edge)


def trim_from_leaf(neurograph, leaf, xyz_coords):
# Initializations
j = list(neurograph.neighbors(leaf))[0]
xyz_coords = set([tuple(xyz) for xyz in xyz_coords])

# Determine points to trim
idx = 0
while len(xyz_coords) > 0:
for xyz_query, radius in get_branch(neurograph, leaf):
for xyz in neurograph.query_kdtree(xyz_query, radius):
if tuple(xyz) in xyz_coords:
xyz_coords.remove(tuple(xyz))
idx += 1
xyz_coords = neurograph.oriented_edge((leaf, j), leaf)
idx = max(len(xyz_coords) + 2, idx)

# Trim points
if length(xyz_coords[idx::]) > 16:
neurograph = trim(neurograph, leaf, j, xyz_coords, idx)
else:
neurograph.remove_node(leaf)
if neurograph.degree[j] == 0:
neurograph.remove_node(j)
return neurograph


def trim_from_edge(neurograph, leaf, xyz_coords):
pass


def trim(neurograph, leaf, j, xyz_coords, idx):
e = (leaf, j)
neurograph.nodes[leaf]["xyz"] = xyz_coords[idx]
neurograph.edges[leaf, j]["xyz"] = xyz_coords[idx::]
neurograph.edges[e]["radius"] = neurograph.edges[e]["radius"][idx::]
return neurograph


# --- utils ---
def get_swc_id(neurograph, nodes):
"""
Expand Down Expand Up @@ -317,19 +374,37 @@ def get_branch(neurograph, i):
return zip(xyz_coords[0:n], radii[0:n])


def compute_length(path):
def get_branch_avg_radii(neurograph, leaf):
"""
Computes the length of "path".
Gets the average radii of the branch emanating from "leaf".

Parameters
----------
path : list
xyz coordinates that form a path.
neurograph : NeuroGraph
Graph containing "leaf".
leaf : int
Node id of leaf node.

Returns
-------
float
Length of "path".
Average radii of the branch emanating from "leaf".

"""
return np.sum([dist(path[i], path[i - 1]) for i in range(1, len(path))])
j = list(neurograph.neighbors(leaf))[0]
return np.mean(neurograph.oriented_edge((leaf, j), leaf, key="radius"))


def unpack_dict(my_dict):
return list(my_dict.items())[0]


def get_edge(neurograph, xyz_coords):
hits = dict()
for xyz in xyz_coords:
try:
edge = neurograph.xyz_to_edge[tuple(xyz)]
hits = utils.append_dict_value(hits, edge, 1)
except:
pass
return tuple(utils.find_best(hits))
21 changes: 15 additions & 6 deletions src/deep_neurographs/neurograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,12 +702,21 @@ def node_xyz_dist(self, node, xyz):
return get_dist(xyz, self.nodes[node]["xyz"])

def edge_length(self, edge):
length = 0
for i in range(1, len(self.edges[edge]["xyz"])):
length += get_dist(
self.edges[edge]["xyz"][i - 1], self.edges[edge]["xyz"][i]
)
return length
"""
Computes length of path stored as xyz coordinates in "edge".

Parameters
----------
edge : tuple
Edge in self.

Returns
-------
float
Path length of edge.

"""
return geometry.length(self.edges[edge]["xyz"])

def is_contained(self, node_or_xyz, buffer=0):
if self.bbox:
Expand Down
Loading