Skip to content

Commit

Permalink
[BUG] Skip WholeGraph Tests if GPU PyTorch Unavailable
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Dec 12, 2024
1 parent 5a33526 commit 4bc3bb9
Showing 1 changed file with 21 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import numpy as np
import os

import numba.cuda

from cugraph.gnn import FeatureStore

from cugraph.utilities.utils import import_optional, MissingModule
Expand All @@ -25,6 +27,11 @@
wgth = import_optional("pylibwholegraph.torch")


def get_cudart_version():
major, minor = numba.cuda.runtime.get_version()
return major * 1000 + minor * 10


def runtest(rank: int, world_size: int):
torch.cuda.set_device(rank)

Expand Down Expand Up @@ -62,10 +69,16 @@ def runtest(rank: int, world_size: int):


@pytest.mark.sg
@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
@pytest.mark.skipif(
isinstance(torch, MissingModule) or not torch.cuda.is_available(),
reason="PyTorch with GPU support not available",
)
@pytest.mark.skipif(
isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available"
)
@pytest.mark.skipif(
get_cudart_version() < 11080, reason="not compatible with CUDA < 11.8"
)
def test_feature_storage_wholegraph_backend():
world_size = torch.cuda.device_count()
print("gpu count:", world_size)
Expand All @@ -77,10 +90,16 @@ def test_feature_storage_wholegraph_backend():


@pytest.mark.mg
@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
@pytest.mark.skipif(
isinstance(torch, MissingModule) or not torch.cuda.is_available(),
reason="PyTorch with GPU support not available",
)
@pytest.mark.skipif(
isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available"
)
@pytest.mark.skipif(
get_cudart_version() < 11080, reason="not compatible with CUDA < 11.8"
)
def test_feature_storage_wholegraph_backend_mg():
world_size = torch.cuda.device_count()
print("gpu count:", world_size)
Expand Down

0 comments on commit 4bc3bb9

Please sign in to comment.