Skip to content

Commit

Permalink
Merge branch 'master' into hjiang/fix-windows-dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
dentiny committed Jan 22, 2025
2 parents cd35adf + 89990f6 commit dfd0116
Show file tree
Hide file tree
Showing 27 changed files with 1,737 additions and 370 deletions.
10 changes: 5 additions & 5 deletions docker/base-deps/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,16 @@ RUN <<EOF

set -euo pipefail

# Install miniconda
# Install miniforge
wget --quiet \
"https://repo.anaconda.com/miniconda/Miniconda3-py311_24.4.0-0-Linux-${HOSTTYPE}.sh" \
-O /tmp/miniconda.sh
"https://github.com/conda-forge/miniforge/releases/download/24.11.3-0/Miniforge3-24.11.3-0-Linux-${HOSTTYPE}.sh" \
-O /tmp/miniforge.sh

/bin/bash /tmp/miniconda.sh -b -u -p $HOME/anaconda3
/bin/bash /tmp/miniforge.sh -b -u -p $HOME/anaconda3

$HOME/anaconda3/bin/conda init
echo 'export PATH=$HOME/anaconda3/bin:$PATH' >> /home/ray/.bashrc
rm /tmp/miniconda.sh
rm /tmp/miniforge.sh
$HOME/anaconda3/bin/conda install -y libgcc-ng python=$PYTHON_VERSION
$HOME/anaconda3/bin/conda install -y -c conda-forge libffi=3.4.2
$HOME/anaconda3/bin/conda clean -y --all
Expand Down
22 changes: 5 additions & 17 deletions python/ray/autoscaler/_private/kuberay/autoscaling_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ class AutoscalingConfigProducer:
"""

def __init__(self, ray_cluster_name, ray_cluster_namespace):
self._headers, self._verify = node_provider.load_k8s_secrets()
self._ray_cr_url = node_provider.url_from_resource(
namespace=ray_cluster_namespace, path=f"rayclusters/{ray_cluster_name}"
self.kubernetes_api_client = node_provider.KubernetesHttpApiClient(
namespace=ray_cluster_namespace
)
self._ray_cr_path = f"rayclusters/{ray_cluster_name}"

def __call__(self):
ray_cr = self._fetch_ray_cr_from_k8s_with_retries()
Expand All @@ -67,7 +67,7 @@ def _fetch_ray_cr_from_k8s_with_retries(self) -> Dict[str, Any]:
"""
for i in range(1, MAX_RAYCLUSTER_FETCH_TRIES + 1):
try:
return self._fetch_ray_cr_from_k8s()
return self.kubernetes_api_client.get(self._ray_cr_path)
except requests.HTTPError as e:
if i < MAX_RAYCLUSTER_FETCH_TRIES:
logger.exception(
Expand All @@ -80,18 +80,6 @@ def _fetch_ray_cr_from_k8s_with_retries(self) -> Dict[str, Any]:
# This branch is inaccessible. Raise to satisfy mypy.
raise AssertionError

def _fetch_ray_cr_from_k8s(self) -> Dict[str, Any]:
result = requests.get(
self._ray_cr_url,
headers=self._headers,
timeout=node_provider.KUBERAY_REQUEST_TIMEOUT_S,
verify=self._verify,
)
if not result.status_code == 200:
result.raise_for_status()
ray_cr = result.json()
return ray_cr


def _derive_autoscaling_config_from_ray_cr(ray_cr: Dict[str, Any]) -> Dict[str, Any]:
provider_config = _generate_provider_config(ray_cr["metadata"]["namespace"])
Expand Down Expand Up @@ -179,7 +167,7 @@ def _generate_legacy_autoscaling_config_fields() -> Dict[str, Any]:


def _generate_available_node_types_from_ray_cr_spec(
ray_cr_spec: Dict[str, Any]
ray_cr_spec: Dict[str, Any],
) -> Dict[str, Any]:
"""Formats autoscaler "available_node_types" field based on the Ray CR's group
specs.
Expand Down
28 changes: 23 additions & 5 deletions python/ray/autoscaler/_private/kuberay/node_provider.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import json
import logging
import os
Expand Down Expand Up @@ -54,6 +55,8 @@
# Key for GKE label that identifies which multi-host replica a pod belongs to
REPLICA_INDEX_KEY = "replicaIndex"

TOKEN_REFRESH_PERIOD = datetime.timedelta(minutes=1)

# Design:

# Each modification the autoscaler wants to make is posted to the API server goal state
Expand Down Expand Up @@ -264,7 +267,19 @@ class KubernetesHttpApiClient(IKubernetesHttpApiClient):
def __init__(self, namespace: str, kuberay_crd_version: str = KUBERAY_CRD_VER):
self._kuberay_crd_version = kuberay_crd_version
self._namespace = namespace
self._headers, self._verify = load_k8s_secrets()
self._token_expires_at = datetime.datetime.now() + TOKEN_REFRESH_PERIOD
self._headers, self._verify = None, None

def _get_refreshed_headers_and_verify(self):
if (datetime.datetime.now() >= self._token_expires_at) or (
self._headers is None or self._verify is None
):
logger.info("Refreshing K8s API client token and certs.")
self._headers, self._verify = load_k8s_secrets()
self._token_expires_at = datetime.datetime.now() + TOKEN_REFRESH_PERIOD
return self._headers, self._verify
else:
return self._headers, self._verify

def get(self, path: str) -> Dict[str, Any]:
"""Wrapper for REST GET of resource with proper headers.
Expand All @@ -283,11 +298,13 @@ def get(self, path: str) -> Dict[str, Any]:
path=path,
kuberay_crd_version=self._kuberay_crd_version,
)

headers, verify = self._get_refreshed_headers_and_verify()
result = requests.get(
url,
headers=self._headers,
headers=headers,
timeout=KUBERAY_REQUEST_TIMEOUT_S,
verify=self._verify,
verify=verify,
)
if not result.status_code == 200:
result.raise_for_status()
Expand All @@ -311,11 +328,12 @@ def patch(self, path: str, payload: List[Dict[str, Any]]) -> Dict[str, Any]:
path=path,
kuberay_crd_version=self._kuberay_crd_version,
)
headers, verify = self._get_refreshed_headers_and_verify()
result = requests.patch(
url,
json.dumps(payload),
headers={**self._headers, "Content-type": "application/json-patch+json"},
verify=self._verify,
headers={**headers, "Content-type": "application/json-patch+json"},
verify=verify,
)
if not result.status_code == 200:
result.raise_for_status()
Expand Down
151 changes: 3 additions & 148 deletions python/ray/dashboard/optional_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@
import collections
import functools
import inspect
import json
import logging
import os
import time
import traceback
from collections import namedtuple
from typing import Any, Callable, Union
from typing import Callable, Union

from aiohttp.web import Request, Response

Expand All @@ -23,13 +22,12 @@
# All third-party dependencies that are not included in the minimal Ray
# installation must be included in this file. This allows us to determine if
# the agent has the necessary dependencies to be started.
from ray.dashboard.optional_deps import PathLike, RouteDef, aiohttp, hdrs
from ray.dashboard.optional_deps import aiohttp, hdrs
from ray.dashboard.utils import (
CustomEncoder,
DashboardAgentModule,
DashboardHeadModule,
to_google_style,
)
from ray.dashboard.routes import method_route_table_factory, rest_response

try:
create_task = asyncio.create_task
Expand All @@ -39,153 +37,10 @@

logger = logging.getLogger(__name__)


def method_route_table_factory():
class MethodRouteTable:
"""A helper class to bind http route to class method."""

_bind_map = collections.defaultdict(dict)
_routes = aiohttp.web.RouteTableDef()

class _BindInfo:
def __init__(self, filename, lineno, instance):
self.filename = filename
self.lineno = lineno
self.instance = instance

@classmethod
def routes(cls):
return cls._routes

@classmethod
def bound_routes(cls):
bound_items = []
for r in cls._routes._items:
if isinstance(r, RouteDef):
route_method = r.handler.__route_method__
route_path = r.handler.__route_path__
instance = cls._bind_map[route_method][route_path].instance
if instance is not None:
bound_items.append(r)
else:
bound_items.append(r)
routes = aiohttp.web.RouteTableDef()
routes._items = bound_items
return routes

@classmethod
def _register_route(cls, method, path, **kwargs):
def _wrapper(handler):
if path in cls._bind_map[method]:
bind_info = cls._bind_map[method][path]
raise Exception(
f"Duplicated route path: {path}, "
f"previous one registered at "
f"{bind_info.filename}:{bind_info.lineno}"
)

bind_info = cls._BindInfo(
handler.__code__.co_filename, handler.__code__.co_firstlineno, None
)

@functools.wraps(handler)
async def _handler_route(*args) -> aiohttp.web.Response:
try:
# Make the route handler as a bound method.
# The args may be:
# * (Request, )
# * (self, Request)
req = args[-1]
return await handler(bind_info.instance, req)
except Exception:
logger.exception("Handle %s %s failed.", method, path)
return rest_response(
success=False, message=traceback.format_exc()
)

cls._bind_map[method][path] = bind_info
_handler_route.__route_method__ = method
_handler_route.__route_path__ = path
return cls._routes.route(method, path, **kwargs)(_handler_route)

return _wrapper

@classmethod
def head(cls, path, **kwargs):
return cls._register_route(hdrs.METH_HEAD, path, **kwargs)

@classmethod
def get(cls, path, **kwargs):
return cls._register_route(hdrs.METH_GET, path, **kwargs)

@classmethod
def post(cls, path, **kwargs):
return cls._register_route(hdrs.METH_POST, path, **kwargs)

@classmethod
def put(cls, path, **kwargs):
return cls._register_route(hdrs.METH_PUT, path, **kwargs)

@classmethod
def patch(cls, path, **kwargs):
return cls._register_route(hdrs.METH_PATCH, path, **kwargs)

@classmethod
def delete(cls, path, **kwargs):
return cls._register_route(hdrs.METH_DELETE, path, **kwargs)

@classmethod
def view(cls, path, **kwargs):
return cls._register_route(hdrs.METH_ANY, path, **kwargs)

@classmethod
def static(cls, prefix: str, path: PathLike, **kwargs: Any) -> None:
cls._routes.static(prefix, path, **kwargs)

@classmethod
def bind(cls, instance):
def predicate(o):
if inspect.ismethod(o):
return hasattr(o, "__route_method__") and hasattr(
o, "__route_path__"
)
return False

handler_routes = inspect.getmembers(instance, predicate)
for _, h in handler_routes:
cls._bind_map[h.__func__.__route_method__][
h.__func__.__route_path__
].instance = instance

return MethodRouteTable


DashboardHeadRouteTable = method_route_table_factory()
DashboardAgentRouteTable = method_route_table_factory()


def rest_response(
success, message, convert_google_style=True, **kwargs
) -> aiohttp.web.Response:
# In the dev context we allow a dev server running on a
# different port to consume the API, meaning we need to allow
# cross-origin access
if os.environ.get("RAY_DASHBOARD_DEV") == "1":
headers = {"Access-Control-Allow-Origin": "*"}
else:
headers = {}
return aiohttp.web.json_response(
{
"result": success,
"msg": message,
"data": to_google_style(kwargs) if convert_google_style else kwargs,
},
dumps=functools.partial(json.dumps, cls=CustomEncoder),
headers=headers,
status=200 if success else 500,
)


# The cache value type used by aiohttp_cache.
_AiohttpCacheValue = namedtuple("AiohttpCacheValue", ["data", "expiration", "task"])
# The methods with no request body used by aiohttp_cache.
Expand Down
Loading

0 comments on commit dfd0116

Please sign in to comment.