Skip to content

Commit

Permalink
refactor: optimize image profile generation (#100)
Browse files Browse the repository at this point in the history
Co-authored-by: anna-grim <[email protected]>
  • Loading branch information
anna-grim and anna-grim authored Apr 3, 2024
1 parent 5f31bbc commit 2d9c6f0
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 69 deletions.
16 changes: 4 additions & 12 deletions src/deep_neurographs/intake.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def build_neurograph_from_local(
# Process swc files
assert swc_dir or swc_paths, "Provide swc_dir or swc_paths!"
img_bbox = utils.get_img_bbox(img_patch_origin, img_patch_shape)
paths = get_paths(swc_dir) if swc_dir else swc_paths
paths = utils.list_paths(swc_dir, ext=".swc") if swc_dir else swc_paths
swc_dicts, paths = process_local_paths(
paths, anisotropy=anisotropy, min_size=min_size, img_bbox=img_bbox
)
Expand Down Expand Up @@ -213,11 +213,9 @@ def build_neurograph_from_gcs_zips(
smooth=smooth,
)
t, unit = utils.time_writer(time() - t0)
print(f"Memory Consumption: {round(utils.get_memory_usage(), 4)} GBs")
print(f"Module Runtime: {round(t, 4)} {unit} \n")

t, unit = utils.time_writer(time() - total_runtime)
print(f"Total Runtime: {round(t, 4)} {unit}")
print(f"Memory Consumption: {round(utils.get_memory_usage(), 4)} GBs")
return neurograph


Expand Down Expand Up @@ -246,7 +244,7 @@ def download_gcs_zips(bucket_name, gcs_path, min_size, anisotropy):
# Initializations
bucket = storage.Client().bucket(bucket_name)
zip_paths = utils.list_gcs_filenames(bucket, gcs_path, ".zip")
chunk_size = int(len(zip_paths) * 0.02)
chunk_size = int(len(zip_paths) * 0.1)

# Parse
cnt = 1
Expand All @@ -262,6 +260,7 @@ def download_gcs_zips(bucket_name, gcs_path, min_size, anisotropy):
cnt, t1 = report_progress(
i, len(zip_paths), chunk_size, cnt, t0, t1
)
break

return swc_dicts

Expand Down Expand Up @@ -394,13 +393,6 @@ def count_edges(irreducibles):


# -- Utils --
def get_paths(swc_dir):
    """
    Lists the path of every ".swc" file contained in "swc_dir".

    Parameters
    ----------
    swc_dir : str
        Directory to be searched for swc files.

    Returns
    -------
    list[str]
        Paths of the swc files in "swc_dir".

    """
    # Comprehension replaces the manual append loop; order matches
    # utils.listdir's output, same as before.
    return [
        os.path.join(swc_dir, f) for f in utils.listdir(swc_dir, ext=".swc")
    ]


def report_progress(current, total, chunk_size, cnt, t0, t1):
eta = get_eta(current, total, chunk_size, t1)
runtime = get_runtime(current, total, chunk_size, t0, t1)
Expand Down
95 changes: 41 additions & 54 deletions src/deep_neurographs/machine_learning/feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@

import numpy as np
import tensorstore as ts
from time import time

from deep_neurographs import geometry, utils

CHUNK_SIZE = [64, 64, 64]
WINDOW = [5, 5, 5]
N_PROFILE_POINTS = 10
N_PROFILE_PTS = 10
N_SKEL_FEATURES = 19
SUPPORTED_MODELS = [
"AdaBoost",
Expand Down Expand Up @@ -72,7 +73,10 @@ def generate_features(
proposals = neurograph.get_proposals()

# Generate features
t0 = time()
features = {"skel": generate_skel_features(neurograph, proposals)}
print(" generate_skel_features():", time() - t0)

if model_type in ["ConvNet", "MultiModalNet"]:
msg = "Must provide img_path and label_path for model_type!"
assert img_path and labels_path, msg
Expand Down Expand Up @@ -152,7 +156,7 @@ def get_img_chunk(img, labels, coord_0, coord_1, thread_id=None):
patch_coord_1 = utils.img_to_patch(coord_1, midpoint, CHUNK_SIZE)

# Generate features
path = geometry.make_line(patch_coord_0, patch_coord_1, N_PROFILE_POINTS)
path = geometry.make_line(patch_coord_0, patch_coord_1, N_PROFILE_PTS)
profile = geometry.get_profile(img_chunk, path, window=WINDOW)
labels_chunk[labels_chunk > 0] = 1
labels_chunk = geometry.fill_path(labels_chunk, path, val=2)
Expand All @@ -166,52 +170,19 @@ def get_img_chunk(img, labels, coord_0, coord_1, thread_id=None):


def generate_img_profiles(neurograph, proposals, path):
    """
    Generates an image intensity profile along each proposal in "proposals".

    Parameters
    ----------
    neurograph : NeuroGraph
        Graph that the proposals belong to.
    proposals : list[frozenset]
        Edge proposals for which profiles are generated.
    path : str
        Path to the raw image.

    Returns
    -------
    dict
        Maps each proposal to its image intensity profile.

    """
    # NOTE(review): the superchunk code path was disabled with an
    # "if False:  # neurograph.bbox:" debug guard, so every call already went
    # through the multithreaded reader. Behavior is preserved here with the
    # dead branch removed; reinstate a "neurograph.bbox" check if the
    # superchunk path (generate_img_profiles_via_superchunk) is revived.
    return generate_img_profiles_via_multithreads(neurograph, proposals, path)


def generate_img_profiles_via_multithreads(neurograph, proposals, path):
    """
    Generates an image intensity profile along each proposal by reading the
    profile line of each proposal from "path" with a separate thread.

    Parameters
    ----------
    neurograph : NeuroGraph
        Graph that the proposals belong to.
    proposals : list[frozenset]
        Edge proposals for which profiles are generated.
    path : str
        Path to the raw image (n5 or zarr).

    Returns
    -------
    dict
        Maps each proposal to its image intensity profile.

    """
    profile_features = dict()
    # Infer the tensorstore driver from the file extension in "path".
    driver = "n5" if ".n5" in path else "zarr"
    img = utils.open_tensorstore(path, driver)
    with ThreadPoolExecutor() as executor:
        # Assign threads
        threads = [None] * len(proposals)
        for i, edge in enumerate(proposals):
            # Endpoints of the proposal, converted from world to image coords.
            xyz_0, xyz_1 = neurograph.proposal_xyz(edge)
            coord_0 = utils.to_img(xyz_0)
            coord_1 = utils.to_img(xyz_1)
            line = geometry.make_line(coord_0, coord_1, N_PROFILE_POINTS)
            # NOTE(review): "edge" is passed as an extra positional arg so the
            # worker can echo it back — presumably geometry.get_profile returns
            # (edge, profile) when given this id; confirm against its signature.
            threads[i] = executor.submit(geometry.get_profile, img, line, edge)

        # Store result
        for thread in as_completed(threads):
            edge, profile = thread.result()
            profile_features[edge] = profile
    return profile_features


def generate_img_profiles_via_superchunk(neurograph, proposals, path):
"""
Generates an image intensity profile along each edge proposal by reading
a single superchunk from cloud that contains all proposals.
from an image on the cloud.
Parameters
----------
neurograph : NeuroGraph
NeuroGraph generated from a directory of swcs generated from a
predicted segmentation.
proposals : list[frozenset]
List of edge proposals for which features will be generated.
path : str
Path to raw image.
proposals : list[frozenset], optional
List of edge proposals for which features will be generated. The
default is None.
Path to image on GCS bucket.
Returns
-------
Expand All @@ -220,25 +191,41 @@ def generate_img_profiles_via_superchunk(neurograph, proposals, path):
profile.
"""
features = dict()
driver = "n5" if ".n5" in path else "zarr"
img = utils.get_superchunk(
path, driver, neurograph.origin, neurograph.shape, from_center=False
)
img = utils.normalize_img(img)
for edge in neurograph.proposals:
# Generate coordinates to be read
coords = set()
lines = dict()
t0 = time()
for i, edge in enumerate(proposals):
xyz_0, xyz_1 = neurograph.proposal_xyz(edge)
coord_0 = utils.to_img(xyz_0) - neurograph.origin
coord_1 = utils.to_img(xyz_1) - neurograph.origin
path = geometry.make_line(coord_0, coord_1, N_PROFILE_POINTS)
features[edge] = geometry.get_profile(img, path, window=WINDOW)
return features
coord_0 = utils.to_img(xyz_0)
coord_1 = utils.to_img(xyz_1)
lines[edge] = geometry.make_line(coord_0, coord_1, N_PROFILE_PTS)
for coord in lines[edge]:
coords.add(tuple(coord))
print(" generate_coords():", time() - t0)

# Read image intensities
t0 = time()
driver = "n5" if ".n5" in path else "zarr"
img = utils.open_tensorstore(path, driver)
print(" open_img():", time() - t0)

t0 = time()
img_intensity = utils.read_img_intensities(img, list(coords))
print(" read_img_intensities():", time() - t0)

# Generate intensity profiles
t0 = time()
profile_features = dict()
for edge, line in lines.items():
profile_features[edge] = [img_intensity[tuple(xyz)] for xyz in line]
print(" generate_profiles():", time() - t0)
return profile_features


def generate_skel_features(neurograph, proposals):
features = dict()
for edge in proposals:
# print("Proposals:", edge)
i, j = tuple(edge)
features[edge] = np.concatenate(
(
Expand Down Expand Up @@ -403,7 +390,7 @@ def get_multimodal_features(neurograph, features, shift=0):
# Initialize
n_edges = neurograph.n_proposals()
X = np.zeros(((n_edges, 2) + tuple(CHUNK_SIZE)))
x = np.zeros((n_edges, N_SKEL_FEATURES + N_PROFILE_POINTS))
x = np.zeros((n_edges, N_SKEL_FEATURES + N_PROFILE_PTS))
y = np.zeros((n_edges))

# Build
Expand Down Expand Up @@ -453,7 +440,7 @@ def count_features(model_type):
Number of features.
"""
if model_type != "ConvNet":
return N_SKEL_FEATURES + N_PROFILE_POINTS + 2
return N_SKEL_FEATURES + N_PROFILE_PTS + 2


def combine_features(features):
Expand Down
53 changes: 52 additions & 1 deletion src/deep_neurographs/machine_learning/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import networkx as nx
import numpy as np
import torch
from time import time
from torch.nn.functional import sigmoid
from torch.utils.data import DataLoader

Expand All @@ -25,7 +26,7 @@
from deep_neurographs.machine_learning import ml_utils
from deep_neurographs.neurograph import NeuroGraph

BATCH_SIZE_PROPOSALS = 2000
BATCH_SIZE_PROPOSALS = 1000
CHUNK_SHAPE = (256, 256, 256)


Expand Down Expand Up @@ -96,13 +97,23 @@ def run_without_seeds(
proposals,
batch_size_proposals=BATCH_SIZE_PROPOSALS,
confidence_threshold=0.7,
progress_bar=True,
):
# Initializations
dists = [neurograph.proposal_length(edge) for edge in proposals]
batches = utils.get_batch(np.argsort(dists), batch_size_proposals)
model = ml_utils.load_model(model_type, model_path)
n_batches = 1 + len(proposals) // BATCH_SIZE_PROPOSALS
print("# batches:", n_batches)

# Run
preds = []
progress_cnt = 1
t0, t1 = utils.init_timers()
chunk_size = max(int(n_batches * 0.02), 1)
for i, batch in enumerate(batches):
# Prediction
t2 = time()
proposals_i = [proposals[j] for j in batch]
preds_i = predict(
neurograph,
Expand All @@ -115,8 +126,18 @@ def run_without_seeds(
)

# Merge proposals
t2 = time()
preds.extend(preds_i)
stop
neurograph = build.fuse_branches(neurograph, preds_i)
print("fuse_branches():", time() - t2)

# Report progress
if i > progress_cnt * chunk_size and progress_bar:
progress_cnt, t1 = report_progress(
i, n_batches, chunk_size, progress_cnt, t0, t1
)
t0, t1 = utils.init_timers()

return neurograph, preds

Expand All @@ -131,6 +152,7 @@ def predict(
confidence_threshold=0.7,
):
# Generate features
t3 = time()
features = extracter.generate_features(
neurograph,
model_type,
Expand All @@ -139,16 +161,22 @@ def predict(
proposals=proposals,
)
dataset = ml_utils.init_dataset(neurograph, features, model_type)
print(" generate_features():", time() - t3)

# Run model
t3 = time()
proposal_probs = run_model(dataset, model, model_type)
print(" run_model():", time() - t3)

t3 = time()
proposal_preds = build.get_reconstruction(
neurograph,
proposal_probs,
dataset["idx_to_edge"],
high_threshold=0.95,
low_threshold=confidence_threshold,
)
print(" get_reconstruction():", time() - t3)
return proposal_preds


Expand Down Expand Up @@ -252,3 +280,26 @@ def run_model(dataset, model, model_type):
else:
hat_y = model.predict_proba(data["inputs"])[:, 1]
return np.array(hat_y)


# Utils
def report_progress(current, total, chunk_size, cnt, t0, t1):
    """
    Prints a progress bar annotated with the estimated time remaining and the
    projected total runtime, then advances the progress counter.

    Parameters
    ----------
    current : int
        Number of items processed so far.
    total : int
        Total number of items to process.
    chunk_size : int
        Number of items covered by the last timing window.
    cnt : int
        Progress counter to be incremented.
    t0 : float
        Start time of the whole run.
    t1 : float
        Start time of the current chunk.

    Returns
    -------
    tuple
        Incremented counter and a fresh timestamp for the next chunk.

    """
    utils.progress_bar(
        current,
        total,
        eta=get_eta(current, total, chunk_size, t1),
        runtime=get_runtime(current, total, chunk_size, t0, t1),
    )
    return cnt + 1, time()


def get_eta(current, total, chunk_size, t0, return_str=True):
    """
    Estimates the time remaining by scaling the runtime of the last chunk
    ("chunk_size" items since "t0") up to the number of items left.

    Parameters
    ----------
    current : int
        Number of items processed so far.
    total : int
        Total number of items to process.
    chunk_size : int
        Number of items covered by the timing window starting at "t0".
    t0 : float
        Start time of the last chunk (callers pass their per-chunk timer).
    return_str : bool, optional
        Whether to format the estimate as a human-readable string. The
        default is True.

    Returns
    -------
    str or float
        Formatted "<value> <unit>" string, or the raw estimate in seconds.

    """
    elapsed = time() - t0
    seconds_left = (total - current) * (elapsed / chunk_size)
    value, unit = utils.time_writer(seconds_left)
    if return_str:
        return f"{round(value, 4)} {unit}"
    return seconds_left


def get_runtime(current, total, chunk_size, t0, t1):
    """
    Projects the total runtime of the job: time elapsed since "t0" plus the
    estimated time remaining (computed from the last chunk's pace).

    Parameters
    ----------
    current : int
        Number of items processed so far.
    total : int
        Total number of items to process.
    chunk_size : int
        Number of items covered by the timing window starting at "t1".
    t0 : float
        Start time of the whole run.
    t1 : float
        Start time of the current chunk.

    Returns
    -------
    str
        Projected total runtime formatted as "<value> <unit>".

    """
    remaining = get_eta(current, total, chunk_size, t1, return_str=False)
    projected = (time() - t0) + remaining
    value, unit = utils.time_writer(projected)
    return f"{round(value, 4)} {unit}"
Loading

0 comments on commit 2d9c6f0

Please sign in to comment.