Skip to content

Commit

Permalink
feat: multimodal model and dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim committed Oct 12, 2024
1 parent 07de06f commit 97f62a6
Show file tree
Hide file tree
Showing 14 changed files with 268 additions and 157 deletions.
2 changes: 1 addition & 1 deletion src/deep_neurographs/generate_proposals.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,4 +484,4 @@ def tangent(branch, idx, depth):
"""
end = min(idx + depth, len(branch))
return geometry.tangent(branch[idx:end])
return geometry.tangent(branch[idx:end])
2 changes: 1 addition & 1 deletion src/deep_neurographs/groundtruth_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,4 +300,4 @@ def orient_branch(branch_i, branch_j):
def upd_dict(node_to_target_id, nodes, target_id):
for node in nodes:
node_to_target_id[node] = target_id
return node_to_target_id
return node_to_target_id
77 changes: 42 additions & 35 deletions src/deep_neurographs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
from tqdm import tqdm

from deep_neurographs.graph_artifact_removal import remove_doubles
from deep_neurographs.machine_learning import feature_generation
from deep_neurographs.machine_learning.feature_generation import (
FeatureGenerator,
)
from deep_neurographs.utils import gnn_util
from deep_neurographs.utils import graph_util as gutil
from deep_neurographs.utils import img_util, ml_util, util
from deep_neurographs.utils import ml_util, util
from deep_neurographs.utils.gnn_util import toCPU
from deep_neurographs.utils.graph_util import GraphLoader

Expand Down Expand Up @@ -65,6 +67,8 @@ def __init__(
output_dir,
config,
device=None,
is_multimodal=False,
label_path=None,
):
"""
Initializes an object that executes the full GraphTrace inference
Expand All @@ -79,7 +83,7 @@ def __init__(
Identifier for the predicted segmentation to be processed by the
inference pipeline.
img_path : str
Path to the raw image of whole brain stored on a GCS bucket.
Path to the raw image assumed to be stored in a GCS bucket.
model_path : str
Path to machine learning model parameters.
output_dir : str
Expand All @@ -89,6 +93,10 @@ def __init__(
for the inference pipeline.
device : str, optional
...
label_path : str, optional
Path to the segmentation assumed to be stored on a GCS bucket.
is_multimodal : bool, optional
...
Returns
-------
Expand All @@ -99,7 +107,6 @@ def __init__(
self.accepted_proposals = list()
self.sample_id = sample_id
self.segmentation_id = segmentation_id
self.img_path = img_path
self.model_path = model_path

# Extract config settings
Expand All @@ -108,13 +115,15 @@ def __init__(

# Inference engine
self.inference_engine = InferenceEngine(
self.img_path,
img_path,
self.model_path,
self.ml_config.model_type,
self.graph_config.search_radius,
confidence_threshold=self.ml_config.threshold,
device=device,
downsample_factor=self.ml_config.downsample_factor,
label_path=label_path,
is_multimodal=is_multimodal,
)

# Set output directory
Expand Down Expand Up @@ -153,15 +162,15 @@ def run(self, fragments_pointer):
print(f"Total Runtime: {round(t, 4)} {unit}\n")

def run_schedule(
self, fragments_pointer, search_radius_schedule, save_all_rounds=False
self, fragments_pointer, radius_schedule, save_all_rounds=False
):
t0 = time()
self.report_experiment()
self.build_graph(fragments_pointer)
for round_id, search_radius in enumerate(search_radius_schedule):
print(f"--- Round {round_id + 1}: Radius = {search_radius} ---")
for round_id, radius in enumerate(radius_schedule):
print(f"--- Round {round_id + 1}: Radius = {radius} ---")
round_id += 1
self.generate_proposals(search_radius)
self.generate_proposals(radius)
self.run_inference()
if save_all_rounds:
self.save_results(round_id=round_id)
Expand Down Expand Up @@ -213,7 +222,7 @@ def build_graph(self, fragments_pointer):
print(f"Module Runtime: {round(t, 4)} {unit}\n")
self.print_graph_overview()

def generate_proposals(self, search_radius=None):
def generate_proposals(self, radius=None):
"""
Generates proposals for the fragment graph based on the specified
configuration.
Expand All @@ -229,13 +238,13 @@ def generate_proposals(self, search_radius=None):
"""
# Initializations
print("(2) Generate Proposals")
if search_radius is None:
search_radius = self.graph_config.search_radius
if radius is None:
radius = self.graph_config.radius

# Main
t0 = time()
self.graph.generate_proposals(
search_radius,
radius,
complex_bool=self.graph_config.complex_bool,
long_range_bool=self.graph_config.long_range_bool,
proposals_per_leaf=self.graph_config.proposals_per_leaf,
Expand Down Expand Up @@ -392,11 +401,13 @@ def __init__(
img_path,
model_path,
model_type,
search_radius,
radius,
batch_size=BATCH_SIZE,
confidence_threshold=CONFIDENCE_THRESHOLD,
device=None,
downsample_factor=1,
label_path=None,
is_multimodal=False
):
"""
Initializes an inference engine by loading images and setting class
Expand All @@ -410,7 +421,7 @@ def __init__(
Path to machine learning model parameters.
model_type : str
Type of machine learning model used to perform inference.
search_radius : float
radius : float
Search radius used to generate proposals.
batch_size : int, optional
Number of proposals to generate features and classify per batch.
Expand All @@ -429,16 +440,20 @@ def __init__(
"""
# Set class attributes
self.batch_size = batch_size
self.downsample_factor = downsample_factor
self.device = "cpu" if device is None else device
self.is_gnn = True if "Graph" in model_type else False
self.model_type = model_type
self.search_radius = search_radius
self.radius = radius
self.threshold = confidence_threshold

# Load image and model
driver = "n5" if ".n5" in img_path else "zarr"
self.img = img_util.open_tensorstore(img_path, driver=driver)
# Features
self.feature_generator = FeatureGenerator(
img_path,
downsample_factor,
label_path=label_path,
is_multimodal=is_multimodal
)

# Model
self.model = ml_util.load_model(model_path)
if self.is_gnn:
self.model = self.model.to(self.device)
Expand Down Expand Up @@ -532,22 +547,14 @@ def get_batch_dataset(self, neurograph, batch):
...
"""
# Generate features
features = feature_generation.run(
neurograph,
self.img,
self.model_type,
batch,
self.search_radius,
downsample_factor=self.downsample_factor,
)

# Initialize dataset
t0 = time()
features = self.feature_generator.run(neurograph, batch, self.radius)
print("Feature Generation:", time() - t0)
computation_graph = batch["graph"] if type(batch) is dict else None
dataset = ml_util.init_dataset(
neurograph,
features,
self.model_type,
self.is_gnn,
computation_graph=computation_graph,
)
return dataset
Expand All @@ -570,7 +577,7 @@ def predict(self, dataset):
"""
# Get predictions
if self.model_type == "GraphNeuralNet":
if self.is_gnn:
with torch.no_grad():
# Get inputs
n = len(dataset.data["proposal"]["y"])
Expand All @@ -585,7 +592,7 @@ def predict(self, dataset):
preds = np.array(self.model.predict_proba(dataset.data.x)[:, 1])

# Reformat prediction
idxs = dataset.idxs_proposals["idx_to_edge"]
idxs = dataset.idxs_proposals["idx_to_id"]
return {idxs[i]: p for i, p in enumerate(preds)}


Expand Down
26 changes: 1 addition & 25 deletions src/deep_neurographs/machine_learning/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(self, proposals, x_proposals, y_proposals, idxs_proposals):
"""
# Conversion idxs
self.block_to_idxs = idxs_proposals["block_to_idxs"]
self.idxs_proposals = init_idxs(idxs_proposals)
self.idxs_proposals = idxs_proposals
self.proposals = proposals

# Features
Expand Down Expand Up @@ -291,27 +291,3 @@ def reformat(arr):
"""
return np.expand_dims(arr, axis=1).astype(np.float32)


def init_idx_mapping(idx_to_id):
"""
Adds dictionary item called "edge_to_index" which maps a branch/proposal
in a neurograph to an idx that represents it's position in the feature
matrix.
Parameters
----------
idxs : dict
Dictionary that maps indices to edges in some neurograph.
Returns
-------
dict
Updated dictionary.
"""
idx_mapping = {
"idx_to_id": idx_to_id,
"id_to_idx": {v: k for k, v in idx_to_id.items()}
}
return idx_mapping
Loading

0 comments on commit 97f62a6

Please sign in to comment.