Skip to content

Commit

Permalink
New precheck procedure to enhance stability. (#1453)
Browse files Browse the repository at this point in the history
* add precheck operator definition

* done basic impl and test

* add design doc

* add args support

* done basic impl and ut

* fixed

* ut fix

* ut fix

* upgrade patch cov

* add debug log

* lint

* fixed

* fix and optimization

* ut fix

* ut fix and doc updated

* lint

* lint
  • Loading branch information
BalaBalaYi authored Feb 8, 2025
1 parent 09a46f9 commit e33cfb4
Show file tree
Hide file tree
Showing 27 changed files with 543 additions and 13 deletions.
2 changes: 1 addition & 1 deletion codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ coverage:
flag_coverage_not_uploaded_behavior: include
patch:
default:
target: 80%
target: 85%
threshold: 3%
removed_code_behavior: fully_covered_patch
project:
Expand Down
12 changes: 12 additions & 0 deletions dlrover/python/common/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,18 @@ class HeartBeat(Message):
timestamp: int = 0


@dataclass
class PreCheckRequest(Message):
    """Request message used to ask for the pre-check result.

    An instance with default field values is built by
    ``get_pre_check_result`` in the master client.
    """

    # Defaults to 0. NOTE(review): presumably the request time; the unit
    # is not visible here -- confirm against the sender.
    timestamp: int = 0
    # Kind of pre-check being queried; defaults to "INITIAL".
    type: str = "INITIAL"


@dataclass
class PreCheckResponse(Message):
    """Response message carrying the pre-check result.

    ``get_pre_check_result`` in the master client returns ``status`` to
    its caller.
    """

    # Pre-check status text; empty by default.
    # NOTE(review): presumably one of the ``PreCheckStatus`` values --
    # confirm against the servicer that fills this message.
    status: str = ""
    # Free-form explanation accompanying the status; empty by default.
    reason: str = ""


@dataclass
class DatasetShardParams(Message):
batch_size: int = 0
Expand Down
11 changes: 11 additions & 0 deletions dlrover/python/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,7 @@ class JobConstant(object):
# sleep 5s before next node check round
NODE_CHECK_NEXT_ROUND_TIMEOUT = 5

# default interval seconds for loop in training agent
TRAINING_AGENT_LOOP_DEFAULT_INTERVAL = 15

# sleep 5s before next rendezvous round
Expand All @@ -393,6 +394,9 @@ class JobConstant(object):
# sleep 5s before next port synchronization
SYNC_PORTS_DEFAULT_INTERVAL = 5

# interval seconds for pre-check waiting
PRE_CHECK_WAIT_SECS = 5


class Accelerators(object):
NVIDIA_GPU = "nvidia.com/gpu"
Expand Down Expand Up @@ -435,3 +439,10 @@ class ErrorMonitorConstants(object):
ACTION_RESUME_MEM_CKPT_START = "resume_mem_ckpt_start"
ACTION_RESUME_MEM_CKPT_COMPLETE = "resume_mem_ckpt_complete"
ACTION_HANG_WARN = "hang_warning"


class PreCheckStatus(object):
    """String constants describing the state of the training pre-check."""

    # Initial state of the job context; also re-set after each operator
    # that passes while later operators are still pending.
    CHECKING = "CHECKING"
    # Set as soon as one pre-check operator ends with a failed result.
    FAIL = "FAIL"
    # Set once the pre-check loop finishes without a failing operator.
    PASS = "PASS"
    # Reported by the job context when pre-check is turned off.
    DISABLED = "DISABLED"
4 changes: 3 additions & 1 deletion dlrover/python/common/global_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class DefaultValues(object):
GPU_NUM_PER_NODE = 8
NPU_NUM_PER_NODE = 16
MAX_METRIC_REC = 30
PRE_CHECK_ENABLED = True


class Context(Singleton):
Expand Down Expand Up @@ -115,8 +116,9 @@ def __init__(self):
self.hang_detection = DefaultValues.HANG_DETECTION
# The duration of downtime as training hang, unit is minute
self.hang_downtime = DefaultValues.HANG_DOWNTIME
#
# The default xpu device type.
self.xpu_type = Accelerators.NVIDIA_GPU
self.pre_check_enabled = DefaultValues.PRE_CHECK_ENABLED
self.gpu_per_node = DefaultValues.GPU_NUM_PER_NODE
self.npu_per_node = DefaultValues.NPU_NUM_PER_NODE

Expand Down
4 changes: 4 additions & 0 deletions dlrover/python/diagnosis/common/diagnosis_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,10 @@ def clear(self):
with self._lock:
self._actions.clear()

def len(self):
with self._lock:
return sum(len(d) for d in self._actions.values())

def next_action(
self,
instance=DiagnosisConstant.LOCAL_INSTANCE,
Expand Down
5 changes: 5 additions & 0 deletions dlrover/python/elastic_agent/master_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,11 @@ def get_elastic_run_config(self) -> Dict[str, str]:
response: comm.ElasticRunConfig = self._get(request)
return response.configs

def get_pre_check_result(self) -> str:
    """Fetch the pre-check status.

    Sends a default ``PreCheckRequest`` through ``self._get`` and returns
    the ``status`` field of the ``PreCheckResponse`` that comes back.
    """
    pre_check_request = comm.PreCheckRequest()
    pre_check_response: comm.PreCheckResponse = self._get(pre_check_request)
    return pre_check_response.status

def report_event(
self,
event_type: str = "",
Expand Down
18 changes: 18 additions & 0 deletions dlrover/python/master/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@
from dlrover.python.common.log import default_logger as logger


def str2bool(value):
    """Convert a command-line boolean token into a ``bool``.

    Intended for use as an ``argparse`` ``type=`` callable.

    Args:
        value: Either a ``bool`` (returned unchanged) or a string.
            Strings are matched case-insensitively, ignoring surrounding
            whitespace: "true", "yes", "t", "y", "1" give ``True``;
            "false", "no", "f", "n", "0" give ``False``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is neither a bool nor a
            recognised string token.
    """
    if isinstance(value, bool):
        return value
    if not isinstance(value, str):
        # Report non-string input the same way as an unknown token instead
        # of failing with AttributeError on ``.lower()``.
        raise argparse.ArgumentTypeError("Boolean value expected.")

    token = value.strip().lower()
    if token in {"true", "yes", "t", "y", "1"}:
        return True
    # "f" is accepted for symmetry with "t" in the true set.
    if token in {"false", "no", "f", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")


def add_params(parser):
parser.add_argument("--job_name", help="ElasticJob name", required=True)
parser.add_argument(
Expand Down Expand Up @@ -54,6 +65,13 @@ def add_params(parser):
type=str,
help="The service type of master: grpc/http.",
)
# Switch for the pre-training check. The option string is listed once;
# it was previously passed twice to the same add_argument call.
parser.add_argument(
    "--pre_check",
    default=DefaultValues.PRE_CHECK_ENABLED,
    type=str2bool,
    help="Enable pre training check or not.",
)


def print_args(args, exclude_args=[], groups=None):
Expand Down
3 changes: 3 additions & 0 deletions dlrover/python/master/diagnosis/diagnosis.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def get_observing_operators(self) -> List[InferenceOperator]:
def get_resolving_operators(self) -> List[InferenceOperator]:
return self._resolvers

def register_precheck_(self, problems: List[Inference]):
    """Store ``problems`` on this instance as ``_training_problems``.

    NOTE(review): this assigns the same attribute as
    ``register_training_problems``, so calling either one overwrites the
    other's value. Given the method name this looks like a copy-paste
    placeholder -- confirm which attribute pre-check problems should use.
    """
    self._training_problems = problems

def register_training_problems(self, problems: List[Inference]):
self._training_problems = problems

Expand Down
78 changes: 75 additions & 3 deletions dlrover/python/master/diagnosis/diagnosis_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Accelerators,
GpuMetricEnum,
NpuMetricEnum,
PreCheckStatus,
)
from dlrover.python.common.global_context import Context, DefaultValues
from dlrover.python.common.log import default_logger as logger
Expand Down Expand Up @@ -52,6 +53,9 @@
from dlrover.python.master.diagnosis.diagnosis_data_manager import (
DiagnosisDataManager,
)
from dlrover.python.master.diagnosis.precheck_operator import (
NoPreCheckOperator,
)
from dlrover.python.master.node.job_context import get_job_context

_metric_context = JobMetricContext.singleton_instance()
Expand All @@ -73,12 +77,80 @@ def __init__(self, job_name=None):
self._metric_monitor = None
self._lock = threading.Lock()

@classmethod
def get_pre_check_operators(cls):
    """Return the pre-check operators to run, in execution order.

    Currently a single ``NoPreCheckOperator`` instance; a new list with a
    new operator instance is built on every call.
    """
    return [NoPreCheckOperator()]

def collect_diagnosis_data(self, data: DiagnosisData):
    """Hand ``data`` to the data manager via ``store_data``."""
    self._data_manager.store_data(data)

def pre_check(self):
    """Run the training pre-check operators and record the outcome.

    Returns immediately when ``pre_check_enabled`` is false on the global
    context. Otherwise every operator from ``get_pre_check_operators`` is
    run in order, each with up to ``get_retry_times()`` attempts. A failed
    attempt triggers ``recover()``, a sleep of
    ``get_retry_interval_secs()`` seconds and an immediate re-check; the
    retry loop stops as soon as a check succeeds.

    For the first operator whose final result is not successful, its
    failed action is enqueued on the job context, the pre-check status is
    set to ``PreCheckStatus.FAIL`` and the method returns. If no operator
    fails, the status ends as ``PreCheckStatus.PASS``.

    An operator that raises, or that ran no check at all because its
    retry count is not positive, is logged and skipped (best effort); it
    does not fail the pre-check.
    """

    def _elapsed_ms(since):
        # time.time() returns seconds, while the log messages below are
        # labelled "ms", so convert before formatting.
        return (time.time() - since) * 1000

    logger.info("Start Diagnosis Manager to pre-check training...")
    if not _dlrover_context.pre_check_enabled:
        return

    start = time.time()
    pre_check_ops = self.get_pre_check_operators()
    logger.info(
        "Start to training pre-check with "
        f"operators: {[op.__class__.__name__ for op in pre_check_ops]}."
    )

    for pre_check_op in pre_check_ops:
        current_start = time.time()
        current_op_result = None
        pre_check_op_name = pre_check_op.__class__.__name__

        try:
            # retry loops for each operator
            for i in range(pre_check_op.get_retry_times()):
                check_start = time.time()

                # do check
                current_op_result = pre_check_op.check()
                logger.info(
                    f"{pre_check_op_name} "
                    f"check({i}) "
                    f"cost: {_elapsed_ms(check_start):.2f}ms, "
                    f"result: {current_op_result}"
                )
                if current_op_result.is_success():
                    break

                # try recover and wait
                pre_check_op.recover()
                time.sleep(pre_check_op.get_retry_interval_secs())

                # check again after recover; stop retrying once it passes
                # instead of spending another attempt on a redundant check
                current_op_result = pre_check_op.check()
                if current_op_result.is_success():
                    break
        except Exception as e:
            # Deliberately best effort: a crashing operator is skipped
            # rather than failing the whole pre-check.
            logger.error(f"Pre-check operator got unexpected error: {e}")
            continue

        if current_op_result is None:
            # No attempt was made (retry times <= 0), so there is no
            # result to evaluate.
            logger.warning(
                f"{pre_check_op_name} did not run any check "
                "(retry times <= 0), skip it."
            )
            continue

        if not current_op_result.is_success():
            action = pre_check_op.get_failed_action()
            self._job_context.enqueue_action(action)
            logger.warning(
                "Training pre-check failed "
                f"by {pre_check_op_name} "
                f"with result: {current_op_result}, "
                f"cost:{_elapsed_ms(current_start):.2f}ms. "
                f"Invoke action: {action}."
            )
            self._job_context.set_pre_check_status(PreCheckStatus.FAIL)
            return

        self._job_context.set_pre_check_status(PreCheckStatus.CHECKING)
        logger.info(
            f"{pre_check_op_name} finish "
            f"with result: {current_op_result}, "
            f"cost:{_elapsed_ms(current_start):.2f}ms."
        )

    self._job_context.set_pre_check_status(PreCheckStatus.PASS)
    logger.info(
        f"Training pre-check complete, cost:{_elapsed_ms(start):.2f}ms."
    )

def start_metric_collect(self):
"""
Expand Down Expand Up @@ -127,7 +199,7 @@ def join_metric_collect(self):
self._metric_monitor.join()

def start_observing(self):
logger.info("Start diagnosis manager training observation...")
logger.info("Start to observing training...")
self._is_observing_started = True

self._diagnostician.register_training_problems(
Expand Down
82 changes: 82 additions & 0 deletions dlrover/python/master/diagnosis/precheck_operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright 2025 The DLRover Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List

from dlrover.python.diagnosis.common.diagnosis_action import (
DiagnosisAction,
NoAction,
)


@dataclass
class PreCheckResult:
    """Outcome reported by a pre-check operator."""

    # Result code. 0 (the default) means success; any other code (>0)
    # is defined by the individual pre-check operator itself.
    result: int = 0

    # Short description of the result.
    result_msg: str = ""

    # Ids of the abnormal nodes.
    abnormal_nodes: List[int] = field(default_factory=list)

    def is_success(self) -> bool:
        """Return True when ``result`` equals the success code 0."""
        succeeded = self.result == 0
        return succeeded


class PreCheckOperator(ABC):
    """Abstract base for a single pre-check step.

    Subclasses must implement ``check``, ``recover`` and
    ``get_failed_action``; the retry settings have defaults that
    subclasses can override.
    """

    @classmethod
    def get_retry_interval_secs(cls) -> int:
        """Seconds between retries (default 5); subclasses may override."""
        return 5

    @classmethod
    def get_retry_times(cls) -> int:
        """
        Maximum number of retries (default 3); subclasses may override.
        For most pre-checks the value should be > 1 (at least one retry).
        If the result is still not ok after the retries, the failed
        action will be executed.
        """
        return 3

    @abstractmethod
    def check(self) -> PreCheckResult:
        """Run the main check procedure and return its result."""

    @abstractmethod
    def recover(self):
        """Run the procedure to apply when the check failed."""

    @abstractmethod
    def get_failed_action(self) -> DiagnosisAction:
        """Return the action to take when this operator's check failed."""


class NoPreCheckOperator(PreCheckOperator):
    """Placeholder operator: checks nothing and needs no recovery."""

    def check(self):
        # A default-constructed result is returned unconditionally.
        default_result = PreCheckResult()
        return default_result

    def recover(self):
        # Nothing to recover.
        return None

    def get_failed_action(self) -> DiagnosisAction:
        no_op_action = NoAction()
        return no_op_action
1 change: 0 additions & 1 deletion dlrover/python/master/dist_master.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,6 @@ def prepare(self):
def pre_check(self):
    """Log the step and delegate to ``diagnosis_manager.pre_check()``."""
    logger.info("Pre-check before running.")
    self.diagnosis_manager.pre_check()
# TODO

def _add_node_event_callback(self):
"""Add NodeEventCallbacks for the listeners of Pod events."""
Expand Down
1 change: 1 addition & 0 deletions dlrover/python/master/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def run(args):
update_context(job_args)
master = DistributedJobMaster(_dlrover_context.master_port, job_args)
master.prepare()
master.pre_check()
return master.run()


Expand Down
14 changes: 13 additions & 1 deletion dlrover/python/master/node/job_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
import time
from typing import Dict, Optional, Union

from dlrover.python.common.constants import NodeType
from dlrover.python.common.constants import NodeType, PreCheckStatus
from dlrover.python.common.global_context import Context
from dlrover.python.common.node import Node
from dlrover.python.common.singleton import Singleton
from dlrover.python.diagnosis.common.constants import (
Expand All @@ -27,6 +28,8 @@
DiagnosisActionQueue,
)

_dlrover_context = Context.singleton_instance()


class JobContext(Singleton):
"""
Expand All @@ -38,6 +41,7 @@ def __init__(self):
self._action_queue = DiagnosisActionQueue()
self._job_nodes: Dict[str, Dict[int, Node]] = {}
self._failed_nodes: Dict[int, int] = {}
self._pre_check_status: str = PreCheckStatus.CHECKING
self._locker = threading.Lock()

def enqueue_action(self, action):
Expand Down Expand Up @@ -193,6 +197,14 @@ def report_failed_node(self, node_id: Union[int, str] = None):
def get_failed_node_cnt(self):
return len(self._failed_nodes)

def set_pre_check_status(self, status: str):
    """Record ``status`` as the current pre-check status.

    NOTE(review): the value is stored as given, without validation;
    callers presumably pass a ``PreCheckStatus`` constant -- confirm.
    """
    self._pre_check_status = status

def get_pre_check_status(self):
    """Return the pre-check status to report.

    When pre-check is turned off in the global context, this is always
    ``PreCheckStatus.DISABLED``; otherwise it is the stored status.
    """
    if not _dlrover_context.pre_check_enabled:
        return PreCheckStatus.DISABLED
    return self._pre_check_status


def get_job_context() -> JobContext:
job_context = JobContext.singleton_instance()
Expand Down
Loading

0 comments on commit e33cfb4

Please sign in to comment.