Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor integrate upds #108

Merged
merged 3 commits into from
Apr 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 83 additions & 3 deletions src/deep_neurographs/delete_merges_gt.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,14 @@ def detect_merges_neuron(
radius : int
Each node within "radius" is deleted.
output_dir : str, optional
...
Directory that merge sites are saved in swc files. The default is
None.
save : bool, optional
Indication of whether to save merge sites. The default is False.

Returns
-------
set
delete_nodes : set
Nodes that are part of a merge mistake.

"""
Expand Down Expand Up @@ -171,6 +172,34 @@ def detect_intersections(target_densegraph, graph, component):


def detect_merges(target_densegraph, graph, hits, radius, output_dir, save):
"""
Detects merge mistakes in "graph" (i.e. whether "graph" is closely aligned
with two distinct connected components in "target_densegraph").

Parameters
----------
target_densegraph : DenseGraph
Graph built from ground truth swc files.
graph : networkx.Graph
Graph built from a predicted swc file.
hits : dict
Dictionary that stores intersections between "target_densegraph" and
"graph", where the keys are swc ids from "target_densegraph" and
values are nodes from "graph".
radius : int
Each node within "radius" is deleted.
output_dir : str, optional
Directory that merge sites are saved in swc files. The default is
None.
save : bool, optional
Indication of whether to save merge sites.

Returns
-------
merge_sites : set
Nodes that are part of a merge site.

"""
merge_sites = set()
if len(hits.keys()) > 0:
visited = set()
Expand All @@ -184,7 +213,6 @@ def detect_merges(target_densegraph, graph, hits, radius, output_dir, save):
# Check for merge site
min_dist, sites = locate_site(graph, hits[id_1], hits[id_2])
visited.add(pair)
print(graph.nodes[sites[0]]["xyz"], min_dist)
if min_dist < MERGE_DIST_THRESHOLD:
merge_nbhd = get_merged_nodes(graph, sites, radius)
merge_sites = merge_sites.union(merge_nbhd)
Expand Down Expand Up @@ -231,6 +259,24 @@ def locate_site(graph, merged_1, merged_2):


def get_merged_nodes(graph, sites, radius):
"""
Gets nodes that are falsely merged.

Parameters
----------
graph : networkx.Graph
Graph that contains a merge at "sites".
sites : list
Nodes in "graph" that are part of a merge mistake.
radius : int
Radius about node to be searched.

Returns
-------
merged_nodes : set
Nodes that are falsely merged.

"""
i, j = tuple(sites)
merged_nodes = set(nx.shortest_path(graph, source=i, target=j))
merged_nodes = merged_nodes.union(get_nbhd(graph, i, radius))
Expand All @@ -239,10 +285,44 @@ def get_merged_nodes(graph, sites, radius):


def get_nbhd(graph, i, radius):
    """
    Gets all nodes within a path length of "radius" from node "i".

    Parameters
    ----------
    graph : networkx.Graph
        Graph to be searched.
    i : node
        Node that is the root of the neighborhood to be returned.
    radius : int
        Radius about node to be searched.

    Returns
    -------
    set
        Nodes within a path length of "radius" from node "i".

    """
    # Depth-limited DFS yields exactly the nodes reachable within "radius" hops
    subtree = nx.dfs_tree(graph, source=i, depth_limit=radius)
    return set(subtree)


def get_point(graph, sites):
    """
    Gets midpoint of merge site defined by the pair contained in "sites".

    Parameters
    ----------
    graph : networkx.Graph
        Graph that contains a merge at "sites".
    sites : list
        Nodes in "graph" that are part of a merge mistake.

    Returns
    -------
    numpy.ndarray
        Midpoint between pair of xyz coordinates in "sites".

    """
    first_xyz = graph.nodes[sites[0]]["xyz"]
    second_xyz = graph.nodes[sites[1]]["xyz"]
    return geometry.get_midpoint(first_xyz, second_xyz)
43 changes: 43 additions & 0 deletions src/deep_neurographs/densegraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,37 @@ def init_kdtree(self):
self.kdtree = KDTree(list(self.xyz_to_swc.keys()))

def get_projection(self, xyz):
    """
    Projects "xyz" onto "self" by finding the closest point in the KD-tree.

    Parameters
    ----------
    xyz : numpy.ndarray
        xyz coordinate to be queried.

    Returns
    -------
    tuple
        Coordinates of the closest point to "xyz" stored in "self.kdtree".

    """
    _, idx = self.kdtree.query(xyz, k=1)
    return tuple(self.kdtree.data[idx])

def save(self, output_dir):
"""
Saves "self" to an swc file.

Parameters
----------
output_dir : str
Path to directory that swc files are written to.

Returns
-------
None

"""
for swc_id, graph in self.graphs.items():
cnt = 0
for component in nx.connected_components(graph):
Expand All @@ -128,6 +155,22 @@ def save(self, output_dir):
swc_utils.write(path, entry_list)

def make_entries(self, graph, component):
"""
Makes swc entries corresponding to nodes in "component".

Parameters
----------
graph : networkx.Graph
Graph that "component" is a connected component of.
component : set
Connected component of "graph".

Returns
-------
entry_list : list
List of swc entries generated from nodes in "component".

"""
node_to_idx = dict()
entry_list = []
for i, j in nx.dfs_edges(graph.subgraph(component)):
Expand Down
2 changes: 1 addition & 1 deletion src/deep_neurographs/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def get_profile(img, xyz_arr, process_id=None, window=[5, 5, 5]):
def fill_path(img, path, val=-1):
    """
    Fills a 3x3x3 voxel neighborhood about each point on "path" with "val".

    Parameters
    ----------
    img : numpy.ndarray
        3D image to be filled; modified in place via slice assignment.
    path : iterable
        xyz coordinates that form the path to be filled.
    val : int, optional
        Value that filled voxels are set to. The default is -1.

    Returns
    -------
    numpy.ndarray
        Image with filled path (same array object as "img").

    """
    for xyz in path:
        # Floor to the containing voxel, then fill that voxel and its
        # immediate neighbors along each axis.
        x, y, z = tuple(np.floor(xyz).astype(int))
        img[x - 1: x + 2, y - 1: y + 2, z - 1: z + 2] = val
    return img


Expand Down
1 change: 0 additions & 1 deletion src/deep_neurographs/intake.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

"""

import os
from concurrent.futures import ProcessPoolExecutor, as_completed
from time import time

Expand Down
3 changes: 1 addition & 2 deletions src/deep_neurographs/machine_learning/feature_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import numpy as np
import tensorstore as ts
from time import time

from deep_neurographs import geometry, utils

Expand Down Expand Up @@ -68,7 +67,7 @@ def run(neurograph, model_type, img_path, labels_path=None, proposals=None):

"""
# Initializations
img_driver = driver = "n5" if ".n5" in img_path else "zarr"
img_driver = "n5" if ".n5" in img_path else "zarr"
img = utils.open_tensorstore(img_path, img_driver)
if labels_path:
labels_driver = "neuroglancer_precomputed"
Expand Down
10 changes: 2 additions & 8 deletions src/deep_neurographs/machine_learning/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,19 @@
"""

from copy import deepcopy
from random import sample
from time import time

import fastremap
import networkx as nx
import numpy as np
import torch
from time import time
from torch.nn.functional import sigmoid
from torch.utils.data import DataLoader

from deep_neurographs import graph_utils as gutils
from deep_neurographs import reconstruction as build
from deep_neurographs import utils
from deep_neurographs.machine_learning import feature_generation
from deep_neurographs.machine_learning import ml_utils
from deep_neurographs.machine_learning import feature_generation, ml_utils
from deep_neurographs.neurograph import NeuroGraph

BATCH_SIZE_PROPOSALS = 1000
Expand Down Expand Up @@ -114,7 +112,6 @@ def run_without_seeds(
chunk_size = max(int(n_batches * 0.02), 1)
for i, batch in enumerate(batches):
# Prediction
t2 = time()
proposals_i = [proposals[j] for j in batch]
accepts_i = predict(
neurograph,
Expand All @@ -128,7 +125,6 @@ def run_without_seeds(
)

# Merge proposals
t2 = time()
neurograph = build.fuse_branches(neurograph, accepts_i)
accepts.extend(accepts_i)

Expand All @@ -153,7 +149,6 @@ def predict(
confidence_threshold=0.7,
):
# Generate features
t3 = time()
features = feature_generation.run(
neurograph,
model_type,
Expand All @@ -164,7 +159,6 @@ def predict(
dataset = ml_utils.init_dataset(neurograph, features, model_type)

# Run model
t3 = time()
proposal_probs = run_model(dataset, model, model_type)
accepts = build.get_accepted_proposals(
neurograph,
Expand Down
7 changes: 3 additions & 4 deletions src/deep_neurographs/reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@

"""

import networkx as nx
import os
from concurrent.futures import ProcessPoolExecutor, as_completed
from random import sample

import networkx as nx
import numpy as np

from deep_neurographs import graph_utils as gutils
Expand Down Expand Up @@ -62,7 +61,7 @@ def get_accepted_proposals(
high_threshold=0.9,
low_threshold=0.6,
structure_aware=True,
):
):
# Get positive edge predictions
preds = threshold_preds(preds, idx_to_edge, low_threshold)
if structure_aware:
Expand Down Expand Up @@ -130,7 +129,7 @@ def get_structure_aware_accepts(
good_probs.append(prob)

more_accepts = check_cycles_sequential(graph, good_preds, good_probs)
accepts.extend(more_accepts)
accepts.extend(more_accepts)
return accepts


Expand Down
9 changes: 1 addition & 8 deletions src/deep_neurographs/swc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,13 +341,6 @@ def save_edge(path, xyz_1, xyz_2, color=None, radius=6):
f.write(make_simple_entry(2, 1, xyz_2, radius=radius))


def set_radius(graph, i):
try:
return graph[i]["radius"]
except:
return 2


def make_entry(graph, i, parent, node_to_idx):
"""
Makes an entry to be written in an swc file.
Expand All @@ -368,7 +361,7 @@ def make_entry(graph, i, parent, node_to_idx):
...

"""
r = set_radius(graph, i)
r = graph[i]["radius"]
x, y, z = tuple(graph.nodes[i]["xyz"])
node_to_idx[i] = len(node_to_idx) + 1
entry = f"{node_to_idx[i]} 2 {x} {y} {z} {r} {node_to_idx[parent]}"
Expand Down
Loading
Loading