Skip to content

Commit

Permalink
feat: soma_ids attr in neurograph (#177)
Browse files Browse the repository at this point in the history
Co-authored-by: anna-grim <[email protected]>
  • Loading branch information
anna-grim and anna-grim authored Jun 27, 2024
1 parent ccf85be commit 72872fd
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 35 deletions.
9 changes: 8 additions & 1 deletion src/deep_neurographs/generate_proposals.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,15 @@ def run_on_leaf(
if not complex_bool and neurograph.degree[node] > 1:
continue

# Check for somas
swc_id = neurograph.nodes[node]["swc_id"]
hit_1 = swc_id in neurograph.soma_ids
hit_2 = leaf_swc_id in neurograph.soma_ids
if hit_1 and hit_2:
continue

# Check whether already connection exists
pair_id = frozenset((leaf_swc_id, neurograph.nodes[node]["swc_id"]))
pair_id = frozenset((leaf_swc_id, swc_id))
if pair_id in connections.keys():
proposal = connections[pair_id]
dist_1 = neurograph.dist(leaf, node)
Expand Down
34 changes: 34 additions & 0 deletions src/deep_neurographs/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,3 +821,37 @@ def count_components(graph):
"""
return nx.number_connected_components(graph)


def largest_components(neurograph, k):
"""
Finds the "k" largest connected components in "neurograph".
Parameters
----------
neurograph : NeuroGraph
Graph to be searched.
k : int
Number of largest connected components to return.
Returns
-------
list
List where each entry is a random node from one of the k largest
connected components.
"""
component_cardinalities = k * [-1]
node_ids = k * [-1]
for nodes in nx.connected_components(neurograph):
if len(nodes) > component_cardinalities[-1]:
i = 0
while i < k:
if len(nodes) > component_cardinalities[i]:
component_cardinalities.insert(i, len(nodes))
component_cardinalities.pop(-1)
nodes_ids.insert(i, utils.sample_singleton(nodes))
nodes_ids.pop(-1)
break
i += 1
return node_ids
4 changes: 2 additions & 2 deletions src/deep_neurographs/machine_learning/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def run_without_seeds(
# Merge proposals
neurograph = build.fuse_branches(neurograph, accepts_i)
accepts.extend(accepts_i)

# Report progress
if i > progress_cnt * chunk_size and progress_bar:
progress_cnt, t1 = utils.report_progress(
Expand Down Expand Up @@ -193,7 +193,7 @@ def get_idxs(dataset, model_type):
return dataset["idx_to_edge"]


# -- Whole Brain Seed-Based Inference --
# -- seed-based inference --
def build_from_soma(
neurograph, labels_path, chunk_origin, chunk_shape=CHUNK_SHAPE, n_hops=1
):
Expand Down
160 changes: 128 additions & 32 deletions src/deep_neurographs/neurograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(
self.target_edges = set()
self.node_cnt = 0
self.node_spacing = node_spacing
self.soma_ids = set()

# Initialize data structures for proposals
self.complex_proposals = set()
Expand Down Expand Up @@ -88,6 +89,26 @@ def copy_graph(self, add_attrs=False):
graph.add_edges_from(deepcopy(self.edges))
return graph

def set_soma_ids(self, k):
"""
Sets class attribute called "self.soma_ids" as the swc ids of the "k"
largest components. These components are used as a proxy for soma
locations.
Paramters
---------
k : int
Number of largest components to be set as proxy soma locations.
Returns
-------
None
"""
node_ids = gutils.largest_components(self, k)
for i in node_ids:
self.soma_ids.add(self.nodes[i]["swc_id"])

# --- Edit Graph --
def add_component(self, irreducibles):
"""
Expand Down Expand Up @@ -250,31 +271,6 @@ def split_edge(self, edge, attrs, idx):
self.__add_edge((node_id, j), attrs, idxs_2, swc_id)
return node_id

def add_proposal(self, i, j):
"""
Adds proposal between nodes "i" and "j".
Parameters
----------
i : int
Node id.
j : int
Node id
Returns
-------
None
"""
edge = frozenset((i, j))
self.nodes[i]["proposals"].add(j)
self.nodes[j]["proposals"].add(i)
self.xyz_to_proposal[tuple(self.nodes[i]["xyz"])] = edge
self.xyz_to_proposal[tuple(self.nodes[j]["xyz"])] = edge
self.proposals[edge] = {
"xyz": np.array([self.nodes[i]["xyz"], self.nodes[j]["xyz"]])
}

# --- Proposal Generation ---
def generate_proposals(
self,
Expand Down Expand Up @@ -314,14 +310,64 @@ def generate_proposals(
self.run_optimization()

def reset_proposals(self):
"""
Deletes all previously generated proposals.
Parameters
----------
None
Returns
-------
None
"""
self.proposals = dict()
self.xyz_to_proposal = dict()
for i in self.nodes:
self.nodes[i]["proposals"] = set()

def set_proposals_per_leaf(self, proposals_per_leaf):
"""
Sets the maximum number of proposals per leaf as a class attribute.
Parameters
----------
proposals_per_leaf : int
Maximum number of proposals per leaf.
Returns
-------
None
"""
self.proposals_per_leaf = proposals_per_leaf

def add_proposal(self, i, j):
"""
Adds proposal between nodes "i" and "j".
Parameters
----------
i : int
Node id.
j : int
Node id
Returns
-------
None
"""
edge = frozenset((i, j))
self.nodes[i]["proposals"].add(j)
self.nodes[j]["proposals"].add(i)
self.xyz_to_proposal[tuple(self.nodes[i]["xyz"])] = edge
self.xyz_to_proposal[tuple(self.nodes[j]["xyz"])] = edge
self.proposals[edge] = {
"xyz": np.array([self.nodes[i]["xyz"], self.nodes[j]["xyz"]])
}

def init_targets(self, target_neurograph):
target_neurograph.init_kdtree()
self.target_edges = init_targets(target_neurograph, self)
Expand Down Expand Up @@ -431,6 +477,19 @@ def n_proposals(self):
return len(self.proposals)

def get_proposals(self):
"""
Gets the proposal ids (i.e. node pairs).
Parameters
----------
None
Returns
-------
list
Proposal ids
"""
return list(self.proposals.keys())

def get_simple_proposals(self):
Expand Down Expand Up @@ -491,15 +550,52 @@ def proposal_midpoint(self, proposal):
return get_midpoint(self.nodes[i]["xyz"], self.nodes[j]["xyz"])

def merge_proposal(self, edge):
# Attributes
i, j = tuple(edge)
xyz = np.vstack([self.nodes[i]["xyz"], self.nodes[j]["xyz"]])
radius = np.array([self.nodes[i]["radius"], self.nodes[j]["radius"]])
swc_id = self.nodes[i]["swc_id"]
soma_bool_1 = self.nodes[i]["swc_id"] in self.soma_ids
soma_bool_2 = self.nodes[j]["swc_id"] in self.soma_ids
if not (soma_bool_1 and soma_bool_1):
# Attributes
xyz = np.vstack([self.nodes[i]["xyz"], self.nodes[j]["xyz"]])
radius = np.array([self.nodes[i]["radius"], self.nodes[j]["radius"]])
if self.nodes[i]["swc_id"] in self.soma_ids:
r = j
swc_id = self.nodes[i]["swc_id"]
else:
r = i
swc_id = self.nodes[j]["swc_id"]

# Update graph
self.upd_ids(swc_id, r)
self.add_edge(i, j, xyz=xyz, radius=radius, swc_id=swc_id)
if i in self.leafs:
self.leafs.remove(i)
if j in self.leafs:
self.leafs.remove(j)
del self.proposals[edge]

def upd_ids(self, swc_id, r):
"""
Updates the swc_id of all nodes connected to "r".
Parameters
----------
swc_id : str
Segment id.
r : int
Node.
# Add
self.add_edge(i, j, xyz=xyz, radius=radius, swc_id=swc_id)
del self.proposals[edge]
Returns
-------
None
"""
queue = [r]
visited = []
while len(queue) > 0:
i = queue.pop()
self.nodes[i]["swc_id"] = swc_id
for j in [j for j in self.neighbors(i) if j not in visited]:
queue.append(j)

# --- Utils ---
def dist(self, i, j):
Expand Down

0 comments on commit 72872fd

Please sign in to comment.