diff --git a/python/ray/serve/tests/test_deploy.py b/python/ray/serve/tests/test_deploy.py index 769bff110093..0068f0c07466 100644 --- a/python/ray/serve/tests/test_deploy.py +++ b/python/ray/serve/tests/test_deploy.py @@ -10,7 +10,7 @@ import ray from ray import serve from ray._private.pydantic_compat import ValidationError -from ray._private.test_utils import SignalActor +from ray._private.test_utils import SignalActor, wait_for_condition from ray.serve._private.constants import RAY_SERVE_EAGERLY_START_REPLACEMENT_REPLICAS from ray.serve._private.utils import get_random_string from ray.serve.exceptions import RayServeException @@ -324,10 +324,16 @@ def make_nonblocking_calls(expected, expect_blocking=False): make_nonblocking_calls({"2": 2}) -def test_reconfigure_with_queries(serve_instance): +def test_reconfigure_does_not_run_while_there_are_active_queries(serve_instance): + """ + This tests checks that reconfigure can't trigger while there are active requests, + so that the actor's state is not mutated mid-request. + + https://github.com/ray-project/ray/pull/20315 + """ signal = SignalActor.remote() - @serve.deployment(max_ongoing_requests=10, num_replicas=3) + @serve.deployment(max_ongoing_requests=10, num_replicas=1) class A: def __init__(self): self.state = None @@ -340,17 +346,38 @@ async def __call__(self): return self.state["a"] handle = serve.run(A.options(version="1", user_config={"a": 1}).bind()) - responses = [handle.remote() for _ in range(30)] + responses = [handle.remote() for _ in range(10)] + + # Give the queries time to get to the replicas before the reconfigure. + wait_for_condition( + lambda: ray.get(signal.cur_num_waiters.remote()) == len(responses) + ) @ray.remote(num_cpus=0) def reconfigure(): serve.run(A.options(version="1", user_config={"a": 2}).bind()) + # Start the reconfigure; + # this will not complete until the signal is released + # to allow the queries to complete. reconfigure_ref = reconfigure.remote() + + # Release the signal to allow the queries to complete. signal.send.remote() + + # Wait for the reconfigure to complete. ray.get(reconfigure_ref) - assert all([r.result() == 1 for r in responses]) + # These should all be 1 because the queries were sent before the reconfigure, + # the reconfigure blocks until they complete, + # and we just waited for the reconfigure to finish. + results = [r.result() for r in responses] + print(results) + assert all([r == 1 for r in results]) + + # If we query again, it should be 2, + # because the reconfigure will have gone through after the + # original queries completed. assert handle.remote().result() == 2