Skip to content

Commit

Permalink
upds (#131)
Browse files Browse the repository at this point in the history
Co-authored-by: anna-grim <[email protected]>
  • Loading branch information
anna-grim and anna-grim authored May 8, 2024
1 parent 0f93b59 commit d5a5185
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 30 deletions.
31 changes: 19 additions & 12 deletions src/deep_neurographs/machine_learning/feature_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def generate_skel_features(neurograph, proposals, search_radius):
i, j = tuple(proposal)
features[proposal] = np.concatenate(
(
1, # edge type
1, # edge type
neurograph.proposal_length(proposal),
neurograph.degree[i],
neurograph.degree[j],
Expand Down Expand Up @@ -361,7 +361,12 @@ def get_directionals(neurograph, proposal, window_size):
# Compute features
inner_product_1 = abs(np.dot(proposal_direction, direction_i))
inner_product_2 = abs(np.dot(proposal_direction, direction_j))
inner_product_3 = np.dot(direction_i, direction_j)
if neurograph.is_simple(proposal):
inner_product_3 = np.dot(direction_i, direction_j)
else:
inner_product_3a = np.dot(direction_i, direction_j)
inner_product_3b = np.dot(direction_i, direction_j)
inner_product_3 = max(inner_product_3a, inner_product_3b)
return np.array([inner_product_1, inner_product_2, inner_product_3])


Expand Down Expand Up @@ -435,21 +440,23 @@ def generate_branch_features(neurograph):
i, j = tuple(edge)
features[frozenset(edge)] = np.concatenate(
(
-1, # edge type
-1,
-1, # edge type
np.zeros((32))
),
axis=None,
)
return features
"""
0,
neurograph.degree[i],
neurograph.degree[j],
-1,
0,
get_radii(neurograph, edge),
np.mean(neurograph.edges[i, j]["radius"]),
np.mean(neurograph.edges[i, j]["radius"]),
-1 * np.ones(12),
-1 * np.ones((N_PROFILE_PTS + 2)),
),
axis=None,
)
return features

np.zeros(12),
np.zeros((N_PROFILE_PTS + 2)),
"""

def compute_curvature(neurograph, edge):
kappa = curvature(neurograph.edges[edge]["xyz"])
Expand Down
10 changes: 5 additions & 5 deletions src/deep_neurographs/machine_learning/graph_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,11 @@ def __init__(
self.n_proposals = len(y_proposals)

# Initialize data
edge_index = set_edge_index(
edge_index, proposal_edges = set_edge_index(
neurograph, proposals, idxs_branches, idxs_proposals
)
self.data = GraphData(x=x, y=y, edge_index=edge_index)
self.dropout_edges = proposal_edges


class HeteroGraphDataset:
Expand Down Expand Up @@ -177,12 +178,11 @@ def set_edge_index(neurograph, proposals, idxs_branches, idxs_proposals):
# Initializations
branches_line_graph = nx.line_graph(neurograph)
proposals_line_graph = init_proposals_line_graph(neurograph, proposals)
proposal_edges = proposal_to_proposal(proposals_line_graph, idxs_proposals)

# Compute edges
edge_index = branch_to_branch(branches_line_graph, idxs_branches)
edge_index.extend(
proposal_to_proposal(proposals_line_graph, idxs_proposals)
)
edge_index.extend(proposal_edges)
edge_index.extend(
branch_to_proposal(
neurograph, proposals, idxs_branches, idxs_proposals
Expand All @@ -192,7 +192,7 @@ def set_edge_index(neurograph, proposals, idxs_branches, idxs_proposals):
# Reshape
edge_index = np.array(edge_index, dtype=np.int64).tolist()
edge_index = torch.Tensor(edge_index).t().contiguous()
return edge_index.long()
return edge_index.long(), proposal_edges


def init_proposals_line_graph(neurograph, proposals):
Expand Down
12 changes: 5 additions & 7 deletions src/deep_neurographs/machine_learning/graph_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def __init__(self, input_channels):
# Convolutional layers
self.conv1 = GCNConv(input_channels, input_channels // 2)
self.conv2 = GCNConv(input_channels // 2, input_channels // 2)
self.conv3 = GCNConv(input_channels, input_channels // 2)

# Activation
self.dropout = Dropout(0.3)
Expand All @@ -35,7 +34,7 @@ def __init__(self, input_channels):
self.init_weights()

def init_weights(self):
layers = [self.conv1, self.conv2, self.conv3, self.input, self.output]
layers = [self.conv1, self.conv2, self.input, self.output]
for layer in layers:
for param in layer.parameters():
if len(param.shape) > 1:
Expand Down Expand Up @@ -69,13 +68,12 @@ class GAT(torch.nn.Module):
def __init__(self, input_channels):
super().__init__()
# Linear layers
self.input = Linear(input_channels, input_channels)
self.input = Linear(input_channels, 2 * input_channels)
self.output = Linear(input_channels // 2, 1)

# Convolutional layers
self.conv1 = GATConv(input_channels, input_channels // 2)
self.conv2 = GATConv(input_channels // 2, input_channels // 2)
self.conv3 = GATConv(input_channels, input_channels // 2)
self.conv1 = GATConv(2 * input_channels, input_channels)
self.conv2 = GATConv(input_channels, input_channels // 2)

# Activation
self.dropout = Dropout(0.3)
Expand All @@ -85,7 +83,7 @@ def __init__(self, input_channels):
self.init_weights()

def init_weights(self):
layers = [self.conv1, self.conv2, self.conv3, self.input, self.output]
layers = [self.conv1, self.conv2, self.input, self.output]
for layer in layers:
for param in layer.parameters():
if len(param.shape) > 1:
Expand Down
62 changes: 56 additions & 6 deletions src/deep_neurographs/machine_learning/graph_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,22 @@
)
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.utils import subgraph

from deep_neurographs.machine_learning import ml_utils

# Training
LR = 1e-3
N_EPOCHS = 1000
SCHEDULER_GAMMA = 0.75
N_EPOCHS = 200
SCHEDULER_GAMMA = 0.5
SCHEDULER_STEP_SIZE = 1000
TEST_PERCENT = 0.15
WEIGHT_DECAY = 1e-3

# Augmentation
MAX_PROPOSAL_DROPOUT = 0.1
SCALING_FACTOR = 0.05


class GraphTrainer:
"""
Expand All @@ -44,6 +50,8 @@ def __init__(
criterion,
lr=LR,
n_epochs=N_EPOCHS,
max_proposal_dropout=MAX_PROPOSAL_DROPOUT,
scaling_factor=SCALING_FACTOR,
weight_decay=WEIGHT_DECAY,
):
"""
Expand All @@ -68,6 +76,7 @@ def __init__(
None.
"""
# Training
self.model = model.to("cuda:0")
self.criterion = criterion
self.n_epochs = n_epochs
Expand All @@ -77,14 +86,18 @@ def __init__(
self.init_scheduler()
self.writer = SummaryWriter()

# Augmentation
self.scaling_factor = scaling_factor
self.max_proposal_dropout = max_proposal_dropout

def init_scheduler(self):
self.scheduler = StepLR(
self.optimizer,
step_size=SCHEDULER_STEP_SIZE,
gamma=SCHEDULER_GAMMA,
)

def run_on_graphs(self, datasets):
def run_on_graphs(self, datasets, augment=False):
"""
Trains a graph neural network in the case where "datasets" is a
dictionary of datasets such that each corresponds to a distinct graph.
Expand Down Expand Up @@ -112,7 +125,7 @@ def run_on_graphs(self, datasets):
y, hat_y = [], []
self.model.train()
for graph_id in train_ids:
y_i, hat_y_i = self.train(datasets[graph_id].data, epoch)
y_i, hat_y_i = self.train(datasets[graph_id], epoch, augment=augment)
y.extend(toCPU(y_i))
hat_y.extend(toCPU(hat_y_i))
self.compute_metrics(y, hat_y, "train", epoch)
Expand Down Expand Up @@ -153,7 +166,7 @@ def run_on_graph(self):
"""
pass

def train(self, data, epoch):
def train(self, dataset, epoch, augment=False):
"""
Performs the forward pass and backpropagation to update the model's
weights.
Expand All @@ -164,6 +177,8 @@ def train(self, data, epoch):
Graph dataset that corresponds to a single connected component.
epoch : int
Current epoch.
augment : bool, optional
Indication of whether to augment data. Default is False.
Returns
-------
Expand All @@ -173,10 +188,17 @@ def train(self, data, epoch):
Prediction.
"""
y, hat_y = self.forward(data)
if augment:
dataset = self.augment(deepcopy(dataset))
y, hat_y = self.forward(dataset.data)
self.backpropagate(y, hat_y, epoch)
return y, hat_y

def augment(self, dataset):
augmented_dataset = rescale_data(dataset, self.scaling_factor)
#augmented_data = proposal_dropout(dataset, self.max_proposal_dropout)
return augmented_dataset

def forward(self, data):
"""
Runs "data" through "self.model" to generate a prediction.
Expand Down Expand Up @@ -397,3 +419,31 @@ def connected_components(data):
for i in range(cc_idxs.max().item() + 1):
cc_list.append(torch.nonzero(cc_idxs == i, as_tuple=False).view(-1))
return cc_list


def rescale_data(dataset, scaling_factor):
# Get scaling factor
low = 1.0 - scaling_factor
high = 1.0 + scaling_factor
scaling_factor = torch.tensor(np.random.uniform(low=low, high=high))

# Rescale
n = count_proposals(dataset)
dataset.data.x[0:n, 1] = scaling_factor * dataset.data.x[0:n, 1]
return dataset


def proposal_dropout(data, max_proposal_dropout):
n_dropout_edges = len(data.dropout_edges) // 2
dropout_prob = np.random.uniform(low=0, high=max_proposal_dropout)
n_remove = int(dropout_prob * n_dropout_edges)
remove_edges = sample(data.dropout_edges, n_remove)
for edge in remove_edges:
reversed_edge = [edge[1], edge[0]]
edges_to_remove = torch.tensor([edge, reversed_edge], dtype=torch.long)
edges_mask = torch.all(data.data.edge_index.T == edges_to_remove[:, None], dim=2).any(dim=0)
data.data.edge_index = data.data.edge_index[:, ~edges_mask]
return data

def count_proposals(dataset):
return dataset.data.y.size(0)

0 comments on commit d5a5185

Please sign in to comment.