Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: write neurograph to zip #171

Merged
merged 1 commit into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion src/deep_neurographs/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def get_irreducibles(
# Extract irreducibles
irreducibles = []
for node_subset in nx.connected_components(graph):
if len(node_subset) > 10:
if len(node_subset) > 30:
subgraph = graph.subgraph(node_subset)
irreducibles_i = __get_irreducibles(subgraph, swc_dict, smooth)
if irreducibles_i:
Expand Down Expand Up @@ -800,3 +800,20 @@ def get_component(graph, root):
if (i, j) in graph.edges:
queue.append(j)
return component


def count_components(graph):
"""
Counts the number of connected components in a graph.

Paramters
---------
graph : networkx.Graph
Graph to be searched.

Returns
-------
Number of connected components.

"""
return len(list(nx.connected_components(graph)))
1 change: 0 additions & 1 deletion src/deep_neurographs/machine_learning/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ def run_without_seeds(
)

# Merge proposals
print("# components:", len(list(nx.connected_components(neurograph))))
neurograph = build.fuse_branches(neurograph, accepts_i)
accepts.extend(accepts_i)

Expand Down
132 changes: 91 additions & 41 deletions src/deep_neurographs/neurograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@

"""
import os
import zipfile
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from io import StringIO
from random import sample

import networkx as nx
Expand Down Expand Up @@ -707,37 +709,94 @@ def component_cardinality(self, root):
queue.append((j, k))
return cardinality

def to_swc(self, path):
def to_zipped_swcs(self, zip_path):
n_components = gutils.count_components(self)
print(f"Writing {n_components} swcs to local machine!")
with zipfile.ZipFile(zip_path, "w") as zipf:
for nodes in nx.connected_components(self):
self.to_zipped_swc(zipf, nodes)

def to_zipped_swc(self, zipf, nodes):
with StringIO() as text_buffer:
n_entries = 0
node_to_idx = dict()
text_buffer.write("# id, type, z, y, x, r, pid")
for i, j in nx.dfs_edges(self.subgraph(nodes)):
# Initialize
if n_entries == 0:
swc_id = self.nodes[i]["swc_id"]
x, y, z = tuple(self.nodes[i]["xyz"])
r = self.nodes[i]["radius"]
text_buffer.write("\n" + f"1 2 {x} {y} {z} {r} -1")
node_to_idx[i] = 1
n_entries += 1

# Create entry
parent = node_to_idx[i]
text_buffer, n_entries = self.branch_to_zip(
text_buffer, n_entries, i, j, parent
)
node_to_idx[j] = n_entries
zipf.writestr(f"{swc_id}.swc", text_buffer.getvalue())

def to_swcs(self, swc_dir):
"""
Write a neurograph to "swc_dir" such that each connected component is
saved as an swc file.

Parameters
----------
swc_dir : str
Directory that neurograph is to be written to

Returns
-------
None

"""
with ThreadPoolExecutor() as executor:
threads = []
print("# swcs:", len(list(nx.connected_components(self))))
for i, component in enumerate(nx.connected_components(self)):
node = sample(component, 1)[0]
swc_id = self.nodes[node]["swc_id"]
path_i = os.path.join(path, f"{swc_id}.swc")
threads.append(
executor.submit(self.component_to_swc, path_i, component)
)
n_components = gutils.count_components(self)
print(f"Writing {n_components} swcs to local machine!")
for i, nodes in enumerate(nx.connected_components(self)):
threads.append(executor.submit(self.to_swc, swc_dir, nodes))

def component_to_swc(self, path, component):
node_to_idx = dict()
def to_swc(self, swc_dir, nodes):
"""
Generates list of swc entries for a given connected component.

Parameters
----------
path : str
Path that swc for component will be written to.
Component : list[int]
List of nodes contained in "component".

Returns
-------
None.

"""
entry_list = []
for i, j in nx.dfs_edges(self.subgraph(component)):
node_to_idx = dict()
for i, j in nx.dfs_edges(self.subgraph(nodes)):
# Initialize
if len(entry_list) == 0:
x, y, z = tuple(self.nodes[i]["xyz"])
r = self.nodes[i]["radius"]
entry_list.append(f"1 2 {x} {y} {z} {r} -1")
node_to_idx[i] = 1

filename = self.nodes[i]["swc_id"] + ".swc"
path = os.path.join(swc_dir, filename)

# Create entry
parent = node_to_idx[i]
entry_list = self.branch_to_entries(entry_list, i, j, parent)
node_to_idx[j] = len(entry_list)

# Write
if len(entry_list) > 0:
swc_utils.write(path, entry_list)
swc_utils.write(path, entry_list)

def branch_to_entries(self, entry_list, i, j, parent):
# Orient branch
Expand All @@ -757,6 +816,24 @@ def branch_to_entries(self, entry_list, i, j, parent):
entry_list.append(entry)
return entry_list

def branch_to_zip(self, text_buffer, n_entries, i, j, parent):
# Orient branch
branch_xyz = self.edges[i, j]["xyz"]
branch_radius = self.edges[i, j]["radius"]
if (branch_xyz[0] != self.nodes[i]["xyz"]).any():
branch_xyz = np.flip(branch_xyz, axis=0)
branch_radius = np.flip(branch_radius, axis=0)

# Make entries
for k in range(1, len(branch_xyz)):
x, y, z = tuple(branch_xyz[k])
r = branch_radius[k]
node_id = n_entries + 1
parent = n_entries if k > 1 else parent
text_buffer.write("\n" + f"{node_id} 2 {x} {y} {z} {r} {parent}")
n_entries += 1
return text_buffer, n_entries

def near_proposal(self, root, depth):
# Check root
if len(self.nodes[root]["proposals"]) > 0:
Expand All @@ -767,30 +844,3 @@ def near_proposal(self, root, depth):
if len(self.nodes[j]["proposals"]) > 0:
return True
return False

def remove_components_without_proposals(self):
remove_nodes = set()
n_components_removed = 0
for component in nx.connected_components(self):
# Check for proposals
hit_proposal = False
for i in component:
if len(self.nodes[i]["proposals"]) > 0:
hit_proposal = True
break

# Check whether hit proposal
if not hit_proposal:
remove_nodes = remove_nodes.union(set(component))
n_components_removed += 1

self.remove_nodes_from(remove_nodes)


def connected_components_with_proposals(self):
nodes = set(self.nodes)
connected_components = list()
while len(nodes) > 0:
root = utils.sample_singleton(nodes)
connected_components.append(self.get_component(root))
return connected_components
16 changes: 10 additions & 6 deletions src/deep_neurographs/reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,19 +197,23 @@ def fuse_branches(neurograph, edges):


# -- Save result --
def save_prediction(neurograph, accepted_proposals, output_dir):
def save_prediction(
neurograph, accepted_proposals, output_dir, save_results=False
):
# Initializations
connections_path = os.path.join(output_dir, "connections.txt")
corrections_dir = os.path.join(output_dir, "corrections")
swc_dir = os.path.join(output_dir, "corrected-processed-swcs")
utils.mkdir(output_dir, delete=True)
swc_zip_path = os.path.join(output_dir, "corrected-processed-swcs.zip")
utils.mkdir(corrections_dir, delete=True)
utils.mkdir(swc_dir, delete=True)

# Write Result
neurograph.to_swc(swc_dir)
save_corrections(neurograph, accepted_proposals, corrections_dir)
n_swcs = gutils.count_components(neurograph)
save_connections(neurograph, accepted_proposals, connections_path)
if save_results:
neurograph.to_zipped_swcs(swc_zip_path)
# save_corrections(neurograph, accepted_proposals, corrections_dir)
else:
print(f"Result contains {n_swcs} swcs!")


def save_corrections(neurograph, accepted_proposals, output_dir):
Expand Down
Loading