Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Serve] Shared LongPollClient for Routers #48807

Merged
merged 27 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
5169e86
set up shared long poll client
JoshKarpel Nov 19, 2024
ae170d3
avoid mutating while iterating
JoshKarpel Nov 20, 2024
53f97fc
better dictionary merge
JoshKarpel Nov 20, 2024
d586f08
protect against empty keys
JoshKarpel Nov 20, 2024
6ce1b3e
Merge branch 'master' into shared-long-poll-client
JoshKarpel Nov 21, 2024
9fa746b
call _count_send
JoshKarpel Nov 21, 2024
83eb5cf
shorter sleep on empty keys
JoshKarpel Nov 21, 2024
cfc7d19
poll again if no callbacks
JoshKarpel Nov 21, 2024
c673bec
Merge branch 'master' into shared-long-poll-client
JoshKarpel Nov 27, 2024
67384eb
Merge branch 'master' into shared-long-poll-client
JoshKarpel Dec 2, 2024
1b4712f
Merge branch 'master' into shared-long-poll-client
JoshKarpel Dec 2, 2024
62bb4c3
Merge branch 'master' into shared-long-poll-client
JoshKarpel Dec 4, 2024
56b7df5
rework test
JoshKarpel Dec 5, 2024
102fb09
use new handle
JoshKarpel Dec 5, 2024
ca4a1f7
does a long sleep fix it?
JoshKarpel Dec 11, 2024
9737dbd
do not stop the dedicated client until the shared client gets an update
JoshKarpel Dec 11, 2024
78d2b77
fix typo
JoshKarpel Dec 12, 2024
9f14f71
tidy up
JoshKarpel Dec 12, 2024
a6676ed
Merge branch 'master' into shared-long-poll-client
JoshKarpel Dec 12, 2024
d7dd6de
undo test changes
JoshKarpel Dec 12, 2024
4f93687
Merge branch 'master' into shared-long-poll-client
JoshKarpel Dec 16, 2024
0c62830
Merge branch 'master' into shared-long-poll-client
JoshKarpel Jan 6, 2025
49cc679
Merge branch 'master' into shared-long-poll-client
JoshKarpel Jan 9, 2025
3dfce1f
Merge branch 'master' into shared-long-poll-client
JoshKarpel Jan 13, 2025
0611dae
handle changes around update_running_targets
JoshKarpel Jan 13, 2025
d62e598
Merge branch 'master' into shared-long-poll-client
JoshKarpel Jan 14, 2025
dc5138c
Merge branch 'master' into shared-long-poll-client
JoshKarpel Jan 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 40 additions & 4 deletions python/ray/serve/_private/long_poll.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import os
import random
from asyncio import sleep
from asyncio.events import AbstractEventLoop
from collections import defaultdict
from collections.abc import Mapping
Expand Down Expand Up @@ -79,7 +80,6 @@ def __init__(
key_listeners: Dict[KeyType, UpdateStateCallable],
call_in_event_loop: AbstractEventLoop,
) -> None:
assert len(key_listeners) > 0
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This case is now handled by https://github.com/ray-project/ray/pull/48807/files#diff-f138b21f7ddcd7d61c0b2704c8b828b9bbe7eb5021531e2c7fabeb20ec322e1aR280-R288 (and is necessary - when the shared client boots up for the first time it will send an RPC with no keys in it)

# We used to allow this to be optional, but due to Ray Client issue
# we now enforce all long poll client to post callback to event loop
# See https://github.com/ray-project/ray/issues/20971
Expand All @@ -99,6 +99,27 @@ def __init__(

self._poll_next()

def stop(self) -> None:
"""Stop the long poll client after the next RPC returns."""
self.is_running = False

def add_key_listeners(
self, key_listeners: Dict[KeyType, UpdateStateCallable]
) -> None:
"""Add more key listeners to the client.
The new listeners will only be included in the *next* long poll request;
the current request will continue with the existing listeners.

If a key is already in the client, the new listener will replace the old one,
but the snapshot ID will be preserved, so the new listener will only be called
on the *next* update to that key.
"""
# Only initialize snapshot ids for *new* keys.
self.snapshot_ids.update(
{key: -1 for key in key_listeners.keys() if key not in self.key_listeners}
)
self.key_listeners.update(key_listeners)

def _on_callback_completed(self, trigger_at: int):
"""Called after a single callback is completed.

Expand All @@ -115,6 +136,9 @@ def _poll_next(self):
"""Poll the update. The callback is expected to scheduler another
_poll_next call.
"""
if not self.is_running:
return

self._callbacks_processed_count = 0
self._current_ref = self.host_actor.listen_for_change.remote(self.snapshot_ids)
self._current_ref._on_completed(lambda update: self._process_update(update))
Expand Down Expand Up @@ -162,6 +186,8 @@ def _process_update(self, updates: Dict[str, UpdatedObject]):
f"{list(updates.keys())}.",
extra={"log_to_stderr": False},
)
if not updates: # no updates, no callbacks to run, just poll again
self._schedule_to_event_loop(self._poll_next)
for key, update in updates.items():
self.snapshot_ids[key] = update.snapshot_id
callback = self.key_listeners[key]
Expand Down Expand Up @@ -246,10 +272,20 @@ async def listen_for_change(
) -> Union[LongPollState, Dict[KeyType, UpdatedObject]]:
"""Listen for changed objects.

This method will returns a dictionary of updated objects. It returns
immediately if the snapshot_ids are outdated, otherwise it will block
until there's an update.
This method will return a dictionary of updated objects. It returns
immediately if any of the snapshot_ids are outdated,
otherwise it will block until there's an update.
"""
# If there are no keys to listen for,
# just wait for a short time to provide backpressure,
# then return an empty update.
if not keys_to_snapshot_ids:
await sleep(1)

updated_objects = {}
self._count_send(updated_objects)
return updated_objects

# If there are any keys with outdated snapshot ids,
# return their updated values immediately.
updated_objects = {}
Expand Down
85 changes: 84 additions & 1 deletion python/ray/serve/_private/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
import threading
import time
import uuid
import weakref
from abc import ABC, abstractmethod
from asyncio import AbstractEventLoop
from collections import defaultdict
from collections.abc import MutableMapping
from contextlib import contextmanager
from functools import partial
from functools import lru_cache, partial
from typing import Any, Coroutine, DefaultDict, Dict, List, Optional, Tuple, Union

import ray
Expand Down Expand Up @@ -398,6 +401,14 @@ def __init__(
),
)

# The Router needs to stay informed about changes to the target deployment's
# running replicas and deployment config. We do this via the long poll system.
# However, for efficiency, we don't want to create a LongPollClient for every
# DeploymentHandle, so we use a shared LongPollClient that all Routers
# register themselves with. But first, the router needs to get a fast initial
# update so that it can start serving requests, which we do with a dedicated
# LongPollClient that stops running once the shared client takes over.

self.long_poll_client = LongPollClient(
controller_handle,
{
Expand All @@ -413,6 +424,11 @@ def __init__(
call_in_event_loop=self._event_loop,
)

shared = SharedRouterLongPollClient.get_or_create(
controller_handle, self._event_loop
)
shared.register(self)

def running_replicas_populated(self) -> bool:
return self._running_replicas_populated

Expand Down Expand Up @@ -684,3 +700,70 @@ def shutdown(self) -> concurrent.futures.Future:
return asyncio.run_coroutine_threadsafe(
self._asyncio_router.shutdown(), loop=self._asyncio_loop
)


class SharedRouterLongPollClient:
def __init__(self, controller_handle: ActorHandle, event_loop: AbstractEventLoop):
self.controller_handler = controller_handle

# We use a WeakSet to store the Routers so that we don't prevent them
# from being garbage-collected.
self.routers: MutableMapping[
DeploymentID, weakref.WeakSet[AsyncioRouter]
] = defaultdict(weakref.WeakSet)

# Creating the LongPollClient implicitly starts it
self.long_poll_client = LongPollClient(
controller_handle,
key_listeners={},
call_in_event_loop=event_loop,
)

@classmethod
@lru_cache(maxsize=None)
def get_or_create(
cls, controller_handle: ActorHandle, event_loop: AbstractEventLoop
) -> "SharedRouterLongPollClient":
shared = cls(controller_handle=controller_handle, event_loop=event_loop)
logger.info(f"Started {shared}.")
return shared

def update_running_replicas(
self, running_replicas: List[RunningReplicaInfo], deployment_id: DeploymentID
) -> None:
for router in self.routers[deployment_id]:
router.update_running_replicas(running_replicas)
router.long_poll_client.stop()

def update_deployment_config(
self, deployment_config: DeploymentConfig, deployment_id: DeploymentID
) -> None:
for router in self.routers[deployment_id]:
router.update_deployment_config(deployment_config)
router.long_poll_client.stop()

def register(self, router: AsyncioRouter) -> None:
self.routers[router.deployment_id].add(router)

# Remove the entries for any deployment ids that no longer have any routers.
# The WeakSets will automatically lose track of Routers that get GC'd,
# but the outer dict will keep the key around, so we need to clean up manually.
# Note the list(...) to avoid mutating self.routers while iterating over it.
for deployment_id, routers in list(self.routers.items()):
if not routers:
self.routers.pop(deployment_id)

# Register the new listeners on the long poll client.
# Some of these listeners may already exist, but it's safe to add them again.
key_listeners = {
(LongPollNamespace.RUNNING_REPLICAS, deployment_id): partial(
self.update_running_replicas, deployment_id=deployment_id
)
for deployment_id in self.routers.keys()
} | {
(LongPollNamespace.DEPLOYMENT_CONFIG, deployment_id): partial(
self.update_deployment_config, deployment_id=deployment_id
)
for deployment_id in self.routers.keys()
}
self.long_poll_client.add_key_listeners(key_listeners)
33 changes: 29 additions & 4 deletions python/ray/serve/tests/test_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,10 +324,16 @@ def make_nonblocking_calls(expected, expect_blocking=False):
make_nonblocking_calls({"2": 2})


def test_reconfigure_with_queries(serve_instance):
def test_reconfigure_does_not_run_while_there_are_active_queries(serve_instance):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried to de-flake this test 🤞🏻

"""
This tests checks that reconfigure can't trigger while there are active requests,
so that the actor's state is not mutated mid-request.

https://github.com/ray-project/ray/pull/20315
"""
signal = SignalActor.remote()

@serve.deployment(max_ongoing_requests=10, num_replicas=3)
@serve.deployment(max_ongoing_requests=10, num_replicas=1)
class A:
def __init__(self):
self.state = None
Expand All @@ -340,17 +346,36 @@ async def __call__(self):
return self.state["a"]

handle = serve.run(A.options(version="1", user_config={"a": 1}).bind())
responses = [handle.remote() for _ in range(30)]
responses = [handle.remote() for _ in range(10)]

# Give the queries time to get to the replicas before the reconfigure.
time.sleep(0.1)

@ray.remote(num_cpus=0)
def reconfigure():
serve.run(A.options(version="1", user_config={"a": 2}).bind())

# Start the reconfigure;
# this will not complete until the signal is released
# to allow the queries to complete.
reconfigure_ref = reconfigure.remote()

# Release the signal to allow the queries to complete.
signal.send.remote()

# Wait for the reconfigure to complete.
ray.get(reconfigure_ref)

assert all([r.result() == 1 for r in responses])
# These should all be 1 because the queries were sent before the reconfigure,
# the reconfigure blocks until they complete,
# and we just waited for the reconfigure to finish.
results = [r.result() for r in responses]
print(results)
assert all([r == 1 for r in results])

# If we query again, it should be 2,
# because the reconfigure will have gone through after the
# original queries completed.
assert handle.remote().result() == 2


Expand Down
Loading