Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat gnn inference #127

Merged
merged 6 commits into from
Apr 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 147 additions & 33 deletions src/deep_neurographs/generate_proposals.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,56 +8,165 @@

"""

BUFFER = 36
import numpy as np

from deep_neurographs import geometry

def run(neurograph, query_id, query_xyz, radius):
ENDPOINT_DIST = 10


def run(neurograph, search_radius, complex_bool=True):
"""
Generates edge proposals for node "query_id" in "neurograph" by finding
candidate points on distinct connected components near "query_xyz".
Generates proposals emanating from "leaf".

Parameters
----------
neurograph : NeuroGraph
Graph built from swc files.
query_id : int
Node id of the query node.
query_xyz : tuple[float]
(x,y,z) coordinates of the query node.
radius : float
Maximum Euclidean distance between end points of edge proposal.
Graph that proposals will be generated for.
search_radius : float
Maximum Euclidean distance between endpoints of proposal.
complex_bool : bool, optional
Indication of whether to generate complex proposals. The default is
True.

Returns
-------
list
Best edge proposals generated from "query_node".
NeuroGraph
Graph containing leaf that may have been updated.

"""
proposals = dict()
query_swc_id = neurograph.nodes[query_id]["swc_id"]
for xyz in neurograph.query_kdtree(query_xyz, radius):
# Check whether xyz is contained (if applicable)
if not neurograph.is_contained(xyz, buffer=36):
connections = dict()
for leaf in neurograph.leafs:
if neurograph.nodes[leaf]["swc_id"] == "864374559":
neurograph, connections = run_on_leaf(
neurograph, connections, leaf, search_radius, complex_bool
)
neurograph.filter_nodes()
return neurograph


def run_on_leaf(neurograph, connections, leaf, search_radius, complex_bool):
"""
Generates proposals emanating from "leaf".

Parameters
----------
neurograph : NeuroGraph
Graph containing leaf.
connections : dict
Dictionary that tracks which connected components are connected by a
proposal. The keys are a frozenset of the pair of swc ids and values
are the corresponding proposal ids.
leaf : int
Leaf node that proposals are to be generated from.
search_radius : float
Maximum Euclidean distance between endpoints of proposal.
complex_bool : bool
Indication of whether to generate complex proposals.

Returns
-------
NeuroGraph
Graph containing leaf that may have been updated.
dict
Updated "connections" dictionary with information about proposals that
were added to "neurograph".

"""
print("leaf:", leaf)
leaf_swc_id = neurograph.nodes[leaf]["swc_id"]
for xyz in get_candidates(neurograph, leaf, search_radius):
# Get connection
neurograph, node = get_conection(neurograph, leaf, xyz, search_radius)
if not complex_bool and neurograph.degree[node] > 1:
continue

# Check whether proposal is valid
edge = neurograph.xyz_to_edge[tuple(xyz)]
swc_id = neurograph.edges[edge]["swc_id"]
if swc_id != query_swc_id and swc_id not in proposals.keys():
proposals[swc_id] = tuple(xyz)
# Check whether already connection exists
pair_id = frozenset((leaf_swc_id, neurograph.nodes[node]["swc_id"]))
if pair_id in connections.keys():
proposal = connections[pair_id]
dist_1 = neurograph.dist(leaf, node)
dist_2 = neurograph.proposal_length(proposal)
if dist_1 < dist_2:
i, j = tuple(proposal)
neurograph.nodes[i]["proposals"].remove(j)
neurograph.nodes[j]["proposals"].remove(i)
del neurograph.proposals[proposal]
del connections[pair_id]
else:
continue

# Add proposal
neurograph.add_proposal(leaf, node)
connections[pair_id] = frozenset({leaf, node})
return neurograph, connections


def get_candidates(neurograph, leaf, search_radius):
"""
Generates proposals for node "leaf" in "neurograph" by finding candidate
xyz points on distinct connected components nearby.

Parameters
----------
neurograph : NeuroGraph
Graph built from swc files.
leaf : int
Leaf node that proposals are to be generated from.
search_radius : float
Maximum Euclidean distance between endpoints of proposal.

# Check whether to stop
if len(proposals) >= neurograph.proposals_per_leaf:
break
Returns
-------
list
Proposals generated from "leaf".

return list(proposals.values())
"""
candidates = dict()
dists = dict()
leaf_xyz = neurograph.nodes[leaf]["xyz"]
for xyz in neurograph.query_kdtree(leaf_xyz, search_radius):
swc_id = neurograph.xyz_to_swc(xyz)
if swc_id != neurograph.nodes[leaf]["swc_id"]:
if swc_id not in candidates.keys():
candidates[swc_id] = tuple(xyz)
dists[swc_id] = geometry.dist(leaf_xyz, xyz)
elif geometry.dist(leaf_xyz, xyz) < dists[swc_id]:
candidates[swc_id] = tuple(xyz)
dists[swc_id] = geometry.dist(leaf_xyz, xyz)
return get_best_candidates(neurograph, candidates, dists)


def get_best_candidates(neurograph, candidates, dists):
if len(candidates) > neurograph.proposals_per_leaf:
worst = None
for key, d in dists.items():
if worst is None:
worst = key
elif dists[key] > dists[worst]:
worst = key
del candidates[worst]
del dists[worst]
return get_best_candidates(neurograph, candidates, dists)
else:
return list(candidates.values())


def get_conection(neurograph, leaf, xyz, search_radius):
edge = neurograph.xyz_to_edge[xyz]
node, d = get_closer_endpoint(neurograph, edge, xyz)
if d > ENDPOINT_DIST or neurograph.dist(leaf, node) > search_radius:
attrs = neurograph.get_edge_data(*edge)
idx = np.where(np.all(attrs["xyz"] == xyz, axis=1))[0][0]
node = neurograph.split_edge(edge, attrs, idx)
return neurograph, node


def is_valid(neurograph, i, filter_doubles):
"""
Determines whether is a valid node to generate proposals from. A node is
considered valid if it is contained in "self.bbox" (if applicable) and is
not contained in a doubled connected component (if applicable).
considered valid if it is not contained in a doubled connected component
(if applicable).

Parameters
----------
Expand All @@ -77,8 +186,13 @@ def is_valid(neurograph, i, filter_doubles):
"""
if filter_doubles:
neurograph.upd_doubles(i)

swc_id = neurograph.nodes[i]["swc_id"]
is_double = True if swc_id in neurograph.doubles else False
is_contained = neurograph.is_contained(i, buffer=BUFFER)
return False if not is_contained or is_double else True
return True if swc_id in neurograph.doubles else False


# -- utils --
def get_closer_endpoint(neurograph, edge, xyz):
i, j = tuple(edge)
d_i = geometry.dist(neurograph.nodes[i]["xyz"], xyz)
d_j = geometry.dist(neurograph.nodes[j]["xyz"], xyz)
return (i, d_i) if d_i < d_j else (j, d_j)
28 changes: 0 additions & 28 deletions src/deep_neurographs/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,34 +454,6 @@ def dist(v_1, v_2, metric="l2"):
return distance.euclidean(v_1, v_2)


def check_dists(xyz_1, xyz_2, xyz_3, radius):
"""
Determine whether to create new vertex at "xyz_2" or draw proposal between
"xyz_1" and existing node at "xyz_3".

Parameters
----------
xyz_1 : np.ndarray
xyz coordinate of leaf node (i.e. source of edge proposal).
xyz_2 : np.ndarray
xyz coordinate queried from kdtree (i.e. dest of edge proposal).
xyz_3 : np.ndarray
xyz coordinate of existing node in graph that is near "xyz_2".
radius : float
Maximum Euclidean length of edge proposal.

Parameters
----------
bool
Indication of whether to draw edge proposal between "xyz_1" and
"xyz_3".

"""
d_1 = dist(xyz_1, xyz_3) < radius
d_2 = dist(xyz_2, xyz_3) < 10
return True if d_1 and d_2 else False


def make_line(xyz_1, xyz_2, n_steps):
"""
Generates a series of points representing a straight line between two 3D
Expand Down
Loading
Loading