Possible PyTorch implementation of WL kernel #153
Conversation
Seems really good, as mentioned previously in MM!
There are some more changes we'll need to discuss regarding the representation we pass in and how we'll get it to work with BoTorch, but we can discuss that later. For now, there's a bug which, from the looks of things, extends to either a bug in GraKel or at least in how it's used, based on your comparison test with it.
Reproduction with description:
import networkx as nx
from torch_wl_kernel import GraphDataset, TorchWLKernel

# Create the same graphs as for the GraKel example
G1 = nx.Graph()
G1.add_edges_from([(0, 1), (1, 2), (1, 3), (1, 4), (2, 3)])
G2 = nx.Graph()
G2.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4)])
G3 = nx.Graph()
G3.add_edges_from([(0, 1), (1, 3), (3, 2)])

# Process graphs
graphs = GraphDataset.from_networkx([G1, G2, G3])

# Initialize and run WL kernel
wl_kernel = TorchWLKernel(n_iter=2, normalize=False)
K = wl_kernel(graphs)

print("Kernel matrix (pairwise similarities):")
print(K)

# Issue: GraphDataset.from_networkx() relabels nodes independently, which is
# incorrect; we assume the edges all refer to the same nodes, i.e. the
# 0, 1, 2, 3, 4 in the graphs above are not independent of each other.
# ------------------------------------------------------------------
# Below, we re-ordered the edges, placing the first edge at the end.
# This is the same graph, yet the kernel returns something different.
#
#  v-------------------------------v
# [(0, 1), (1, 2), (1, 3), (1, 4), (2, 3)]
# [(1, 2), (1, 3), (1, 4), (2, 3), (0, 1)]
#
# Take a look at the implementation of `nx.convert_node_labels_to_integers()`
# that is used in `GraphDataset.from_networkx()`. We likely need to create
# our own mapping and relabel the nodes as they do (see the sketch after
# this snippet).
G1 = nx.Graph()
edges_g1 = [(1, 2), (1, 3), (1, 4), (2, 3), (0, 1)]
G1.add_edges_from(edges_g1)
print(list(G1.nodes()))

G2 = nx.Graph()
G2.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4)])
print(list(G2.nodes()))

G3 = nx.Graph()
G3.add_edges_from([(0, 1), (1, 3), (3, 2)])
print(list(G3.nodes()))

# Process graphs
graphs = GraphDataset.from_networkx([G1, G2, G3])
for g in graphs:
    print(g.edges())

# Initialize and run WL kernel
wl_kernel = TorchWLKernel(n_iter=2, normalize=False)
K = wl_kernel(graphs)
print("Kernel matrix (pairwise similarities):")
print(K)
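A minimal sketch of what a deterministic relabeling could look like. The `relabel_consistently` helper is hypothetical (not part of this PR), and it assumes node labels are sortable, e.g. all integers:

import networkx as nx

def relabel_consistently(graphs: list[nx.Graph]) -> list[nx.Graph]:
    # Hypothetical helper: map each graph's sorted node labels to
    # consecutive integers, so the result depends only on the node set,
    # not on edge insertion order.
    relabeled = []
    for g in graphs:
        mapping = {node: i for i, node in enumerate(sorted(g.nodes()))}
        relabeled.append(nx.relabel_nodes(g, mapping))
    return relabeled

With a mapping like this, both edge orderings of G1 above yield the same node numbering, so the kernel matrix should come out identical.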
Looks mostly good, but I have some questions regarding the definition of the class within the actual kernel.
Otherwise it looks really solid, and the tests are really good!
Looks solid! I did a skim over the tests, but not enough to have a good idea. I do see that you're testing against GraKel, which is nice!
Let's talk tomorrow about the last steps to merge this in :)
grakel_replace/torch_wl_kernel.py
Outdated
out = torch.empty((q_dim_size, x1.shape[1], x2.shape[1]), device=x1.device)
for q in range(q_dim_size):
    out[q] = self._compute_kernel(x1[q], x2[q], diag=diag)
return out
Nice, I like the usage of `out` here :)
grakel_replace/torch_wl_kernel.py
Outdated
def _prepare_indices(self, x1: Tensor, x2: Tensor) -> tuple[list[int], list[int]]:
    """Convert tensor indices to integer lists and handle special cases."""
    indices1 = x1.flatten().round().to(torch.int64).tolist()
    indices2 = x2.flatten().round().to(torch.int64).tolist()

    # Handle special case for -1 index
    if -1 in indices1 or -1 in indices2:
        self._handle_negative_one_index()

    return indices1, indices2

def _handle_negative_one_index(self) -> None:
    """Handle the special case where -1 index is present."""
    if -1 not in self.adjacency_cache:
        last_graph_idx = len(self.graph_lookup) - 1
        self.adjacency_cache[-1] = self.adjacency_cache[last_graph_idx]
        self.label_cache[-1] = self.label_cache[last_graph_idx]
This `-1` you'll have to explain to me.
Slight optimization things since we're at it:
- Why the `round()` in `_prepare_indices`? If we manage to keep all indices in `int64` format, then we shouldn't need to do this.
- Slight optimization around the `if -1 in indices1 or -1 in indices2:` ... It's a nice little trick to check the smaller one first, which can sometimes short-circuit past the longer one ;) (see the sketch below)
Also, just for fun, I ran some benchmarks to see whether it would be faster to check if `-1` is in a torch tensor, if you're interested (spoiler: it's not, unless the `-1` occurs at the end of the list and the tensor size is > 1000):
import random
import time

import torch

def needle_haystack_torch(
    needle: int,
    haystack: torch.Tensor,
    iters: int = 10,
    warmup: int = 3,
) -> float:
    # Avoiding logic to prevent branch misses
    # Avoiding lists to avoid memory thrashing
    masks = [0] * warmup + [1] * iters
    sum_time = 0.0
    for m in masks:
        now = time.perf_counter_ns()
        b = torch.sum(haystack == needle)
        sum_time += (time.perf_counter_ns() - now) * m
    return sum_time / iters

def needle_haystack_assumed_minus_one_needle(
    haystack: torch.Tensor,
    iters: int = 10,
    warmup: int = 3,
) -> float:
    # Avoiding logic to prevent branch misses
    # Avoiding lists to avoid memory thrashing
    masks = [0] * warmup + [1] * iters
    sum_time = 0.0
    for m in masks:
        now = time.perf_counter_ns()
        b = haystack.bool().any()
        sum_time += (time.perf_counter_ns() - now) * m
    return sum_time / iters

def needle_haystack_list(
    needle: int,
    haystack: list[int],
    iters: int = 10,
    warmup: int = 3,
) -> float:
    # Avoiding logic to prevent branch misses
    # Avoiding lists to avoid memory thrashing
    masks = [0] * warmup + [1] * iters
    sum_time = 0.0
    for m in masks:
        now = time.perf_counter_ns()
        b = needle in haystack
        sum_time += (time.perf_counter_ns() - now) * m
    return sum_time / iters

def needle_haystack_torch_contains(
    needle: int,
    haystack: torch.Tensor,
    iters: int = 10,
    warmup: int = 3,
) -> float:
    # Avoiding logic to prevent branch misses
    # Avoiding lists to avoid memory thrashing
    masks = [0] * warmup + [1] * iters
    sum_time = 0.0
    needle_i64 = torch.tensor(needle, dtype=torch.int64)
    for m in masks:
        now = time.perf_counter_ns()
        b = needle_i64 in haystack
        sum_time += (time.perf_counter_ns() - now) * m
    return sum_time / iters

sizes = [10, 100, 1_000, 10_000]
NEEDLE = -1
bias_first_ten_percent_needles = [random.randint(0, int(s * 0.1)) for s in sizes]
bias_last_ten_percent_needles = [random.randint(int(s * 0.9), s - 1) for s in sizes]
bias_middle_ten_percent_needles = [
    random.randint(int(s * 0.4), int(s * 0.6)) for s in sizes
]

# list
print("=================================")
print("List")
print("=================================")
for name, needles in (
    ("early", bias_first_ten_percent_needles),
    ("mid", bias_middle_ten_percent_needles),
    ("late", bias_last_ten_percent_needles),
):
    print(f"{name}")
    for needle_loc, size in zip(needles, sizes):
        haystack = torch.randint(0, 100, (size,)).to(torch.int64).tolist()
        haystack[needle_loc] = NEEDLE
        _time = needle_haystack_list(NEEDLE, haystack)
        print(f"{size:>8} : {_time} ns")

# torch
print("=================================")
print("Torch __contains__")
print("=================================")
for name, needles in (
    ("early", bias_first_ten_percent_needles),
    ("mid", bias_middle_ten_percent_needles),
    ("late", bias_last_ten_percent_needles),
):
    print(f"{name}")
    for needle_loc, size in zip(needles, sizes):
        haystack = torch.randint(0, 100, (size,)).to(torch.int64)
        haystack[needle_loc] = NEEDLE
        _time = needle_haystack_torch_contains(NEEDLE, haystack)
        print(f"{size:>8} : {_time} ns")

# torch
print("=================================")
print("Torch (branchless)")
print("=================================")
for name, needles in (
    ("early", bias_first_ten_percent_needles),
    ("mid", bias_middle_ten_percent_needles),
    ("late", bias_last_ten_percent_needles),
):
    print(f"{name}")
    for needle_loc, size in zip(needles, sizes):
        haystack = torch.randint(0, 100, (size,)).to(torch.int64)
        haystack[needle_loc] = NEEDLE
        _time = needle_haystack_torch(NEEDLE, haystack)
        print(f"{size:>8} : {_time} ns")

# Assuming needle is -1
print("=================================")
print("Assumed -1 Needle Torch")
print("=================================")
for name, needles in (
    ("early", bias_first_ten_percent_needles),
    ("mid", bias_middle_ten_percent_needles),
    ("late", bias_last_ten_percent_needles),
):
    print(f"{name}")
    for needle_loc, size in zip(needles, sizes):
        haystack = torch.randint(0, 100, (size,)).to(torch.int64)
        haystack[needle_loc] = NEEDLE
        _time = needle_haystack_assumed_minus_one_needle(haystack)
        print(f"{size:>8} : {_time} ns")

# Assuming needle is -1 (32bit int)
print("=================================")
print("Assumed -1 Needle Torch (32bit)")
print("=================================")
for name, needles in (
    ("early", bias_first_ten_percent_needles),
    ("mid", bias_middle_ten_percent_needles),
    ("late", bias_last_ten_percent_needles),
):
    print(f"{name}")
    for needle_loc, size in zip(needles, sizes):
        haystack = torch.randint(0, 100, (size,)).to(torch.int32)
        haystack[needle_loc] = NEEDLE
        _time = needle_haystack_assumed_minus_one_needle(haystack)
        print(f"{size:>8} : {_time} ns")
=================================
List
=================================
early
10 : 114.9 ns
100 : 150.5 ns
1000 : 493.8 ns
10000 : 7026.0 ns
mid
10 : 144.5 ns
100 : 516.3 ns
1000 : 4377.6 ns
10000 : 37341.5 ns
late
10 : 183.1 ns
100 : 803.7 ns
1000 : 7243.2 ns
10000 : 70000.9 ns
=================================
Torch __contains__
=================================
early
10 : 5270.2 ns
100 : 5586.7 ns
1000 : 6733.4 ns
10000 : 18434.5 ns
mid
10 : 5317.3 ns
100 : 5547.9 ns
1000 : 6803.5 ns
10000 : 18253.8 ns
late
10 : 5482.2 ns
100 : 5674.8 ns
1000 : 6891.6 ns
10000 : 19183.8 ns
=================================
Torch (branchless)
=================================
early
10 : 7294.9 ns
100 : 7318.5 ns
1000 : 8565.4 ns
10000 : 16958.5 ns
mid
10 : 7301.2 ns
100 : 7333.3 ns
1000 : 8558.1 ns
10000 : 17051.4 ns
late
10 : 7274.5 ns
100 : 8327.8 ns
1000 : 8515.6 ns
10000 : 17387.1 ns
=================================
Assumed -1 Needle Torch
=================================
early
10 : 4597.8 ns
100 : 4610.9 ns
1000 : 5755.6 ns
10000 : 17368.8 ns
mid
10 : 4594.5 ns
100 : 4732.0 ns
1000 : 5830.6 ns
10000 : 15438.9 ns
late
10 : 4621.7 ns
100 : 4710.7 ns
1000 : 5925.3 ns
10000 : 15835.6 ns
=================================
Assumed -1 Needle Torch (32bit)
=================================
early
10 : 4641.2 ns
100 : 4812.4 ns
1000 : 5747.4 ns
10000 : 15771.1 ns
mid
10 : 4804.2 ns
100 : 5974.5 ns
1000 : 5961.7 ns
10000 : 14754.5 ns
late
10 : 4773.3 ns
100 : 4962.7 ns
1000 : 7646.2 ns
10000 : 15066.1 ns
Update from meeting:
def set_graph_lookup(...):
    ...
    # Save the current graph lookup and set the new graph lookup
    for kern in modules:
        if isinstance(kern, TorchWLKernel):
            kern._get_node_neighbors.cache_clear()
            kern._wl_iteration.cache_clear()
        elif isinstance(kern, BoTorchWLKernel):
            kern._compute_kernel.cache_clear()
        kernel_prev_graphs.append((kern, kern.graph_lookup))
        if append:
            kern.set_graph_lookup([*kern.graph_lookup, *new_graphs])
        else:
            kern.set_graph_lookup(new_graphs)

def optimize_acqf_graph(
    ...,
    sampled_graphs,  # We would likely want to either pass in sampled graphs,
    # or otherwise have some graph sampler to pass in.
    # For now we just leave it as is; it doesn't work generically,
    # but we need the `Grammar` parameter first to define this, so it's fine.
):
    # Because BoTorch only supports tensors, we use a column of the input
    # tensor as indices into a list of graphs.
    # The `with set_graph_lookup()` contextmanager injects the test graph into
    # the kernel, i.e. these are what the indices in the tensor refer to.
    # We set the index to `-1` to indicate the test graph, i.e. the last one
    # in `graph_lookup`.
    # Since we set the test graph in a loop, `best_candidates[best_idx]` will
    # always contain a `-1`, which says "the only sample to consider was the
    # best one", but this `-1` means something different at each iteration.
    # tl;dr: we should save `best_graph` separately in the for loop for now.
    # The [best_idx, :-1] effectively removes the `-1` graph index column.
    return best_candidates[best_idx, :-1], best_graph, best_scores[best_idx].item()
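A rough sketch of the per-candidate loop this describes. Here `sampled_graphs`, `acq_function`, `candidate_tensor`, and `kernels` are all placeholders for objects the surrounding infrastructure would provide, and `set_graph_lookup` is assumed to be the context manager sketched above:

best_graph, best_score = None, float("-inf")
for graph in sampled_graphs:  # placeholder: a graph sampler's output
    # While inside the context, index -1 in the graph-index column of
    # `candidate_tensor` refers to this candidate graph.
    with set_graph_lookup(kernels, [graph], append=True):
        score = acq_function(candidate_tensor)
    if score > best_score:
        # Keep the graph object itself; the stored -1 index means a
        # different graph on every iteration, so it can't be recovered later.
        best_graph, best_score = graph, score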
Ultimately, we don't have the infrastructure to actually use this function yet, so it's fine if it's not perfect; the implementation will evolve slowly. As long as the core loop works, we're all good. @timurcarstensen We have a basic BoTorch kernel that can jointly optimize nx.Graph objects and other parameters, it's just not hooked up to the system as a whole yet, as we don't know what that looks like yet.
This pull request introduces a custom PyTorch implementation of the Weisfeiler-Lehman (WL) kernel to replace the existing GraKel-based implementation. The most important changes include adding new example scripts, creating the custom WL kernel class, and adding tests to ensure correctness and compatibility.
New Implementation of Weisfeiler-Lehman Kernel:
grakel_replace/torch_wl_kernel.py: Implemented a custom PyTorch class `TorchWLKernel` for the WL kernel, including methods to convert NetworkX graphs to sparse adjacency tensors, initialize node labels, perform WL iterations, and compute the kernel matrix.