Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support proxy of binary messages from addons to HA #4605

Merged
merged 5 commits into from
Oct 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 31 additions & 11 deletions supervisor/api/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
import aiohttp
from aiohttp import web
from aiohttp.client_exceptions import ClientConnectorError
from aiohttp.client_ws import ClientWebSocketResponse
from aiohttp.hdrs import AUTHORIZATION, CONTENT_TYPE
from aiohttp.http import WSMessage
from aiohttp.http_websocket import WSMsgType
from aiohttp.web_exceptions import HTTPBadGateway, HTTPUnauthorized

from ..coresys import CoreSysAttributes
Expand Down Expand Up @@ -114,7 +117,7 @@ async def api(self, request: web.Request):
body=data, status=client.status, content_type=client.content_type
)

async def _websocket_client(self):
async def _websocket_client(self) -> ClientWebSocketResponse:
"""Initialize a WebSocket API connection."""
url = f"{self.sys_homeassistant.api_url}/api/websocket"

Expand Down Expand Up @@ -167,6 +170,25 @@ async def _websocket_client(self):

raise APIError()

async def _proxy_message(
self,
read_task: asyncio.Task,
target: web.WebSocketResponse | ClientWebSocketResponse,
) -> None:
"""Proxy a message from client to server or vice versa."""
if read_task.exception():
raise read_task.exception()

msg: WSMessage = read_task.result()
if msg.type == WSMsgType.TEXT:
return await target.send_str(msg.data)
if msg.type == WSMsgType.BINARY:
return await target.send_bytes(msg.data)

raise TypeError(
f"Cannot proxy websocket message of unsupported type: {msg.type}"
)

async def websocket(self, request: web.Request):
"""Initialize a WebSocket API connection."""
if not await self.sys_homeassistant.api.check_api_state():
Expand Down Expand Up @@ -214,13 +236,13 @@ async def websocket(self, request: web.Request):

_LOGGER.info("Home Assistant WebSocket API request running")
try:
client_read = None
server_read = None
client_read: asyncio.Task | None = None
server_read: asyncio.Task | None = None
while not server.closed and not client.closed:
if not client_read:
client_read = self.sys_create_task(client.receive_str())
client_read = self.sys_create_task(client.receive())
if not server_read:
server_read = self.sys_create_task(server.receive_str())
server_read = self.sys_create_task(server.receive())

# wait until data need to be processed
await asyncio.wait(
Expand All @@ -229,14 +251,12 @@ async def websocket(self, request: web.Request):

# server
if server_read.done() and not client.closed:
server_read.exception()
await client.send_str(server_read.result())
await self._proxy_message(server_read, client)
server_read = None

# client
if client_read.done() and not server.closed:
client_read.exception()
await server.send_str(client_read.result())
await self._proxy_message(client_read, server)
client_read = None

except asyncio.CancelledError:
Expand All @@ -246,9 +266,9 @@ async def websocket(self, request: web.Request):
_LOGGER.info("Home Assistant WebSocket API error: %s", err)

finally:
if client_read:
if client_read and not client_read.done():
client_read.cancel()
if server_read:
if server_read and not server_read.done():
server_read.cancel()

# close connections
Expand Down
177 changes: 177 additions & 0 deletions tests/api/test_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
"""Test Home Assistant proxy."""

from __future__ import annotations

import asyncio
from collections.abc import Awaitable, Callable, Coroutine, Generator
from json import dumps
from typing import Any, cast
from unittest.mock import patch

from aiohttp import ClientWebSocketResponse
from aiohttp.http_websocket import WSMessage, WSMsgType
from aiohttp.test_utils import TestClient
import pytest

from supervisor.addons.addon import Addon
from supervisor.api.proxy import APIProxy
from supervisor.const import ATTR_ACCESS_TOKEN


def id_generator() -> Generator[int, None, None]:
"""Generate IDs for WS messages."""
i = 0
while True:
yield (i := i + 1)


class MockHAClientWebSocket(ClientWebSocketResponse):
"""Protocol for a wrapped ClientWebSocketResponse."""

client: TestClient
send_json_auto_id: Callable[[dict[str, Any]], Coroutine[Any, Any, None]]


class MockHAServerWebSocket:
"""Mock of HA Websocket server."""

closed: bool = False

def __init__(self) -> None:
"""Initialize object."""
self.outgoing: asyncio.Queue[WSMessage] = asyncio.Queue()
self.incoming: asyncio.Queue[WSMessage] = asyncio.Queue()
self._id_generator = id_generator()

def receive(self) -> Awaitable[WSMessage]:
"""Receive next message."""
return self.outgoing.get()

def send_str(self, data: str) -> Awaitable[None]:
"""Incoming string message."""
return self.incoming.put(WSMessage(WSMsgType.TEXT, data, None))

def send_bytes(self, data: bytes) -> Awaitable[None]:
"""Incoming string message."""
return self.incoming.put(WSMessage(WSMsgType.BINARY, data, None))

def respond_json(self, data: dict[str, Any]) -> Awaitable[None]:
"""Respond with JSON."""
return self.outgoing.put(
WSMessage(
WSMsgType.TEXT, dumps(data | {"id": next(self._id_generator)}), None
)
)

def respond_bytes(self, data: bytes) -> Awaitable[None]:
"""Respond with binary."""
return self.outgoing.put(WSMessage(WSMsgType.BINARY, data, None))

async def close(self) -> None:
"""Close connection."""
self.closed = True


WebSocketGenerator = Callable[..., Coroutine[Any, Any, MockHAClientWebSocket]]


@pytest.fixture(name="ha_ws_server")
async def fixture_ha_ws_server() -> MockHAServerWebSocket:
"""Mock HA WS server for testing."""
with patch.object(
APIProxy,
"_websocket_client",
return_value=(mock_server := MockHAServerWebSocket()),
):
yield mock_server


@pytest.fixture(name="proxy_ws_client")
def fixture_proxy_ws_client(
api_client: TestClient, ha_ws_server: MockHAServerWebSocket
) -> WebSocketGenerator:
"""Websocket client fixture connected to websocket server."""

async def create_client(auth_token: str) -> MockHAClientWebSocket:
"""Create a websocket client."""
websocket = await api_client.ws_connect("/core/websocket")
auth_resp = await websocket.receive_json()
assert auth_resp["type"] == "auth_required"
await websocket.send_json({"type": "auth", "access_token": auth_token})

auth_ok = await websocket.receive_json()
assert auth_ok["type"] == "auth_ok"

_id_generator = id_generator()

def _send_json_auto_id(data: dict[str, Any]) -> Coroutine[Any, Any, None]:
data["id"] = next(_id_generator)
return websocket.send_json(data)

# wrap in client
wrapped_websocket = cast(MockHAClientWebSocket, websocket)
wrapped_websocket.client = api_client
wrapped_websocket.send_json_auto_id = _send_json_auto_id
return wrapped_websocket

return create_client


async def test_proxy_message(
proxy_ws_client: WebSocketGenerator,
ha_ws_server: MockHAServerWebSocket,
install_addon_ssh: Addon,
):
"""Test proxy a message to and from Home Assistant."""
install_addon_ssh.persist[ATTR_ACCESS_TOKEN] = "abc123"
client: MockHAClientWebSocket = await proxy_ws_client(
install_addon_ssh.supervisor_token
)

await client.send_json_auto_id({"hello": "world"})
proxied_msg = await ha_ws_server.incoming.get()
assert proxied_msg.type == WSMsgType.TEXT
assert proxied_msg.data == '{"hello": "world", "id": 1}'

await ha_ws_server.respond_json({"world": "received"})
assert await client.receive_json() == {"world": "received", "id": 1}

assert await client.close()


async def test_proxy_binary_message(
proxy_ws_client: WebSocketGenerator,
ha_ws_server: MockHAServerWebSocket,
install_addon_ssh: Addon,
):
"""Test proxy a binary message to and from Home Assistant."""
install_addon_ssh.persist[ATTR_ACCESS_TOKEN] = "abc123"
client: MockHAClientWebSocket = await proxy_ws_client(
install_addon_ssh.supervisor_token
)

await client.send_bytes(b"hello world")
proxied_msg = await ha_ws_server.incoming.get()
assert proxied_msg.type == WSMsgType.BINARY
assert proxied_msg.data == b"hello world"

await ha_ws_server.respond_bytes(b"world received")
assert await client.receive_bytes() == b"world received"

assert await client.close()


@pytest.mark.parametrize("auth_token", ["abc123", "bad"])
async def test_proxy_invalid_auth(
api_client: TestClient, install_addon_example: Addon, auth_token: str
):
"""Test invalid access token or addon with no access."""
install_addon_example.persist[ATTR_ACCESS_TOKEN] = "abc123"
websocket = await api_client.ws_connect("/core/websocket")
auth_resp = await websocket.receive_json()
assert auth_resp["type"] == "auth_required"
await websocket.send_json({"type": "auth", "access_token": auth_token})

auth_not_ok = await websocket.receive_json()
assert auth_not_ok["type"] == "auth_invalid"
assert auth_not_ok["message"] == "Invalid access"
1 change: 1 addition & 0 deletions tests/fixtures/addons/local/ssh/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ panel_icon: "mdi:console"
panel_title: Terminal
hassio_api: true
hassio_role: manager
homeassistant_api: true
audio: true
uart: true
ports:
Expand Down