Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor rename fragmentsgraph #281

Merged
merged 2 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 97 additions & 61 deletions src/deep_neurographs/fragment_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,83 +8,98 @@
other from a FragmentsGraph.

"""

from collections import defaultdict

import networkx as nx
import numpy as np
from networkx import connected_components
from tqdm import tqdm

from deep_neurographs import geometry
from deep_neurographs.utils import util

COLOR = "1.0 0.0 0.0"
QUERY_DIST = 15


# --- Curvy Removal ---
def remove_curvy(graph, max_length, ratio=0.5):
def remove_curvy(fragments_graph, max_length, ratio=0.5):
"""
Removes connected components with 2 nodes from "fragments_graph" that are
"curvy" fragments, based on a specified ratio of endpoint distance to edge
length and a maximum length threshold.

Parameters
----------
fragments_graph : FragmentsGraph
Graph generated from fragments of a predicted segmentation.
max_length : float
The maximum allowable length (in microns) for an edge to be considered
for removal.
ratio : float, optional
Threshold ratio of endpoint distance to edge length. Components with a
ratio below this value are considered "curvy" and are removed. The
default is 0.5.

Returns
-------
int
Number of fragments removed from the graph.

"""
deleted_ids = set()
components = [c for c in connected_components(graph) if len(c) == 2]
for nodes in tqdm(components, desc="Filter Curvy Fragment"):
if len(nodes) == 2:
i, j = tuple(nodes)
length = graph.edges[i, j]["length"]
endpoint_dist = graph.dist(i, j)
if endpoint_dist / length < ratio and length < max_length:
deleted_ids.add(graph.edges[i, j]["swc_id"])
delete_fragment(graph, i, j)
components = get_line_components(fragments_graph)
for nodes in tqdm(components, desc="Filter Curvy Fragments"):
i, j = tuple(nodes)
length = fragments_graph.edges[i, j]["length"]
endpoint_dist = fragments_graph.dist(i, j)
if endpoint_dist / length < ratio and length < max_length:
deleted_ids.add(fragments_graph.edges[i, j]["swc_id"])
delete_fragment(fragments_graph, i, j)
return len(deleted_ids)


# --- Doubles Removal ---
def remove_doubles(graph, max_length, node_spacing, output_dir=None):
def remove_doubles(fragments_graph, max_length, node_spacing):
"""
Removes connected components from "neurgraph" that are likely to be a
double.
Removes connected components from "fragments_graph" that are likely to be
a double -- caused by ghosting in the image.

Parameters
----------
graph : FragmentsGraph
fragments_graph : FragmentsGraph
Graph to be searched for doubles.
max_length : int
Maximum size of connected components to be searched.
node_spacing : int
Expected distance in microns between nodes in "graph".
output_dir : str or None, optional
Directory that doubles will be written to. The default is None.
Expected distance (in microns) between nodes in "fragments_graph".

Returns
-------
graph
Graph with doubles removed.
int
Number of fragments removed from graph.

"""
# Initializations
components = [c for c in connected_components(graph) if len(c) == 2]
components = get_line_components(fragments_graph)
deleted_ids = set()
kdtree = graph.get_kdtree()
if output_dir:
util.mkdir(output_dir, delete=True)
kdtree = fragments_graph.get_kdtree()

# Main
desc = "Filter Doubled Fragment"
desc = "Filter Doubled Fragments"
for idx in tqdm(np.argsort([len(c) for c in components]), desc=desc):
i, j = tuple(components[idx])
swc_id = graph.nodes[i]["swc_id"]
swc_id = fragments_graph.nodes[i]["swc_id"]
if swc_id not in deleted_ids:
if graph.edges[i, j]["length"] < max_length:
if fragments_graph.edges[i, j]["length"] < max_length:
# Check doubles criteria
n_points = len(graph.edges[i, j]["xyz"])
hits = compute_projections(graph, kdtree, (i, j))
n_points = len(fragments_graph.edges[i, j]["xyz"])
hits = compute_projections(fragments_graph, kdtree, (i, j))
if check_doubles_criteria(hits, n_points):
if output_dir:
graph.to_swc(output_dir, [i, j], color=COLOR)
delete_fragment(graph, i, j)
delete_fragment(fragments_graph, i, j)
deleted_ids.add(swc_id)
return len(deleted_ids)


def compute_projections(graph, kdtree, edge):
def compute_projections(fragments_graph, kdtree, edge):
"""
Given a fragment defined by "edge", this routine iterates of every xyz in
the fragment and projects it onto the closest fragment. For each detected
Expand All @@ -93,11 +108,11 @@ def compute_projections(graph, kdtree, edge):

Parameters
----------
graph : graph
fragments_graph : graph
Graph that contains "edge".
kdtree : KDTree
KD-Tree that contains all xyz coordinates of every fragment in
"graph".
"fragments_graph".
edge : tuple
Pair of leaf nodes that define a fragment.

Expand All @@ -109,13 +124,13 @@ def compute_projections(graph, kdtree, edge):

"""
hits = defaultdict(list)
query_id = graph.edges[edge]["swc_id"]
for i, xyz in enumerate(graph.edges[edge]["xyz"]):
query_id = fragments_graph.edges[edge]["swc_id"]
for i, xyz in enumerate(fragments_graph.edges[edge]["xyz"]):
# Compute projections
best_id = None
best_dist = np.inf
for hit_xyz in geometry.query_ball(kdtree, xyz, QUERY_DIST):
hit_id = graph.xyz_to_swc(hit_xyz)
hit_id = fragments_graph.xyz_to_swc(hit_xyz)
if hit_id is not None and hit_id != query_id:
if geometry.dist(hit_xyz, xyz) < best_dist:
best_dist = geometry.dist(hit_xyz, xyz)
Expand Down Expand Up @@ -157,54 +172,54 @@ def check_doubles_criteria(hits, n_points):
return False


def delete_fragment(graph, i, j):
def delete_fragment(fragments_graph, i, j):
"""
Deletes nodes "i" and "j" from "graph", where these nodes form a connected
component.
Deletes nodes "i" and "j" from "fragments_graph", where these nodes form a
connected component.

Parameters
----------
graph : FragmentsGraph
Graph that contains nodes to be deleted.
fragments_graph : FragmentsGraph
Graph that contains nodes to be removed.
i : int
Node to be removed.
j : int
Node to be removed.

Returns
-------
graph
fragments_graph
Graph with nodes removed.

"""
graph = remove_xyz_entries(graph, i, j)
graph.swc_ids.remove(graph.nodes[i]["swc_id"])
graph.remove_nodes_from([i, j])
fragments_graph = remove_xyz_entries(fragments_graph, i, j)
fragments_graph.swc_ids.remove(fragments_graph.nodes[i]["swc_id"])
fragments_graph.remove_nodes_from([i, j])


def remove_xyz_entries(graph, i, j):
def remove_xyz_entries(fragments_graph, i, j):
"""
Removes dictionary entries from "graph.xyz_to_edge" corresponding to
the edge {i, j}.
Removes dictionary entries from "fragments_graph.xyz_to_edge"
corresponding to the edge {i, j}.

Parameters
----------
graph : graph
fragments_graph : graph
Graph to be updated.
i : int
Node in "graph".
Node in graph.
j : int
Node in "graph".
Node in graph.

Returns
-------
graph
Updated graph.

"""
for xyz in graph.edges[i, j]["xyz"]:
del graph.xyz_to_edge[tuple(xyz)]
return graph
for xyz in fragments_graph.edges[i, j]["xyz"]:
del fragments_graph.xyz_to_edge[tuple(xyz)]
return fragments_graph


def upd_hits(hits, key, value):
Expand All @@ -215,8 +230,8 @@ def upd_hits(hits, key, value):
Parameters
----------
hits : dict
Stores swd_ids of fragments that are within a certain distance a query
fragment along with the corresponding distances.
Stores swd_ids of fragments within a certain distance a query fragment
along with the corresponding distances.
key : str
swc id of some fragment.
value : float
Expand All @@ -229,9 +244,30 @@ def upd_hits(hits, key, value):
Updated version of hits.

"""
if key in hits.keys():
if key in hits:
if value < hits[key]:
hits[key] = value
else:
hits[key] = value
return hits


# --- utils ---
def get_line_components(graph):
"""
Identifies and returns all line components in the given graph. A line
component is defined as a connected component with exactly two nodes.

Parameters
----------
graph : networkx.Graph
Input graph in which line components are to be identified.

Returns
-------
List[set]
List of sets, where each set contains two nodes representing a
connected component with exactly two nodes.

"""
return [c for c in nx.connected_components(graph) if len(c) == 2]
Loading
Loading