[BUG] Skip WholeGraph Tests if GPU PyTorch Unavailable (#4820)
Skips WholeGraph tests if GPU PyTorch is not available. Required to get tests passing on ARM. In the future, we should move all WholeGraph-dependent code, as well as the bulk sampling API, into `cugraph-gnn` so these errors do not recur.

Authors:
  - Alex Barghi (https://github.com/alexbarghi-nv)
  - James Lamb (https://github.com/jameslamb)

Approvers:
  - James Lamb (https://github.com/jameslamb)
  - Don Acosta (https://github.com/acostadon)
  - Rick Ratzel (https://github.com/rlratzel)

URL: #4820
alexbarghi-nv authored Jan 3, 2025
1 parent dd34c15 commit 04f0984
Showing 1 changed file with 15 additions and 15 deletions.
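
The skip conditions in the diff below check `isinstance(torch, MissingModule)` rather than catching an `ImportError` at call time. A minimal sketch of the optional-import pattern this assumes (the actual import block sits above the hunks shown here, and the use of cugraph's `import_optional` helper is an assumption):

# Sketch (assumption): how the test module's optional imports are set up.
# import_optional returns the real module when it can be imported, and a
# MissingModule placeholder otherwise, so this file imports cleanly even
# when torch or pylibwholegraph is absent.
from cugraph.utilities.utils import import_optional, MissingModule

torch = import_optional("torch")
pylibwholegraph = import_optional("pylibwholegraph")

Because `pytestmark` is evaluated when pytest imports the module, the guard expressions must not raise on machines without these dependencies; the placeholder objects keep the `isinstance` checks safe at collection time.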
@@ -1,4 +1,4 @@
-# Copyright (c) 2023-2024, NVIDIA CORPORATION.
+# Copyright (c) 2023-2025, NVIDIA CORPORATION.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -32,6 +32,20 @@ def get_cudart_version():
     return major * 1000 + minor * 10
 
 
+pytestmark = [
+    pytest.mark.skipif(
+        isinstance(torch, MissingModule) or not torch.cuda.is_available(),
+        reason="PyTorch with GPU support not available",
+    ),
+    pytest.mark.skipif(
+        isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available"
+    ),
+    pytest.mark.skipif(
+        get_cudart_version() < 11080, reason="not compatible with CUDA < 11.8"
+    ),
+]
+
+
 def runtest(rank: int, world_size: int):
     torch.cuda.set_device(rank)
 
@@ -69,13 +83,6 @@ def runtest(rank: int, world_size: int):
 
 
 @pytest.mark.sg
-@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
-@pytest.mark.skipif(
-    isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available"
-)
-@pytest.mark.skipif(
-    get_cudart_version() < 11080, reason="not compatible with CUDA < 11.8"
-)
 def test_feature_storage_wholegraph_backend():
     world_size = torch.cuda.device_count()
     print("gpu count:", world_size)
@@ -87,13 +94,6 @@ def test_feature_storage_wholegraph_backend():
 
 
 @pytest.mark.mg
-@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
-@pytest.mark.skipif(
-    isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available"
-)
-@pytest.mark.skipif(
-    get_cudart_version() < 11080, reason="not compatible with CUDA < 11.8"
-)
 def test_feature_storage_wholegraph_backend_mg():
     world_size = torch.cuda.device_count()
     print("gpu count:", world_size)
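Consolidating the three `skipif` marks into a module-level `pytestmark` list applies them to every test collected from this file, so future WholeGraph tests inherit the GPU, WholeGraph, and CUDA-version guards without repeating decorators on each function. The CUDA guard compares against a cudart-style integer, where 11.8 encodes as 11080. The body of `get_cudart_version` is mostly outside the hunks above; a minimal sketch consistent with the visible `return major * 1000 + minor * 10` line, assuming numba supplies the runtime version:

from numba import cuda  # assumption: the (major, minor) pair comes from numba

def get_cudart_version():
    # cuda.runtime.get_version() returns e.g. (11, 8); encode it the way
    # cudart does so that CUDA 11.8 compares as 11080.
    major, minor = cuda.runtime.get_version()
    return major * 1000 + minor * 10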
