Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support http communication implement for DLRover Master and Agent. #1429

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
b1bea1a
impl http server
BalaBalaYi Nov 21, 2024
5141164
Merge remote-tracking branch 'origin/master' into new_http_comm_impl
BalaBalaYi Dec 4, 2024
00f4a75
fix ut
BalaBalaYi Dec 4, 2024
5c10074
impl new servicer based on http
BalaBalaYi Dec 4, 2024
c253c0e
add deps
BalaBalaYi Dec 4, 2024
cb57a9f
fix ut
BalaBalaYi Dec 5, 2024
95869dd
fix ut
BalaBalaYi Dec 5, 2024
4c10514
stash
BalaBalaYi Dec 6, 2024
572859a
Merge remote-tracking branch 'origin/master' into new_http_comm_impl
BalaBalaYi Dec 30, 2024
d8f390c
fix
BalaBalaYi Jan 2, 2025
a3feea7
done master http server/client ut
BalaBalaYi Jan 3, 2025
1e28aca
fix http/grpc response
BalaBalaYi Jan 7, 2025
7274506
add args params
BalaBalaYi Jan 7, 2025
6c0f213
lint
BalaBalaYi Jan 8, 2025
2e2fdca
Merge remote-tracking branch 'origin/master' into new_http_comm_impl
BalaBalaYi Jan 9, 2025
a6862bc
merged
BalaBalaYi Jan 9, 2025
ca1973e
add deps
BalaBalaYi Jan 9, 2025
ed65caa
Merge remote-tracking branch 'origin/master' into new_http_comm_impl
BalaBalaYi Jan 10, 2025
6778f2a
ut fix
BalaBalaYi Jan 10, 2025
a95ec24
ut fix
BalaBalaYi Jan 10, 2025
59161e6
ut fix
BalaBalaYi Jan 10, 2025
b343935
ut fix
BalaBalaYi Jan 10, 2025
711fd9b
optimized http server
BalaBalaYi Jan 13, 2025
acfa58c
Merge branch 'master' into new_http_comm_impl
BalaBalaYi Jan 13, 2025
10e994c
add params support
BalaBalaYi Jan 14, 2025
e52d763
Merge remote-tracking branch 'origin/master' into new_http_comm_impl
BalaBalaYi Jan 14, 2025
085b2ef
Merge remote-tracking branch 'origin/master' into new_http_comm_impl
BalaBalaYi Jan 17, 2025
f1daf6b
Merge branch 'master' into new_http_comm_impl
BalaBalaYi Jan 20, 2025
cba7e7b
Merge remote-tracking branch 'origin/master' into new_http_comm_impl
BalaBalaYi Jan 20, 2025
5a2b2a6
updated
BalaBalaYi Jan 20, 2025
3e21672
Merge remote-tracking branch 'origin/new_http_comm_impl' into new_htt…
BalaBalaYi Jan 20, 2025
198f061
lint
BalaBalaYi Jan 20, 2025
40f7197
Merge branch 'master' into new_http_comm_impl
BalaBalaYi Jan 23, 2025
e4a33a1
Merge branch 'master' into new_http_comm_impl
BalaBalaYi Jan 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ runs:
args:
- "/bin/bash"
- "-c"
- " python -m grpc_tools.protoc -I. \
- "sh scripts/ci_install.sh basic && python -m grpc_tools.protoc -I. \
dlrover/proto/*.proto --python_out=. --grpc_python_out=. \
&& export PYTHONPATH=`pwd` \
&& cd examples/tensorflow/criteo_deeprec\
Expand Down
2 changes: 1 addition & 1 deletion .github/actions/dlrover-system-test-deepfm/action.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ runs:
args:
- "/bin/bash"
- "-c"
- " python -m grpc_tools.protoc -I. \
- "sh scripts/ci_install.sh basic && python -m grpc_tools.protoc -I. \
dlrover/proto/*.proto --python_out=. --grpc_python_out=. \
&& pip install deepctr deprecated\
&& export PYTHONPATH=`pwd` \
Expand Down
3 changes: 1 addition & 2 deletions .github/actions/dlrover-system-test-tf2/action.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ runs:
args:
- "/bin/bash"
- "-c"
- "pip install protobuf==3.20 kubernetes grpcio-tools psutil deprecated\
&& python -m grpc_tools.protoc -I. \
- "sh scripts/ci_install.sh basic && python -m grpc_tools.protoc -I. \
dlrover/proto/*.proto --python_out=. --grpc_python_out=. \
&& pip install deepctr \
&& pip install h5py==3.7.0 \
Expand Down
4 changes: 2 additions & 2 deletions dlrover/python/brain/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import os

from dlrover.proto import brain_pb2, brain_pb2_grpc
from dlrover.python.common.grpc import build_channel, grpc_server_ready
from dlrover.python.common.comm import build_grpc_channel, grpc_server_ready
from dlrover.python.common.log import default_logger as logger

DATA_STORE = "base_datastore"
Expand Down Expand Up @@ -268,7 +268,7 @@ def build_brain_client():
```
"""
brain_addr = os.getenv(_ENV_BRAIN_ADDR_KEY, "")
channel = build_channel(brain_addr)
channel = build_grpc_channel(brain_addr)
if channel and grpc_server_ready(channel):
return BrainClient(channel)
else:
Expand Down
142 changes: 68 additions & 74 deletions dlrover/python/common/grpc.py → dlrover/python/common/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import pickle
import random
import socket
from contextlib import closing
from dataclasses import dataclass, field
from typing import Dict, List

import grpc

from dlrover.python.common.constants import GRPC, AscendConstants
from dlrover.python.common.constants import GRPC
from dlrover.python.common.log import default_logger as logger
from dlrover.python.common.serialize import JsonSerializable

TIMEOUT_SEC = 5


def build_channel(addr):
def build_grpc_channel(addr):
if not addr_connected(addr):
return None
channel = grpc.insecure_channel(
Expand Down Expand Up @@ -68,74 +66,6 @@
return False


def find_free_port(port=0):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", port))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]


def find_free_port_in_range(start=0, end=65535, random_port=True):
"""Find a free port from a range."""
bind_ports = set()
while True:
if random_port:
port = random.randint(start, end)
else:
port = start + len(bind_ports)
if port in bind_ports:
continue
try:
return find_free_port(port)
except OSError:
logger.warning(f"Socket creation attempt failed with {port}.")
bind_ports.add(port)
if len(bind_ports) == end - start + 1:
break
raise RuntimeError(f"Fail to find a free port in [{start}, {end})")


def find_free_port_in_set(ports):
for port in ports:
try:
return find_free_port(port)
except OSError:
logger.warning(f"Socket creation attempt failed with {port}.")
raise RuntimeError(f"Fail to find a free port in {ports}")


def find_free_port_for_hccl(
start=AscendConstants.HCCL_PORT_START_DEFAULT,
) -> int:
max_port = 65500
cur_start = start
end = start + 10000
if end > max_port:
end = max_port
logger.info(f"Try to find available port for hccl from {start}")
checking_port = 0
while True:
try:
cur_end = cur_start + AscendConstants.NPU_PER_NODE
for port in range(cur_start, cur_end):
checking_port = port
find_free_port(port)
logger.info(f"Find available port start from: {cur_start}")
break
except OSError:
logger.warning(
f"Target port has already been used: {checking_port}."
)
if checking_port > 0:
cur_start = checking_port + 1
else:
cur_start = cur_start + AscendConstants.NPU_PER_NODE
if cur_start > end:
cur_start = 0
break
return cur_start


def grpc_server_ready(channel) -> bool:
try:
grpc.channel_ready_future(channel).result(timeout=TIMEOUT_SEC)
Expand All @@ -144,11 +74,25 @@
return False


def deserialize_message(data: bytes):
def serialize_message(message):
"""The method will create a message instance with the content.
Args:
pickle_data: pickle bytes of a class instance.
"""
data = None
if message:
try:
data = pickle.dumps(message)
except Exception as e:
logger.warning(f"Pickle failed to load {str(data)}", e)
return data

Check warning on line 88 in dlrover/python/common/comm.py

View check run for this annotation

Codecov / codecov/patch

dlrover/python/common/comm.py#L82-L88

Added lines #L82 - L88 were not covered by tests


def deserialize_message(data: bytes):
"""The method will create a message instance with the content.
Args:
data: pickle bytes of a class instance.
"""
message = None
if data:
try:
Expand All @@ -163,6 +107,47 @@
return pickle.dumps(self)


@dataclass
class BaseRequest(Message):
node_id: int = -1
node_type: str = ""
data: bytes = b""

def to_json(self):
return {
"node_id": self.node_id,
"node_type": self.node_type,
"data": base64.b64encode(self.data).decode("utf-8"),
}

@staticmethod
def from_json(json_data):
return BaseRequest(
node_id=json_data.get("node_id"),
node_type=json_data.get("node_type"),
data=base64.b64decode(json_data.get("data")),
)


@dataclass
class BaseResponse(Message):
success: bool = False
data: bytes = b""

def to_json(self):
return {

Check warning on line 138 in dlrover/python/common/comm.py

View check run for this annotation

Codecov / codecov/patch

dlrover/python/common/comm.py#L138

Added line #L138 was not covered by tests
"success": self.success,
"data": base64.b64encode(self.data).decode("utf-8"),
}

@staticmethod
def from_json(json_data):
return BaseResponse(

Check warning on line 145 in dlrover/python/common/comm.py

View check run for this annotation

Codecov / codecov/patch

dlrover/python/common/comm.py#L145

Added line #L145 was not covered by tests
success=bool(json_data.get("success")),
data=base64.b64decode(json_data.get("data")),
)


@dataclass
class TaskRequest(Message):
dataset_name: str = ""
Expand Down Expand Up @@ -526,3 +511,12 @@
@dataclass
class HeartbeatResponse(Message):
action: DiagnosisAction = field(default_factory=DiagnosisAction)


class TaskType(object):
NONE = "NONE"
TRAINING = "TRAINING"
EVALUATION = "EVALUATION"
PREDICTION = "PREDICTION"
WAIT = "WAIT"
TRAIN_END_CALLBACK = "TRAIN_END_CALLBACK"
9 changes: 9 additions & 0 deletions dlrover/python/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ class PlatformType(object):
LOCAL = "local"


class CommunicationType(object):
COMM_SERVICE_GRPC = "grpc"
COMM_SERVICE_HTTP = "http"


class ElasticJobApi(object):
GROUP = "elastic.iml.github.io"
VERION = "v1alpha1"
Expand Down Expand Up @@ -248,6 +253,7 @@ class TrainingLoopStatus(object):
class NodeEnv(object):
RELAUNCHED_POD = "RELAUNCHED_POD"
DLROVER_MASTER_ADDR = "DLROVER_MASTER_ADDR"
DLROVER_MASTER_SERVICE_TYPE = "DLROVER_MASTER_SERVICE_TYPE"
GRPC_ENABLE_FORK = "GRPC_ENABLE_FORK_SUPPORT"
GRPC_POLL_STRATEGY = "GRPC_POLL_STRATEGY"
POD_NAME = "POD_NAME"
Expand Down Expand Up @@ -360,6 +366,9 @@ class JobConstant(object):
PENDING_NODE_TIMEOUT_DEFAULT_MIN = 600
NODE_CHECK_TIMEOUT = 300

# timeout 60s
MASTER_CLIENT_DEFAULT_TIMEOUT = 60

# grpc timeout 60s
MASTER_CLIENT_GRPC_DEFAULT_TIMEOUT = 60

Expand Down
13 changes: 9 additions & 4 deletions dlrover/python/common/global_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@

import os

from dlrover.python.common import grpc
from dlrover.python.common.constants import UserEnv
from dlrover.python.common.constants import CommunicationType, UserEnv
from dlrover.python.common.log import default_logger as logger
from dlrover.python.common.singleton import Singleton
from dlrover.python.util.common_util import (
find_free_port_in_range,
find_free_port_in_set,
)


class ConfigKeys(object):
Expand All @@ -38,6 +41,7 @@ class ConfigKeys(object):


class DefaultValues(object):
SERVICE_TYPE = CommunicationType.COMM_SERVICE_GRPC
TRAIN_SPEED_RECORD_NUM = 50
SEC_TO_START_AUTOSCALE_WORKER = 90
STEP_TO_ADJUST_WORKER = 200
Expand All @@ -61,6 +65,7 @@ class DefaultValues(object):

class Context(Singleton):
def __init__(self):
self.master_service_type = DefaultValues.SERVICE_TYPE
self.train_speed_record_num = DefaultValues.TRAIN_SPEED_RECORD_NUM
self.seconds_to_autoscale_worker = (
DefaultValues.SEC_TO_START_AUTOSCALE_WORKER
Expand Down Expand Up @@ -173,13 +178,13 @@ def config_master_port(self, port=0):
for port in host_ports_env.split(","):
ports.append(int(port))
try:
self.master_port = grpc.find_free_port_in_set(ports)
self.master_port = find_free_port_in_set(ports)
except RuntimeError as e:
logger.warning(e)
elif port > 0:
self.master_port = port
if self.master_port is None:
self.master_port = grpc.find_free_port_in_range(20000, 30000)
self.master_port = find_free_port_in_range(20000, 30000)

def get_param_value_from_brain(self, key_name, default_value, dtype=float):
"""TODO: Get the configured value from Brain service."""
Expand Down
Loading
Loading