Possible PyTorch implementation of WL kernel #153

Closed
wants to merge 62 commits

62 commits
36fc3bd
Add a PyTorch implementation of WL kernel
vladislavalerievich Oct 28, 2024
b0d3842
Fix imports
vladislavalerievich Oct 29, 2024
f87abd6
Remove redundant copy
vladislavalerievich Oct 29, 2024
358fbb7
Increase precision for allclose
vladislavalerievich Oct 29, 2024
de140b6
Fix calculation for graphs with reordered edges
vladislavalerievich Oct 29, 2024
08c7aea
Increase test coverage
vladislavalerievich Oct 29, 2024
6f07858
Improve readability of TorchWLKernel
vladislavalerievich Oct 30, 2024
896f461
Add additional comments to TorchWLKernel
vladislavalerievich Oct 30, 2024
383e924
Add MixedSingleTaskGP to process graphs
vladislavalerievich Nov 8, 2024
65666a3
Refactor WLKernelWrapper into a standalone WLKernel class.
vladislavalerievich Nov 20, 2024
7fa9432
Update tests
vladislavalerievich Nov 20, 2024
4227f22
Add a check for empty inputs
vladislavalerievich Nov 20, 2024
f194bd2
Improve and combine tests
vladislavalerievich Nov 20, 2024
a104840
Update WLKernel
vladislavalerievich Nov 21, 2024
246f9f6
Add acquisition function with graph sampling
vladislavalerievich Nov 21, 2024
770c626
Add a custom __call__ method to pass graphs during optimization
vladislavalerievich Nov 21, 2024
8bf7ea7
Update MixedSingleTaskGP
vladislavalerievich Dec 7, 2024
84d0104
Remove not used argument
vladislavalerievich Dec 7, 2024
d63239a
Update sample_graphs
vladislavalerievich Dec 7, 2024
3db3f89
Handle different batch dimensions
vladislavalerievich Dec 7, 2024
f69ddbe
Set num_restarts=10
vladislavalerievich Dec 7, 2024
1c4cc83
Add acquisition function
vladislavalerievich Dec 7, 2024
dab9a8c
Update WLKernel
vladislavalerievich Dec 7, 2024
2999582
Make train_inputs private
vladislavalerievich Dec 7, 2024
ad55030
Update tests
vladislavalerievich Dec 7, 2024
8093d31
fix: Implement graph acquisition
eddiebergman Dec 16, 2024
9f978d6
fix: Implement graph acquisition (#164)
vladislavalerievich Dec 24, 2024
a1a29a8
Delete unused MixedSingleTaskGP
vladislavalerievich Dec 24, 2024
046ad66
Add seed_all and min_max_scale
vladislavalerievich Dec 24, 2024
0a609f7
Refactor optimize.py
vladislavalerievich Dec 24, 2024
5486dcc
Speed up WL kernel computations
vladislavalerievich Dec 24, 2024
f140c56
Process wl iterations in batches
vladislavalerievich Dec 24, 2024
371b530
Use CSR
vladislavalerievich Dec 25, 2024
1478fd9
Implement caching
vladislavalerievich Jan 16, 2025
a4ffaaf
Clean up __init__ methods
vladislavalerievich Jan 16, 2025
2ec7d5b
Split _compute_kernel logic into smaller methods
vladislavalerievich Jan 16, 2025
8d6b63b
Rename kernel to BoTorchWLKernel
vladislavalerievich Jan 16, 2025
f18642b
Move GraphDataset class into utils.py
vladislavalerievich Jan 16, 2025
bb92de4
Delete GraphDataset
vladislavalerievich Jan 19, 2025
e409798
Update tests
vladislavalerievich Jan 20, 2025
51e6ae4
Simplify TorchWLKernel
vladislavalerievich Jan 20, 2025
bdd32db
Remove torch_wl_usage_example.py
vladislavalerievich Jan 21, 2025
7747e49
Update grakel_wl_usage_example.py
vladislavalerievich Jan 23, 2025
21b32c8
Update TestTorchWLKernel
vladislavalerievich Jan 23, 2025
dabf4f0
Create graphs_to_tensors function
vladislavalerievich Jan 23, 2025
22cf6d5
Add docstring to BoTorchWLKernel
vladislavalerievich Jan 23, 2025
52b3b14
Add tests for the BoTorchWLKernel
vladislavalerievich Jan 23, 2025
fe79d63
Move redundant files to examples directory
vladislavalerievich Jan 23, 2025
7729d2c
Combine set_graph_lookup context managers into one
vladislavalerievich Jan 23, 2025
1cecacc
Update comments for the optimize_acqf_graph function
vladislavalerievich Jan 23, 2025
ab730d3
Move sample_graphs into utils.py
vladislavalerievich Jan 23, 2025
3eb793d
Rename mixed_single_task_gp_usage_example.py
vladislavalerievich Jan 23, 2025
0cfae28
Add comments
vladislavalerievich Jan 23, 2025
88ddfe1
Move set_graph_lookup into its own file
vladislavalerievich Jan 23, 2025
6d9ea56
Update imports
vladislavalerievich Jan 23, 2025
be04ad2
Print results
vladislavalerievich Jan 23, 2025
458d420
Provide better file names
vladislavalerievich Jan 23, 2025
f7922db
Organize imports
vladislavalerievich Jan 23, 2025
4cc0b29
Use lru_cache instead of simple dict cache
vladislavalerievich Jan 23, 2025
4e8bdad
Improve tests
vladislavalerievich Jan 23, 2025
ea77e44
Fix ruff and mypy complaints
vladislavalerievich Jan 23, 2025
5e2a33b
Improve kernels
vladislavalerievich Jan 24, 2025
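For reference, the Weisfeiler-Lehman (WL) kernel this PR ports to PyTorch compares graphs by iteratively relabeling each node with a hash of its own label and its sorted neighbor labels, then taking the inner product of the resulting label-count histograms. Below is a minimal, unoptimized sketch of that idea — not the PR's `TorchWLKernel`, which per the commits above uses CSR adjacency, batching, and caching:

```python
import networkx as nx

def wl_kernel(g1: nx.Graph, g2: nx.Graph, n_iter: int = 3) -> float:
    """Unnormalized WL subtree kernel with degree-based initial labels."""
    def relabel(g: nx.Graph, labels: dict) -> dict:
        # Each node's new label hashes its old label plus sorted neighbor labels.
        return {
            v: hash((labels[v], tuple(sorted(labels[u] for u in g.neighbors(v)))))
            for v in g.nodes()
        }

    def histogram(labels: dict) -> dict:
        counts: dict = {}
        for lab in labels.values():
            counts[lab] = counts.get(lab, 0) + 1
        return counts

    l1 = {v: g1.degree(v) for v in g1.nodes()}
    l2 = {v: g2.degree(v) for v in g2.nodes()}
    k = 0.0
    for _ in range(n_iter + 1):
        h1, h2 = histogram(l1), histogram(l2)
        # Accumulate the dot product of the two label-count histograms.
        k += sum(c * h2.get(lab, 0) for lab, c in h1.items())
        l1, l2 = relabel(g1, l1), relabel(g2, l2)
    return k

g1 = nx.erdos_renyi_graph(n=5, p=0.5, seed=0)
g2 = nx.erdos_renyi_graph(n=5, p=0.5, seed=1)
print(wl_kernel(g1, g2))
```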
165 changes: 165 additions & 0 deletions grakel_replace/mixed_single_task_gp.py
@@ -0,0 +1,165 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from botorch.models import SingleTaskGP
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.kernels import AdditiveKernel
from gpytorch.module import Module
from grakel_replace.torch_wl_kernel import GraphDataset, TorchWLKernel

if TYPE_CHECKING:
    import networkx as nx
    from torch import Tensor


class MixedSingleTaskGP(SingleTaskGP):
    """A Gaussian Process model that handles numerical, categorical, and graph inputs.

    This class extends BoTorch's SingleTaskGP to work with hybrid input spaces containing:
    - Numerical features
    - Categorical features
    - Graph structures

    It uses the Weisfeiler-Lehman (WL) kernel for graph inputs and combines it with
    standard kernels for numerical/categorical features using an additive kernel
    structure.

    Attributes:
        _wl_kernel (TorchWLKernel): The Weisfeiler-Lehman kernel for graph similarity
        _train_graphs (list[nx.Graph]): Training set graph instances
        _K_train (Tensor): Pre-computed graph kernel matrix for training data
        num_cat_kernel (Module | None): Kernel for numerical/categorical features
    """

    def __init__(
        self,
        train_X: Tensor,  # Shape: (n_samples, n_numerical_categorical_features)
        train_graphs: list[nx.Graph],  # List of n_samples graph instances
        train_Y: Tensor,  # Shape: (n_samples, n_outputs)
        train_Yvar: Tensor | None = None,  # Shape: (n_samples, n_outputs) or None
        num_cat_kernel: Module | None = None,
        wl_kernel: TorchWLKernel | None = None,
        **kwargs,  # Additional arguments passed to SingleTaskGP
    ) -> None:
        """Initialize the mixed input Gaussian Process model.

        Args:
            train_X: Training data tensor for numerical and categorical features
            train_graphs: List of training graphs
            train_Y: Target values
            train_Yvar: Observation noise variance (optional)
            num_cat_kernel: Kernel for numerical/categorical features (optional)
            wl_kernel: Custom Weisfeiler-Lehman kernel instance (optional)
            **kwargs: Additional arguments for SingleTaskGP initialization
        """
        # Initialize parent class with an initial covar_module
        super().__init__(
            train_X=train_X,
            train_Y=train_Y,
            train_Yvar=train_Yvar,
            covar_module=num_cat_kernel or self._graph_kernel_wrapper(),
            **kwargs,
        )

        # Initialize WL kernel with default parameters if not provided
        self._wl_kernel = wl_kernel or TorchWLKernel(n_iter=5, normalize=True)
        self._train_graphs = train_graphs

        # Convert graphs to the required format and compute the kernel matrix
        self._train_graph_dataset = GraphDataset.from_networkx(train_graphs)
        self._K_train = self._wl_kernel(self._train_graph_dataset)

        if num_cat_kernel is not None:
            # Create an additive kernel combining numerical/categorical and graph kernels
            combined_kernel = AdditiveKernel(
                num_cat_kernel,
                self._graph_kernel_wrapper(),
            )
            self.covar_module = combined_kernel

        self.num_cat_kernel = num_cat_kernel

    def _graph_kernel_wrapper(self) -> Module:
        """Create a GPyTorch-compatible kernel module wrapping the WL kernel.

        This wrapper allows the WL kernel to be used within the GPyTorch framework
        by providing a forward method that returns the pre-computed kernel matrix.

        Returns:
            Module: A GPyTorch kernel module wrapping the WL kernel computation
        """

        class WLKernelWrapper(Module):
            def __init__(self, parent: MixedSingleTaskGP):
                super().__init__()
                self.parent = parent

            def forward(
                self,
                x1: Tensor,
                x2: Tensor | None = None,
                diag: bool = False,
                last_dim_is_batch: bool = False,
            ) -> Tensor:
                """Compute the kernel matrix for the graph inputs.

                Args:
                    x1: First input tensor (unused, required for interface
                        compatibility)
                    x2: Second input tensor; if None, the pre-computed training
                        kernel matrix is returned
                    diag: Whether to return only diagonal elements
                    last_dim_is_batch: Whether the last dimension is a batch
                        dimension

                Returns:
                    Tensor: The pre-computed training kernel matrix if x2 is None,
                        otherwise the train/test cross-covariance matrix
                """
                if x2 is None:
                    return self.parent._K_train

                # Compute cross-covariance between train and test graphs.
                # Note: this assumes `_test_graphs` has been set on the parent
                # model before evaluation.
                test_dataset = GraphDataset.from_networkx(self.parent._test_graphs)
                return self.parent._wl_kernel(
                    self.parent._train_graph_dataset,
                    test_dataset,
                )

        return WLKernelWrapper(self)

    def forward(self, X: Tensor, graphs: list[nx.Graph]) -> MultivariateNormal:
        """Forward pass computing the GP distribution for given inputs.

        Computes the kernel matrix for both numerical/categorical features and graphs,
        combines them if both are present, and returns the resulting GP distribution.

        Args:
            X: Input tensor for numerical and categorical features
            graphs: List of input graphs

        Returns:
            MultivariateNormal: GP distribution for the given inputs
        """
        if len(X) != len(graphs):
            raise ValueError(
                f"Number of feature vectors ({len(X)}) must match "
                f"number of graphs ({len(graphs)})"
            )

        # Process new graphs and compute the graph kernel matrix
        proc_graphs = GraphDataset.from_networkx(graphs)
        K_new = self._wl_kernel(proc_graphs)  # Shape: (n_samples, n_samples)

        # If we have both numerical/categorical and graph features
        if self.num_cat_kernel is not None:
            # Compute the kernel for numerical/categorical features
            K_num_cat = self.num_cat_kernel(X)
            # Add the kernels (element-wise addition)
            K_combined = K_num_cat + K_new
        else:
            K_combined = K_new

        # Compute the mean using the mean module
        mean_x = self.mean_module(X)

        return MultivariateNormal(mean_x, K_combined)
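The `WLKernelWrapper` above follows a "precomputed Gram matrix" pattern: the module ignores its tensor inputs and serves a kernel matrix computed elsewhere (from the graphs). Here is a self-contained toy version of the same pattern, assuming a recent GPyTorch; `PrecomputedKernel` and the random PSD stand-in matrix are illustrative, not part of the PR:

```python
import torch
from gpytorch.kernels import Kernel

class PrecomputedKernel(Kernel):
    """Toy kernel that returns a fixed Gram matrix, ignoring tensor inputs."""

    def __init__(self, gram: torch.Tensor) -> None:
        super().__init__()
        self.register_buffer("gram", gram)

    def forward(self, x1, x2, diag: bool = False, **params) -> torch.Tensor:
        return self.gram.diagonal() if diag else self.gram

A = torch.randn(10, 10, dtype=torch.float64)
gram = A @ A.T  # random PSD matrix standing in for a graph kernel matrix
kernel = PrecomputedKernel(gram)
x = torch.rand(10, 2, dtype=torch.float64)  # inputs are ignored by forward
print(kernel.forward(x, x).shape)  # torch.Size([10, 10])
```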
102 changes: 102 additions & 0 deletions grakel_replace/mixed_single_task_gp_usage_example.py
@@ -0,0 +1,102 @@
import networkx as nx
import torch
from botorch.models.gp_regression_mixed import CategoricalKernel, ScaleKernel
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.kernels import AdditiveKernel, MaternKernel
from grakel_replace.mixed_single_task_gp import MixedSingleTaskGP
from grakel_replace.torch_wl_kernel import TorchWLKernel

TRAIN_CONFIGS = 10
TEST_CONFIGS = 10
TOTAL_CONFIGS = TRAIN_CONFIGS + TEST_CONFIGS

N_NUMERICAL = 2
N_CATEGORICAL = 2
N_CATEGORICAL_VALUES_PER_CATEGORY = 3
N_GRAPH = 2

kernels = []

# Create numerical and categorical features
X = torch.empty(size=(TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL), dtype=torch.float64)
if N_NUMERICAL > 0:
    X[:, :N_NUMERICAL] = torch.rand(
        size=(TOTAL_CONFIGS, N_NUMERICAL),
        dtype=torch.float64,
    )

if N_CATEGORICAL > 0:
    X[:, N_NUMERICAL:] = torch.randint(
        0,
        N_CATEGORICAL_VALUES_PER_CATEGORY,
        size=(TOTAL_CONFIGS, N_CATEGORICAL),
        dtype=torch.float64,
    )

# Create random graph architectures
graphs = []
for _ in range(TOTAL_CONFIGS):
    G = nx.erdos_renyi_graph(n=5, p=0.5)  # Random graph with 5 nodes
    graphs.append(G)

# Create random target values
y = torch.rand(size=(TOTAL_CONFIGS,), dtype=torch.float64)

# Set up kernels for numerical and categorical features
if N_NUMERICAL > 0:
    matern = ScaleKernel(
        MaternKernel(
            nu=2.5,
            ard_num_dims=N_NUMERICAL,
            active_dims=tuple(range(N_NUMERICAL)),
        ),
    )
    kernels.append(matern)

if N_CATEGORICAL > 0:
    hamming = ScaleKernel(
        CategoricalKernel(
            ard_num_dims=N_CATEGORICAL,
            active_dims=tuple(range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL)),
        ),
    )
    kernels.append(hamming)

# Combine numerical and categorical kernels
combined_num_cat_kernel = AdditiveKernel(*kernels) if kernels else None

# Create the WL kernel for graphs
wl_kernel = TorchWLKernel(n_iter=5, normalize=True)

# Split into train and test sets
train_x = X[:TRAIN_CONFIGS]
train_graphs = graphs[:TRAIN_CONFIGS]
train_y = y[:TRAIN_CONFIGS].unsqueeze(-1)  # Add an output dimension for BoTorch

test_x = X[TRAIN_CONFIGS:]
test_graphs = graphs[TRAIN_CONFIGS:]
test_y = y[TRAIN_CONFIGS:].unsqueeze(-1)

# Initialize the mixed GP
gp = MixedSingleTaskGP(
    train_X=train_x,
    train_graphs=train_graphs,
    train_Y=train_y,
    num_cat_kernel=combined_num_cat_kernel,
    wl_kernel=wl_kernel,
)

# Compute the GP distribution over the training inputs
multivariate_normal: MultivariateNormal = gp.forward(train_x, train_graphs)
print("GP distribution:", multivariate_normal)

# Evaluate the model on the test data
with torch.no_grad():
    mvn = gp.forward(test_x, test_graphs)
    predictions = mvn.mean
    uncertainties = mvn.variance.sqrt()
    covar = mvn.covariance_matrix

print("\nMean:", predictions)
print("Standard deviation:", uncertainties)
print("Covariance matrix:", covar)
81 changes: 81 additions & 0 deletions grakel_replace/single_task_gp_usage_example.py
@@ -0,0 +1,81 @@
import torch
from botorch.models import SingleTaskGP
from botorch.models.gp_regression_mixed import CategoricalKernel, ScaleKernel
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.kernels import AdditiveKernel, MaternKernel

TRAIN_CONFIGS = 10
TEST_CONFIGS = 10
TOTAL_CONFIGS = TRAIN_CONFIGS + TEST_CONFIGS

N_NUMERICAL = 2
N_CATEGORICAL = 2
N_CATEGORICAL_VALUES_PER_CATEGORY = 3

kernels = []

# Create some random encoded hyperparameter configurations
X = torch.empty(size=(TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL), dtype=torch.float64)
if N_NUMERICAL > 0:
    X[:, :N_NUMERICAL] = torch.rand(
        size=(TOTAL_CONFIGS, N_NUMERICAL),
        dtype=torch.float64,
    )

if N_CATEGORICAL > 0:
    X[:, N_NUMERICAL:] = torch.randint(
        0,
        N_CATEGORICAL_VALUES_PER_CATEGORY,
        size=(TOTAL_CONFIGS, N_CATEGORICAL),
        dtype=torch.float64,
    )

y = torch.rand(size=(TOTAL_CONFIGS,), dtype=torch.float64)

if N_NUMERICAL > 0:
    matern = ScaleKernel(
        MaternKernel(
            nu=2.5,
            ard_num_dims=N_NUMERICAL,
            active_dims=tuple(range(N_NUMERICAL)),
        ),
    )
    kernels.append(matern)

if N_CATEGORICAL > 0:
    hamming = ScaleKernel(
        CategoricalKernel(
            ard_num_dims=N_CATEGORICAL,
            active_dims=tuple(range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL)),
        ),
    )
    kernels.append(hamming)


combined_num_cat_kernel = AdditiveKernel(*kernels)

train_x = X[:TRAIN_CONFIGS]
train_y = y[:TRAIN_CONFIGS]

test_x = X[TRAIN_CONFIGS:]
test_y = y[TRAIN_CONFIGS:]

K_matrix = combined_num_cat_kernel.forward(train_x, train_x)
print("K_matrix:", K_matrix.to_dense())

train_y = train_y.unsqueeze(-1)
test_y = test_y.unsqueeze(-1)

gp = SingleTaskGP(
    train_X=train_x,
    train_Y=train_y,
    mean_module=None,  # Leave as None to use the default, which is `ConstantMean`
    covar_module=combined_num_cat_kernel,
)

multivariate_normal: MultivariateNormal = gp.forward(train_x)
print("Mean:", multivariate_normal.mean)
print("Variance:", multivariate_normal.variance)
print("Covariance matrix:", multivariate_normal.covariance_matrix)