class Bootstrap:
    """Generates the bootstrap.json file that is mounted into containers.

    The file is rendered from the Mako template at BOOTSTRAP_JSON_TEMPLATE
    and lists two servers on the same host: the primary one first, the
    fallback one second.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(
        self,
        base: pathlib.Path,
        primary_port: int,
        fallback_port: int,
        host_name: str,
    ):
        """Creates the working directory and writes bootstrap.json into it.

        Args:
            base: Directory under which the working directory is created.
            primary_port: Port of the first server listed in the bootstrap.
            fallback_port: Port of the second server listed in the bootstrap.
            host_name: Host name used for both server addresses.
        """
        # Ports are kept so that callers can start servers on them later.
        self.primary_port = primary_port
        self.fallback_port = fallback_port
        # Directory that will hold the generated bootstrap file.
        self.mount_dir = _make_working_dir(base)

        # Order matters: primary first, fallback second.
        server_uris = [
            f"{host_name}:{port}" for port in (primary_port, fallback_port)
        ]
        rendered = mako.template.Template(
            filename=BOOTSTRAP_JSON_TEMPLATE
        ).render(servers=server_uris)

        bootstrap_path = self.mount_dir / "bootstrap.json"
        with open(bootstrap_path, "w", encoding="utf-8") as out:
            out.write(rendered)
        logger.debug("Generated bootstrap file at %s", bootstrap_path)
= queue.Queue() self.bootstrap = bootstrap - self.verbosity = verbosity - - def next_event(self, timeout: int) -> ChildProcessEvent: - event: ChildProcessEvent = self.queue.get(timeout=timeout) - source = event.source - message = event.data - self.outputs[source].append(message) - return event - - def expect_output( - self, process_name: str, expected_message: str, timeout_s: int - ) -> bool: - """ - Checks if the specified message appears in the output of a given process within a timeout. - - Returns: - True if the expected message is found in the process's output within - the timeout, False otherwise. - - Behavior: - - If the process has already produced output, it checks there first. - - Otherwise, it waits for new events from the process, up to the specified timeout. - - If an event from the process contains the expected message, it returns True. - - If the timeout is reached without finding the message, it returns False. - """ - logger.debug( - 'Waiting for message "%s" from %s', expected_message, process_name - ) - if any( - m - for m in self.outputs[process_name] - if m.find(expected_message) >= 0 - ): - return True - deadline = datetime.datetime.now() + datetime.timedelta( - seconds=timeout_s - ) - while datetime.datetime.now() <= deadline: - event = self.next_event(timeout_s) - if ( - event.source == process_name - and event.data.find(expected_message) >= 0 - ): - return True - return False - - def on_message(self, source: str, message: str): - self.queue.put(ChildProcessEvent(source, message)) def _Sanitize(l: str) -> str: @@ -147,12 +104,12 @@ def _Sanitize(l: str) -> str: return l.replace("\0", "�") -def Configure(config, image: str, name: str, verbosity: str): +def Configure(config, image: str, name: str): config["detach"] = True config["environment"] = { "GRPC_EXPERIMENTAL_XDS_FALLBACK": "true", "GRPC_TRACE": "xds_client", - "GRPC_VERBOSITY": verbosity, + "GRPC_VERBOSITY": "info", "GRPC_XDS_BOOTSTRAP": "/grpc/bootstrap.json", } config["extra_hosts"] = 
def expect_channel_status(
    self,
    port: int,
    expected_status: channelz_pb2.ChannelConnectivityState,
    timeout: datetime.timedelta,
    poll_interval: datetime.timedelta,
) -> channelz_pb2.ChannelConnectivityState:
    """Polls channelz until the channel targeting `port` has the given state.

    Args:
        port: Port the channel of interest is connected to. A channel is
            selected when its target ends with ":<port>".
        expected_status: Connectivity state to wait for. Passing None waits
            until no channel with a matching target is reported.
        timeout: Maximum total time to keep polling.
        poll_interval: Pause between two consecutive channelz queries.

    Returns:
        The state observed by the last query: `expected_status` on success,
        otherwise whatever was last seen. None if no channel matched in the
        last query, or if `timeout` elapsed before the first query.
    """
    deadline = datetime.datetime.now() + timeout
    channelz = ChannelzServiceClient(self.channel())
    # Anchor on the ":" separator so that e.g. port 5678 does not match a
    # channel whose target ends in ":45678".
    target_suffix = f":{port}"
    # total_seconds() covers the whole interval; the `.microseconds`
    # attribute is only the sub-second component and is 0 for whole seconds.
    pause_s = poll_interval.total_seconds()
    status = None
    while datetime.datetime.now() < deadline:
        # Reset on every pass so a channel that disappeared reads as None.
        status = None
        for ch in channelz.list_channels():
            if ch.data.target.endswith(target_suffix):
                status = ch.data.state.state
                break
        if status == expected_status:
            break
        time.sleep(pause_s)
    return status
def assert_ads_connections(
    self,
    client: framework.helpers.docker.Client,
    primary_status: channelz_pb2.ChannelConnectivityState,
    fallback_status: channelz_pb2.ChannelConnectivityState,
):
    """Asserts the client's channel states for the primary and fallback ports.

    For each of the two bootstrap ports (primary first, then fallback) the
    client is asked to wait for the expected channel status, using the
    timeout and poll interval taken from the command-line flags. The test
    fails on the first port whose reported status differs from the
    expected one.

    Args:
        client: Client whose channels are inspected.
        primary_status: Expected status of the channel to the primary port.
        fallback_status: Expected status of the channel to the fallback port.
    """
    expectations = (
        (self.bootstrap.primary_port, primary_status),
        (self.bootstrap.fallback_port, fallback_status),
    )
    for port, wanted in expectations:
        observed = client.expect_channel_status(
            port,
            wanted,
            timeout=datetime.timedelta(
                milliseconds=_STATUS_TIMEOUT_MS.value
            ),
            poll_interval=datetime.timedelta(
                milliseconds=_STATUS_POLL_INTERVAL_MS.value
            ),
        )
        self.assertEqual(observed, wanted)
fallback_status=channelz_pb2.ChannelConnectivityState.READY, + ) stats = client.get_stats(5) self.assertGreater(stats.rpcs_by_peer["server2"], 0) self.assertNotIn("server1", stats.rpcs_by_peer) # Primary control plane start. Will use it with self.start_control_plane( name="primary_xds_config", - index=0, + port=self.bootstrap.primary_port, upstream_port=server1.port, ): + self.assert_ads_connections( + client=client, + primary_status=channelz_pb2.ChannelConnectivityState.READY, + fallback_status=None, + ) stats = client.get_stats(10) self.assertEqual(stats.num_failures, 0) self.assertIn("server1", stats.rpcs_by_peer) self.assertGreater(stats.rpcs_by_peer["server1"], 0) + self.assert_ads_connections( + client=client, + primary_status=channelz_pb2.ChannelConnectivityState.TRANSIENT_FAILURE, + fallback_status=None, + ) # Primary control plane down, cached value is used stats = client.get_stats(5) self.assertEqual(stats.num_failures, 0) self.assertEqual(stats.rpcs_by_peer["server1"], 5) + self.assert_ads_connections( + client=client, + primary_status=channelz_pb2.ChannelConnectivityState.TRANSIENT_FAILURE, + fallback_status=None, + ) # Fallback control plane down, cached value is used stats = client.get_stats(5) self.assertEqual(stats.num_failures, 0) @@ -154,18 +218,27 @@ def test_fallback_mid_startup(self): self.start_server(name="server1") as server1, self.start_server(name="server2") as server2, self.start_control_plane( - "primary_xds_config_run_1", 0, server1.port, "cluster_name" + "primary_xds_config_run_1", + port=self.bootstrap.primary_port, + upstream_port=server1.port, + cluster_name="cluster_name", ) as primary, - self.start_control_plane("fallback_xds_config", 1, server2.port), + self.start_control_plane( + "fallback_xds_config", + port=self.bootstrap.fallback_port, + upstream_port=server2.port, + ), ): primary.stop_on_resource_request( "type.googleapis.com/envoy.config.cluster.v3.Cluster", "cluster_name", ) # Run client - with (self.start_client() as 
client,): - self.assertTrue( - client.expect_message_in_output("creating xds client") + with self.start_client() as client: + self.assert_ads_connections( + client, + primary_status=channelz_pb2.ChannelConnectivityState.TRANSIENT_FAILURE, + fallback_status=channelz_pb2.ChannelConnectivityState.READY, ) # Secondary xDS config start, send traffic to server2 stats = client.get_stats(5) @@ -174,8 +247,15 @@ def test_fallback_mid_startup(self): self.assertNotIn("server1", stats.rpcs_by_peer) # Rerun primary control plane with self.start_control_plane( - "primary_xds_config_run_2", 0, server1.port + "primary_xds_config_run_2", + port=self.bootstrap.primary_port, + upstream_port=server1.port, ): + self.assert_ads_connections( + client, + primary_status=channelz_pb2.ChannelConnectivityState.READY, + fallback_status=None, + ) stats = client.get_stats(10) self.assertEqual(stats.num_failures, 0) self.assertIn("server1", stats.rpcs_by_peer) @@ -187,13 +267,21 @@ def test_fallback_mid_update(self): self.start_server(name="server2") as server2, self.start_server(name="server3") as server3, self.start_control_plane( - "primary_xds_config_run_1", 0, server1.port + "primary_xds_config_run_1", + port=self.bootstrap.primary_port, + upstream_port=server1.port, ) as primary, - self.start_control_plane("fallback_xds_config", 1, server2.port), + self.start_control_plane( + "fallback_xds_config", + port=self.bootstrap.fallback_port, + upstream_port=server2.port, + ), self.start_client() as client, ): - self.assertTrue( - client.expect_message_in_output("creating xds client") + self.assert_ads_connections( + client, + primary_status=channelz_pb2.ChannelConnectivityState.READY, + fallback_status=None, ) # Secondary xDS config start, send traffic to server2 stats = client.get_stats(5) @@ -210,15 +298,25 @@ def test_fallback_mid_update(self): upstream_host=_HOST_NAME.value, ) ) + self.assert_ads_connections( + client, + primary_status=channelz_pb2.ChannelConnectivityState.TRANSIENT_FAILURE, + 
fallback_status=channelz_pb2.ChannelConnectivityState.READY, + ) stats = client.get_stats(10) self.assertEqual(stats.num_failures, 0) self.assertIn("server2", stats.rpcs_by_peer) # Check that post-recovery uses a new config with self.start_control_plane( name="primary_xds_config_run_2", - index=0, + port=self.bootstrap.primary_port, upstream_port=server3.port, ): + self.assert_ads_connections( + client, + primary_status=channelz_pb2.ChannelConnectivityState.READY, + fallback_status=None, + ) stats = client.get_stats(20) self.assertEqual(stats.num_failures, 0) self.assertIn("server3", stats.rpcs_by_peer)