From 26614bc58c74f97d653d77b1d4337d48d5f33da4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 1 Apr 2024 22:42:50 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .gitignore | 2 +- notebooks/model.py | 70 ++++++++++++++++++++++++++++++----------- notebooks/project.ipynb | 24 ++++++++------ 3 files changed, 67 insertions(+), 29 deletions(-) diff --git a/.gitignore b/.gitignore index 3b1df20..6b12674 100644 --- a/.gitignore +++ b/.gitignore @@ -175,4 +175,4 @@ cython_debug/ .data/ data/ -raw_data/ \ No newline at end of file +raw_data/ diff --git a/notebooks/model.py b/notebooks/model.py index 71802ff..42571ed 100644 --- a/notebooks/model.py +++ b/notebooks/model.py @@ -1,15 +1,24 @@ -import torch +from __future__ import annotations +import torch +from pytorch_lightning.core.mixins import HyperparametersMixin +from torch_cluster import knn from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear -from torch_cluster import knn -from pytorch_lightning.core.mixins import HyperparametersMixin class GravNet(MessagePassing): - def __init__(self, in_channels: int, out_channels: int, space_dimensions: int, k: int = 4, message_multiple: int = 2, **kwargs): - super().__init__(aggr=['mean', 'max'], flow='source_to_target', **kwargs) + def __init__( + self, + in_channels: int, + out_channels: int, + space_dimensions: int, + k: int = 4, + message_multiple: int = 2, + **kwargs, + ): + super().__init__(aggr=["mean", "max"], flow="source_to_target", **kwargs) assert not (in_channels != out_channels and message_multiple == 0) @@ -28,44 +37,67 @@ def __init__(self, in_channels: int, out_channels: int, space_dimensions: int, k self.in_channels = in_channels self.k = k - def forward(self, x, batch_index = None): + def forward(self, x, batch_index=None): m_1, m_2, s = self.lin_embed(x).split(self.in_channels, dim=-1) edge_index = knn(s, s, self.k, batch_index, batch_index).flip([0]) - edge_weight = torch.exp(-10. * (s[edge_index[0]] - s[edge_index[1]]).pow(2).sum(-1)) - out = self.propagate(edge_index, x=(m_1, m_2), edge_weight=edge_weight, size=None).view(x.size()[0], -1) + edge_weight = torch.exp( + -10.0 * (s[edge_index[0]] - s[edge_index[1]]).pow(2).sum(-1) + ) + out = self.propagate( + edge_index, x=(m_1, m_2), edge_weight=edge_weight, size=None + ).view(x.size()[0], -1) return self.lin_out(torch.cat([x, out], dim=-1)) def message(self, x_i, x_j, edge_weight): if self.lin_message != None: mes = self.lin_message(torch.cat([x_j, x_i], dim=-1)) - return (mes/torch.linalg.vector_norm(mes, dim=-1).unsqueeze(-1)) * edge_weight.unsqueeze(1) + return ( + mes / torch.linalg.vector_norm(mes, dim=-1).unsqueeze(-1) + ) * edge_weight.unsqueeze(1) else: return x_j * edge_weight.unsqueeze(1) def __repr__(self) -> str: - return (f'{self.__class__.__name__}({self.in_channels}, ' - f'{self.out_channels}, k={self.k})') - + return ( + f"{self.__class__.__name__}({self.in_channels}, " + f"{self.out_channels}, k={self.k})" + ) + + class Model(torch.nn.Module, HyperparametersMixin): - def __init__(self, embed_dim, space_dim, num_layers, k = 4, message_multiple = 2, input_dim = 14, output_dim = 4): + def __init__( + self, + embed_dim, + space_dim, + num_layers, + k=4, + message_multiple=2, + input_dim=14, + output_dim=4, + ): super().__init__() self.layers = torch.nn.ModuleList() for _ in range(num_layers): - self.layers.append(GravNet(in_channels=embed_dim, out_channels=embed_dim, space_dimensions=space_dim, k=k, message_multiple=message_multiple)) + self.layers.append( + GravNet( + in_channels=embed_dim, + out_channels=embed_dim, + space_dimensions=space_dim, + k=k, + message_multiple=message_multiple, + ) + ) self.linear_in = Linear(input_dim, embed_dim) self.linear_out = Linear(embed_dim, output_dim + 1) self.act = torch.nn.LeakyReLU() - + def forward(self, batch, batch_index): batch = self.act(self.linear_in(batch)) for layer in self.layers(): batch = self.act(layer(batch, batch_index)) batch = self.linear_out(batch) - return { - "B" : torch.sigmoid(batch[:, 0]), - "H" : batch[:, 1:] - } + return {"B": torch.sigmoid(batch[:, 0]), "H": batch[:, 1:]} diff --git a/notebooks/project.ipynb b/notebooks/project.ipynb index 35f3540..93b4a11 100644 --- a/notebooks/project.ipynb +++ b/notebooks/project.ipynb @@ -20,6 +20,7 @@ "from pytorch_lightning import Trainer\n", "from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger\n", "from pytorch_lightning.callbacks import RichProgressBar\n", + "\n", "# from wandb_osh.lightning_hooks import TriggerWandbSyncLightningCallback\n", "\n", "from model import Model\n", @@ -127,7 +128,7 @@ } ], "source": [ - "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "print(\"Current device:\", device)" ] }, @@ -156,7 +157,7 @@ "metadata": {}, "outputs": [], "source": [ - "model = Model(embed_dim = 64, space_dim = 4, num_layers = 4)" + "model = Model(embed_dim=64, space_dim=4, num_layers=4)" ] }, { @@ -200,10 +201,8 @@ " max_epochs=1,\n", " accelerator=device,\n", " log_every_n_steps=1,\n", - " callbacks=[\n", - " PrintValidationMetrics()\n", - " ],\n", - " logger=logger\n", + " callbacks=[PrintValidationMetrics()],\n", + " logger=logger,\n", ")" ] }, @@ -269,15 +268,22 @@ "POS = y.detach()\n", "ID = raw_data[850].particle_id\n", "\n", - "id_np = ID.to('cpu').numpy()\n", - "pos_np = POS.to('cpu').numpy()\n", + "id_np = ID.to(\"cpu\").numpy()\n", + "pos_np = POS.to(\"cpu\").numpy()\n", "\n", "unique_ids = np.unique(id_np)\n", "highlight_id = np.random.choice(unique_ids)\n", "highlight_id_2 = np.random.choice(unique_ids)\n", "highlight_id_3 = np.random.choice(unique_ids)\n", "\n", - "colors = ['black' if (id == highlight_id or id == highlight_id_2 or id == highlight_id_3) else 'wheat' for id in id_np]\n", + "colors = [\n", + " (\n", + " \"black\"\n", + " if (id == highlight_id or id == highlight_id_2 or id == highlight_id_3)\n", + " else \"wheat\"\n", + " )\n", + " for id in id_np\n", + "]\n", "\n", "plt.figure(figsize=(10, 8))\n", "plt.scatter(pos_np[:, 0], pos_np[:, 1], color=colors, alpha=0.25)\n",