diff --git a/python/ray/serve/_private/long_poll.py b/python/ray/serve/_private/long_poll.py
index 1e958f0be8a4a..a47511207e5e6 100644
--- a/python/ray/serve/_private/long_poll.py
+++ b/python/ray/serve/_private/long_poll.py
@@ -2,6 +2,7 @@
 import logging
 import os
 import random
+from asyncio import sleep
 from asyncio.events import AbstractEventLoop
 from collections import defaultdict
 from collections.abc import Mapping
@@ -79,7 +80,6 @@ def __init__(
         key_listeners: Dict[KeyType, UpdateStateCallable],
         call_in_event_loop: AbstractEventLoop,
     ) -> None:
-        assert len(key_listeners) > 0
         # We used to allow this to be optional, but due to Ray Client issue
         # we now enforce all long poll client to post callback to event loop
         # See https://github.com/ray-project/ray/issues/20971
@@ -99,6 +99,27 @@ def __init__(
 
         self._poll_next()
 
+    def stop(self) -> None:
+        """Stop the long poll client after the next RPC returns."""
+        self.is_running = False
+
+    def add_key_listeners(
+        self, key_listeners: Dict[KeyType, UpdateStateCallable]
+    ) -> None:
+        """Add more key listeners to the client.
+        The new listeners will only be included in the *next* long poll request;
+        the current request will continue with the existing listeners.
+
+        If a key is already in the client, the new listener will replace the old one,
+        but the snapshot ID will be preserved, so the new listener will only be called
+        on the *next* update to that key.
+        """
+        # Only initialize snapshot ids for *new* keys.
+        self.snapshot_ids.update(
+            {key: -1 for key in key_listeners.keys() if key not in self.key_listeners}
+        )
+        self.key_listeners.update(key_listeners)
+
     def _on_callback_completed(self, trigger_at: int):
         """Called after a single callback is completed.
 
@@ -115,6 +136,9 @@ def _poll_next(self):
         """Poll the update. The callback is expected to scheduler another
         _poll_next call.
         """
+        if not self.is_running:
+            return
+
         self._callbacks_processed_count = 0
         self._current_ref = self.host_actor.listen_for_change.remote(self.snapshot_ids)
         self._current_ref._on_completed(lambda update: self._process_update(update))
@@ -162,6 +186,8 @@ def _process_update(self, updates: Dict[str, UpdatedObject]):
             f"{list(updates.keys())}.",
             extra={"log_to_stderr": False},
         )
+        if not updates:  # no updates, no callbacks to run, just poll again
+            self._schedule_to_event_loop(self._poll_next)
         for key, update in updates.items():
             self.snapshot_ids[key] = update.snapshot_id
             callback = self.key_listeners[key]
@@ -246,10 +272,20 @@ async def listen_for_change(
     ) -> Union[LongPollState, Dict[KeyType, UpdatedObject]]:
         """Listen for changed objects.
 
-        This method will returns a dictionary of updated objects. It returns
-        immediately if the snapshot_ids are outdated, otherwise it will block
-        until there's an update.
+        This method will return a dictionary of updated objects. It returns
+        immediately if any of the snapshot_ids are outdated,
+        otherwise it will block until there's an update.
         """
+        # If there are no keys to listen for,
+        # just wait for a short time to provide backpressure,
+        # then return an empty update.
+        if not keys_to_snapshot_ids:
+            await sleep(1)
+
+            updated_objects = {}
+            self._count_send(updated_objects)
+            return updated_objects
+
         # If there are any keys with outdated snapshot ids,
         # return their updated values immediately.
         updated_objects = {}
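For intuition, here is a minimal, self-contained sketch of the snapshot-ID bookkeeping that the new `add_key_listeners` performs. `MiniClient` and its fields are hypothetical stand-ins for `LongPollClient`, not Serve code; only the dict logic mirrors the diff above.

```python
from typing import Callable, Dict, Tuple

KeyType = Tuple[str, str]
UpdateStateCallable = Callable[[object], None]


class MiniClient:
    """Toy stand-in for LongPollClient; only the listener bookkeeping is modeled."""

    def __init__(self, key_listeners: Dict[KeyType, UpdateStateCallable]) -> None:
        self.key_listeners = dict(key_listeners)
        # -1 means "no snapshot seen yet", so the first poll always returns data.
        self.snapshot_ids: Dict[KeyType, int] = {key: -1 for key in key_listeners}

    def add_key_listeners(
        self, key_listeners: Dict[KeyType, UpdateStateCallable]
    ) -> None:
        # Initialize snapshot ids only for *new* keys, so replacing an existing
        # listener does not reset its snapshot and re-deliver the last update.
        self.snapshot_ids.update(
            {key: -1 for key in key_listeners if key not in self.key_listeners}
        )
        self.key_listeners.update(key_listeners)


client = MiniClient({("config", "a"): print})
client.snapshot_ids[("config", "a")] = 7  # pretend an update was processed
client.add_key_listeners({("config", "a"): print, ("targets", "b"): print})
assert client.snapshot_ids == {("config", "a"): 7, ("targets", "b"): -1}
```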
diff --git a/python/ray/serve/_private/router.py b/python/ray/serve/_private/router.py
index 1bb6d5b510e0e..d6b67611c1edb 100644
--- a/python/ray/serve/_private/router.py
+++ b/python/ray/serve/_private/router.py
@@ -3,10 +3,13 @@
 import logging
 import threading
 import time
+import weakref
 from abc import ABC, abstractmethod
+from asyncio import AbstractEventLoop
 from collections import defaultdict
+from collections.abc import MutableMapping
 from contextlib import contextmanager
-from functools import partial
+from functools import lru_cache, partial
 from typing import Any, Coroutine, DefaultDict, Dict, List, Optional, Tuple, Union
 
 import ray
@@ -399,6 +402,14 @@ def __init__(
             ),
         )
 
+        # The Router needs to stay informed about changes to the target deployment's
+        # running replicas and deployment config. We do this via the long poll system.
+        # However, for efficiency, we don't want to create a LongPollClient for every
+        # DeploymentHandle, so we use a shared LongPollClient that all Routers
+        # register themselves with. But first, the router needs to get a fast initial
+        # update so that it can start serving requests, which we do with a dedicated
+        # LongPollClient that stops running once the shared client takes over.
+
         self.long_poll_client = LongPollClient(
             controller_handle,
             {
@@ -414,6 +425,11 @@ def __init__(
             call_in_event_loop=self._event_loop,
         )
 
+        shared = SharedRouterLongPollClient.get_or_create(
+            controller_handle, self._event_loop
+        )
+        shared.register(self)
+
     def running_replicas_populated(self) -> bool:
         return self._running_replicas_populated
 
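The comment block added to `__init__` describes a handoff: each Router's dedicated client delivers a fast initial update, then is stopped once the shared client covers its keys. Below is a rough asyncio sketch of that lifecycle, assuming stopping simply flips a flag that is checked before each poll; the names are illustrative, not Serve APIs.

```python
import asyncio


class Poller:
    """Illustrative poller; stop() takes effect before the next iteration."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.is_running = True

    def stop(self) -> None:
        self.is_running = False

    async def poll_forever(self) -> None:
        while self.is_running:
            await asyncio.sleep(0.01)  # stand-in for the listen_for_change RPC


async def main() -> None:
    dedicated = Poller("dedicated")  # serves the Router's first update quickly
    shared = Poller("shared")        # takes over for all Routers afterwards
    polls = asyncio.gather(dedicated.poll_forever(), shared.poll_forever())
    await asyncio.sleep(0.05)  # first update delivered; shared client registered
    dedicated.stop()           # mirrors router.long_poll_client.stop()
    shared.stop()              # only so this demo exits cleanly
    await polls


asyncio.run(main())
```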
@@ -690,3 +706,72 @@ def shutdown(self) -> concurrent.futures.Future:
         return asyncio.run_coroutine_threadsafe(
             self._asyncio_router.shutdown(), loop=self._asyncio_loop
         )
+
+
+class SharedRouterLongPollClient:
+    def __init__(self, controller_handle: ActorHandle, event_loop: AbstractEventLoop):
+        self.controller_handler = controller_handle
+
+        # We use a WeakSet to store the Routers so that we don't prevent them
+        # from being garbage-collected.
+        self.routers: MutableMapping[
+            DeploymentID, weakref.WeakSet[AsyncioRouter]
+        ] = defaultdict(weakref.WeakSet)
+
+        # Creating the LongPollClient implicitly starts it
+        self.long_poll_client = LongPollClient(
+            controller_handle,
+            key_listeners={},
+            call_in_event_loop=event_loop,
+        )
+
+    @classmethod
+    @lru_cache(maxsize=None)
+    def get_or_create(
+        cls, controller_handle: ActorHandle, event_loop: AbstractEventLoop
+    ) -> "SharedRouterLongPollClient":
+        shared = cls(controller_handle=controller_handle, event_loop=event_loop)
+        logger.info(f"Started {shared}.")
+        return shared
+
+    def update_deployment_targets(
+        self,
+        deployment_target_info: DeploymentTargetInfo,
+        deployment_id: DeploymentID,
+    ) -> None:
+        for router in self.routers[deployment_id]:
+            router.update_deployment_targets(deployment_target_info)
+            router.long_poll_client.stop()
+
+    def update_deployment_config(
+        self, deployment_config: DeploymentConfig, deployment_id: DeploymentID
+    ) -> None:
+        for router in self.routers[deployment_id]:
+            router.update_deployment_config(deployment_config)
+            router.long_poll_client.stop()
+
+    def register(self, router: AsyncioRouter) -> None:
+        self.routers[router.deployment_id].add(router)
+
+        # Remove the entries for any deployment ids that no longer have any routers.
+        # The WeakSets will automatically lose track of Routers that get GC'd,
+        # but the outer dict will keep the key around, so we need to clean up manually.
+        # Note the list(...) to avoid mutating self.routers while iterating over it.
+        for deployment_id, routers in list(self.routers.items()):
+            if not routers:
+                self.routers.pop(deployment_id)
+
+        # Register the new listeners on the long poll client.
+        # Some of these listeners may already exist, but it's safe to add them again.
+        key_listeners = {
+            (LongPollNamespace.DEPLOYMENT_TARGETS, deployment_id): partial(
+                self.update_deployment_targets, deployment_id=deployment_id
+            )
+            for deployment_id in self.routers.keys()
+        } | {
+            (LongPollNamespace.DEPLOYMENT_CONFIG, deployment_id): partial(
+                self.update_deployment_config, deployment_id=deployment_id
+            )
+            for deployment_id in self.routers.keys()
+        }
+        self.long_poll_client.add_key_listeners(key_listeners)
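Two standard-library behaviors carry most of the weight in `SharedRouterLongPollClient`, sketched below with hypothetical names. `lru_cache` on the factory memoizes by argument, so callers passing equal arguments get the same shared instance; the `WeakSet` forgets Routers once nothing else references them, which is why `register` can prune empty entries.

```python
import gc
import weakref
from functools import lru_cache


class Router:
    """Hypothetical stand-in for AsyncioRouter."""


class Shared:
    """Hypothetical stand-in for SharedRouterLongPollClient."""

    def __init__(self) -> None:
        self.routers = weakref.WeakSet()  # does not keep Routers alive

    @classmethod
    @lru_cache(maxsize=None)
    def get_or_create(cls, key: str) -> "Shared":
        # Memoized per argument: equal keys return the same instance.
        return cls()


a = Shared.get_or_create("controller-1")
b = Shared.get_or_create("controller-1")
assert a is b  # one shared client per key, as with (handle, event_loop)

router = Router()
a.routers.add(router)
assert len(a.routers) == 1
del router  # drop the only strong reference
gc.collect()
assert len(a.routers) == 0  # the WeakSet let the Router be collected
```

Note that `lru_cache` keys on the equality and hash of the arguments, so in the real code sharing only occurs when the same `ActorHandle` and event loop are passed to `get_or_create`.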