Commit 26614bc
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Apr 1, 2024
1 parent 8076663 commit 26614bc
Showing 3 changed files with 67 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -175,4 +175,4 @@ cython_debug/
 
 .data/
 data/
-raw_data/
\ No newline at end of file
+raw_data/
70 changes: 51 additions & 19 deletions notebooks/model.py
@@ -1,15 +1,24 @@
-import torch
+from __future__ import annotations
+
+import torch
+from pytorch_lightning.core.mixins import HyperparametersMixin
+from torch_cluster import knn
 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.dense.linear import Linear
 
-from torch_cluster import knn
-from pytorch_lightning.core.mixins import HyperparametersMixin
 
 class GravNet(MessagePassing):
 
-    def __init__(self, in_channels: int, out_channels: int, space_dimensions: int, k: int = 4, message_multiple: int = 2, **kwargs):
-        super().__init__(aggr=['mean', 'max'], flow='source_to_target', **kwargs)
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        space_dimensions: int,
+        k: int = 4,
+        message_multiple: int = 2,
+        **kwargs,
+    ):
+        super().__init__(aggr=["mean", "max"], flow="source_to_target", **kwargs)
 
         assert not (in_channels != out_channels and message_multiple == 0)
 
@@ -28,44 +37,67 @@ def __init__(self, in_channels: int, out_channels: int, space_dimensions: int, k
         self.in_channels = in_channels
         self.k = k
 
-    def forward(self, x, batch_index = None):
+    def forward(self, x, batch_index=None):
 
         m_1, m_2, s = self.lin_embed(x).split(self.in_channels, dim=-1)
         edge_index = knn(s, s, self.k, batch_index, batch_index).flip([0])
-        edge_weight = torch.exp(-10. * (s[edge_index[0]] - s[edge_index[1]]).pow(2).sum(-1))
-        out = self.propagate(edge_index, x=(m_1, m_2), edge_weight=edge_weight, size=None).view(x.size()[0], -1)
+        edge_weight = torch.exp(
+            -10.0 * (s[edge_index[0]] - s[edge_index[1]]).pow(2).sum(-1)
+        )
+        out = self.propagate(
+            edge_index, x=(m_1, m_2), edge_weight=edge_weight, size=None
+        ).view(x.size()[0], -1)
 
         return self.lin_out(torch.cat([x, out], dim=-1))
 
     def message(self, x_i, x_j, edge_weight):
         if self.lin_message is not None:
             mes = self.lin_message(torch.cat([x_j, x_i], dim=-1))
-            return (mes/torch.linalg.vector_norm(mes, dim=-1).unsqueeze(-1)) * edge_weight.unsqueeze(1)
+            return (
+                mes / torch.linalg.vector_norm(mes, dim=-1).unsqueeze(-1)
+            ) * edge_weight.unsqueeze(1)
         else:
             return x_j * edge_weight.unsqueeze(1)
 
     def __repr__(self) -> str:
-        return (f'{self.__class__.__name__}({self.in_channels}, '
-                f'{self.out_channels}, k={self.k})')
+        return (
+            f"{self.__class__.__name__}({self.in_channels}, "
+            f"{self.out_channels}, k={self.k})"
+        )
 
 
 class Model(torch.nn.Module, HyperparametersMixin):
-    def __init__(self, embed_dim, space_dim, num_layers, k = 4, message_multiple = 2, input_dim = 14, output_dim = 4):
+    def __init__(
+        self,
+        embed_dim,
+        space_dim,
+        num_layers,
+        k=4,
+        message_multiple=2,
+        input_dim=14,
+        output_dim=4,
+    ):
         super().__init__()
         self.layers = torch.nn.ModuleList()
         for _ in range(num_layers):
-            self.layers.append(GravNet(in_channels=embed_dim, out_channels=embed_dim, space_dimensions=space_dim, k=k, message_multiple=message_multiple))
+            self.layers.append(
+                GravNet(
+                    in_channels=embed_dim,
+                    out_channels=embed_dim,
+                    space_dimensions=space_dim,
+                    k=k,
+                    message_multiple=message_multiple,
+                )
+            )
         self.linear_in = Linear(input_dim, embed_dim)
         self.linear_out = Linear(embed_dim, output_dim + 1)
 
         self.act = torch.nn.LeakyReLU()
 
     def forward(self, batch, batch_index):
         batch = self.act(self.linear_in(batch))
         for layer in self.layers:
             batch = self.act(layer(batch, batch_index))
         batch = self.linear_out(batch)
 
-        return {
-            "B" : torch.sigmoid(batch[:, 0]),
-            "H" : batch[:, 1:]
-        }
+        return {"B": torch.sigmoid(batch[:, 0]), "H": batch[:, 1:]}
24 changes: 15 additions & 9 deletions notebooks/project.ipynb
@@ -20,6 +20,7 @@
"from pytorch_lightning import Trainer\n",
"from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger\n",
"from pytorch_lightning.callbacks import RichProgressBar\n",
"\n",
"# from wandb_osh.lightning_hooks import TriggerWandbSyncLightningCallback\n",
"\n",
"from model import Model\n",
@@ -127,7 +128,7 @@
 }
 ],
 "source": [
-"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
 "print(\"Current device:\", device)"
 ]
 },
@@ -156,7 +157,7 @@
"metadata": {},
"outputs": [],
"source": [
"model = Model(embed_dim = 64, space_dim = 4, num_layers = 4)"
"model = Model(embed_dim=64, space_dim=4, num_layers=4)"
]
},
{
@@ -200,10 +201,8 @@
" max_epochs=1,\n",
" accelerator=device,\n",
" log_every_n_steps=1,\n",
" callbacks=[\n",
" PrintValidationMetrics()\n",
" ],\n",
" logger=logger\n",
" callbacks=[PrintValidationMetrics()],\n",
" logger=logger,\n",
")"
]
},
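
For readability, here is the reformatted Trainer cell un-escaped from the notebook JSON. The opening trainer = Trainer( line and the definitions of PrintValidationMetrics and logger sit in lines this hunk elides, so those names are assumptions here:

# The Trainer cell above as plain Python (un-escaped from the notebook JSON).
# The "trainer =" assignment, PrintValidationMetrics, and logger are assumed
# from elided cells; device comes from the cell shown earlier.
trainer = Trainer(
    max_epochs=1,
    accelerator=device,
    log_every_n_steps=1,
    callbacks=[PrintValidationMetrics()],
    logger=logger,
)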
@@ -269,15 +268,22 @@
"POS = y.detach()\n",
"ID = raw_data[850].particle_id\n",
"\n",
"id_np = ID.to('cpu').numpy()\n",
"pos_np = POS.to('cpu').numpy()\n",
"id_np = ID.to(\"cpu\").numpy()\n",
"pos_np = POS.to(\"cpu\").numpy()\n",
"\n",
"unique_ids = np.unique(id_np)\n",
"highlight_id = np.random.choice(unique_ids)\n",
"highlight_id_2 = np.random.choice(unique_ids)\n",
"highlight_id_3 = np.random.choice(unique_ids)\n",
"\n",
"colors = ['black' if (id == highlight_id or id == highlight_id_2 or id == highlight_id_3) else 'wheat' for id in id_np]\n",
"colors = [\n",
" (\n",
" \"black\"\n",
" if (id == highlight_id or id == highlight_id_2 or id == highlight_id_3)\n",
" else \"wheat\"\n",
" )\n",
" for id in id_np\n",
"]\n",
"\n",
"plt.figure(figsize=(10, 8))\n",
"plt.scatter(pos_np[:, 0], pos_np[:, 1], color=colors, alpha=0.25)\n",
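
Likewise, the reformatted plotting cell un-escaped as plain Python. ID and POS come from the event-850 outputs built earlier in the same cell, and np/plt are the usual numpy/matplotlib imports from earlier cells; note that the three highlight ids are drawn independently, so two of them can coincide:

# The plotting cell above as plain Python (un-escaped from the notebook JSON).
# ID and POS are defined just above in the same cell; np and plt come from
# earlier (elided) import cells.
id_np = ID.to("cpu").numpy()
pos_np = POS.to("cpu").numpy()

unique_ids = np.unique(id_np)
# Three independent draws: the same particle id can be highlighted twice.
highlight_id = np.random.choice(unique_ids)
highlight_id_2 = np.random.choice(unique_ids)
highlight_id_3 = np.random.choice(unique_ids)

colors = [
    (
        "black"
        if (id == highlight_id or id == highlight_id_2 or id == highlight_id_3)
        else "wheat"
    )
    for id in id_np
]

plt.figure(figsize=(10, 8))
plt.scatter(pos_np[:, 0], pos_np[:, 1], color=colors, alpha=0.25)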
