Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize gateway transport #1898

Merged
merged 4 commits into from
Apr 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions changes/1898.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Optimize gateway transport
- Merge cold path for zlib compression into main path to avoid additional call
- Handle data in `bytes`, rather than in `str` to make good use of speedups (similar to `RESTClient`)
2 changes: 1 addition & 1 deletion hikari/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ class GatewayTransportError(GatewayError):
"""An exception thrown if an issue occurs at the transport layer."""

def __str__(self) -> str:
return f"Gateway transport error: {self.reason!r}"
return f"Gateway transport error: {self.reason}"


@attrs.define(auto_exc=True, repr=False, slots=False)
Expand Down
47 changes: 24 additions & 23 deletions hikari/impl/shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@
_CUSTOM_STATUS_NAME = "Custom Status"


def _log_filterer(token: str) -> typing.Callable[[str], str]:
def filterer(entry: str) -> str:
return entry.replace(token, "**REDACTED TOKEN**")
def _log_filterer(token: bytes) -> typing.Callable[[bytes], bytes]:
def filterer(entry: bytes) -> bytes:
return entry.replace(token, b"**REDACTED TOKEN**")

return filterer

Expand Down Expand Up @@ -153,7 +153,7 @@ def __init__(
transport_compression: bool,
exit_stack: contextlib.AsyncExitStack,
logger: logging.Logger,
log_filterer: typing.Callable[[str], str],
log_filterer: typing.Callable[[bytes], bytes],
dumps: data_binding.JSONEncoder,
loads: data_binding.JSONDecoder,
) -> None:
Expand Down Expand Up @@ -203,7 +203,7 @@ async def receive_json(self) -> typing.Any:
async def send_json(self, data: data_binding.JSONObject) -> None:
pl = self._dumps(data)
if self._logger.isEnabledFor(ux.TRACE):
filtered = self._log_filterer(pl.decode("utf-8"))
filtered = self._log_filterer(pl)
self._logger.log(ux.TRACE, "sending payload with size %s\n %s", len(pl), filtered)

await self._ws.send_bytes(pl)
Expand Down Expand Up @@ -232,39 +232,40 @@ def _handle_other_message(self, message: aiohttp.WSMessage, /) -> typing.NoRetur
reason = f"{message.data!r} [extra={message.extra!r}, type={message.type}]"
raise errors.GatewayTransportError(reason) from self._ws.exception()

async def _receive_and_check_text(self) -> str:
async def _receive_and_check_text(self) -> bytes:
message = await self._ws.receive()

if message.type == aiohttp.WSMsgType.TEXT:
assert isinstance(message.data, str)
return message.data
return message.data.encode()

self._handle_other_message(message)

async def _receive_and_check_zlib(self) -> str:
async def _receive_and_check_zlib(self) -> bytes:
message = await self._ws.receive()

if message.type == aiohttp.WSMsgType.BINARY:
if message.data.endswith(_ZLIB_SUFFIX):
return self._zlib.decompress(message.data).decode("utf-8")

return await self._receive_and_check_complete_zlib_package(message.data)
# Hot and fast path: we already have the full message
# in a single frame
return self._zlib.decompress(message.data)

self._handle_other_message(message)
# Cold and slow path: we need to keep receiving frames to complete
# the whole message. Only then do we create a buffer
buff = bytearray(message.data)

async def _receive_and_check_complete_zlib_package(self, initial_data: bytes, /) -> str:
buff = bytearray(initial_data)
while not buff.endswith(_ZLIB_SUFFIX):
message = await self._ws.receive()

while not buff.endswith(_ZLIB_SUFFIX):
message = await self._ws.receive()
if message.type == aiohttp.WSMsgType.BINARY:
buff.extend(message.data)
continue

if message.type == aiohttp.WSMsgType.BINARY:
buff.extend(message.data)
continue
self._handle_other_message(message)

self._handle_other_message(message)
return self._zlib.decompress(buff)

return self._zlib.decompress(buff).decode("utf-8")
self._handle_other_message(message)

@classmethod
async def connect(
Expand All @@ -273,7 +274,7 @@ async def connect(
http_settings: config.HTTPSettings,
logger: logging.Logger,
proxy_settings: config.ProxySettings,
log_filterer: typing.Callable[[str], str],
log_filterer: typing.Callable[[bytes], bytes],
dumps: data_binding.JSONEncoder,
loads: data_binding.JSONDecoder,
transport_compression: bool,
Expand Down Expand Up @@ -810,7 +811,7 @@ async def _connect(self) -> typing.Tuple[asyncio.Task[None], ...]:

self._ws = await _GatewayTransport.connect(
http_settings=self._http_settings,
log_filterer=_log_filterer(self._token),
log_filterer=_log_filterer(self._token.encode()),
logger=self._logger,
proxy_settings=self._proxy_settings,
transport_compression=self._transport_compression,
Expand Down
103 changes: 36 additions & 67 deletions tests/hikari/impl/test_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@


def test_log_filterer():
filterer = shard._log_filterer("TOKEN")
filterer = shard._log_filterer(b"TOKEN")

returned = filterer("this log contains the TOKEN and it should get removed and the TOKEN here too")
returned = filterer(b"this log contains the TOKEN and it should get removed and the TOKEN here too")
assert returned == (
"this log contains the **REDACTED TOKEN** and it should get removed and the **REDACTED TOKEN** here too"
b"this log contains the **REDACTED TOKEN** and it should get removed and the **REDACTED TOKEN** here too"
)


Expand Down Expand Up @@ -275,100 +275,69 @@ def test__handle_other_message_when_message_type_is_unknown(self, transport_impl
assert exc_info.value.__cause__ is exception

@pytest.mark.asyncio
async def test__receive_and_check_text_when_message_type_is_TEXT(self, transport_impl):
async def test__receive_and_check_text(self, transport_impl):
transport_impl._ws.receive = mock.AsyncMock(
return_value=StubResponse(type=aiohttp.WSMsgType.TEXT, data="some text")
)

assert await transport_impl._receive_and_check_text() == "some text"
assert await transport_impl._receive_and_check_text() == b"some text"

transport_impl._ws.receive.assert_awaited_once_with()

@pytest.mark.asyncio
async def test__receive_and_check_text_when_message_type_is_unknown(self, transport_impl):
mock_exception = errors.GatewayError("aye")
transport_impl._ws.receive = mock.AsyncMock(return_value=StubResponse(type=aiohttp.WSMsgType.BINARY))

with mock.patch.object(
shard._GatewayTransport, "_handle_other_message", side_effect=mock_exception
) as handle_other_message:
with pytest.raises(errors.GatewayError) as exc_info:
await transport_impl._receive_and_check_text()
with pytest.raises(
errors.GatewayTransportError,
match="Gateway transport error: Unexpected message type received BINARY, expected TEXT",
):
await transport_impl._receive_and_check_text()

assert exc_info.value is mock_exception
transport_impl._ws.receive.assert_awaited_once_with()
handle_other_message.assert_called_once_with(transport_impl._ws.receive.return_value)

@pytest.mark.asyncio
async def test__receive_and_check_zlib_when_message_type_is_BINARY(self, transport_impl):
response = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"some initial data")
transport_impl._ws.receive = mock.AsyncMock(return_value=response)
async def test__receive_and_check_zlib_when_payload_split_across_frames(self, transport_impl):
response1 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"x\xda\xf2H\xcd\xc9")
response2 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"\xc9W(\xcf/\xcaIQ\x04\x00\x00")
response3 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"\x00\xff\xff")
transport_impl._ws.receive = mock.AsyncMock(side_effect=[response1, response2, response3])

with mock.patch.object(
shard._GatewayTransport, "_receive_and_check_complete_zlib_package"
) as receive_and_check_complete_zlib_package:
assert (
await transport_impl._receive_and_check_zlib() is receive_and_check_complete_zlib_package.return_value
)
assert await transport_impl._receive_and_check_zlib() == b"Hello world!"

transport_impl._ws.receive.assert_awaited_once_with()
receive_and_check_complete_zlib_package.assert_awaited_once_with(b"some initial data")
assert transport_impl._ws.receive.call_count == 3

@pytest.mark.asyncio
async def test__receive_and_check_zlib_when_message_type_is_BINARY_and_the_full_payload(self, transport_impl):
response = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"some initial data\x00\x00\xff\xff")
async def test__receive_and_check_zlib_when_full_payload_in_one_frame(self, transport_impl):
response = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"x\xdaJLD\x07\x00\x00\x00\x00\xff\xff")
transport_impl._ws.receive = mock.AsyncMock(return_value=response)
transport_impl._zlib = mock.Mock(decompress=mock.Mock(return_value=b"aaaaaaaaaaaaaaaaaa"))

assert await transport_impl._receive_and_check_zlib() == "aaaaaaaaaaaaaaaaaa"
assert await transport_impl._receive_and_check_zlib() == b"aaaaaaaaaaaaaaaaaa"

transport_impl._ws.receive.assert_awaited_once_with()
transport_impl._zlib.decompress.assert_called_once_with(response.data)

@pytest.mark.asyncio
async def test__receive_and_check_zlib_when_message_type_is_unknown(self, transport_impl):
mock_exception = errors.GatewayError("aye")
transport_impl._ws.receive = mock.AsyncMock(return_value=StubResponse(type=aiohttp.WSMsgType.TEXT))

with mock.patch.object(
shard._GatewayTransport, "_handle_other_message", side_effect=mock_exception
) as handle_other_message:
with pytest.raises(errors.GatewayError) as exc_info:
await transport_impl._receive_and_check_zlib()

assert exc_info.value is mock_exception
transport_impl._ws.receive.assert_awaited_once_with()
handle_other_message.assert_called_once_with(transport_impl._ws.receive.return_value)

@pytest.mark.asyncio
async def test__receive_and_check_complete_zlib_package_for_unexpected_message_type(self, transport_impl):
mock_exception = errors.GatewayError("aye")
response = StubResponse(type=aiohttp.WSMsgType.TEXT)
transport_impl._ws.receive = mock.AsyncMock(return_value=response)

with mock.patch.object(
shard._GatewayTransport, "_handle_other_message", side_effect=mock_exception
) as handle_other_message:
with pytest.raises(errors.GatewayError) as exc_info:
await transport_impl._receive_and_check_complete_zlib_package(b"some")

assert exc_info.value is mock_exception
transport_impl._ws.receive.assert_awaited_with()
handle_other_message.assert_called_once_with(response)
with pytest.raises(
errors.GatewayTransportError,
match="Gateway transport error: Unexpected message type received TEXT, expected BINARY",
):
await transport_impl._receive_and_check_zlib()

@pytest.mark.asyncio
async def test__receive_and_check_complete_zlib_package(self, transport_impl):
response1 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"more")
response2 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"data")
response3 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"\x00\x00\xff\xff")
async def test__receive_and_check_zlib_when_issue_during_reception_of_multiple_frames(self, transport_impl):
response1 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"x\xda\xf2H\xcd\xc9")
response2 = StubResponse(type=aiohttp.WSMsgType.ERROR, data="Something broke!")
response3 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"\x00\xff\xff")
transport_impl._ws.receive = mock.AsyncMock(side_effect=[response1, response2, response3])
transport_impl._zlib = mock.Mock(decompress=mock.Mock(return_value=b"decoded utf-8 encoded bytes"))

assert await transport_impl._receive_and_check_complete_zlib_package(b"some") == "decoded utf-8 encoded bytes"
transport_impl._ws.exception = mock.Mock(return_value=None)

assert transport_impl._ws.receive.call_count == 3
transport_impl._ws.receive.assert_has_awaits([mock.call(), mock.call(), mock.call()])
transport_impl._zlib.decompress.assert_called_once_with(bytearray(b"somemoredata\x00\x00\xff\xff"))
with pytest.raises(
errors.GatewayTransportError, match=r"Gateway transport error: 'Something broke!' \[extra=None, type=258\]"
):
await transport_impl._receive_and_check_zlib()

@pytest.mark.parametrize("transport_compression", [True, False])
@pytest.mark.asyncio
Expand Down Expand Up @@ -1002,7 +971,7 @@ async def test__connect_when_not_reconnecting(self, client, http_settings, proxy
with stack:
assert await client._connect() == (heartbeat_task, poll_events_task)

log_filterer.assert_called_once_with("sometoken")
log_filterer.assert_called_once_with(b"sometoken")
gateway_transport_connect.assert_called_once_with(
http_settings=http_settings,
log_filterer=log_filterer.return_value,
Expand Down Expand Up @@ -1087,7 +1056,7 @@ async def test__connect_when_reconnecting(self, client, http_settings, proxy_set
with stack:
assert await client._connect() == (heartbeat_task, poll_events_task)

log_filterer.assert_called_once_with("sometoken")
log_filterer.assert_called_once_with(b"sometoken")
gateway_transport_connect.assert_called_once_with(
http_settings=http_settings,
log_filterer=log_filterer.return_value,
Expand Down
Loading