diff --git a/src/deep_neurographs/machine_learning/feature_generation.py b/src/deep_neurographs/machine_learning/feature_generation.py index f0c5256..b4292ae 100644 --- a/src/deep_neurographs/machine_learning/feature_generation.py +++ b/src/deep_neurographs/machine_learning/feature_generation.py @@ -331,7 +331,7 @@ def generate_skel_features(neurograph, proposals, search_radius): i, j = tuple(proposal) features[proposal] = np.concatenate( ( - 1, # edge type + 1, # edge type neurograph.proposal_length(proposal), neurograph.degree[i], neurograph.degree[j], @@ -361,7 +361,12 @@ def get_directionals(neurograph, proposal, window_size): # Compute features inner_product_1 = abs(np.dot(proposal_direction, direction_i)) inner_product_2 = abs(np.dot(proposal_direction, direction_j)) - inner_product_3 = np.dot(direction_i, direction_j) + if neurograph.is_simple(proposal): + inner_product_3 = np.dot(direction_i, direction_j) + else: + inner_product_3a = np.dot(direction_i, direction_j) + inner_product_3b = np.dot(direction_i, direction_j) + inner_product_3 = max(inner_product_3a, inner_product_3b) return np.array([inner_product_1, inner_product_2, inner_product_3]) @@ -435,21 +440,23 @@ def generate_branch_features(neurograph): i, j = tuple(edge) features[frozenset(edge)] = np.concatenate( ( - -1, # edge type - -1, + -1, # edge type + np.zeros((32)) + ), + axis=None, + ) + return features +""" + 0, neurograph.degree[i], neurograph.degree[j], - -1, + 0, get_radii(neurograph, edge), np.mean(neurograph.edges[i, j]["radius"]), np.mean(neurograph.edges[i, j]["radius"]), - -1 * np.ones(12), - -1 * np.ones((N_PROFILE_PTS + 2)), - ), - axis=None, - ) - return features - + np.zeros(12), + np.zeros((N_PROFILE_PTS + 2)), +""" def compute_curvature(neurograph, edge): kappa = curvature(neurograph.edges[edge]["xyz"]) diff --git a/src/deep_neurographs/machine_learning/graph_datasets.py b/src/deep_neurographs/machine_learning/graph_datasets.py index b704239..371dff2 100644 --- a/src/deep_neurographs/machine_learning/graph_datasets.py +++ b/src/deep_neurographs/machine_learning/graph_datasets.py @@ -97,10 +97,11 @@ def __init__( self.n_proposals = len(y_proposals) # Initialize data - edge_index = set_edge_index( + edge_index, proposal_edges = set_edge_index( neurograph, proposals, idxs_branches, idxs_proposals ) self.data = GraphData(x=x, y=y, edge_index=edge_index) + self.dropout_edges = proposal_edges class HeteroGraphDataset: @@ -177,12 +178,11 @@ def set_edge_index(neurograph, proposals, idxs_branches, idxs_proposals): # Initializations branches_line_graph = nx.line_graph(neurograph) proposals_line_graph = init_proposals_line_graph(neurograph, proposals) + proposal_edges = proposal_to_proposal(proposals_line_graph, idxs_proposals) # Compute edges edge_index = branch_to_branch(branches_line_graph, idxs_branches) - edge_index.extend( - proposal_to_proposal(proposals_line_graph, idxs_proposals) - ) + edge_index.extend(proposal_edges) edge_index.extend( branch_to_proposal( neurograph, proposals, idxs_branches, idxs_proposals @@ -192,7 +192,7 @@ def set_edge_index(neurograph, proposals, idxs_branches, idxs_proposals): # Reshape edge_index = np.array(edge_index, dtype=np.int64).tolist() edge_index = torch.Tensor(edge_index).t().contiguous() - return edge_index.long() + return edge_index.long(), proposal_edges def init_proposals_line_graph(neurograph, proposals): diff --git a/src/deep_neurographs/machine_learning/graph_models.py b/src/deep_neurographs/machine_learning/graph_models.py index bff4bf1..4c88f49 100644 --- a/src/deep_neurographs/machine_learning/graph_models.py +++ b/src/deep_neurographs/machine_learning/graph_models.py @@ -25,7 +25,6 @@ def __init__(self, input_channels): # Convolutional layers self.conv1 = GCNConv(input_channels, input_channels // 2) self.conv2 = GCNConv(input_channels // 2, input_channels // 2) - self.conv3 = GCNConv(input_channels, input_channels // 2) # Activation self.dropout = Dropout(0.3) @@ -35,7 +34,7 @@ def __init__(self, input_channels): self.init_weights() def init_weights(self): - layers = [self.conv1, self.conv2, self.conv3, self.input, self.output] + layers = [self.conv1, self.conv2, self.input, self.output] for layer in layers: for param in layer.parameters(): if len(param.shape) > 1: @@ -69,13 +68,12 @@ class GAT(torch.nn.Module): def __init__(self, input_channels): super().__init__() # Linear layers - self.input = Linear(input_channels, input_channels) + self.input = Linear(input_channels, 2 * input_channels) self.output = Linear(input_channels // 2, 1) # Convolutional layers - self.conv1 = GATConv(input_channels, input_channels // 2) - self.conv2 = GATConv(input_channels // 2, input_channels // 2) - self.conv3 = GATConv(input_channels, input_channels // 2) + self.conv1 = GATConv(2 * input_channels, input_channels) + self.conv2 = GATConv(input_channels, input_channels // 2) # Activation self.dropout = Dropout(0.3) @@ -85,7 +83,7 @@ def __init__(self, input_channels): self.init_weights() def init_weights(self): - layers = [self.conv1, self.conv2, self.conv3, self.input, self.output] + layers = [self.conv1, self.conv2, self.input, self.output] for layer in layers: for param in layer.parameters(): if len(param.shape) > 1: diff --git a/src/deep_neurographs/machine_learning/graph_trainer.py b/src/deep_neurographs/machine_learning/graph_trainer.py index 25937c0..04c6e6e 100644 --- a/src/deep_neurographs/machine_learning/graph_trainer.py +++ b/src/deep_neurographs/machine_learning/graph_trainer.py @@ -21,16 +21,22 @@ ) from torch.optim.lr_scheduler import StepLR from torch.utils.tensorboard import SummaryWriter +from torch_geometric.utils import subgraph from deep_neurographs.machine_learning import ml_utils +# Training LR = 1e-3 -N_EPOCHS = 1000 -SCHEDULER_GAMMA = 0.75 +N_EPOCHS = 200 +SCHEDULER_GAMMA = 0.5 SCHEDULER_STEP_SIZE = 1000 TEST_PERCENT = 0.15 WEIGHT_DECAY = 1e-3 +# Augmentation +MAX_PROPOSAL_DROPOUT = 0.1 +SCALING_FACTOR = 0.05 + class GraphTrainer: """ @@ -44,6 +50,8 @@ def __init__( criterion, lr=LR, n_epochs=N_EPOCHS, + max_proposal_dropout=MAX_PROPOSAL_DROPOUT, + scaling_factor=SCALING_FACTOR, weight_decay=WEIGHT_DECAY, ): """ @@ -68,6 +76,7 @@ def __init__( None. """ + # Training self.model = model.to("cuda:0") self.criterion = criterion self.n_epochs = n_epochs @@ -77,6 +86,10 @@ def __init__( self.init_scheduler() self.writer = SummaryWriter() + # Augmentation + self.scaling_factor = scaling_factor + self.max_proposal_dropout = max_proposal_dropout + def init_scheduler(self): self.scheduler = StepLR( self.optimizer, @@ -84,7 +97,7 @@ def init_scheduler(self): gamma=SCHEDULER_GAMMA, ) - def run_on_graphs(self, datasets): + def run_on_graphs(self, datasets, augment=False): """ Trains a graph neural network in the case where "datasets" is a dictionary of datasets such that each corresponds to a distinct graph. @@ -112,7 +125,7 @@ def run_on_graphs(self, datasets): y, hat_y = [], [] self.model.train() for graph_id in train_ids: - y_i, hat_y_i = self.train(datasets[graph_id].data, epoch) + y_i, hat_y_i = self.train(datasets[graph_id], epoch, augment=augment) y.extend(toCPU(y_i)) hat_y.extend(toCPU(hat_y_i)) self.compute_metrics(y, hat_y, "train", epoch) @@ -153,7 +166,7 @@ def run_on_graph(self): """ pass - def train(self, data, epoch): + def train(self, dataset, epoch, augment=False): """ Performs the forward pass and backpropagation to update the model's weights. @@ -164,6 +177,8 @@ def train(self, data, epoch): Graph dataset that corresponds to a single connected component. epoch : int Current epoch. + augment : bool, optional + Indication of whether to augment data. Default is False. Returns ------- @@ -173,10 +188,17 @@ def train(self, data, epoch): Prediction. """ - y, hat_y = self.forward(data) + if augment: + dataset = self.augment(deepcopy(dataset)) + y, hat_y = self.forward(dataset.data) self.backpropagate(y, hat_y, epoch) return y, hat_y + def augment(self, dataset): + augmented_dataset = rescale_data(dataset, self.scaling_factor) + #augmented_data = proposal_dropout(dataset, self.max_proposal_dropout) + return augmented_dataset + def forward(self, data): """ Runs "data" through "self.model" to generate a prediction. @@ -397,3 +419,31 @@ def connected_components(data): for i in range(cc_idxs.max().item() + 1): cc_list.append(torch.nonzero(cc_idxs == i, as_tuple=False).view(-1)) return cc_list + + +def rescale_data(dataset, scaling_factor): + # Get scaling factor + low = 1.0 - scaling_factor + high = 1.0 + scaling_factor + scaling_factor = torch.tensor(np.random.uniform(low=low, high=high)) + + # Rescale + n = count_proposals(dataset) + dataset.data.x[0:n, 1] = scaling_factor * dataset.data.x[0:n, 1] + return dataset + + +def proposal_dropout(data, max_proposal_dropout): + n_dropout_edges = len(data.dropout_edges) // 2 + dropout_prob = np.random.uniform(low=0, high=max_proposal_dropout) + n_remove = int(dropout_prob * n_dropout_edges) + remove_edges = sample(data.dropout_edges, n_remove) + for edge in remove_edges: + reversed_edge = [edge[1], edge[0]] + edges_to_remove = torch.tensor([edge, reversed_edge], dtype=torch.long) + edges_mask = torch.all(data.data.edge_index.T == edges_to_remove[:, None], dim=2).any(dim=0) + data.data.edge_index = data.data.edge_index[:, ~edges_mask] + return data + +def count_proposals(dataset): + return dataset.data.y.size(0)