From 97f62a62a25415e484dfe0336693e750886de752 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Sat, 12 Oct 2024 06:38:08 +0000 Subject: [PATCH] feat: multimodal model and dataset --- src/deep_neurographs/generate_proposals.py | 2 +- .../groundtruth_generation.py | 2 +- src/deep_neurographs/inference.py | 77 +++++++------ .../machine_learning/datasets.py | 26 +---- .../machine_learning/feature_generation.py | 75 +++++++++---- .../machine_learning/heterograph_datasets.py | 78 +++++++++---- .../machine_learning/heterograph_models.py | 105 ++++++++++++++---- .../machine_learning/models.py | 19 ++-- src/deep_neurographs/train.py | 20 ++-- src/deep_neurographs/utils/gnn_util.py | 4 +- src/deep_neurographs/utils/graph_util.py | 2 +- src/deep_neurographs/utils/img_util.py | 3 +- src/deep_neurographs/utils/ml_util.py | 10 +- src/deep_neurographs/utils/util.py | 2 +- 14 files changed, 268 insertions(+), 157 deletions(-) diff --git a/src/deep_neurographs/generate_proposals.py b/src/deep_neurographs/generate_proposals.py index 6c3881f..9267b8a 100644 --- a/src/deep_neurographs/generate_proposals.py +++ b/src/deep_neurographs/generate_proposals.py @@ -484,4 +484,4 @@ def tangent(branch, idx, depth): """ end = min(idx + depth, len(branch)) - return geometry.tangent(branch[idx:end]) \ No newline at end of file + return geometry.tangent(branch[idx:end]) diff --git a/src/deep_neurographs/groundtruth_generation.py b/src/deep_neurographs/groundtruth_generation.py index 13818a2..5907a43 100644 --- a/src/deep_neurographs/groundtruth_generation.py +++ b/src/deep_neurographs/groundtruth_generation.py @@ -300,4 +300,4 @@ def orient_branch(branch_i, branch_j): def upd_dict(node_to_target_id, nodes, target_id): for node in nodes: node_to_target_id[node] = target_id - return node_to_target_id \ No newline at end of file + return node_to_target_id diff --git a/src/deep_neurographs/inference.py b/src/deep_neurographs/inference.py index e9b31bf..4d087aa 100644 --- a/src/deep_neurographs/inference.py +++ b/src/deep_neurographs/inference.py @@ -20,10 +20,12 @@ from tqdm import tqdm from deep_neurographs.graph_artifact_removal import remove_doubles -from deep_neurographs.machine_learning import feature_generation +from deep_neurographs.machine_learning.feature_generation import ( + FeatureGenerator, +) from deep_neurographs.utils import gnn_util from deep_neurographs.utils import graph_util as gutil -from deep_neurographs.utils import img_util, ml_util, util +from deep_neurographs.utils import ml_util, util from deep_neurographs.utils.gnn_util import toCPU from deep_neurographs.utils.graph_util import GraphLoader @@ -65,6 +67,8 @@ def __init__( output_dir, config, device=None, + is_multimodal=False, + label_path=None, ): """ Initializes an object that executes the full GraphTrace inference @@ -79,7 +83,7 @@ def __init__( Identifier for the predicted segmentation to be processed by the inference pipeline. img_path : str - Path to the raw image of whole brain stored on a GCS bucket. + Path to the raw image assumed to be stored in a GCS bucket. model_path : str Path to machine learning model parameters. output_dir : str @@ -89,6 +93,10 @@ def __init__( for the inference pipeline. device : str, optional ... + label_path : str, optional + Path to the segmentation assumed to be stored on a GCS bucket. + is_multimodal : bool, optional + ... Returns ------- @@ -99,7 +107,6 @@ def __init__( self.accepted_proposals = list() self.sample_id = sample_id self.segmentation_id = segmentation_id - self.img_path = img_path self.model_path = model_path # Extract config settings @@ -108,13 +115,15 @@ def __init__( # Inference engine self.inference_engine = InferenceEngine( - self.img_path, + img_path, self.model_path, self.ml_config.model_type, self.graph_config.search_radius, confidence_threshold=self.ml_config.threshold, device=device, downsample_factor=self.ml_config.downsample_factor, + label_path=label_path, + is_multimodal=is_multimodal, ) # Set output directory @@ -153,15 +162,15 @@ def run(self, fragments_pointer): print(f"Total Runtime: {round(t, 4)} {unit}\n") def run_schedule( - self, fragments_pointer, search_radius_schedule, save_all_rounds=False + self, fragments_pointer, radius_schedule, save_all_rounds=False ): t0 = time() self.report_experiment() self.build_graph(fragments_pointer) - for round_id, search_radius in enumerate(search_radius_schedule): - print(f"--- Round {round_id + 1}: Radius = {search_radius} ---") + for round_id, radius in enumerate(radius_schedule): + print(f"--- Round {round_id + 1}: Radius = {radius} ---") round_id += 1 - self.generate_proposals(search_radius) + self.generate_proposals(radius) self.run_inference() if save_all_rounds: self.save_results(round_id=round_id) @@ -213,7 +222,7 @@ def build_graph(self, fragments_pointer): print(f"Module Runtime: {round(t, 4)} {unit}\n") self.print_graph_overview() - def generate_proposals(self, search_radius=None): + def generate_proposals(self, radius=None): """ Generates proposals for the fragment graph based on the specified configuration. @@ -229,13 +238,13 @@ def generate_proposals(self, search_radius=None): """ # Initializations print("(2) Generate Proposals") - if search_radius is None: - search_radius = self.graph_config.search_radius + if radius is None: + radius = self.graph_config.radius # Main t0 = time() self.graph.generate_proposals( - search_radius, + radius, complex_bool=self.graph_config.complex_bool, long_range_bool=self.graph_config.long_range_bool, proposals_per_leaf=self.graph_config.proposals_per_leaf, @@ -392,11 +401,13 @@ def __init__( img_path, model_path, model_type, - search_radius, + radius, batch_size=BATCH_SIZE, confidence_threshold=CONFIDENCE_THRESHOLD, device=None, downsample_factor=1, + label_path=None, + is_multimodal=False ): """ Initializes an inference engine by loading images and setting class @@ -410,7 +421,7 @@ def __init__( Path to machine learning model parameters. model_type : str Type of machine learning model used to perform inference. - search_radius : float + radius : float Search radius used to generate proposals. batch_size : int, optional Number of proposals to generate features and classify per batch. @@ -429,16 +440,20 @@ def __init__( """ # Set class attributes self.batch_size = batch_size - self.downsample_factor = downsample_factor self.device = "cpu" if device is None else device self.is_gnn = True if "Graph" in model_type else False - self.model_type = model_type - self.search_radius = search_radius + self.radius = radius self.threshold = confidence_threshold - # Load image and model - driver = "n5" if ".n5" in img_path else "zarr" - self.img = img_util.open_tensorstore(img_path, driver=driver) + # Features + self.feature_generator = FeatureGenerator( + img_path, + downsample_factor, + label_path=label_path, + is_multimodal=is_multimodal + ) + + # Model self.model = ml_util.load_model(model_path) if self.is_gnn: self.model = self.model.to(self.device) @@ -532,22 +547,14 @@ def get_batch_dataset(self, neurograph, batch): ... """ - # Generate features - features = feature_generation.run( - neurograph, - self.img, - self.model_type, - batch, - self.search_radius, - downsample_factor=self.downsample_factor, - ) - - # Initialize dataset + t0 = time() + features = self.feature_generator.run(neurograph, batch, self.radius) + print("Feature Generation:", time() - t0) computation_graph = batch["graph"] if type(batch) is dict else None dataset = ml_util.init_dataset( neurograph, features, - self.model_type, + self.is_gnn, computation_graph=computation_graph, ) return dataset @@ -570,7 +577,7 @@ def predict(self, dataset): """ # Get predictions - if self.model_type == "GraphNeuralNet": + if self.is_gnn: with torch.no_grad(): # Get inputs n = len(dataset.data["proposal"]["y"]) @@ -585,7 +592,7 @@ def predict(self, dataset): preds = np.array(self.model.predict_proba(dataset.data.x)[:, 1]) # Reformat prediction - idxs = dataset.idxs_proposals["idx_to_edge"] + idxs = dataset.idxs_proposals["idx_to_id"] return {idxs[i]: p for i, p in enumerate(preds)} diff --git a/src/deep_neurographs/machine_learning/datasets.py b/src/deep_neurographs/machine_learning/datasets.py index d1a7e78..59b1e3e 100644 --- a/src/deep_neurographs/machine_learning/datasets.py +++ b/src/deep_neurographs/machine_learning/datasets.py @@ -82,7 +82,7 @@ def __init__(self, proposals, x_proposals, y_proposals, idxs_proposals): """ # Conversion idxs self.block_to_idxs = idxs_proposals["block_to_idxs"] - self.idxs_proposals = init_idxs(idxs_proposals) + self.idxs_proposals = idxs_proposals self.proposals = proposals # Features @@ -291,27 +291,3 @@ def reformat(arr): """ return np.expand_dims(arr, axis=1).astype(np.float32) - - -def init_idx_mapping(idx_to_id): - """ - Adds dictionary item called "edge_to_index" which maps a branch/proposal - in a neurograph to an idx that represents it's position in the feature - matrix. - - Parameters - ---------- - idxs : dict - Dictionary that maps indices to edges in some neurograph. - - Returns - ------- - dict - Updated dictionary. - - """ - idx_mapping = { - "idx_to_id": idx_to_id, - "id_to_idx": {v: k for k, v in idx_to_id.items()} - } - return idx_mapping \ No newline at end of file diff --git a/src/deep_neurographs/machine_learning/feature_generation.py b/src/deep_neurographs/machine_learning/feature_generation.py index ca06efb..5187eab 100644 --- a/src/deep_neurographs/machine_learning/feature_generation.py +++ b/src/deep_neurographs/machine_learning/feature_generation.py @@ -15,7 +15,6 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from copy import deepcopy -from random import sample import numpy as np from scipy.ndimage import zoom @@ -39,7 +38,7 @@ def __init__( img_path, downsample_factor, label_path=None, - use_img_embedding=False, + is_multimodal=False, ): """ Initializes object that generates features for a graph. @@ -54,7 +53,7 @@ def __init__( label_path : str, optional Path to the segmentation assumed to be stored on a GCS bucket. The default is None. - use_img_embedding : bool, optional + is_multimodal : bool, optional ... Returns @@ -64,7 +63,7 @@ def __init__( """ # Initialize instance attributes self.downsample_factor = downsample_factor - self.use_img_embedding = use_img_embedding + self.is_multimodal = is_multimodal # Initialize image-based attributes driver = "n5" if ".n5" in img_path else "zarr" @@ -79,7 +78,7 @@ def __init__( self.label_patch_shape = self.set_patch_shape(0) # Validate embedding requirements - if self.use_img_embedding and not label_path: + if self.is_multimodal and not label_path: raise("Must provide labels to generate image embeddings") @classmethod @@ -144,7 +143,7 @@ def run(self, neurograph, proposals_dict, radius): } # Generate image patches (if applicable) - if self.use_img_embedding: + if self.is_multimodal: features["patches"] = self.proposal_patches(neurograph, proposals) return features @@ -206,7 +205,7 @@ def run_on_proposals(self, neurograph, proposals, radius): """ features = self.proposal_skeletal(neurograph, proposals, radius) - if not self.use_img_embedding: + if not self.is_multimodal: profiles = self.proposal_profiles(neurograph, proposals) for p in proposals: features[p] = np.concatenate((features[p], profiles[p])) @@ -571,35 +570,66 @@ def get_branching_path(neurograph, i): # --- Build feature matrix --- def get_matrix(features, gt_accepts=set()): # Initialize matrices - key = sample(list(features.keys()), 1)[0] - X = np.zeros((len(features.keys()), len(features[key]))) + key = util.sample_once(list(features.keys())) + x = np.zeros((len(features.keys()), len(features[key]))) y = np.zeros((len(features.keys()))) # Populate idx_to_id = dict() for i, id_i in enumerate(features): idx_to_id[i] = id_i - X[i, :] = features[id_i] + x[i, :] = features[id_i] y[i] = 1 if id_i in gt_accepts else 0 - return X, y, idx_to_id + return x, y, init_idx_mapping(idx_to_id) + + +def get_patches_matrix(patches, id_to_idx): + patch = util.sample_once(list(patches.values())) + x = np.zeros((len(id_to_idx),) + patch.shape) + for key, patch in patches.items(): + x[id_to_idx[key], ...] = patch + return x def stack_matrices(neurographs, features, blocks): - idx_to_id = dict() - X, y = None, None + x, y = None, None for block_id in blocks: - X_i, y_i, _ = get_matrix(features[block_id]) - if X is None: - X = deepcopy(X_i) + x_i, y_i, _ = get_matrix(features[block_id]) + if x is None: + x = deepcopy(x_i) y = deepcopy(y_i) else: - X = np.concatenate((X, X_i), axis=0) + x = np.concatenate((x, x_i), axis=0) y = np.concatenate((y, y_i), axis=0) - return X, y + return x, y + + +def init_idx_mapping(idx_to_id): + """ + Adds dictionary item called "edge_to_index" which maps a branch/proposal + in a neurograph to an idx that represents it's position in the feature + matrix. + + Parameters + ---------- + idxs : dict + Dictionary that maps indices to edges in some neurograph. + + Returns + ------- + dict + Updated dictionary. + + """ + idx_mapping = { + "idx_to_id": idx_to_id, + "id_to_idx": {v: k for k, v in idx_to_id.items()} + } + return idx_mapping # --- Utils --- -def get_node_dict(use_img_embedding=False): +def get_node_dict(is_multimodal=False): """ Returns the number of features for different node types. @@ -613,7 +643,10 @@ def get_node_dict(use_img_embedding=False): A dictionary containing the number of features for each node type """ - return {"branch": 2, "proposal": 34} + if is_multimodal: + return {"branch": 2, "proposal": 16} + else: + return {"branch": 2, "proposal": 34} def get_edge_dict(): @@ -635,4 +668,4 @@ def get_edge_dict(): ("branch", "edge", "branch"): 3, ("branch", "edge", "proposal"): 3 } - return edge_dict \ No newline at end of file + return edge_dict diff --git a/src/deep_neurographs/machine_learning/heterograph_datasets.py b/src/deep_neurographs/machine_learning/heterograph_datasets.py index a45d913..4c1a922 100644 --- a/src/deep_neurographs/machine_learning/heterograph_datasets.py +++ b/src/deep_neurographs/machine_learning/heterograph_datasets.py @@ -17,8 +17,10 @@ import torch from torch_geometric.data import HeteroData as HeteroGraphData -from deep_neurographs.machine_learning import datasets -from deep_neurographs.machine_learning.feature_generation import get_matrix +from deep_neurographs.machine_learning.feature_generation import ( + get_matrix, + get_patches_matrix, +) from deep_neurographs.utils import gnn_util DTYPE = torch.float32 @@ -52,24 +54,35 @@ def init(neurograph, features, computation_graph): gt_accepts = set() # Extract features - x_branches, _, idxs_branches = get_matrix(features["branches"]) - x_proposals, y_proposals, idxs_proposals = get_matrix( + idxs, x_dict = dict(), dict() + x_dict["branches"], _, idxs["branches"] = get_matrix(features["branches"]) + x_dict["proposals"], y_proposals, idxs["proposals"] = get_matrix( features["proposals"], gt_accepts ) - x_nodes = features["nodes"] + x_dict["nodes"] = features["nodes"] + + # Build patch matrix + is_multimodel = "patches" in features + if is_multimodel: + x_dict["patches"] = get_patches_matrix( + features["patches"], idxs["proposals"]["id_to_idx"] + ) # Initialize dataset proposals = list(features["proposals"].keys()) - heterograph_dataset = HeteroGraphDataset( + if is_multimodel: + heterograph_dataset_class = HeteroGraphMultiModalDataset + else: + heterograph_dataset_class = HeteroGraphDataset + + heterograph_dataset = heterograph_dataset_class( computation_graph, proposals, - x_nodes, - x_branches, - x_proposals, + x_dict, y_proposals, - idxs_branches, - idxs_proposals + idxs, ) + return heterograph_dataset @@ -84,12 +97,9 @@ def __init__( self, computation_graph, proposals, - x_nodes, - x_branches, - x_proposals, + x_dict, y_proposals, - idxs_branches, - idxs_proposals, + idxs, ): """ Constructs a HeteroGraphDataset object. @@ -118,8 +128,8 @@ def __init__( """ # Conversion idxs - self.idxs_branches = datasets.init_idx_mapping(idxs_branches) - self.idxs_proposals = datasets.init_idx_mapping(idxs_proposals) + self.idxs_branches = idxs["branches"] + self.idxs_proposals = idxs["proposals"] self.computation_graph = computation_graph self.proposals = proposals @@ -133,15 +143,15 @@ def __init__( # Features self.data = HeteroGraphData() - self.data["branch"].x = torch.tensor(x_branches, dtype=DTYPE) - self.data["proposal"].x = torch.tensor(x_proposals, dtype=DTYPE) + self.data["branch"].x = torch.tensor(x_dict["branches"], dtype=DTYPE) + self.data["proposal"].x = torch.tensor(x_dict["proposals"], dtype=DTYPE) self.data["proposal"].y = torch.tensor(y_proposals, dtype=DTYPE) # Edges self.init_edges() self.check_missing_edge_types() - self.init_edge_attrs(x_nodes) - self.n_edge_attrs = n_edge_features(x_nodes) + self.init_edge_attrs(x_dict["nodes"]) + self.n_edge_attrs = n_edge_features(x_dict["nodes"]) def init_edges(self): """ @@ -402,6 +412,28 @@ def set_hetero_edge_attrs(self, x_nodes, edge_type, idx_map_1, idx_map_2): self.data[edge_type].edge_attr = arrs +class HeteroGraphMultiModalDataset(HeteroGraphDataset): + def __init__( + self, + computation_graph, + proposals, + x_dict, + y_proposals, + idxs, + ): + # Call super constructor + super().__init__( + computation_graph, + proposals, + x_dict, + y_proposals, + idxs, + ) + + # Instance attributes + self.data["patches"].x = torch.tensor(x_dict["patches"], dtype=DTYPE) + + # -- util -- def node_intersection(idx_map, e1, e2): """ @@ -469,4 +501,4 @@ def n_edge_features(x): """ key = sample(list(x.keys()), 1)[0] - return x[key].shape[0] \ No newline at end of file + return x[key].shape[0] diff --git a/src/deep_neurographs/machine_learning/heterograph_models.py b/src/deep_neurographs/machine_learning/heterograph_models.py index 9124acf..3272c94 100644 --- a/src/deep_neurographs/machine_learning/heterograph_models.py +++ b/src/deep_neurographs/machine_learning/heterograph_models.py @@ -8,23 +8,17 @@ """ -import numpy as np import torch import torch.nn.init as init from torch import nn from torch.nn import Dropout, LeakyReLU from torch_geometric.nn import GATv2Conv as GATConv -from torch_geometric.nn import HEATConv, HeteroConv, Linear +from torch_geometric.nn import HeteroConv, Linear -from deep_neurographs import machine_learning as ml +from deep_neurographs.machine_learning.models import ConvNet -CONV_TYPES = ["GATConv", "GCNConv"] -DROPOUT = 0.3 -HEADS_1 = 1 -HEADS_2 = 1 - -class HeteroGNN(torch.nn.Module): +class HGAT(torch.nn.Module): """ Heterogeneous graph attention network that classifies proposals. @@ -38,40 +32,41 @@ class HeteroGNN(torch.nn.Module): def __init__( self, + node_dict, + edge_dict, device=None, - scale_hidden=2, - dropout=DROPOUT, - heads_1=HEADS_1, - heads_2=HEADS_2, + hidden_dim=64, + dropout=0.3, + heads_1=2, + heads_2=2, ): """ Constructs a heterogeneous graph neural network. """ super().__init__() - # Feature vector sizes - node_dict = ml.feature_generation.get_node_dict() - edge_dict = ml.feature_generation.get_edge_dict() - hidden_dim = scale_hidden * np.max(list(node_dict.values())) - output_dim = heads_1 * heads_2 * hidden_dim + # Layer dimensions + hidden_dim_1 = hidden_dim + hidden_dim_2 = hidden_dim_1 * heads_2 + output_dim = hidden_dim * heads_1 * heads_2 # Nonlinear activation self.dropout = dropout self.dropout_layer = Dropout(dropout) self.leaky_relu = LeakyReLU() - # Linear layers + # Linear layers self.input_nodes = nn.ModuleDict() self.input_edges = dict() for key, d in node_dict.items(): - self.input_nodes[key] = nn.Linear(d, hidden_dim, device=device) + self.input_nodes[key] = nn.Linear(d, hidden_dim_1, device=device) for key, d in edge_dict.items(): - self.input_edges[key] = nn.Linear(d, hidden_dim, device=device) + self.input_edges[key] = nn.Linear(d, hidden_dim_1, device=device) self.output = Linear(output_dim, 1).to(device) # Message passing layers - self.gat1 = self.init_gat_layer(hidden_dim, hidden_dim, heads_1) - self.gat2 = self.init_gat_layer(hidden_dim * heads_2, hidden_dim, heads_2) + self.gat1 = self.init_gat_layer(hidden_dim_1, hidden_dim_1, heads_1) + self.gat2 = self.init_gat_layer(hidden_dim_2, hidden_dim_1, heads_2) # Initialize weights self.init_weights() @@ -109,6 +104,7 @@ def init_gat_mixed(self, hidden_dim, edge_dim, heads): heads=heads, ) return gat_layer + def init_weights(self): """ Initializes linear layers. @@ -158,7 +154,68 @@ def forward(self, x_dict, edge_index_dict, edge_attr_dict): edge_attr_dict[key] = f(edge_attr_dict[key]) edge_attr_dict = self.activation(edge_attr_dict) - # Convolutional layers + # Message passing layers + x_dict = self.gat1( + x_dict, edge_index_dict, edge_attr_dict=edge_attr_dict + ) + x_dict = self.gat2( + x_dict, edge_index_dict, edge_attr_dict=edge_attr_dict + ) + + # Output + x_dict = self.output(x_dict["proposal"]) + return x_dict + + +class MultiModalHGAT(HGAT): + """ + Heterogeneous graph attention network that uses multimodal features which + includes an image patch of the proposal and a vector of geometric and + graphical features. + + """ + + def __init__( + self, + node_dict, + edge_dict, + device=None, + hidden_dim=64, + dropout=0.3, + heads_1=2, + heads_2=2, + ): + # Call super constructor + super().__init__( + node_dict, + edge_dict, + device, + hidden_dim, + dropout, + heads_1, + heads_2, + ) + + # Instance attributes + self.input_patches = ConvNet(hidden_dim) + + def forward(self, x_dict, edge_index_dict, edge_attr_dict): + # Input - Patches + x_patches = self.input_patches(x_dict["patches"]) + del x_dict["patches"] + + # Input - Nodes + x_dict = {key: f(x_dict[key]) for key, f in self.input_nodes.items()} + x_dict = self.activation(x_dict) + + # Input - Edges + for key, f in self.input_edges.items(): + edge_attr_dict[key] = f(edge_attr_dict[key]) + edge_attr_dict = self.activation(edge_attr_dict) + + # Concatenate multimodal embeddings + + # Message passing layers x_dict = self.gat1( x_dict, edge_index_dict, edge_attr_dict=edge_attr_dict ) diff --git a/src/deep_neurographs/machine_learning/models.py b/src/deep_neurographs/machine_learning/models.py index ed7169b..1562c4f 100644 --- a/src/deep_neurographs/machine_learning/models.py +++ b/src/deep_neurographs/machine_learning/models.py @@ -91,7 +91,7 @@ class ConvNet(nn.Module): """ - def __init__(self): + def __init__(self, patch_shape, output_dim): """ Constructs a ConvNet object. @@ -105,10 +105,12 @@ def __init__(self): """ nn.Module.__init__(self) - self.conv1 = self._init_conv_layer(2, 4) - self.conv2 = self._init_conv_layer(4, 4) + self.conv1 = self._init_conv_layer(2, 32) + self.conv2 = self._init_conv_layer(32, 64) self.output = nn.Sequential( - nn.Linear(10976, 64), nn.LeakyReLU(), nn.Linear(64, 1) + nn.Linear(-1, 64), + nn.LeakyReLU(), + nn.Linear(output_dim, output_dim), ) def _init_conv_layer(self, in_channels, out_channels): @@ -132,14 +134,14 @@ def _init_conv_layer(self, in_channels, out_channels): nn.Conv3d( in_channels, out_channels, - kernel_size=(3, 3, 3), + kernel_size=3, stride=1, padding=0, ), nn.BatchNorm3d(out_channels), nn.LeakyReLU(), - nn.Dropout(p=0.2), - nn.MaxPool3d(kernel_size=(2, 2, 2), stride=2), + nn.Dropout(p=0.3), + nn.MaxPool3d(kernel_size=2, stride=2), ) return conv_layer @@ -163,9 +165,6 @@ def forward(self, x): x = self.output(vectorize(x)) return x - def model_type(self): - return "ConvNet" - class MultiModalNet(nn.Module): """ diff --git a/src/deep_neurographs/train.py b/src/deep_neurographs/train.py index 24a223b..70f44f7 100644 --- a/src/deep_neurographs/train.py +++ b/src/deep_neurographs/train.py @@ -25,8 +25,10 @@ from torch.optim.lr_scheduler import StepLR from torch.utils.tensorboard import SummaryWriter -from deep_neurographs.machine_learning.feature_generation import FeatureGenerator -from deep_neurographs.utils import gnn_util, img_util, ml_util, util +from deep_neurographs.machine_learning.feature_generation import ( + FeatureGenerator, +) +from deep_neurographs.utils import gnn_util, ml_util, util from deep_neurographs.utils.gnn_util import toCPU from deep_neurographs.utils.graph_util import GraphLoader @@ -50,7 +52,7 @@ def __init__( model_type, criterion=None, output_dir=None, - use_img_embedding=False, + is_multimodal=False, validation_ids=None, save_model_bool=True, ): @@ -65,7 +67,7 @@ def __init__( self.model_type = model_type self.output_dir = output_dir self.save_model_bool = save_model_bool - self.use_img_embedding = use_img_embedding + self.is_multimodal = is_multimodal self.validation_ids = validation_ids # Set data structures for training examples @@ -116,7 +118,7 @@ def load_example( pred_pointer, sample_id, example_id=None, - pred_id=None, + segmentation_id=None, metadata_path=None, ): # Read metadata @@ -140,7 +142,7 @@ def load_example( { "sample_id": sample_id, "example_id": example_id, - "pred_id": pred_id, + "segmentation_id": segmentation_id, } ) @@ -152,7 +154,7 @@ def load_img( img_path, downsample_factor, label_path=label_path, - use_img_embedding=self.use_img_embedding, + is_multimodal=self.is_multimodal, ) # --- main pipeline --- @@ -360,7 +362,7 @@ def predict(self, data): Prediction. """ - x, edge_index, edge_attr = gnn_util.get_inputs(data, "HeteroGNN") + x, edge_index, edge_attr = gnn_util.get_inputs(data) hat_y = self.model(x, edge_index, edge_attr) y = data["proposal"]["y"] return truncate(hat_y, y), y @@ -470,4 +472,4 @@ def get_predictions(hat_y, threshold=0.5): Binary predictions based on the given threshold. """ - return (ml_util.sigmoid(np.array(hat_y)) > threshold).tolist() \ No newline at end of file + return (ml_util.sigmoid(np.array(hat_y)) > threshold).tolist() diff --git a/src/deep_neurographs/utils/gnn_util.py b/src/deep_neurographs/utils/gnn_util.py index 9cb54a4..7621730 100644 --- a/src/deep_neurographs/utils/gnn_util.py +++ b/src/deep_neurographs/utils/gnn_util.py @@ -21,7 +21,7 @@ # --- Tensor Operations --- -def get_inputs(data, device=None): +def get_inputs(data, device=None, is_multimodal=False): x = data.x_dict edge_index = data.edge_index_dict edge_attr = data.edge_attr_dict @@ -297,4 +297,4 @@ def init_line_graph(edges): """ graph = nx.Graph() graph.add_edges_from(edges) - return nx.line_graph(graph) \ No newline at end of file + return nx.line_graph(graph) diff --git a/src/deep_neurographs/utils/graph_util.py b/src/deep_neurographs/utils/graph_util.py index dd8a3dd..e0ef859 100644 --- a/src/deep_neurographs/utils/graph_util.py +++ b/src/deep_neurographs/utils/graph_util.py @@ -858,4 +858,4 @@ def largest_components(neurograph, k): node_ids.pop(-1) break i += 1 - return node_ids \ No newline at end of file + return node_ids diff --git a/src/deep_neurographs/utils/img_util.py b/src/deep_neurographs/utils/img_util.py index db50b06..3abf0a6 100644 --- a/src/deep_neurographs/utils/img_util.py +++ b/src/deep_neurographs/utils/img_util.py @@ -13,7 +13,6 @@ import numpy as np import tensorstore as ts from skimage.color import label2rgb -from tifffile import imwrite from deep_neurographs.utils import util @@ -421,4 +420,4 @@ def find_img_path(bucket_name, img_root, dataset_name): for subdir in util.list_gcs_subdirectories(bucket_name, img_root): if dataset_name in subdir: return subdir + "whole-brain/fused.zarr/" - raise f"Dataset not found in {bucket_name} - {img_root}" \ No newline at end of file + raise f"Dataset not found in {bucket_name} - {img_root}" diff --git a/src/deep_neurographs/utils/ml_util.py b/src/deep_neurographs/utils/ml_util.py index bb2e236..3c25e2f 100644 --- a/src/deep_neurographs/utils/ml_util.py +++ b/src/deep_neurographs/utils/ml_util.py @@ -61,7 +61,13 @@ def save_model(path, model, model_type): # --- dataset utils --- -def init_dataset(neurograph, features, is_gnn=True, computation_graph=None): +def init_dataset( + neurograph, + features, + is_gnn=True, + is_multimodal=False, + computation_graph=None +): """ Initializes a dataset given features generated from some set of proposals and neurograph. @@ -140,4 +146,4 @@ def get_kfolds(filenames, k): folds.append(samples_i) if n_samples > len(samples): break - return folds \ No newline at end of file + return folds diff --git a/src/deep_neurographs/utils/util.py b/src/deep_neurographs/utils/util.py index a82cb28..023b6a6 100644 --- a/src/deep_neurographs/utils/util.py +++ b/src/deep_neurographs/utils/util.py @@ -663,4 +663,4 @@ def spaced_idxs(container, k): idxs = np.arange(0, len(container) + k, k)[:-1] if len(container) % 2 == 0: idxs = np.append(idxs, len(container) - 1) - return idxs \ No newline at end of file + return idxs