Skip to content

Commit

Permalink
feat: isolated proposal detection
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim committed Apr 2, 2024
1 parent 7bd82c7 commit 6313095
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 63 deletions.
68 changes: 46 additions & 22 deletions src/deep_neurographs/machine_learning/groundtruth_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from deep_neurographs import utils
from deep_neurographs.geometry import dist as get_dist

CLOSE_THRESHOLD = 3.5
ALIGNED_THRESHOLD = 3.5
MIN_INTERSECTION = 10


Expand All @@ -42,10 +42,10 @@ def get_valid_proposals(target_neurograph, pred_neurograph):
invalid_proposals = set()
node_to_target = dict()
for component in nx.connected_components(pred_neurograph):
aligned_bool, target_id = is_component_aligned(
aligned, target_id = is_component_aligned(
target_neurograph, pred_neurograph, component
)
if not aligned_bool:
if not aligned:
i = utils.sample_singleton(component)
invalid_proposals.add(pred_neurograph.nodes[i]["swc_id"])
else:
Expand Down Expand Up @@ -104,31 +104,50 @@ def is_component_aligned(target_neurograph, pred_neurograph, component):
d = get_dist(hat_xyz, xyz)
dists = utils.append_dict_value(dists, hat_swc_id, d)

# Check whether there's a merge
hits = []
for key in dists.keys():
if len(dists[key]) > 10 and np.mean(dists[key]) < CLOSE_THRESHOLD:
hits.append(key)
if len(hits) > 1:
print("pred_swc_id:", pred_neurograph.edges[edge]["swc_id"])
print("target_swc_id:", list(dists.keys()))
print("")

# Deterine whether aligned
hat_swc_id = utils.find_best(dists)
dists = np.array(dists[hat_swc_id])
aligned_score = np.mean(dists[dists < np.percentile(dists, 90)])
if aligned_score < 4 and hat_swc_id:
intersects = True if len(dists) > MIN_INTERSECTION else False
aligned_score = np.mean(dists[dists < np.percentile(dists, 85)])
if (aligned_score < ALIGNED_THRESHOLD and hat_swc_id) and intersects:
return True, hat_swc_id
else:
return False, None


def is_valid(target_neurograph, pred_neurograph, target_id, edge):
"""
Determines whether a proposal is valid, meaning it must be consistent and
aligned.
Parameters
----------
target_neurograph : NeuroGraph
Graph built from ground truth swc files.
pred_neurograph : NeuroGraph
Graph build from predicted swc files.
target_id : str
swc id of target that the proposal "edge" corresponds to.
edge : frozenset
Edge proposal to be checked.
Returns
-------
bool
Indication of whether proposal is consistent
"""
#aligned = is_proposal_aligned(target_neurograph, pred_neurograph, edge)
consistent = is_consistent(
target_neurograph, pred_neurograph, target_id, edge
)
return True if consistent else False


def is_consistent(target_neurograph, pred_neurograph, target_id, edge):
"""
Determines whether the proposal connects two branches that correspond to
either the same or adjacent branches on the ground truth. If either
condition holds, then the proposal is said to be valid.
condition holds, then the proposal is said to be consistent.
Parameters
----------
Expand All @@ -137,14 +156,14 @@ def is_valid(target_neurograph, pred_neurograph, target_id, edge):
pred_neurograph : NeuroGraph
Graph build from predicted swc files.
target_id : str
...
swc id of target that the proposal "edge" corresponds to.
edge : frozenset
Edge proposal to be checked.
Returns
-------
bool
Indication of whether proposal is valid
Indication of whether proposal is consistent
"""
# Find closest edges from target_neurograph
Expand All @@ -169,6 +188,15 @@ def is_valid(target_neurograph, pred_neurograph, target_id, edge):
return False


def is_proposal_aligned(target_neurograph, pred_neurograph, edge):
xyz_0, xyz_1 = pred_neurograph.proposal_xyz(edge)
proj_dists = []
for xyz in geometry.make_line(xyz_0, xyz_1, 10):
hat_xyz = target_neurograph.get_projection(tuple(xyz))
proj_dists.append(get_dist(hat_xyz, xyz))
return True if np.mean(proj_dists) < ALIGNED_THRESHOLD else False


def proj_branch(target_neurograph, pred_neurograph, target_id, i):
# Compute projections
hits = dict()
Expand Down Expand Up @@ -196,10 +224,6 @@ def proj_branch(target_neurograph, pred_neurograph, target_id, i):
return best_edge


def is_proposal_aligned(target_neurograph, pred_neurograph, edge):
pass


def is_adjacent(neurograph, edge_i, edge_j):
"""
Determines whether "edge_i" and "edge_j" are adjacent, meaning there
Expand Down
126 changes: 85 additions & 41 deletions src/deep_neurographs/neurograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from deep_neurographs.generate_proposals import is_valid
from deep_neurographs.geometry import check_dists
from deep_neurographs.geometry import dist as get_dist
from deep_neurographs.geometry import get_midpoint
from deep_neurographs.machine_learning.groundtruth_generation import (
init_targets,
)
Expand Down Expand Up @@ -85,7 +86,7 @@ def copy_graph(self, add_attrs=False):
graph.add_edges_from(self.edges)
return graph

# --- Add to graph --
# --- Edit Graph --
def add_component(self, irreducibles):
"""
Adds a connected component to "self".
Expand Down Expand Up @@ -187,6 +188,21 @@ def __add_edge(self, edge, attrs, idxs, swc_id):
for xyz in attrs["xyz"][idxs]:
self.xyz_to_edge[tuple(xyz)] = edge

def absorb_node(self, i, nb_1, nb_2):
# Get attributes
xyz = self.get_branches(i, key="xyz")
radius = self.get_branches(i, key="radius")

# Edit graph
self.remove_node(i)
self.add_edge(
nb_1,
nb_2,
xyz=np.vstack([np.flip(xyz[1], axis=0), xyz[0][1:, :]]),
radius=np.concatenate((radius[0], np.flip(radius[1]))),
swc_id=self.nodes[nb_1]["swc_id"],
)

def split_edge(self, edge, attrs, idx):
"""
Splits "edge" into two distinct edges by making the subnode at "idx" a
Expand Down Expand Up @@ -248,11 +264,13 @@ def add_proposal(self, i, j):
edge = frozenset((i, j))
self.nodes[i]["proposals"].add(j)
self.nodes[j]["proposals"].add(i)
self.xyz_to_proposal[tuple(self.nodes[i]["xyz"])] = edge
self.xyz_to_proposal[tuple(self.nodes[j]["xyz"])] = edge
self.proposals[edge] = {
"xyz": np.array([self.nodes[i]["xyz"], self.nodes[j]["xyz"]])
}

# --- Proposal and Ground Truth Generation ---
# --- Proposal Generation ---
def generate_proposals(
self,
radius,
Expand Down Expand Up @@ -312,11 +330,13 @@ def generate_proposals(

# Finish
self.filter_nodes()
self.init_proposal_kdtree()
if optimize:
self.run_optimization()

def reset_proposals(self):
self.proposals = dict()
self.xyz_to_proposal = dict()
for i in self.nodes:
self.nodes[i]["proposals"] = set()

Expand Down Expand Up @@ -351,7 +371,7 @@ def run_optimization(self):
xyz_1, xyz_2 = geometry.optimize_alignment(self, img, edge)
self.proposals[edge]["xyz"] = np.array([xyz_1, xyz_2])

# -- kdtree --
# -- KDTree --
def init_kdtree(self):
"""
Builds a KD-Tree from the (x,y,z) coordinates of the subnodes of
Expand All @@ -366,8 +386,22 @@ def init_kdtree(self):
None
"""
if not self.kdtree:
self.kdtree = KDTree(list(self.xyz_to_edge.keys()))
self.kdtree = KDTree(list(self.xyz_to_edge.keys()))

def init_proposal_kdtree(self):
"""
Builds a KD-Tree from the (x,y,z) coordinates of the proposals.
Parameters
----------
None
Returns
-------
None
"""
self.proposal_kdtree = KDTree(list(self.xyz_to_proposal.keys()))

def query_kdtree(self, xyz, d):
"""
Expand All @@ -392,7 +426,7 @@ def get_projection(self, xyz):
_, idx = self.kdtree.query(xyz, k=1)
return tuple(self.kdtree.data[idx])

# --- utils ---
# --- Proposal Utils ---
def n_proposals(self):
"""
Computes number of edges proposals in the graph.
Expand All @@ -412,13 +446,58 @@ def n_proposals(self):
def get_proposals(self):
return list(self.proposals.keys())

def get_simple_proposals(self):
return set([e for e in self.get_proposals() if self.is_simple(e)])

def get_complex_proposals(self):
return set([e for e in self.get_proposals() if not self.is_simple(e)])

def get_isolated_proposals(self, radius):
isolated_proposals = set()
for edge in self.proposals.keys():
xyz = self.proposal_midpoint(edge)
if len(self.proposal_kdtree.query_ball_point(xyz, radius)) <= 2:
isolated_proposals.add(edge)
return isolated_proposals

def is_simple(self, edge):
i, j = tuple(edge)
return True if self.is_leaf(i) and self.is_leaf(j) else False

def proposal_xyz(self, edge):
return tuple(self.proposals[edge]["xyz"])

def proposal_length(self, edge):
i, j = tuple(edge)
return get_dist(self.nodes[i]["xyz"], self.nodes[j]["xyz"])

def proposal_midpoint(self, edge):
i, j = tuple(edge)
return get_midpoint(self.nodes[i]["xyz"], self.nodes[j]["xyz"])

def merge_proposal(self, edge):
# Attributes
i, j = tuple(edge)
xyz = np.vstack([self.nodes[i]["xyz"], self.nodes[j]["xyz"]])
radius = np.array([self.nodes[i]["radius"], self.nodes[j]["radius"]])

# Add
self.add_edge(i, j, xyz=xyz, radius=radius, swc_id="merged")
del self.proposals[edge]
# delete from kdtree

def remove_nonisolated_proposals(self, radius):
isolated_proposals = self.get_isolated_proposals(radius)
proposals = self.get_proposals()
while len(proposals) > 0:
edge = proposals.pop()
if edge not in isolated_proposals:
i, j = tuple(edge)
self.nodes[i]["proposals"].remove(j)
self.nodes[j]["proposals"].remove(i)
del self.proposals[edge]

# --- Utils ---
def get_branches(self, i, key="xyz"):
branches = []
for j in self.neighbors(i):
Expand Down Expand Up @@ -464,16 +543,6 @@ def get_edge_attr(self, edge, key):
xyz_arr = gutils.get_edge_attr(self, edge, key)
return xyz_arr[0], xyz_arr[-1]

def get_complex_proposals(self):
return set([e for e in self.get_proposals() if not self.is_simple(e)])

def get_simple_proposals(self):
return set([e for e in self.get_proposals() if self.is_simple(e)])

def is_simple(self, edge):
i, j = tuple(edge)
return True if self.is_leaf(i) and self.is_leaf(j) else False

def to_patch_coords(self, edge, midpoint, chunk_size):
patch_coords = []
for xyz in self.edges[edge]["xyz"]:
Expand Down Expand Up @@ -582,31 +651,6 @@ def filter_nodes(self):
nbs = list(self.neighbors(i))
self.absorb_node(i, nbs[0], nbs[1])

def absorb_node(self, i, nb_1, nb_2):
# Get attributes
xyz = self.get_branches(i, key="xyz")
radius = self.get_branches(i, key="radius")

# Edit graph
self.remove_node(i)
self.add_edge(
nb_1,
nb_2,
xyz=np.vstack([np.flip(xyz[1], axis=0), xyz[0][1:, :]]),
radius=np.concatenate((radius[0], np.flip(radius[1]))),
swc_id=self.nodes[nb_1]["swc_id"],
)

def merge_proposal(self, edge):
# Attributes
i, j = tuple(edge)
xyz = np.vstack([self.nodes[i]["xyz"], self.nodes[j]["xyz"]])
radius = np.array([self.nodes[i]["radius"], self.nodes[j]["radius"]])

# Add
self.add_edge(i, j, xyz=xyz, radius=radius, swc_id="merged")
del self.proposals[edge]

def to_swc(self, path):
with ThreadPoolExecutor() as executor:
threads = []
Expand Down

0 comments on commit 6313095

Please sign in to comment.