diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index aea81c152e0..c167da995c8 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -13,6 +13,7 @@ jobs: # Please keep pr-builder as the top job here pr-builder: needs: + - check-nightly-ci - changed-files - checks - conda-cpp-build @@ -42,6 +43,18 @@ jobs: - name: Telemetry setup if: ${{ vars.TELEMETRY_ENABLED == 'true' }} uses: rapidsai/shared-actions/telemetry-dispatch-stash-base-env-vars@main + check-nightly-ci: + # Switch to ubuntu-latest once it defaults to a version of Ubuntu that + # provides at least Python 3.11 (see + # https://docs.python.org/3/library/datetime.html#datetime.date.fromisoformat) + runs-on: ubuntu-24.04 + env: + RAPIDS_GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + steps: + - name: Check if nightly CI is passing + uses: rapidsai/shared-actions/check_nightly_success/dispatch@main + with: + repo: cugraph changed-files: secrets: inherit needs: telemetry-setup diff --git a/ci/notebook_list.py b/ci/notebook_list.py index f7a284beeeb..659ac4de755 100644 --- a/ci/notebook_list.py +++ b/ci/notebook_list.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2023, NVIDIA CORPORATION. +# Copyright (c) 2021-2025, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,7 +17,7 @@ import glob from pathlib import Path -from numba import cuda +from cuda.bindings import runtime # for adding another run type and skip file name add to this dictionary runtype_dict = { @@ -30,20 +30,27 @@ def skip_book_dir(runtype): # Add all run types here, currently only CI supported + return runtype in runtype_dict and Path(runtype_dict.get(runtype)).is_file() - if runtype in runtype_dict.keys(): - if Path(runtype_dict.get(runtype)).is_file(): - return True - return False +def _get_cuda_version_string(): + status, version = runtime.getLocalRuntimeVersion() + if status != runtime.cudaError_t.cudaSuccess: + raise RuntimeError("Could not get CUDA runtime version.") + major, minor = divmod(version, 1000) + minor //= 10 + return f"{major}.{minor}" + + +def _is_ampere_or_newer(): + status, device_id = runtime.cudaGetDevice() + if status != runtime.cudaError_t.cudaSuccess: + raise RuntimeError("Could not get CUDA device.") + status, device_prop = runtime.cudaGetDeviceProperties(device_id) + if status != runtime.cudaError_t.cudaSuccess: + raise RuntimeError("Could not get CUDA device properties.") + return (device_prop.major, device_prop.minor) >= (8, 0) -cuda_version_string = ".".join([str(n) for n in cuda.runtime.get_version()]) -# -# Not strictly true... however what we mean is -# Pascal or earlier -# -ampere = False -device = cuda.get_current_device() parser = argparse.ArgumentParser(description="Condition for running the notebook tests") parser.add_argument("runtype", type=str) @@ -52,19 +59,10 @@ def skip_book_dir(runtype): runtype = args.runtype -if runtype not in runtype_dict.keys(): +if runtype not in runtype_dict: print(f"Unknown Run Type = {runtype}", file=sys.stderr) exit() - -# check for the attribute using both pre and post numba 0.53 names -cc = getattr(device, "COMPUTE_CAPABILITY", None) or getattr( - device, "compute_capability" -) -if cc[0] >= 8: - ampere = True - -skip = False for filename in glob.iglob("**/*.ipynb", recursive=True): skip = False if skip_book_dir(runtype): @@ -88,7 +86,7 @@ def skip_book_dir(runtype): ) skip = True break - elif ampere and re.search("# Does not run on Ampere", line): + elif _is_ampere_or_newer() and re.search("# Does not run on Ampere", line): print(f"SKIPPING {filename} (does not run on Ampere)", file=sys.stderr) skip = True break diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index 8f56372c81a..62633d95c64 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -1,5 +1,5 @@ #============================================================================= -# Copyright (c) 2022-2024, NVIDIA CORPORATION. +# Copyright (c) 2022-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -39,7 +39,7 @@ function(find_and_configure_raft) endif() rapids_cpm_find(raft ${PKG_VERSION} - GLOBAL_TARGETS raft::raft + GLOBAL_TARGETS raft::raft raft::raft_logger raft::raft_logger_impl BUILD_EXPORT_SET cugraph-exports INSTALL_EXPORT_SET cugraph-exports COMPONENTS ${RAFT_COMPONENTS} diff --git a/cpp/src/c_api/neighbor_sampling.cpp b/cpp/src/c_api/neighbor_sampling.cpp index 927d7d9c769..9c6c1f0f021 100644 --- a/cpp/src/c_api/neighbor_sampling.cpp +++ b/cpp/src/c_api/neighbor_sampling.cpp @@ -948,11 +948,9 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor { std::exclusive_scan( recvcounts.begin(), recvcounts.end(), displacements.begin(), size_t{0}); - rmm::device_uvector tmp_label_to_comm_rank( + label_to_comm_rank = rmm::device_uvector( displacements.back() + recvcounts.back(), handle_.get_stream()); - label_to_comm_rank = std::move(tmp_label_to_comm_rank); - cugraph::device_allgatherv(handle_.get_comms(), local_label_to_comm_rank.begin(), (*label_to_comm_rank).begin(), diff --git a/python/cugraph/cugraph/dask/common/mg_utils.py b/python/cugraph/cugraph/dask/common/mg_utils.py index b04f293dc0e..e4e3ac9a44e 100644 --- a/python/cugraph/cugraph/dask/common/mg_utils.py +++ b/python/cugraph/cugraph/dask/common/mg_utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2025, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,7 +13,7 @@ import os import gc -import numba.cuda +from cuda.bindings import runtime # FIXME: this raft import breaks the library if ucx-py is @@ -53,11 +53,10 @@ def prepare_worker_to_parts(data, client=None): def is_single_gpu(): - ngpus = len(numba.cuda.gpus) - if ngpus > 1: - return False - else: - return True + status, count = runtime.cudaGetDeviceCount() + if status != runtime.cudaError_t.cudaSuccess: + raise RuntimeError("Could not get CUDA device count.") + return count > 1 def get_visible_devices(): diff --git a/python/cugraph/cugraph/tests/data_store/test_gnn_feat_storage_wholegraph.py b/python/cugraph/cugraph/tests/data_store/test_gnn_feat_storage_wholegraph.py index f760ef3e1ba..964449276a2 100644 --- a/python/cugraph/cugraph/tests/data_store/test_gnn_feat_storage_wholegraph.py +++ b/python/cugraph/cugraph/tests/data_store/test_gnn_feat_storage_wholegraph.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Copyright (c) 2023-2025, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -32,6 +32,20 @@ def get_cudart_version(): return major * 1000 + minor * 10 +pytestmark = [ + pytest.mark.skipif( + isinstance(torch, MissingModule) or not torch.cuda.is_available(), + reason="PyTorch with GPU support not available", + ), + pytest.mark.skipif( + isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available" + ), + pytest.mark.skipif( + get_cudart_version() < 11080, reason="not compatible with CUDA < 11.8" + ), +] + + def runtest(rank: int, world_size: int): torch.cuda.set_device(rank) @@ -69,13 +83,6 @@ def runtest(rank: int, world_size: int): @pytest.mark.sg -@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") -@pytest.mark.skipif( - isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available" -) -@pytest.mark.skipif( - get_cudart_version() < 11080, reason="not compatible with CUDA < 11.8" -) def test_feature_storage_wholegraph_backend(): world_size = torch.cuda.device_count() print("gpu count:", world_size) @@ -87,13 +94,6 @@ def test_feature_storage_wholegraph_backend(): @pytest.mark.mg -@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") -@pytest.mark.skipif( - isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available" -) -@pytest.mark.skipif( - get_cudart_version() < 11080, reason="not compatible with CUDA < 11.8" -) def test_feature_storage_wholegraph_backend_mg(): world_size = torch.cuda.device_count() print("gpu count:", world_size) diff --git a/python/cugraph/cugraph/tests/docs/test_doctests.py b/python/cugraph/cugraph/tests/docs/test_doctests.py index 2095fd41fe9..9d9f8436b99 100644 --- a/python/cugraph/cugraph/tests/docs/test_doctests.py +++ b/python/cugraph/cugraph/tests/docs/test_doctests.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION. +# Copyright (c) 2022-2025, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -25,14 +25,21 @@ import cugraph import pylibcugraph import cudf -from numba import cuda +from cuda.bindings import runtime from cugraph.testing import utils modules_to_skip = ["dask", "proto", "raft"] datasets = utils.RAPIDS_DATASET_ROOT_DIR_PATH -cuda_version_string = ".".join([str(n) for n in cuda.runtime.get_version()]) + +def _get_cuda_version_string(): + status, version = runtime.getLocalRuntimeVersion() + if status != runtime.cudaError_t.cudaSuccess: + raise RuntimeError("Could not get CUDA runtime version.") + major = version // 1000 + minor = (version % 1000) // 10 + return f"{major}.{minor}" def _is_public_name(name): @@ -131,6 +138,7 @@ def skip_docstring(docstring_obj): NOTE: this function is currently not available on CUDA 11.4 systems. """ docstring = docstring_obj.docstring + cuda_version_string = _get_cuda_version_string() for line in docstring.splitlines(): if f"currently not available on CUDA {cuda_version_string} systems" in line: return f"docstring example not supported on CUDA {cuda_version_string}" diff --git a/python/cugraph/cugraph/utilities/path_retrieval_wrapper.pyx b/python/cugraph/cugraph/utilities/path_retrieval_wrapper.pyx index 98d11ad07df..8e71c7aae4e 100644 --- a/python/cugraph/cugraph/utilities/path_retrieval_wrapper.pyx +++ b/python/cugraph/cugraph/utilities/path_retrieval_wrapper.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. +# Copyright (c) 2021-2025, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,7 +19,6 @@ from cugraph.utilities.path_retrieval cimport get_traversed_cost as c_get_traversed_cost from cugraph.structure.graph_primtypes cimport * from libc.stdint cimport uintptr_t -from numba import cuda import cudf import numpy as np diff --git a/python/cugraph/cugraph/utilities/utils.py b/python/cugraph/cugraph/utilities/utils.py index 0257da4ffc0..074503e2f60 100644 --- a/python/cugraph/cugraph/utilities/utils.py +++ b/python/cugraph/cugraph/utilities/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2024, NVIDIA CORPORATION. +# Copyright (c) 2020-2025, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,13 +15,10 @@ import os import shutil -from numba import cuda - import cudf from cudf.core.column import as_column -from cuda.cudart import cudaDeviceAttr -from rmm._cuda.gpu import getDeviceAttribute +from cuda.bindings import runtime from warnings import warn @@ -210,45 +207,42 @@ def get_traversed_path_list(df, id): return answer -def is_cuda_version_less_than(min_version=(10, 2)): +def is_cuda_version_less_than(min_version): """ Returns True if the version of CUDA being used is less than min_version """ - this_cuda_ver = cuda.runtime.get_version() # returns (, ) - if this_cuda_ver[0] > min_version[0]: - return False - if this_cuda_ver[0] < min_version[0]: - return True - if this_cuda_ver[1] < min_version[1]: - return True - return False + status, version = runtime.getLocalRuntimeVersion() + if status != runtime.cudaError_t.cudaSuccess: + raise RuntimeError("Could not get CUDA runtime version.") + major = version // 1000 + minor = (version % 1000) // 10 + return (major, minor) < min_version -def is_device_version_less_than(min_version=(7, 0)): +def is_device_version_less_than(min_version): """ Returns True if the version of CUDA being used is less than min_version """ - major_version = getDeviceAttribute( - cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor, 0 - ) - minor_version = getDeviceAttribute( - cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor, 0 - ) - if major_version > min_version[0]: - return False - if major_version < min_version[0]: - return True - if minor_version < min_version[1]: - return True - return False + status, device_id = runtime.cudaGetDevice() + if status != runtime.cudaError_t.cudaSuccess: + raise RuntimeError("Could not get CUDA device.") + status, device_prop = runtime.cudaGetDeviceProperties(device_id) + if status != runtime.cudaError_t.cudaSuccess: + raise RuntimeError("Could not get CUDA device properties.") + return (device_prop.major, device_prop.minor) < min_version def get_device_memory_info(): """ Returns the total amount of global memory on the device in bytes """ - meminfo = cuda.current_context().get_memory_info() - return meminfo[1] + status, device_id = runtime.cudaGetDevice() + if status != runtime.cudaError_t.cudaSuccess: + raise RuntimeError("Could not get CUDA device.") + status, device_prop = runtime.cudaGetDeviceProperties(device_id) + if status != runtime.cudaError_t.cudaSuccess: + raise RuntimeError("Could not get CUDA device properties.") + return device_prop.totalGlobalMem # FIXME: if G is a Nx type, the weight attribute is assumed to be "weight", if diff --git a/scripts/dask/README.md b/scripts/dask/README.md new file mode 100644 index 00000000000..0c8853351b2 --- /dev/null +++ b/scripts/dask/README.md @@ -0,0 +1,55 @@ +# Dask scripts for multi-GPU environments + +This directory contains tools for configuring environments for single-node or +multi-node, multi-gpu (SNMG or MNMG) Dask-based cugraph runs, currently +consisting of shell and python scripts. + +Users should also consult the multi-GPU utilities in the +`python/cugraph/cugraph/testing/mg_utils.py` module, specifically the +`start_dask_client()` function, to see how to create `client` and `cluster` +instances in Python code to access the corresponding Dask processes created by +the tools here. + + +### run-dask-process.sh + + This script is used to start the Dask scheduler and workers as needed. + + To start a scheduler and workers on a node, run it like this: + ``` + bash$ run-dask-process.sh scheduler workers + ``` + Once a scheduler is running on a node in the cluster, workers can be started + on other nodes in the cluster by running the script on each worker node like + this: + ``` + bash$ run-dask-process.sh workers + ``` + The env var SCHEDULER_FILE must be set to the location where the scheduler + will generate the scheduler JSON file. The same env var is used by the + workers to locate the generated scheduler JSON file for reading. + + The script will ensure the scheduler is started before the workers when both + are specified. + + Additional options can be specified for using different communication + mechanisms: + ``` + --tcp - initalize a TCP cluster (default) + --ucx - initialize a UCX cluster with NVLink + --ucxib | --ucx-ib - initialize a UCX cluster with InfiniBand+NVLink + ``` + Finally, the script can be run with `-h` or `--help` to see the full set of + options. + +### wait_for_workers.py + + This script can be used to ensure all workers that are expected to be present + in the cluster are up and running. This is useful for automation that sets up + the Dask cluster and cannot proceed until the Dask cluster is available + to accept tasks. + + This example waits for 16 workers to be present: + ``` + bash$ python wait_for_workers.py --scheduler-file-path=$SCHEDULER_FILE --num-expected-workers=16 + ``` diff --git a/scripts/dask/run-dask-process.sh b/scripts/dask/run-dask-process.sh new file mode 100755 index 00000000000..9eef17fc5e5 --- /dev/null +++ b/scripts/dask/run-dask-process.sh @@ -0,0 +1,274 @@ +#!/bin/bash +# Copyright (c) 2024-2025, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +################################################################################ +NUMARGS=$# +ARGS=$* +VALIDARGS="-h --help scheduler workers --tcp --ucx --ucxib --ucx-ib" +HELP="$0 [ ...] [ ...] + where is: + scheduler - start dask scheduler + workers - start dask workers + and is: + --tcp - initalize a TCP cluster (default) + --ucx - initialize a UCX cluster with NVLink + --ucxib | --ucx-ib - initialize a UCX cluster with InfiniBand+NVLink + -h | --help - print this text + + The cluster config order of precedence is any specification on the + command line (--tcp, --ucx, etc.) if provided, then the value of the + env var DASK_CLUSTER_CONFIG_TYPE if set, then the default value of TCP. + + The env var SCHEDULER_FILE must be set to the location of the dask scheduler + file that the scheduler will generate and the worker(s) will read. This + location must be accessible by the scheduler and workers, meaning a multi-node + configuration will need to set this to a location on a shared file system. +" + +# Default configuration variables. Most are defined using the bash := or :- +# syntax, which means they will be set only if they were previously unset in +# the environment. +WORKER_RMM_POOL_SIZE=${WORKER_RMM_POOL_SIZE:-12G} +DASK_CUDA_INTERFACE=${DASK_CUDA_INTERFACE:-ibp5s0f0} +DASK_SCHEDULER_PORT=${DASK_SCHEDULER_PORT:-8792} +DASK_DEVICE_MEMORY_LIMIT=${DASK_DEVICE_MEMORY_LIMIT:-auto} +DASK_HOST_MEMORY_LIMIT=${DASK_HOST_MEMORY_LIMIT:-auto} + +# Logs can be written to a specific location by setting the DASK_LOGS_DIR +# env var. If unset, all logs are created under a dir named after the +# current PID. +DASK_LOGS_DIR=${DASK_LOGS_DIR:-dask_logs-$$} +DASK_SCHEDULER_LOG=${DASK_LOGS_DIR}/scheduler_log.txt +DASK_WORKERS_LOG=${DASK_LOGS_DIR}/worker-${HOSTNAME}_log.txt + +# DASK_CLUSTER_CONFIG_TYPE defaults to the env var value if set, else TCP. CLI +# options to this script take precedence. Valid values are TCP, UCX, UCXIB +DASK_CLUSTER_CONFIG_TYPE=${DASK_CLUSTER_CONFIG_TYPE:-TCP} + + +################################################################################ +# FUNCTIONS + +numargs=$# +args=$* +hasArg () { + (( ${numargs} != 0 )) && (echo " ${args} " | grep -q " $1 ") +} + +logger_prefix=">>>> " +logger () { + if (( $# > 0 )) && [ "$1" == "-p" ]; then + shift + echo -e "${logger_prefix}$@" + else + echo -e "$(date --utc "+%D-%T.%N")_UTC${logger_prefix}$@" + fi +} + +buildTcpArgs () { + export DASK_DISTRIBUTED__COMM__TIMEOUTS__CONNECT="100s" + export DASK_DISTRIBUTED__COMM__TIMEOUTS__TCP="600s" + export DASK_DISTRIBUTED__COMM__RETRY__DELAY__MIN="1s" + export DASK_DISTRIBUTED__COMM__RETRY__DELAY__MAX="60s" + export DASK_DISTRIBUTED__WORKER__MEMORY__Terminate="False" + + SCHEDULER_ARGS="--protocol=tcp + --port=$DASK_SCHEDULER_PORT + --scheduler-file $SCHEDULER_FILE + " + + WORKER_ARGS="--rmm-pool-size=$WORKER_RMM_POOL_SIZE + --local-directory=/tmp/$LOGNAME + --scheduler-file=$SCHEDULER_FILE + --memory-limit=$DASK_HOST_MEMORY_LIMIT + --device-memory-limit=$DASK_DEVICE_MEMORY_LIMIT + " + +} + +buildUCXWithInfinibandArgs () { + export DASK_RMM__POOL_SIZE=0.5GB + export DASK_DISTRIBUTED__COMM__UCX__CREATE_CUDA_CONTEXT=True + + SCHEDULER_ARGS="--protocol=ucx + --port=$DASK_SCHEDULER_PORT + --interface=$DASK_CUDA_INTERFACE + --scheduler-file $SCHEDULER_FILE + " + + WORKER_ARGS="--interface=$DASK_CUDA_INTERFACE + --rmm-pool-size=$WORKER_RMM_POOL_SIZE + --rmm-async + --local-directory=/tmp/$LOGNAME + --scheduler-file=$SCHEDULER_FILE + --memory-limit=$DASK_HOST_MEMORY_LIMIT + --device-memory-limit=$DASK_DEVICE_MEMORY_LIMIT + " +} + +buildUCXwithoutInfinibandArgs () { + export UCX_TCP_CM_REUSEADDR=y + export UCX_MAX_RNDV_RAILS=1 + export UCX_TCP_TX_SEG_SIZE=8M + export UCX_TCP_RX_SEG_SIZE=8M + + export DASK_DISTRIBUTED__COMM__UCX__CUDA_COPY=True + export DASK_DISTRIBUTED__COMM__UCX__TCP=True + export DASK_DISTRIBUTED__COMM__UCX__NVLINK=True + export DASK_DISTRIBUTED__COMM__UCX__INFINIBAND=False + export DASK_DISTRIBUTED__COMM__UCX__RDMACM=False + export DASK_RMM__POOL_SIZE=0.5GB + + + SCHEDULER_ARGS="--protocol=ucx + --port=$DASK_SCHEDULER_PORT + --scheduler-file $SCHEDULER_FILE + " + + WORKER_ARGS="--enable-tcp-over-ucx + --enable-nvlink + --disable-infiniband + --disable-rdmacm + --rmm-pool-size=$WORKER_RMM_POOL_SIZE + --local-directory=/tmp/$LOGNAME + --scheduler-file=$SCHEDULER_FILE + --memory-limit=$DASK_HOST_MEMORY_LIMIT + --device-memory-limit=$DASK_DEVICE_MEMORY_LIMIT + " +} + +scheduler_pid="" +worker_pid="" +num_scheduler_tries=0 + +startScheduler () { + mkdir -p $(dirname $SCHEDULER_FILE) + echo "RUNNING: \"dask scheduler $SCHEDULER_ARGS\"" > $DASK_SCHEDULER_LOG + dask scheduler $SCHEDULER_ARGS >> $DASK_SCHEDULER_LOG 2>&1 & + scheduler_pid=$! +} + + +################################################################################ +# READ CLI OPTIONS + +START_SCHEDULER=0 +START_WORKERS=0 + +if (( ${NUMARGS} == 0 )); then + echo "${HELP}" + exit 0 +else + if hasArg -h || hasArg --help; then + echo "${HELP}" + exit 0 + fi + for a in ${ARGS}; do + if ! (echo " ${VALIDARGS} " | grep -q " ${a} "); then + echo "Invalid option: ${a}" + exit 1 + fi + done +fi + +if [ -z ${SCHEDULER_FILE+x} ]; then + echo "Env var SCHEDULER_FILE must be set. See -h for details" + exit 1 +fi + +if hasArg scheduler; then + START_SCHEDULER=1 +fi +if hasArg workers; then + START_WORKERS=1 +fi +# Allow the command line to take precedence +if hasArg --tcp; then + DASK_CLUSTER_CONFIG_TYPE=TCP +elif hasArg --ucx; then + DASK_CLUSTER_CONFIG_TYPE=UCX +elif hasArg --ucxib || hasArg --ucx-ib; then + DASK_CLUSTER_CONFIG_TYPE=UCXIB +fi + + +################################################################################ +# SETUP & RUN + +#export DASK_LOGGING__DISTRIBUTED="DEBUG" +#ulimit -n 100000 + +if [[ "$DASK_CLUSTER_CONFIG_TYPE" == "UCX" ]]; then + logger "Using cluster configurtion for UCX" + buildUCXwithoutInfinibandArgs +elif [[ "$DASK_CLUSTER_CONFIG_TYPE" == "UCXIB" ]]; then + logger "Using cluster configurtion for UCX with Infiniband" + buildUCXWithInfinibandArgs +else + logger "Using cluster configurtion for TCP" + buildTcpArgs +fi + +mkdir -p $DASK_LOGS_DIR +logger "Logs written to: $DASK_LOGS_DIR" + +if [[ $START_SCHEDULER == 1 ]]; then + rm -f $SCHEDULER_FILE $DASK_SCHEDULER_LOG $DASK_WORKERS_LOG + + startScheduler + sleep 6 + num_scheduler_tries=$(( num_scheduler_tries+1 )) + + # Wait for the scheduler to start first before proceeding, since + # it may require several retries (if prior run left ports open + # that need time to close, etc.) + while [ ! -f "$SCHEDULER_FILE" ]; do + scheduler_alive=$(ps -p $scheduler_pid > /dev/null ; echo $?) + if [[ $scheduler_alive != 0 ]]; then + if [[ $num_scheduler_tries != 30 ]]; then + logger "scheduler failed to start, retry #$num_scheduler_tries" + startScheduler + sleep 6 + num_scheduler_tries=$(( num_scheduler_tries+1 )) + else + logger "could not start scheduler, exiting." + exit 1 + fi + fi + done + logger "scheduler started." +fi + +if [[ $START_WORKERS == 1 ]]; then + rm -f $DASK_WORKERS_LOG + while [ ! -f "$SCHEDULER_FILE" ]; do + logger "run-dask-process.sh: $SCHEDULER_FILE not present - waiting to start workers..." + sleep 2 + done + echo "RUNNING: \"dask_cuda_worker $WORKER_ARGS\"" > $DASK_WORKERS_LOG + dask-cuda-worker $WORKER_ARGS >> $DASK_WORKERS_LOG 2>&1 & + worker_pid=$! + logger "worker(s) started." +fi + +# This script will not return until the following background process +# have been completed/killed. +if [[ $worker_pid != "" ]]; then + logger "waiting for worker pid $worker_pid to finish before exiting script..." + wait $worker_pid +fi +if [[ $scheduler_pid != "" ]]; then + logger "waiting for scheduler pid $scheduler_pid to finish before exiting script..." + wait $scheduler_pid +fi diff --git a/scripts/dask/wait_for_workers.py b/scripts/dask/wait_for_workers.py new file mode 100644 index 00000000000..931e991c4cf --- /dev/null +++ b/scripts/dask/wait_for_workers.py @@ -0,0 +1,123 @@ +# Copyright (c) 2024-2025, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import time + +from dask.distributed import Client + + +def initialize_dask_cuda(communication_type): + communication_type = communication_type.lower() + if "ucx" in communication_type: + os.environ["UCX_MAX_RNDV_RAILS"] = "1" + + if communication_type == "ucx-ib": + os.environ["UCX_MEMTYPE_REG_WHOLE_ALLOC_TYPES"]="cuda" + os.environ["DASK_RMM__POOL_SIZE"]="0.5GB" + os.environ["DASK_DISTRIBUTED__COMM__UCX__CREATE_CUDA_CONTEXT"]="True" + + +def wait_for_workers( + num_expected_workers, scheduler_file_path, communication_type, timeout_after=0 +): + """ + Waits until num_expected_workers workers are available based on + the workers managed by scheduler_file_path, then returns 0. If + timeout_after is specified, will return 1 if num_expected_workers + workers are not available before the timeout. + """ + # FIXME: use scheduler file path from global environment if none + # supplied in configuration yaml + + print("wait_for_workers.py - initializing client...", end="") + sys.stdout.flush() + initialize_dask_cuda(communication_type) + print("done.") + sys.stdout.flush() + + ready = False + start_time = time.time() + while not ready: + if timeout_after and ((time.time() - start_time) >= timeout_after): + print( + f"wait_for_workers.py timed out after {timeout_after} seconds before finding {num_expected_workers} workers." + ) + sys.stdout.flush() + break + with Client(scheduler_file=scheduler_file_path) as client: + num_workers = len(client.scheduler_info()["workers"]) + if num_workers < num_expected_workers: + print( + f"wait_for_workers.py expected {num_expected_workers} but got {num_workers}, waiting..." + ) + sys.stdout.flush() + time.sleep(5) + else: + print(f"wait_for_workers.py got {num_workers} workers, done.") + sys.stdout.flush() + ready = True + + if ready is False: + return 1 + return 0 + + +if __name__ == "__main__": + import argparse + + ap = argparse.ArgumentParser() + ap.add_argument( + "--num-expected-workers", + type=int, + required=False, + help="Number of workers to wait for. If not specified, " + "uses the NUM_WORKERS env var if set, otherwise defaults " + "to 16.", + ) + ap.add_argument( + "--scheduler-file-path", + type=str, + required=True, + help="Path to shared scheduler file to read.", + ) + ap.add_argument( + "--communication-type", + type=str, + default="tcp", + required=False, + help="Initiliaze dask_cuda based on the cluster communication type." + "Supported values are tcp(default), ucx, ucxib, ucx-ib.", + ) + ap.add_argument( + "--timeout-after", + type=int, + default=0, + required=False, + help="Number of seconds to wait for workers. " + "Default is 0 which means wait forever.", + ) + args = ap.parse_args() + + if args.num_expected_workers is None: + args.num_expected_workers = os.environ.get("NUM_WORKERS", 16) + + exitcode = wait_for_workers( + num_expected_workers=args.num_expected_workers, + scheduler_file_path=args.scheduler_file_path, + communication_type=args.communication_type, + timeout_after=args.timeout_after, + ) + + sys.exit(exitcode)