
Possible PyTorch implementation of WL kernel #153

Closed
wants to merge 62 commits from feat-torch-wl-kernel

Conversation

@vladislavalerievich (Collaborator) commented Oct 28, 2024

This pull request introduces a custom PyTorch implementation of the Weisfeiler-Lehman (WL) kernel to replace the existing Grakel-based implementation. The most important changes include adding new example scripts, creating the custom WL kernel class, and adding tests to ensure correctness and compatibility.

New Implementation of Weisfeiler-Lehman Kernel:

  • grakel_replace/torch_wl_kernel.py: Implemented a custom PyTorch class TorchWLKernel for the WL kernel, including methods to convert NetworkX graphs to sparse adjacency tensors, initialize node labels, perform WL iterations, and compute the kernel matrix.
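For orientation, here is a minimal, hypothetical sketch of the WL relabeling step at the heart of such a class (illustrative names and hash-based label compression, not the actual TorchWLKernel code): each node's new label is derived from its current label and the sorted multiset of its neighbors' labels, and after n_iter rounds the kernel value is the dot product of the graphs' label-count histograms.

import networkx as nx

def wl_iteration(G: nx.Graph, labels: dict[int, int]) -> dict[int, int]:
    """One WL step: relabel each node from its own label plus its neighbors' labels."""
    new_labels = {}
    for node in G.nodes():
        signature = (labels[node], tuple(sorted(labels[n] for n in G.neighbors(node))))
        # In practice the signature is compressed into a small integer id;
        # Python's hash() stands in for that compression here.
        new_labels[node] = hash(signature)
    return new_labels

G = nx.Graph([(0, 1), (1, 2), (1, 3), (1, 4), (2, 3)])
labels = {n: 1 for n in G.nodes()}  # uniform initial labels for unlabeled graphs
for _ in range(2):                  # n_iter = 2
    labels = wl_iteration(G, labels)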

@eddiebergman (Contributor) left a comment

Seems really good as mentioned previously in MM!

There are some more changes we'll need to discuss regarding the representation we pass in and how we'll get it to work with BoTorch, but we can discuss that later. For now, there's a bug which, from the looks of things, extends to either a bug in GraKel or at least in how it's used, based on your comparison test with it.

Reproduction with description:

import networkx as nx
from torch_wl_kernel import GraphDataset, TorchWLKernel

# Create the same graphs as for the Grakel example

G1 = nx.Graph()
G1.add_edges_from([(0, 1), (1, 2), (1, 3), (1, 4), (2, 3)])

G2 = nx.Graph()
G2.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4)])
G3 = nx.Graph()
G3.add_edges_from([(0, 1), (1, 3), (3, 2)])

# Process graphs
graphs = GraphDataset.from_networkx([G1, G2, G3])

# Initialize and run WL kernel
wl_kernel = TorchWLKernel(n_iter=2, normalize=False)

K = wl_kernel(graphs)

print("Kernel matrix (pairwise similarities):")
print(K)

# Issue: GraphDataset.from_networkx() relabels nodes independently, which is incorrect;
# we assume the edges all refer to the same nodes, i.e. the 0, 1, 2, 3, 4 in the graphs
# above are not independent of each other
# ------------------------------------------------------------------
# Below, we re-ordered the edges, placing the first edge at the end.
# This is the same graph, yet the kernel returns something different.
#
#
#     v-------------------------------------------v
#  [(0, 1), (1, 2), (1, 3), (1, 4), (2, 3)]
#          [(1, 2), (1, 3), (1, 4), (2, 3), (0, 1)]
#
# Take a look at the implementation of `nx.convert_node_labels_to_integers()`
# that is used in `GraphDataset.from_networkx()`. We likely need to create
# our own mapping and relabel the nodes as they do (see the sketch after this example).
G1 = nx.Graph()
edges_g1 = [(1, 2), (1, 3), (1, 4), (2, 3), (0, 1)]
G1.add_edges_from(edges_g1)
print(list(G1.nodes()))

G2 = nx.Graph()
G2.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4)])
print(list(G2.nodes()))
G3 = nx.Graph()
G3.add_edges_from([(0, 1), (1, 3), (3, 2)])
print(list(G3.nodes()))

# Process graphs
graphs = GraphDataset.from_networkx([G1, G2, G3])
for g in graphs:
    print(g.edges())

# Initialize and run WL kernel
wl_kernel = TorchWLKernel(n_iter=2, normalize=False)

K = wl_kernel(graphs)

print("Kernel matrix (pairwise similarities):")
print(K)
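One possible shape of the fix, as a sketch (assuming that sorting node ids gives the canonical order we want; `relabel_deterministically` is just an illustrative name, not existing code):

import networkx as nx

def relabel_deterministically(G: nx.Graph) -> nx.Graph:
    """Map nodes to 0..n-1 by sorted node id, independent of edge insertion order."""
    mapping = {node: i for i, node in enumerate(sorted(G.nodes()))}
    return nx.relabel_nodes(G, mapping)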

@eddiebergman (Contributor) left a comment

Looks mostly good but I have some questions regarding the definition of the class within the actual Kernel.

Otherwise, it looks really solid and the tests are really good!

@vladislavalerievich added the enhancement label Nov 22, 2024
@vladislavalerievich removed the enhancement label Dec 7, 2024
@vladislavalerievich self-assigned this Dec 7, 2024
@vladislavalerievich added the enhancement label Jan 23, 2025
@eddiebergman (Contributor) left a comment

Looks solid! I did a skim over the tests, though not closely enough to have a full picture. I do see that you're testing against GraKel, which is nice!

Let's talk tomorrow about the last steps to merge it in :)

Comment on lines 111 to 114
out = torch.empty((q_dim_size, x1.shape[1], x2.shape[1]), device=x1.device)
for q in range(q_dim_size):
    out[q] = self._compute_kernel(x1[q], x2[q], diag=diag)
return out
Nice, like the usage of out here :)

Comment on lines 116 to 132
def _prepare_indices(self, x1: Tensor, x2: Tensor) -> tuple[list[int], list[int]]:
    """Convert tensor indices to integer lists and handle special cases."""
    indices1 = x1.flatten().round().to(torch.int64).tolist()
    indices2 = x2.flatten().round().to(torch.int64).tolist()

    # Handle special case for -1 index
    if -1 in indices1 or -1 in indices2:
        self._handle_negative_one_index()

    return indices1, indices2

def _handle_negative_one_index(self) -> None:
    """Handle the special case where -1 index is present."""
    if -1 not in self.adjacency_cache:
        last_graph_idx = len(self.graph_lookup) - 1
        self.adjacency_cache[-1] = self.adjacency_cache[last_graph_idx]
        self.label_cache[-1] = self.label_cache[last_graph_idx]

This -1 you'll have to explain to me.

A couple of slight optimizations while we're at it:

  • Why the round() in _prepare_indices? If we manage to keep all indices in int64 format, then we shouldn't need it.
  • A slight optimization around the if -1 in indices1 or -1 in indices2: it's a nice little trick to check the smaller list first, which can sometimes short-circuit away the scan of the longer one ;) (a sketch follows below)
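Something like this hypothetical helper captures the trick (illustrative name, not existing code):

def contains_negative_one(indices1: list[int], indices2: list[int]) -> bool:
    """Scan the shorter list first so a hit there short-circuits the `or`."""
    short, long_ = sorted((indices1, indices2), key=len)
    return -1 in short or -1 in long_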


Also, just for fun, I benchmarked whether it would be faster to check if -1 is in a torch tensor, in case you're interested (spoiler: it's not, unless the -1 occurs at the end of the list and the tensor size is >1000):

import random
import time

import torch


def needle_haystack_torch(
    needle: int,
    haystack: torch.Tensor,
    iters: int = 10,
    warmup: int = 3,
) -> float:
    # Avoiding logic to prevent branch misses
    # Avoiding lists to avoid memory thrashing
    masks = [0] * warmup + [1] * iters
    sum_time = 0.0

    for m in masks:
        now = time.perf_counter_ns()
        b = torch.sum(haystack == needle)
        sum_time += (time.perf_counter_ns() - now) * m

    return sum_time / iters


def needle_haystack_assumed_minus_one_needle(
    haystack: torch.Tensor,
    iters: int = 10,
    warmup: int = 3,
) -> float:
    # Avoiding logic to prevent branch misses
    # Avoiding lists to avoid memory thrashing
    masks = [0] * warmup + [1] * iters
    sum_time = 0.0

    for m in masks:
        now = time.perf_counter_ns()
        b = haystack.bool().any()
        sum_time += (time.perf_counter_ns() - now) * m

    return sum_time / iters


def needle_haystack_list(
    needle: int,
    haystack: list[int],
    iters: int = 10,
    warmup: int = 3,
) -> float:
    # Avoiding logic to prevent branch misses
    # Avoiding lists to avoid memory thrashing
    masks = [0] * warmup + [1] * iters
    sum_time = 0.0

    for m in masks:
        now = time.perf_counter_ns()
        b = needle in haystack
        sum_time += (time.perf_counter_ns() - now) * m

    return sum_time / iters


def needle_haystack_torch_contains(
    needle: int,
    haystack: torch.Tensor,
    iters: int = 10,
    warmup: int = 3,
) -> float:
    # Avoiding logic to prevent branch misses
    # Avoiding lists to avoid memory thrashing
    masks = [0] * warmup + [1] * iters
    sum_time = 0.0
    needle_i64 = torch.tensor(needle, dtype=torch.int64)

    for m in masks:
        now = time.perf_counter_ns()
        b = needle_i64 in haystack
        sum_time += (time.perf_counter_ns() - now) * m

    return sum_time / iters


sizes = [10, 100, 1_000, 10_000]

NEEDLE = -1
bias_first_ten_percent_needles = [random.randint(0, int(s * 0.1)) for s in sizes]
bias_last_ten_percent_needles = [random.randint(int(s * 0.9), s - 1) for s in sizes]
bias_middle_ten_percent_needles = [
    random.randint(int(s * 0.4), int(s * 0.6)) for s in sizes
]

# list
print("=================================")
print("List")
print("=================================")
for name, needles in (
    ("early", bias_first_ten_percent_needles),
    ("mid", bias_middle_ten_percent_needles),
    ("late", bias_last_ten_percent_needles),
):
    print(f"{name}")
    for needle_loc, size in zip(needles, sizes):
        haystack = torch.randint(0, 100, (size,)).to(torch.int64).tolist()
        haystack[needle_loc] = NEEDLE

        _time = needle_haystack_list(NEEDLE, haystack)
        print(f"{size:>8} : {_time} ns")

print("=================================")
print("Torch __contains__")
print("=================================")
# torch
for name, needles in (
    ("early", bias_first_ten_percent_needles),
    ("mid", bias_middle_ten_percent_needles),
    ("late", bias_last_ten_percent_needles),
):
    print(f"{name}")
    for needle_loc, size in zip(needles, sizes):
        haystack = torch.randint(0, 100, (size,)).to(torch.int64)
        haystack[needle_loc] = NEEDLE

        _time = needle_haystack_torch_contains(NEEDLE, haystack)
        print(f"{size:>8} : {_time} ns")

print("=================================")
print("Torch (branchless")
print("=================================")
# torch
for name, needles in (
    ("early", bias_first_ten_percent_needles),
    ("mid", bias_middle_ten_percent_needles),
    ("late", bias_last_ten_percent_needles),
):
    print(f"{name}")
    for needle_loc, size in zip(needles, sizes):
        haystack = torch.randint(0, 100, (size,)).to(torch.int64)
        haystack[needle_loc] = NEEDLE

        _time = needle_haystack_torch(NEEDLE, haystack)
        print(f"{size:>8} : {_time} ns")

# Assuming haystack is -1
print("=================================")
print("Assumed -1 Needle Torch")
print("=================================")
for name, needles in (
    ("early", bias_first_ten_percent_needles),
    ("mid", bias_middle_ten_percent_needles),
    ("late", bias_last_ten_percent_needles),
):
    print(f"{name}")
    for needle_loc, size in zip(needles, sizes):
        haystack = torch.randint(0, 100, (size,)).to(torch.int64)
        haystack[needle_loc] = NEEDLE

        _time = needle_haystack_assumed_minus_one_needle(haystack)
        print(f"{size:>8} : {_time} ns")


# Assuming haystack is -1 (32bit int)
print("=================================")
print("Assumed -1 Needle Torch (32bit)")
print("=================================")
for name, needles in (
    ("early", bias_first_ten_percent_needles),
    ("mid", bias_middle_ten_percent_needles),
    ("late", bias_last_ten_percent_needles),
):
    print(f"{name}")
    for needle_loc, size in zip(needles, sizes):
        haystack = torch.randint(0, 100, (size,)).to(torch.int32)
        haystack[needle_loc] = NEEDLE

        _time = needle_haystack_assumed_minus_one_needle(haystack)
        print(f"{size:>8} : {_time} ns")
=================================
List
=================================
early
      10 : 114.9 ns
     100 : 150.5 ns
    1000 : 493.8 ns
   10000 : 7026.0 ns
mid
      10 : 144.5 ns
     100 : 516.3 ns
    1000 : 4377.6 ns
   10000 : 37341.5 ns
late
      10 : 183.1 ns
     100 : 803.7 ns
    1000 : 7243.2 ns
   10000 : 70000.9 ns
=================================
Torch __contains__
=================================
early
      10 : 5270.2 ns
     100 : 5586.7 ns
    1000 : 6733.4 ns
   10000 : 18434.5 ns
mid
      10 : 5317.3 ns
     100 : 5547.9 ns
    1000 : 6803.5 ns
   10000 : 18253.8 ns
late
      10 : 5482.2 ns
     100 : 5674.8 ns
    1000 : 6891.6 ns
   10000 : 19183.8 ns
=================================
Torch (branchless)
=================================
early
      10 : 7294.9 ns
     100 : 7318.5 ns
    1000 : 8565.4 ns
   10000 : 16958.5 ns
mid
      10 : 7301.2 ns
     100 : 7333.3 ns
    1000 : 8558.1 ns
   10000 : 17051.4 ns
late
      10 : 7274.5 ns
     100 : 8327.8 ns
    1000 : 8515.6 ns
   10000 : 17387.1 ns
=================================
Assumed -1 Needle Torch
=================================
early
      10 : 4597.8 ns
     100 : 4610.9 ns
    1000 : 5755.6 ns
   10000 : 17368.8 ns
mid
      10 : 4594.5 ns
     100 : 4732.0 ns
    1000 : 5830.6 ns
   10000 : 15438.9 ns
late
      10 : 4621.7 ns
     100 : 4710.7 ns
    1000 : 5925.3 ns
   10000 : 15835.6 ns
=================================
Assumed -1 Needle Torch (32bit)
=================================
early
      10 : 4641.2 ns
     100 : 4812.4 ns
    1000 : 5747.4 ns
   10000 : 15771.1 ns
mid
      10 : 4804.2 ns
     100 : 5974.5 ns
    1000 : 5961.7 ns
   10000 : 14754.5 ns
late
      10 : 4773.3 ns
     100 : 4962.7 ns
    1000 : 7646.2 ns
   10000 : 15066.1 ns

@eddiebergman (Contributor) commented Jan 24, 2025

Update from meeting:

  • Decided it's best to put the code on a fresh branch off main.
  • Most graph stuff should live under neps/optimizers/bayesian_optimization/graphs/
  • There is still a bug: we should clear the caches at each fixed-graph iteration (a fuller contextmanager sketch follows the snippet).
def set_graph_lookup(...):
    ...
    
    # Save the current graph lookup and set the new graph lookup
    for kern in modules:
        if isinstance(kern, TorchWLKernel):
            kern._get_node_neighbors.cache_clear()
            kern._wl_iteration.cache_clear()
        elif isinstance(kern, BoTorchWLKernel):
            kern._compute_kernel.cache_clear()

        kernel_prev_graphs.append((kern, kern.graph_lookup))
        if append:
            kern.set_graph_lookup([*kern.graph_lookup, *new_graphs])
        else:
            kern.set_graph_lookup(new_graphs)
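For reference, a minimal sketch of the overall contextmanager shape this implies — the real set_graph_lookup may differ, and the cache clearing above would go at the top of the try block:

from contextlib import contextmanager

@contextmanager
def set_graph_lookup_sketch(modules, new_graphs, *, append=True):
    # Save each kernel's current lookup so it can be restored afterwards
    prev = [(kern, kern.graph_lookup) for kern in modules]
    try:
        for kern, lookup in prev:
            kern.set_graph_lookup([*lookup, *new_graphs] if append else new_graphs)
        yield
    finally:
        # Restore the saved lookups even if optimization raises
        for kern, lookup in prev:
            kern.set_graph_lookup(lookup)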
  • There are some remaining parts related to improving the acquisition function; however, the core loop works.
def optimize_acqf_graph(
    ...,
    sampled_graphs,  # We would likely want to either pass in sampled graphs,
                     # or otherwise have some graph sampler to pass in.
                     # For now we just leave it as is, it doesn't work generically
                     # but we need the `Grammar` Parameter first to define this so it's fine.
):

    # Due to BoTorch only supporting tensors, we
    # use a column of the input tensor as indices into a list of graphs.
    # The `with set_graph_lookup()` contextmanager injects the test graph into the
    # kernel, i.e. these are what the indices in the tensor refer to.
    # We set the index to `-1` to indicate the test graph, i.e. the last one in `graph_lookup`.
    # Since we set the test graph in a loop, `best_candidates[best_idx]` will
    # always contain a `-1` (saying "the only sample to consider was the best one"), but
    # this `-1` means something different at each iteration.
    # tldr; we should save `best_graph` separately in the for loop for now.
    
    # The [best_idx, :-1] effectively removes the `-1` graph index column
    return best_candidates[best_idx, :-1], best_graph, best_scores[best_idx].item()
  • What we might actually prefer is to put all sampled graphs into the kernel and compute the scores all at once. That would change the above logic; we can do it in a follow-up PR. Please leave a comment along these lines in the code.
    • The caching may need some revisiting if we accept all graphs at once. Probably not, as the kernel is an element-wise one and doesn't have a batch matrix implementation. However, doing it all at once means one less outer loop in the following (a selection sketch follows the snippet):
        for graph in sampled_graphs:
          # Temporarily set the graph lookup for the kernel
          with set_graph_lookup(acq_function.model.covar_module, [graph], append=True):
              # Iterate through each fixed feature configuration (if provided)
              for fixed_features in fixed_features_list or [{}]:
                  # Add the graph index to the fixed features, indicating that the
                  # last graph in the lookup should be used
                  updated_fixed_features = {**fixed_features, graph_idx: -1.0}
    
                  # Optimize the acquisition function with the updated fixed features
                  candidates, scores = optimize_acqf_mixed(
                      acq_function=acq_function,
                      bounds=bounds,
                      fixed_features_list=[updated_fixed_features],
                      num_restarts=num_restarts,
                      raw_samples=raw_samples,
                      q=q,
                  )
    
                  # Store the candidates and their scores
                  best_candidates.append(candidates)
                  best_scores.append(scores)
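After the loop, the selection the comments describe might look like this sketch, assuming each `candidates` entry is a (q, d + 1) tensor, each `scores` entry a (q,) tensor, and that `best_graph` was tracked separately inside the loop as suggested above:

import torch

candidates_t = torch.cat(best_candidates, dim=0)   # (n, d + 1); last column is the graph index
scores_t = torch.cat(best_scores, dim=0)           # (n,)
best_idx = torch.argmax(scores_t)
x_best = candidates_t[best_idx, :-1]               # drop the -1 graph-index column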

Ultimately, we don't have the infrastructure to actually use this function yet, so it's fine if it's not perfect. The implementation will evolve slowly; as long as the core loop works, we're all good.


@timurcarstensen We have a basic BoTorch kernel that can jointly optimize nx.Graphs and other parameters; it's just not hooked up to the system as a whole, as we don't know what that looks like yet.

@vladislavalerievich (Collaborator, Author)

  • Most graph stuff should live under neps/optimizers/bayesian_optimization/graphs/

The bayesian_optimization folder is no longer there.

@eddiebergman (Contributor)

  • Most graph stuff should live under neps/optimizers/bayesian_optimization/graphs/

The bayesian_optimization folder is no longer there.

Ahh sorry, you can do neps/optimizers/models/graphs

@vladislavalerievich deleted the feat-torch-wl-kernel branch January 29, 2025 12:24