Skip to content

Commit

Permalink
Move torch.jit.script check to test (#194)
Browse files Browse the repository at this point in the history
* update

* update

* update
  • Loading branch information
rusty1s authored Oct 11, 2023
1 parent 89b74f0 commit 29cd22b
Show file tree
Hide file tree
Showing 16 changed files with 54 additions and 22 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.0)
project(torchcluster)
set(CMAKE_CXX_STANDARD 14)
set(TORCHCLUSTER_VERSION 1.6.2)
set(TORCHCLUSTER_VERSION 1.6.3)

option(WITH_CUDA "Enable CUDA support" OFF)
option(WITH_PYTHON "Link to Python when building" ON)
Expand Down
2 changes: 1 addition & 1 deletion conda/pytorch-cluster/meta.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package:
name: pytorch-cluster
version: 1.6.2
version: 1.6.3

source:
path: ../..
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension,
CUDAExtension)

__version__ = '1.6.2'
__version__ = '1.6.3'
URL = 'https://github.com/rusty1s/pytorch_cluster'

WITH_CUDA = False
Expand Down
4 changes: 4 additions & 0 deletions test/test_graclus.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,7 @@ def test_graclus_cluster(test, dtype, device):

cluster = graclus_cluster(row, col, weight)
assert_correct(row, col, cluster)

jit = torch.jit.script(graclus_cluster)
cluster = jit(row, col, weight)
assert_correct(row, col, cluster)
3 changes: 3 additions & 0 deletions test/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,6 @@ def test_grid_cluster(test, dtype, device):

cluster = grid_cluster(pos, size, start, end)
assert cluster.tolist() == test['cluster']

jit = torch.jit.script(grid_cluster)
assert torch.equal(jit(pos, size, start, end), cluster)
9 changes: 9 additions & 0 deletions test/test_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ def test_knn(dtype, device):
edge_index = knn(x, y, 2)
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)])

jit = torch.jit.script(knn)
edge_index = jit(x, y, 2)
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)])

edge_index = knn(x, y, 2, batch_x, batch_y)
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)])

Expand Down Expand Up @@ -65,6 +69,11 @@ def test_knn_graph(dtype, device):
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)])

jit = torch.jit.script(knn_graph)
edge_index = jit(x, k=2, flow='source_to_target')
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)])


@pytest.mark.parametrize('dtype,device', product([torch.float], devices))
def test_knn_graph_large(dtype, device):
Expand Down
15 changes: 14 additions & 1 deletion test/test_radius.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ def test_radius(dtype, device):
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1),
(1, 2), (1, 5), (1, 6)])

jit = torch.jit.script(radius)
edge_index = jit(x, y, 2, max_num_neighbors=4)
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1),
(1, 2), (1, 5), (1, 6)])

edge_index = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4)
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 5),
(1, 6)])
Expand Down Expand Up @@ -64,12 +69,20 @@ def test_radius_graph(dtype, device):
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)])

jit = torch.jit.script(radius_graph)
edge_index = jit(x, r=2.5, flow='source_to_target')
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)])


@pytest.mark.parametrize('dtype,device', product([torch.float], devices))
def test_radius_graph_large(dtype, device):
x = torch.randn(1000, 3, dtype=dtype, device=device)

edge_index = radius_graph(x, r=0.5, flow='target_to_source', loop=True,
edge_index = radius_graph(x,
r=0.5,
flow='target_to_source',
loop=True,
max_num_neighbors=2000)

tree = scipy.spatial.cKDTree(x.cpu().numpy())
Expand Down
3 changes: 3 additions & 0 deletions test/test_rw.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ def test_rw_small(device):
out = random_walk(row, col, start, walk_length, num_nodes=3)
assert out.tolist() == [[0, 1, 0, 1, 0], [1, 0, 1, 0, 1], [2, 2, 2, 2, 2]]

jit = torch.jit.script(random_walk)
assert torch.equal(jit(row, col, start, walk_length, num_nodes=3), out)


@pytest.mark.parametrize('device', devices)
def test_rw_large_with_edge_indices(device):
Expand Down
2 changes: 1 addition & 1 deletion torch_cluster/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch

__version__ = '1.6.2'
__version__ = '1.6.3'

for library in [
'_version', '_grid', '_graclus', '_fps', '_rw', '_sampler', '_nearest',
Expand Down
10 changes: 6 additions & 4 deletions torch_cluster/graclus.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import torch


@torch.jit.script
def graclus_cluster(row: torch.Tensor, col: torch.Tensor,
weight: Optional[torch.Tensor] = None,
num_nodes: Optional[int] = None) -> torch.Tensor:
def graclus_cluster(
row: torch.Tensor,
col: torch.Tensor,
weight: Optional[torch.Tensor] = None,
num_nodes: Optional[int] = None,
) -> torch.Tensor:
"""A greedy clustering algorithm of picking an unmarked vertex and matching
it with one its unmarked neighbors (that maximizes its edge weight).
Expand Down
10 changes: 6 additions & 4 deletions torch_cluster/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import torch


@torch.jit.script
def grid_cluster(pos: torch.Tensor, size: torch.Tensor,
start: Optional[torch.Tensor] = None,
end: Optional[torch.Tensor] = None) -> torch.Tensor:
def grid_cluster(
pos: torch.Tensor,
size: torch.Tensor,
start: Optional[torch.Tensor] = None,
end: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""A clustering algorithm, which overlays a regular grid of user-defined
size over a point cloud and clusters all points within a voxel.
Expand Down
2 changes: 0 additions & 2 deletions torch_cluster/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import torch


@torch.jit.script
def knn(
x: torch.Tensor,
y: torch.Tensor,
Expand Down Expand Up @@ -83,7 +82,6 @@ def knn(
num_workers)


@torch.jit.script
def knn_graph(
x: torch.Tensor,
k: int,
Expand Down
2 changes: 0 additions & 2 deletions torch_cluster/radius.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import torch


@torch.jit.script
def radius(
x: torch.Tensor,
y: torch.Tensor,
Expand Down Expand Up @@ -84,7 +83,6 @@ def radius(
max_num_neighbors, num_workers)


@torch.jit.script
def radius_graph(
x: torch.Tensor,
r: float,
Expand Down
4 changes: 1 addition & 3 deletions torch_cluster/rw.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from torch import Tensor


@torch.jit.script
def random_walk(
row: Tensor,
col: Tensor,
Expand Down Expand Up @@ -55,8 +54,7 @@ def random_walk(
torch.cumsum(deg, 0, out=rowptr[1:])

node_seq, edge_seq = torch.ops.torch_cluster.random_walk(
rowptr, col, start, walk_length, p, q,
)
rowptr, col, start, walk_length, p, q)

if return_edge_indices:
return node_seq, edge_seq
Expand Down
1 change: 0 additions & 1 deletion torch_cluster/sampler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch


@torch.jit.script
def neighbor_sampler(start: torch.Tensor, rowptr: torch.Tensor, size: float):
assert not start.is_cuda

Expand Down
5 changes: 4 additions & 1 deletion torch_cluster/typing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import torch

WITH_PTR_LIST = hasattr(torch.ops.torch_cluster, 'fps_ptr_list')
try:
WITH_PTR_LIST = hasattr(torch.ops.torch_cluster, 'fps_ptr_list')
except Exception:
WITH_PTR_LIST = False

0 comments on commit 29cd22b

Please sign in to comment.