refactor: optimize image profile generation #100

Merged 1 commit on Apr 3, 2024
16 changes: 4 additions & 12 deletions src/deep_neurographs/intake.py
@@ -102,7 +102,7 @@ def build_neurograph_from_local(
# Process swc files
assert swc_dir or swc_paths, "Provide swc_dir or swc_paths!"
img_bbox = utils.get_img_bbox(img_patch_origin, img_patch_shape)
paths = get_paths(swc_dir) if swc_dir else swc_paths
paths = utils.list_paths(swc_dir, ext=".swc") if swc_dir else swc_paths
swc_dicts, paths = process_local_paths(
paths, anisotropy=anisotropy, min_size=min_size, img_bbox=img_bbox
)
@@ -213,11 +213,9 @@ def build_neurograph_from_gcs_zips(
smooth=smooth,
)
t, unit = utils.time_writer(time() - t0)
print(f"Memory Consumption: {round(utils.get_memory_usage(), 4)} GBs")
print(f"Module Runtime: {round(t, 4)} {unit} \n")

t, unit = utils.time_writer(time() - total_runtime)
print(f"Total Runtime: {round(t, 4)} {unit}")
print(f"Memory Consumption: {round(utils.get_memory_usage(), 4)} GBs")
return neurograph


@@ -246,7 +244,7 @@ def download_gcs_zips(bucket_name, gcs_path, min_size, anisotropy):
# Initializations
bucket = storage.Client().bucket(bucket_name)
zip_paths = utils.list_gcs_filenames(bucket, gcs_path, ".zip")
chunk_size = int(len(zip_paths) * 0.02)
chunk_size = int(len(zip_paths) * 0.1)

# Parse
cnt = 1
@@ -262,6 +260,7 @@ def download_gcs_zips(bucket_name, gcs_path, min_size, anisotropy):
cnt, t1 = report_progress(
i, len(zip_paths), chunk_size, cnt, t0, t1
)
break

return swc_dicts
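
For context, `utils.list_gcs_filenames` is not shown in this diff. Below is a minimal sketch under the assumption that it lists blob names under a prefix and filters by extension:

```python
# Hypothetical sketch of utils.list_gcs_filenames (not part of this diff).
# Assumption: returns the names of all blobs under gcs_path whose name
# contains the given extension, e.g. ".zip".
def list_gcs_filenames(bucket, gcs_path, ext):
    blobs = bucket.list_blobs(prefix=gcs_path)
    return [blob.name for blob in blobs if ext in blob.name]
```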

@@ -394,13 +393,6 @@ def count_edges(irreducibles):


# -- Utils --
def get_paths(swc_dir):
paths = []
for f in utils.listdir(swc_dir, ext=".swc"):
paths.append(os.path.join(swc_dir, f))
return paths


def report_progress(current, total, chunk_size, cnt, t0, t1):
eta = get_eta(current, total, chunk_size, t1)
runtime = get_runtime(current, total, chunk_size, t0, t1)
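
The deleted `get_paths` helper is replaced by `utils.list_paths`. A sketch of the likely equivalent, assuming `list_paths` generalizes the removed helper with an `ext` filter:

```python
# Hypothetical sketch of utils.list_paths (not shown in this diff).
# Assumption: it behaves like the deleted get_paths helper, returning the
# full path of every file in a directory that ends with the given extension.
import os

def list_paths(directory, ext=None):
    paths = []
    for f in os.listdir(directory):
        if ext is None or f.endswith(ext):
            paths.append(os.path.join(directory, f))
    return paths
```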
95 changes: 41 additions & 54 deletions src/deep_neurographs/machine_learning/feature_extraction.py
@@ -21,12 +21,13 @@

import numpy as np
import tensorstore as ts
from time import time

from deep_neurographs import geometry, utils

CHUNK_SIZE = [64, 64, 64]
WINDOW = [5, 5, 5]
N_PROFILE_POINTS = 10
N_PROFILE_PTS = 10
N_SKEL_FEATURES = 19
SUPPORTED_MODELS = [
"AdaBoost",
@@ -72,7 +73,10 @@ def generate_features(
proposals = neurograph.get_proposals()

# Generate features
t0 = time()
features = {"skel": generate_skel_features(neurograph, proposals)}
print(" generate_skel_features():", time() - t0)

if model_type in ["ConvNet", "MultiModalNet"]:
msg = "Must provide img_path and label_path for model_type!"
assert img_path and labels_path, msg
@@ -152,7 +156,7 @@ def get_img_chunk(img, labels, coord_0, coord_1, thread_id=None):
patch_coord_1 = utils.img_to_patch(coord_1, midpoint, CHUNK_SIZE)

# Generate features
path = geometry.make_line(patch_coord_0, patch_coord_1, N_PROFILE_POINTS)
path = geometry.make_line(patch_coord_0, patch_coord_1, N_PROFILE_PTS)
profile = geometry.get_profile(img_chunk, path, window=WINDOW)
labels_chunk[labels_chunk > 0] = 1
labels_chunk = geometry.fill_path(labels_chunk, path, val=2)
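
`geometry.make_line`, used above with the renamed `N_PROFILE_PTS` constant, is not part of this diff. A minimal sketch, assuming it returns `n_points` evenly spaced integer coordinates on the segment between two voxel coordinates:

```python
# Hypothetical sketch of geometry.make_line (not shown in this diff).
# Assumption: returns n_points evenly spaced voxel coordinates on the
# segment from coord_0 to coord_1, rounded to integer indices.
import numpy as np

def make_line(coord_0, coord_1, n_points):
    line = np.linspace(coord_0, coord_1, n_points)
    return np.round(line).astype(int)
```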
@@ -166,52 +170,19 @@ def get_img_chunk(img, labels, coord_0, coord_1, thread_id=None):


def generate_img_profiles(neurograph, proposals, path):
if False: # neurograph.bbox:
return generate_img_profiles_via_superchunk(
neurograph, proposals, path
)
else:
return generate_img_profiles_via_multithreads(
neurograph, proposals, path
)


def generate_img_profiles_via_multithreads(neurograph, proposals, path):
profile_features = dict()
driver = "n5" if ".n5" in path else "zarr"
img = utils.open_tensorstore(path, driver)
with ThreadPoolExecutor() as executor:
# Assign threads
threads = [None] * len(proposals)
for i, edge in enumerate(proposals):
xyz_0, xyz_1 = neurograph.proposal_xyz(edge)
coord_0 = utils.to_img(xyz_0)
coord_1 = utils.to_img(xyz_1)
line = geometry.make_line(coord_0, coord_1, N_PROFILE_POINTS)
threads[i] = executor.submit(geometry.get_profile, img, line, edge)

# Store result
for thread in as_completed(threads):
edge, profile = thread.result()
profile_features[edge] = profile
return profile_features


def generate_img_profiles_via_superchunk(neurograph, proposals, path):
"""
Generates an image intensity profile along each edge proposal by reading
a single superchunk from cloud that contains all proposals.
from an image on the cloud.

Parameters
----------
neurograph : NeuroGraph
NeuroGraph generated from a directory of swcs generated from a
predicted segmentation.
proposals : list[frozenset]
List of edge proposals for which features will be generated.
path : str
Path to raw image.
proposals : list[frozenset], optional
List of edge proposals for which features will be generated. The
default is None.
Path to image on GCS bucket.

Returns
-------
@@ -220,25 +191,41 @@ def generate_img_profiles_via_superchunk(neurograph, proposals, path):
profile.

"""
features = dict()
driver = "n5" if ".n5" in path else "zarr"
img = utils.get_superchunk(
path, driver, neurograph.origin, neurograph.shape, from_center=False
)
img = utils.normalize_img(img)
for edge in neurograph.proposals:
# Generate coordinates to be read
coords = set()
lines = dict()
t0 = time()
for i, edge in enumerate(proposals):
xyz_0, xyz_1 = neurograph.proposal_xyz(edge)
coord_0 = utils.to_img(xyz_0) - neurograph.origin
coord_1 = utils.to_img(xyz_1) - neurograph.origin
path = geometry.make_line(coord_0, coord_1, N_PROFILE_POINTS)
features[edge] = geometry.get_profile(img, path, window=WINDOW)
return features
coord_0 = utils.to_img(xyz_0)
coord_1 = utils.to_img(xyz_1)
lines[edge] = geometry.make_line(coord_0, coord_1, N_PROFILE_PTS)
for coord in lines[edge]:
coords.add(tuple(coord))
print(" generate_coords():", time() - t0)

# Read image intensities
t0 = time()
driver = "n5" if ".n5" in path else "zarr"
img = utils.open_tensorstore(path, driver)
print(" open_img():", time() - t0)

t0 = time()
img_intensity = utils.read_img_intensities(img, list(coords))
print(" read_img_intensities():", time() - t0)

# Generate intensity profiles
t0 = time()
profile_features = dict()
for edge, line in lines.items():
profile_features[edge] = [img_intensity[tuple(xyz)] for xyz in line]
print(" generate_profiles():", time() - t0)
return profile_features
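
The rewritten `generate_img_profiles` deduplicates voxel reads across proposals: every profile coordinate is collected into one set, read in a single pass, then reassembled per proposal, which appears to be the point of this refactor. `utils.read_img_intensities` is not shown in the diff; a sketch under the assumption that it maps each coordinate tuple to its image intensity:

```python
# Hypothetical sketch of utils.read_img_intensities (not part of this diff).
# Assumption: reads each voxel from the TensorStore handle and returns a
# dict keyed by coordinate tuple.
def read_img_intensities(img, coords):
    intensities = dict()
    for xyz in coords:
        xyz = tuple(int(x) for x in xyz)
        intensities[xyz] = float(img[xyz].read().result())
    return intensities
```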


def generate_skel_features(neurograph, proposals):
features = dict()
for edge in proposals:
# print("Proposals:", edge)
i, j = tuple(edge)
features[edge] = np.concatenate(
(
@@ -403,7 +390,7 @@ def get_multimodal_features(neurograph, features, shift=0):
# Initialize
n_edges = neurograph.n_proposals()
X = np.zeros(((n_edges, 2) + tuple(CHUNK_SIZE)))
x = np.zeros((n_edges, N_SKEL_FEATURES + N_PROFILE_POINTS))
x = np.zeros((n_edges, N_SKEL_FEATURES + N_PROFILE_PTS))
y = np.zeros((n_edges))

# Build
@@ -453,7 +440,7 @@ def count_features(model_type):
Number of features.
"""
if model_type != "ConvNet":
return N_SKEL_FEATURES + N_PROFILE_POINTS + 2
return N_SKEL_FEATURES + N_PROFILE_PTS + 2


def combine_features(features):
53 changes: 52 additions & 1 deletion src/deep_neurographs/machine_learning/inference.py
@@ -15,6 +15,7 @@
import networkx as nx
import numpy as np
import torch
from time import time
from torch.nn.functional import sigmoid
from torch.utils.data import DataLoader

@@ -25,7 +26,7 @@
from deep_neurographs.machine_learning import ml_utils
from deep_neurographs.neurograph import NeuroGraph

BATCH_SIZE_PROPOSALS = 2000
BATCH_SIZE_PROPOSALS = 1000
CHUNK_SHAPE = (256, 256, 256)


@@ -96,13 +97,23 @@ def run_without_seeds(
proposals,
batch_size_proposals=BATCH_SIZE_PROPOSALS,
confidence_threshold=0.7,
progress_bar=True,
):
# Initializations
dists = [neurograph.proposal_length(edge) for edge in proposals]
batches = utils.get_batch(np.argsort(dists), batch_size_proposals)
model = ml_utils.load_model(model_type, model_path)
n_batches = 1 + len(proposals) // BATCH_SIZE_PROPOSALS
print("# batches:", n_batches)

# Run
preds = []
progress_cnt = 1
t0, t1 = utils.init_timers()
chunk_size = max(int(n_batches * 0.02), 1)
for i, batch in enumerate(batches):
# Prediction
t2 = time()
proposals_i = [proposals[j] for j in batch]
preds_i = predict(
neurograph,
@@ -115,8 +126,18 @@
)

# Merge proposals
t2 = time()
preds.extend(preds_i)
stop
neurograph = build.fuse_branches(neurograph, preds_i)
print("fuse_branches():", time() - t2)

# Report progress
if i > progress_cnt * chunk_size and progress_bar:
progress_cnt, t1 = report_progress(
i, n_batches, chunk_size, progress_cnt, t0, t1
)
t0, t1 = utils.init_timers()

return neurograph, preds
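
`utils.get_batch` is not shown here. Since the indices come from `np.argsort(dists)`, proposals of similar length land in the same batch; a sketch of the batching helper, assuming it slices the index array into fixed-size chunks:

```python
# Hypothetical sketch of utils.get_batch (not part of this diff).
# Assumption: returns consecutive slices of at most batch_size indices.
def get_batch(idxs, batch_size):
    return [idxs[i:i + batch_size] for i in range(0, len(idxs), batch_size)]
```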

@@ -131,6 +152,7 @@ def predict(
confidence_threshold=0.7,
):
# Generate features
t3 = time()
features = extracter.generate_features(
neurograph,
model_type,
@@ -139,16 +161,22 @@
proposals=proposals,
)
dataset = ml_utils.init_dataset(neurograph, features, model_type)
print(" generate_features():", time() - t3)

# Run model
t3 = time()
proposal_probs = run_model(dataset, model, model_type)
print(" run_model():", time() - t3)

t3 = time()
proposal_preds = build.get_reconstruction(
neurograph,
proposal_probs,
dataset["idx_to_edge"],
high_threshold=0.95,
low_threshold=confidence_threshold,
)
print(" get_reconstruction():", time() - t3)
return proposal_preds


@@ -252,3 +280,26 @@ def run_model(dataset, model, model_type):
else:
hat_y = model.predict_proba(data["inputs"])[:, 1]
return np.array(hat_y)


# Utils
def report_progress(current, total, chunk_size, cnt, t0, t1):
eta = get_eta(current, total, chunk_size, t1)
runtime = get_runtime(current, total, chunk_size, t0, t1)
utils.progress_bar(current, total, eta=eta, runtime=runtime)
return cnt + 1, time()


def get_eta(current, total, chunk_size, t0, return_str=True):
chunk_runtime = time() - t0
remaining = total - current
eta = remaining * (chunk_runtime / chunk_size)
t, unit = utils.time_writer(eta)
return f"{round(t, 4)} {unit}" if return_str else eta


def get_runtime(current, total, chunk_size, t0, t1):
eta = get_eta(current, total, chunk_size, t1, return_str=False)
total_runtime = time() - t0 + eta
t, unit = utils.time_writer(total_runtime)
return f"{round(t, 4)} {unit}"