Skip to content

Commit

Permalink
feat: multimodal gnn is functional (#268)
Browse files Browse the repository at this point in the history
Co-authored-by: anna-grim <[email protected]>
  • Loading branch information
anna-grim and anna-grim authored Oct 12, 2024
1 parent 97f62a6 commit f248ecb
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 38 deletions.
25 changes: 13 additions & 12 deletions src/deep_neurographs/machine_learning/feature_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
@author: Anna Grim
@email: [email protected]
Generates features for training a model and performing inference.
Generates features for training a machine learning model and performing
inference.
Conventions:
(1) "xyz" refers to a real world coordinate such as those from an swc file
(2) "voxel" refers to an voxel coordinate in a whole exaspim image.
"""
Expand All @@ -26,7 +26,7 @@
class FeatureGenerator:
"""
Class that generates features vectors that are used by a graph neural
network to classify proposals.
network (GNN) to classify proposals.
"""
# Class attributes
Expand Down Expand Up @@ -54,7 +54,8 @@ def __init__(
Path to the segmentation assumed to be stored on a GCS bucket. The
default is None.
is_multimodal : bool, optional
...
Indication of whether to generate multimodal features (i.e. image
and label patch for each proposal). The default is False.
Returns
-------
Expand Down Expand Up @@ -118,7 +119,7 @@ def run(self, neurograph, proposals_dict, radius):
proposals_dict : dict
Dictionary that contains the items (1) "proposals" which are the
proposals from "neurograph" that features will be generated and
(2) "graph" which is the computation graph used by the gnn.
(2) "graph" which is the computation graph used by the GNN.
radius : float
Search radius used to generate proposals.
Expand Down Expand Up @@ -156,7 +157,7 @@ def run_on_nodes(self, neurograph, computation_graph):
neurograph : NeuroGraph
NeuroGraph generated from a predicted segmentation.
computation_graph : networkx.Graph
Graph used by gnn to classify proposals.
Graph used by GNN to classify proposals.
Returns
-------
Expand All @@ -175,12 +176,12 @@ def run_on_branches(self, neurograph, computation_graph):
neurograph : NeuroGraph
NeuroGraph generated from a predicted segmentation.
computation_graph : networkx.Graph
Graph used by gnn to classify proposals.
Graph used by GNN to classify proposals.
Returns
-------
dict
Dictionary that maps an edge id to a feature vector.
Dictionary that maps an branch id to a feature vector.
"""
return self.branch_skeletal(neurograph, computation_graph)
Expand Down Expand Up @@ -221,7 +222,7 @@ def node_skeletal(self, neurograph, computation_graph):
neurograph : NeuroGraph
NeuroGraph generated from a predicted segmentation.
computation_graph : networkx.Graph
Graph used by gnn to classify proposals.
Graph used by GNN to classify proposals.
Returns
-------
Expand Down Expand Up @@ -250,7 +251,7 @@ def branch_skeletal(self, neurograph, computation_graph):
neurograph : NeuroGraph
NeuroGraph generated from a predicted segmentation.
computation_graph : networkx.Graph
Graph used by gnn to classify proposals.
Graph used by GNN to classify proposals.
Returns
-------
Expand Down Expand Up @@ -313,7 +314,7 @@ def node_profiles(self, neurograph, computation_graph):
neurograph : NeuroGraph
NeuroGraph generated from a predicted segmentation.
computation_graph : networkx.Graph
Graph used by gnn to classify proposals.
Graph used by GNN to classify proposals.
Returns
-------
Expand Down Expand Up @@ -435,7 +436,7 @@ def get_profile(self, xyz_path, profile_id):
def get_spec(self, xyz_path):
"""
Gets image bounding box and voxel coordinates needed to compute an
image intensity profile or extract image chunk for cnn embedding.
image intensity profile or extract image patch.
Parameters
----------
Expand Down
60 changes: 51 additions & 9 deletions src/deep_neurographs/machine_learning/heterograph_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def init(neurograph, features, computation_graph):
# Build patch matrix
is_multimodel = "patches" in features
if is_multimodel:
x_dict["patches"] = get_patches_matrix(
x_dict["patch"] = get_patches_matrix(
features["patches"], idxs["proposals"]["id_to_idx"]
)

Expand Down Expand Up @@ -142,17 +142,18 @@ def __init__(
]

# Features
self.data = HeteroGraphData()
self.data["branch"].x = torch.tensor(x_dict["branches"], dtype=DTYPE)
self.data["proposal"].x = torch.tensor(x_dict["proposals"], dtype=DTYPE)
self.data["proposal"].y = torch.tensor(y_proposals, dtype=DTYPE)

# Edges
self.init_nodes(x_dict, y_proposals)
self.init_edges()
self.check_missing_edge_types()
self.init_edge_attrs(x_dict["nodes"])
self.n_edge_attrs = n_edge_features(x_dict["nodes"])

def init_nodes(self, x_dict, y_proposals):
self.data = HeteroGraphData()
self.data["branch"].x = torch.tensor(x_dict["branches"], dtype=DTYPE)
self.data["proposal"].x = torch.tensor(x_dict["proposals"], dtype=DTYPE)
self.data["proposal"].y = torch.tensor(y_proposals, dtype=DTYPE)

def init_edges(self):
"""
Initializes edge index for a graph dataset.
Expand Down Expand Up @@ -430,8 +431,49 @@ def __init__(
idxs,
)

# Instance attributes
self.data["patches"].x = torch.tensor(x_dict["patches"], dtype=DTYPE)
def init_nodes(self, x_dict, y_proposals):
self.data = HeteroGraphData()
self.data["branch"].x = torch.tensor(x_dict["branches"], dtype=DTYPE)
self.data["proposal"].x = torch.tensor(x_dict["proposals"], dtype=DTYPE)
self.data["proposal"].y = torch.tensor(y_proposals, dtype=DTYPE)
self.data["patch"].x = torch.tensor(x_dict["patch"], dtype=DTYPE)

def check_missing_edge_types(self):
for node_type in ["branch", "proposal"]:
edge_type = (node_type, "edge", node_type)
if len(self.data[edge_type].edge_index) == 0:
# Add dummy features - nodes
dtype = self.data[node_type].x.dtype
if node_type == "branch":
d = self.n_branch_features()
else:
d = self.n_proposal_features()

zeros = torch.zeros(2, d, dtype=dtype)
self.data[node_type].x = torch.cat(
(self.data[node_type].x, zeros), dim=0
)

# Add dummy features - patches
if node_type == "proposal":
patch_shape = self.data["patch"].x.size()[1:]
zeros = torch.zeros((2,) + patch_shape, dtype=dtype)
self.data["patch"].x = torch.cat(
(self.data["patch"].x, zeros), dim=0
)

# Update edge_index
n = self.data[node_type]["x"].size(0)
e_1 = frozenset({-1, -2})
e_2 = frozenset({-2, -3})
edges = [[n - 1, n - 2], [n - 2, n - 1]]
self.data[edge_type].edge_index = gnn_util.toTensor(edges)
if node_type == "branch":
self.idxs_branches["idx_to_id"][n - 1] = e_1
self.idxs_branches["idx_to_id"][n - 2] = e_2
else:
self.idxs_proposals["idx_to_id"][n - 1] = e_1
self.idxs_proposals["idx_to_id"][n - 2] = e_2


# -- util --
Expand Down
53 changes: 37 additions & 16 deletions src/deep_neurographs/machine_learning/heterograph_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,28 +45,29 @@ def __init__(
"""
super().__init__()
# Layer dimensions
hidden_dim_1 = hidden_dim
hidden_dim_2 = hidden_dim_1 * heads_2
output_dim = hidden_dim * heads_1 * heads_2

# Nonlinear activation
self.dropout = dropout
self.dropout_layer = Dropout(dropout)
self.leaky_relu = LeakyReLU()

# Linear layers
# Initial Embedding
self.input_nodes = nn.ModuleDict()
self.input_edges = dict()
for key, d in node_dict.items():
self.input_nodes[key] = nn.Linear(d, hidden_dim_1, device=device)
self.input_nodes[key] = nn.Linear(d, hidden_dim, device=device)

self.input_edges = dict()
for key, d in edge_dict.items():
self.input_edges[key] = nn.Linear(d, hidden_dim_1, device=device)
self.output = Linear(output_dim, 1).to(device)
self.input_edges[key] = nn.Linear(d, hidden_dim, device=device)

# Layer dimensions
hidden_dim_1 = hidden_dim
hidden_dim_2 = hidden_dim_1 * heads_2
output_dim = hidden_dim_1 * heads_1 * heads_2

# Message passing layers
self.gat1 = self.init_gat_layer(hidden_dim_1, hidden_dim_1, heads_1)
self.gat2 = self.init_gat_layer(hidden_dim_2, hidden_dim_1, heads_2)
self.output = Linear(output_dim, 1).to(device)

# Initialize weights
self.init_weights()
Expand Down Expand Up @@ -190,22 +191,41 @@ def __init__(
node_dict,
edge_dict,
device,
hidden_dim,
hidden_dim * 2,
dropout,
heads_1,
heads_2,
)

# Instance attributes
self.input_patches = ConvNet(hidden_dim)
# Patch Embedding
self.input_patches = ConvNet((48, 48, 48), hidden_dim)

# Node Embedding
proposal_dim = node_dict["proposal"]
branch_dim = node_dict["branch"]
self.input_nodes = nn.ModuleDict({
"proposal": nn.Linear(proposal_dim, hidden_dim, device=device),
"branch": nn.Linear(branch_dim, hidden_dim * 2, device=device),
})

# Edge Embedding
self.input_edges = dict()
for key, d in edge_dict.items():
self.input_edges[key] = nn.Linear(
d, hidden_dim * 2, device=device
)

# Initialize weights
self.init_weights()

def forward(self, x_dict, edge_index_dict, edge_attr_dict):
# Input - Patches
x_patches = self.input_patches(x_dict["patches"])
del x_dict["patches"]
x_patch = self.input_patches(x_dict["patch"])
del x_dict["patch"]

# Input - Nodes
x_dict = {key: f(x_dict[key]) for key, f in self.input_nodes.items()}
for key, f in self.input_nodes.items():
x_dict[key] = f(x_dict[key])
x_dict = self.activation(x_dict)

# Input - Edges
Expand All @@ -214,6 +234,7 @@ def forward(self, x_dict, edge_index_dict, edge_attr_dict):
edge_attr_dict = self.activation(edge_attr_dict)

# Concatenate multimodal embeddings
x_dict["proposal"] = torch.cat((x_dict["proposal"], x_patch), dim=1)

# Message passing layers
x_dict = self.gat1(
Expand Down
2 changes: 1 addition & 1 deletion src/deep_neurographs/machine_learning/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __init__(self, patch_shape, output_dim):
self.conv1 = self._init_conv_layer(2, 32)
self.conv2 = self._init_conv_layer(32, 64)
self.output = nn.Sequential(
nn.Linear(-1, 64),
nn.Linear(64000, 64),
nn.LeakyReLU(),
nn.Linear(output_dim, output_dim),
)
Expand Down

0 comments on commit f248ecb

Please sign in to comment.