diff --git a/rapids_dask_dependency/dask_loader.py b/rapids_dask_dependency/dask_loader.py index da61d3e..9d2492d 100644 --- a/rapids_dask_dependency/dask_loader.py +++ b/rapids_dask_dependency/dask_loader.py @@ -7,7 +7,7 @@ import sys from contextlib import contextmanager -from rapids_dask_dependency.utils import patch_warning_stacklevel +from rapids_dask_dependency.utils import _patching_enabled, patch_warning_stacklevel class DaskLoader(importlib.machinery.SourceFileLoader): @@ -59,6 +59,8 @@ def disable(self, name): def find_spec(self, fullname: str, _, __=None): if fullname in self._blocklist: return None + if not _patching_enabled(): + return None if ( fullname in ("dask", "distributed") or fullname.startswith("dask.") diff --git a/rapids_dask_dependency/utils.py b/rapids_dask_dependency/utils.py index 3adc184..5801da3 100644 --- a/rapids_dask_dependency/utils.py +++ b/rapids_dask_dependency/utils.py @@ -1,5 +1,6 @@ # Copyright (c) 2024, NVIDIA CORPORATION. +import os import warnings from contextlib import contextmanager from functools import lru_cache @@ -24,3 +25,30 @@ def patch_warning_stacklevel(level): warnings.warn = _make_warning_func(level) yield warnings.warn = previous_warn + + +# Default patching behavior depends on the value of the +# `RAPIDS_DASK_PATCHING` environment variable. If this +# environment variable does not exist, patching will be +# enabled. Otherwise, this variable must be set to +# `'True'` for patching to be enabled. + + +_env = "RAPIDS_DASK_PATCHING" + + +def _patching_enabled() -> bool: + return os.environ.get(_env, "True") == "True" + + +@contextmanager +def patching_context(enabled: bool = True): + original = os.environ.get(_env) + os.environ[_env] = "True" if enabled else "False" + try: + yield + finally: + if original is None: + os.environ.pop(_env, None) + else: + os.environ[_env] = "True" if original else "False" diff --git a/tests/test_patch.py b/tests/test_patch.py index 586cdce..cc5c922 100644 --- a/tests/test_patch.py +++ b/tests/test_patch.py @@ -102,3 +102,15 @@ def test_distributed_cli_dask_spec_as_module(): print(e.stdout.decode()) print(e.stderr.decode()) raise + + +@run_test_in_subprocess +def test_dask_patching_disabled(): + from rapids_dask_dependency.utils import patching_context + + with patching_context(enabled=False): + import dask + import distributed + + assert not hasattr(dask, "_rapids_patched") + assert not hasattr(distributed, "_rapids_patched")