# [Serve] Shared LongPollClient for Routers (#48807)

## Why are these changes needed?

In our use case we use Ray Serve with many hundreds/thousands of apps,
plus a "router" app that routes traffic to those apps using
`DeploymentHandle`s. Right now, that means we have a `LongPollClient`
for each `DeploymentHandle` in each router app replica, which could be
tens or hundreds of thousands of `LongPollClient`s. This is expensive on
both the Serve Controller and on the router app replicas. The resource
usage on the Serve Controller is particularly problematic: the main thing
blocking us from running as many router replicas as we'd like is the
stability of the controller.

This PR aims to amortize the cost of having so many `LongPollClient`s
by going from one-long-poll-client-per-handle to
one-long-poll-client-per-process. Each `DeploymentHandle`'s `Router` now
registers itself with a shared `LongPollClient` held by a singleton.

The actual implementation that I've gone with is a bit clunky because
I'm trying to bridge the gap between the current solution and a design
that *only* has shared `LongPollClient`s. This could potentially be
cleaned up in the future. Right now, each `Router` still gets a
dedicated `LongPollClient` that only runs temporarily, until the shared
client tells it to stop.
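
For illustration, here's a minimal sketch of the intended handoff pattern (no Ray; `FakeRouter`, `DedicatedClient`, and `SharedClient` are simplified stand-ins for `AsyncioRouter`, the per-handle `LongPollClient`, and `SharedRouterLongPollClient` in the diff below):

```python
import weakref
from collections import defaultdict
from functools import lru_cache


class DedicatedClient:
    """Stand-in for the per-Router LongPollClient used only for bootstrapping."""

    def __init__(self) -> None:
        self.is_running = True

    def stop(self) -> None:
        # The real client finishes its in-flight RPC and then stops polling.
        self.is_running = False


class FakeRouter:
    """Stand-in for AsyncioRouter: starts with its own dedicated client."""

    def __init__(self, deployment_id: str) -> None:
        self.deployment_id = deployment_id
        self.long_poll_client = DedicatedClient()

    def update_deployment_config(self, config: dict) -> None:
        print(f"{self.deployment_id}: new config {config}")


class SharedClient:
    """Stand-in for SharedRouterLongPollClient: one per process, fans out updates."""

    def __init__(self) -> None:
        # WeakSets so registered routers can still be garbage-collected.
        self.routers = defaultdict(weakref.WeakSet)

    @classmethod
    @lru_cache(maxsize=None)
    def get_or_create(cls) -> "SharedClient":
        # lru_cache makes this a process-wide singleton per argument tuple.
        return cls()

    def register(self, router: FakeRouter) -> None:
        self.routers[router.deployment_id].add(router)

    def on_config_update(self, deployment_id: str, config: dict) -> None:
        # Once the shared client delivers updates for a deployment, the
        # routers' dedicated clients are no longer needed.
        for router in self.routers[deployment_id]:
            router.update_deployment_config(config)
            router.long_poll_client.stop()


router = FakeRouter("app_1")
shared = SharedClient.get_or_create()
shared.register(router)
shared.on_config_update("app_1", {"max_ongoing_requests": 10})
assert not router.long_poll_client.is_running
```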

Related: #45957 is the same idea
but for handle autoscaling metrics pushing.


## Checks

- [x] I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- [x] I've run `scripts/format.sh` to lint the changes in this PR.
- [ ] I've included any doc changes needed for
https://docs.ray.io/en/master/.
- [ ] I've added any new APIs to the API Reference. For example, if I added a
  method in Tune, I've added it in `doc/source/tune/api/` under the
  corresponding `.rst` file.
- [x] I've made sure the tests are passing. Note that there might be a few
  flaky tests; see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
   - [x] Unit tests
   - [ ] Release tests
   - [ ] This PR is not tested :(

---------

Signed-off-by: Josh Karpel <[email protected]>
JoshKarpel authored Jan 16, 2025
1 parent d71e088 commit 7452bc6
Showing 2 changed files with 126 additions and 5 deletions.
44 changes: 40 additions & 4 deletions python/ray/serve/_private/long_poll.py
@@ -2,6 +2,7 @@
import logging
import os
import random
from asyncio import sleep
from asyncio.events import AbstractEventLoop
from collections import defaultdict
from collections.abc import Mapping
@@ -79,7 +80,6 @@ def __init__(
key_listeners: Dict[KeyType, UpdateStateCallable],
call_in_event_loop: AbstractEventLoop,
) -> None:
assert len(key_listeners) > 0
# We used to allow this to be optional, but due to Ray Client issue
# we now enforce all long poll client to post callback to event loop
# See https://github.com/ray-project/ray/issues/20971
@@ -99,6 +99,27 @@ def __init__(

self._poll_next()

def stop(self) -> None:
"""Stop the long poll client after the next RPC returns."""
self.is_running = False

def add_key_listeners(
self, key_listeners: Dict[KeyType, UpdateStateCallable]
) -> None:
"""Add more key listeners to the client.
The new listeners will only be included in the *next* long poll request;
the current request will continue with the existing listeners.
If a key is already in the client, the new listener will replace the old one,
but the snapshot ID will be preserved, so the new listener will only be called
on the *next* update to that key.
"""
# Only initialize snapshot ids for *new* keys.
self.snapshot_ids.update(
{key: -1 for key in key_listeners.keys() if key not in self.key_listeners}
)
self.key_listeners.update(key_listeners)

def _on_callback_completed(self, trigger_at: int):
"""Called after a single callback is completed.
@@ -115,6 +136,9 @@ def _poll_next(self):
"""Poll the update. The callback is expected to scheduler another
_poll_next call.
"""
if not self.is_running:
return

self._callbacks_processed_count = 0
self._current_ref = self.host_actor.listen_for_change.remote(self.snapshot_ids)
self._current_ref._on_completed(lambda update: self._process_update(update))
@@ -162,6 +186,8 @@ def _process_update(self, updates: Dict[str, UpdatedObject]):
f"{list(updates.keys())}.",
extra={"log_to_stderr": False},
)
if not updates: # no updates, no callbacks to run, just poll again
self._schedule_to_event_loop(self._poll_next)
for key, update in updates.items():
self.snapshot_ids[key] = update.snapshot_id
callback = self.key_listeners[key]
@@ -246,10 +272,20 @@ async def listen_for_change(
) -> Union[LongPollState, Dict[KeyType, UpdatedObject]]:
"""Listen for changed objects.
- This method will returns a dictionary of updated objects. It returns
- immediately if the snapshot_ids are outdated, otherwise it will block
- until there's an update.
+ This method will return a dictionary of updated objects. It returns
+ immediately if any of the snapshot_ids are outdated,
+ otherwise it will block until there's an update.
"""
# If there are no keys to listen for,
# just wait for a short time to provide backpressure,
# then return an empty update.
if not keys_to_snapshot_ids:
await sleep(1)

updated_objects = {}
self._count_send(updated_objects)
return updated_objects

# If there are any keys with outdated snapshot ids,
# return their updated values immediately.
updated_objects = {}
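
The subtle part of `add_key_listeners` above is the snapshot-ID bookkeeping: keys that are already tracked keep their snapshot IDs, so a replacement listener only fires on the *next* update to that key. A minimal stand-alone mimic of that logic (plain dicts and string keys instead of Ray's `(LongPollNamespace, DeploymentID)` tuples):

```python
# Simplified mimic of LongPollClient.add_key_listeners' bookkeeping (no Ray).
snapshot_ids = {"config:app_1": 7}       # already tracked at snapshot 7
key_listeners = {"config:app_1": print}  # existing listener for that key


def add_key_listeners(new_listeners: dict) -> None:
    # Only initialize snapshot IDs for *new* keys; existing IDs are preserved
    # so a replacement listener is not re-run for updates it has already seen.
    snapshot_ids.update(
        {key: -1 for key in new_listeners if key not in key_listeners}
    )
    key_listeners.update(new_listeners)


add_key_listeners({"config:app_1": repr, "config:app_2": repr})
assert snapshot_ids == {"config:app_1": 7, "config:app_2": -1}
```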
87 changes: 86 additions & 1 deletion python/ray/serve/_private/router.py
@@ -3,10 +3,13 @@
import logging
import threading
import time
import weakref
from abc import ABC, abstractmethod
from asyncio import AbstractEventLoop
from collections import defaultdict
from collections.abc import MutableMapping
from contextlib import contextmanager
- from functools import partial
+ from functools import lru_cache, partial
from typing import Any, Coroutine, DefaultDict, Dict, List, Optional, Tuple, Union

import ray
@@ -399,6 +402,14 @@ def __init__(
),
)

# The Router needs to stay informed about changes to the target deployment's
# running replicas and deployment config. We do this via the long poll system.
# However, for efficiency, we don't want to create a LongPollClient for every
# DeploymentHandle, so we use a shared LongPollClient that all Routers
# register themselves with. But first, the router needs to get a fast initial
# update so that it can start serving requests, which we do with a dedicated
# LongPollClient that stops running once the shared client takes over.

self.long_poll_client = LongPollClient(
controller_handle,
{
@@ -414,6 +425,11 @@ def __init__(
call_in_event_loop=self._event_loop,
)

shared = SharedRouterLongPollClient.get_or_create(
controller_handle, self._event_loop
)
shared.register(self)

def running_replicas_populated(self) -> bool:
return self._running_replicas_populated

@@ -690,3 +706,72 @@ def shutdown(self) -> concurrent.futures.Future:
return asyncio.run_coroutine_threadsafe(
self._asyncio_router.shutdown(), loop=self._asyncio_loop
)


class SharedRouterLongPollClient:
def __init__(self, controller_handle: ActorHandle, event_loop: AbstractEventLoop):
self.controller_handler = controller_handle

# We use a WeakSet to store the Routers so that we don't prevent them
# from being garbage-collected.
self.routers: MutableMapping[
DeploymentID, weakref.WeakSet[AsyncioRouter]
] = defaultdict(weakref.WeakSet)

# Creating the LongPollClient implicitly starts it
self.long_poll_client = LongPollClient(
controller_handle,
key_listeners={},
call_in_event_loop=event_loop,
)

@classmethod
@lru_cache(maxsize=None)
def get_or_create(
cls, controller_handle: ActorHandle, event_loop: AbstractEventLoop
) -> "SharedRouterLongPollClient":
shared = cls(controller_handle=controller_handle, event_loop=event_loop)
logger.info(f"Started {shared}.")
return shared

def update_deployment_targets(
self,
deployment_target_info: DeploymentTargetInfo,
deployment_id: DeploymentID,
) -> None:
for router in self.routers[deployment_id]:
router.update_deployment_targets(deployment_target_info)
router.long_poll_client.stop()

def update_deployment_config(
self, deployment_config: DeploymentConfig, deployment_id: DeploymentID
) -> None:
for router in self.routers[deployment_id]:
router.update_deployment_config(deployment_config)
router.long_poll_client.stop()

def register(self, router: AsyncioRouter) -> None:
self.routers[router.deployment_id].add(router)

# Remove the entries for any deployment ids that no longer have any routers.
# The WeakSets will automatically lose track of Routers that get GC'd,
# but the outer dict will keep the key around, so we need to clean up manually.
# Note the list(...) to avoid mutating self.routers while iterating over it.
for deployment_id, routers in list(self.routers.items()):
if not routers:
self.routers.pop(deployment_id)

# Register the new listeners on the long poll client.
# Some of these listeners may already exist, but it's safe to add them again.
key_listeners = {
(LongPollNamespace.DEPLOYMENT_TARGETS, deployment_id): partial(
self.update_deployment_targets, deployment_id=deployment_id
)
for deployment_id in self.routers.keys()
} | {
(LongPollNamespace.DEPLOYMENT_CONFIG, deployment_id): partial(
self.update_deployment_config, deployment_id=deployment_id
)
for deployment_id in self.routers.keys()
}
self.long_poll_client.add_key_listeners(key_listeners)
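
A side note on the `WeakSet` cleanup in `register` above: a garbage-collected router silently disappears from its `WeakSet`, but the outer dict keeps the (now empty) entry, which is why `register` prunes empty entries on each call. A small stand-alone demonstration of that behavior:

```python
import gc
import weakref
from collections import defaultdict


class Router:
    """Hypothetical placeholder; any weakly-referenceable object works."""


routers = defaultdict(weakref.WeakSet)

r = Router()
routers["app_1"].add(r)
assert len(routers["app_1"]) == 1

del r
gc.collect()  # CPython frees it immediately; collect() just makes this explicit

# The router is gone from the WeakSet, but the dict still has the key...
assert len(routers["app_1"]) == 0 and "app_1" in routers

# ...so empty entries are pruned manually, as register() does above.
for deployment_id, remaining in list(routers.items()):
    if not remaining:
        routers.pop(deployment_id)
assert "app_1" not in routers
```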
