diff --git a/src/meatie/aio/descriptor.py b/src/meatie/aio/descriptor.py index 15735e5..ebe0567 100644 --- a/src/meatie/aio/descriptor.py +++ b/src/meatie/aio/descriptor.py @@ -59,6 +59,7 @@ def __init__( self.response_decoder = response_decoder self.get_json: Optional[Callable[[Any], Awaitable[dict[str, Any]]]] = None self.get_text: Optional[Callable[[Any], Awaitable[str]]] = None + self.get_error: Optional[Callable[[AsyncResponse], Awaitable[Optional[Exception]]]] = None self.__operator_by_priority: dict[int, AsyncOperator[ResponseBodyType]] = {} def __set_name__(self, owner: type[object], name: str) -> None: @@ -91,7 +92,13 @@ def __get__( operators = [operator for _, operator in priority_operator_pair] return BoundAsyncEndpointDescriptor( # type: ignore[return-value] - instance, operators, self.template, self.response_decoder, self.get_json, self.get_text + instance, + operators, + self.template, + self.response_decoder, + self.get_json, + self.get_text, + self.get_error, ) @@ -102,8 +109,9 @@ def __init__( operators: Iterable[AsyncOperator[ResponseBodyType]], template: RequestTemplate[Any], response_decoder: TypeAdapter[ResponseBodyType], - get_json: Optional[Callable[[Any], Awaitable[dict[str, Any]]]], + get_json: Optional[Callable[[Any], Awaitable[Any]]], get_text: Optional[Callable[[Any], Awaitable[str]]], + get_error: Optional[Callable[[AsyncResponse], Awaitable[Optional[Exception]]]], ) -> None: self.__instance = instance self.__operators = list(operators) @@ -112,6 +120,7 @@ def __init__( self.__response_decoder = response_decoder self.__get_json = get_json self.__get_text = get_text + self.__get_error = get_error async def __call__(self, *args: PT.args, **kwargs: PT.kwargs) -> ResponseBodyType: request = self.__template.build_request(*args, **kwargs) @@ -124,5 +133,12 @@ async def __send_request(self, context: AsyncContext[ResponseBodyType]) -> Respo response.get_json = self.__get_json if self.__get_text is not None: response.get_text = self.__get_text + context.response = response + + if self.__get_error is not None: + error = await self.__get_error(response) + if error is not None: + raise error + return await self.__response_decoder.from_async_response(response) diff --git a/src/meatie/descriptor.py b/src/meatie/descriptor.py index 3de04cd..d7840dc 100644 --- a/src/meatie/descriptor.py +++ b/src/meatie/descriptor.py @@ -27,8 +27,9 @@ def __init__( ) -> None: self.template = template self.response_decoder = response_decoder - self.get_json: Optional[Callable[[Any], dict[str, Any]]] = None + self.get_json: Optional[Callable[[Any], Any]] = None self.get_text: Optional[Callable[[Any], str]] = None + self.get_error: Optional[Callable[[Response], Optional[Exception]]] = None self.__operator_by_priority: dict[int, Operator[ResponseBodyType]] = {} def __set_name__(self, owner: type[object], name: str) -> None: @@ -61,7 +62,13 @@ def __get__( operators = [operator for _, operator in priority_operator_pair] return BoundEndpointDescriptor( # type: ignore[return-value] - instance, operators, self.template, self.response_decoder, self.get_json, self.get_text + instance, + operators, + self.template, + self.response_decoder, + self.get_json, + self.get_text, + self.get_error, ) @@ -110,6 +117,7 @@ def __init__( response_decoder: TypeAdapter[ResponseBodyType], get_json: Optional[Callable[[Any], dict[str, Any]]], get_text: Optional[Callable[[Any], str]], + get_error: Optional[Callable[[Response], Optional[Exception]]], ) -> None: self.__instance = instance self.__operators = list(operators) @@ -118,6 +126,7 @@ def __init__( self.__response_decoder = response_decoder self.__get_json = get_json self.__get_text = get_text + self.__get_error = get_error def __call__(self, *args: PT.args, **kwargs: PT.kwargs) -> ResponseBodyType: request = self.__template.build_request(*args, **kwargs) @@ -131,4 +140,10 @@ def __send_request(self, context: Context[ResponseBodyType]) -> ResponseBodyType if self.__get_text is not None: response.get_text = self.__get_text context.response = response + + if self.__get_error is not None: + error = self.__get_error(response) + if error is not None: + raise error + return self.__response_decoder.from_response(response) diff --git a/src/meatie/option/body_option.py b/src/meatie/option/body_option.py index e21a5c9..bb8cfa7 100644 --- a/src/meatie/option/body_option.py +++ b/src/meatie/option/body_option.py @@ -2,7 +2,7 @@ # Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. from typing import Any, Awaitable, Callable, Optional, Union -from meatie import EndpointDescriptor +from meatie import AsyncResponse, EndpointDescriptor, Response from meatie.aio import AsyncEndpointDescriptor from meatie.internal.types import PT, T @@ -10,12 +10,18 @@ class BodyOption: - __slots__ = ("json", "text") + __slots__ = ("json", "text", "error") def __init__( self, json: Optional[Union[Callable[[Any], Any], Callable[[Any], Awaitable[Any]]]] = None, text: Optional[Union[Callable[[Any], str], Callable[[Any], Awaitable[str]]]] = None, + error: Optional[ + Union[ + Callable[[Response], Optional[Exception]], + Callable[[AsyncResponse], Awaitable[Optional[Exception]]], + ] + ] = None, ) -> None: """ Customize handling of HTTP response body. @@ -24,16 +30,19 @@ def __init__( implemented in the client library :param text: function to apply on the HTTP response to extract response text, otherwise use the default method implemented in the client library + :param error: function to apply on the HTTP response to extract an error """ self.json = json self.text = text + self.error = error def __call__( self, descriptor: Union[EndpointDescriptor[PT, T], AsyncEndpointDescriptor[PT, T]] ) -> None: descriptor.get_text = self.text descriptor.get_json = self.json + descriptor.get_error = self.error body = BodyOption diff --git a/tests/client/aiohttp/adapter/test_json.py b/tests/client/aiohttp/adapter/test_json.py index 507842c..dd2cadb 100644 --- a/tests/client/aiohttp/adapter/test_json.py +++ b/tests/client/aiohttp/adapter/test_json.py @@ -4,11 +4,18 @@ from decimal import Decimal from functools import partial from http import HTTPStatus -from http.server import BaseHTTPRequestHandler import pytest from aiohttp import ClientResponse, ClientSession from http_test import HTTPTestServer +from http_test.handlers import ( + NGINX_GATEWAY_TIMEOUT, + magic_number, + nginx_gateway_timeout, + status_ok, + status_ok_as_text, + truncated_json, +) from meatie import ParseResponseError, body, endpoint from meatie_aiohttp import Client @@ -16,13 +23,7 @@ @pytest.mark.asyncio() async def test_can_parse_json(http_server: HTTPTestServer) -> None: # GIVEN - def handler(request: BaseHTTPRequestHandler) -> None: - request.send_response(HTTPStatus.OK) - request.send_header("Content-Type", "application/json") - request.end_headers() - request.wfile.write('{"status": "ok"}'.encode("utf-8")) - - http_server.handler = handler + http_server.handler = status_ok class TestClient(Client): @endpoint("/") @@ -40,14 +41,7 @@ async def get_response(self) -> dict[str, str]: @pytest.mark.asyncio() async def test_can_handle_invalid_content_type(http_server: HTTPTestServer) -> None: # GIVEN - content = "{'status': 'ok'}" - - def handler(request: BaseHTTPRequestHandler) -> None: - request.send_response(HTTPStatus.OK) - request.end_headers() - request.wfile.write(content.encode("utf-8")) - - http_server.handler = handler + http_server.handler = status_ok_as_text class TestClient(Client): @endpoint("/") @@ -61,27 +55,14 @@ async def get_response(self) -> dict[str, str]: # THEN exc = exc_info.value - assert content == exc.text + assert "{'status': 'ok'}" == exc.text assert HTTPStatus.OK == exc.response.status @pytest.mark.asyncio() async def test_can_handle_html_response(http_server: HTTPTestServer) -> None: # GIVEN - content = ( - "" - "504 Gateway Time-out" - "

504 Gateway Time-out


nginx
" - "" - ) - - def handler(request: BaseHTTPRequestHandler) -> None: - request.send_response(HTTPStatus.GATEWAY_TIMEOUT) - request.send_header("Content-Type", "text/html") - request.end_headers() - request.wfile.write(content.encode("utf-8")) - - http_server.handler = handler + http_server.handler = nginx_gateway_timeout class TestClient(Client): @endpoint("/") @@ -95,20 +76,14 @@ async def get_response(self) -> dict[str, str]: # THEN exc = exc_info.value - assert content == exc.text + assert NGINX_GATEWAY_TIMEOUT == exc.text assert HTTPStatus.GATEWAY_TIMEOUT == exc.response.status @pytest.mark.asyncio() async def test_can_handle_corrupted_json(http_server: HTTPTestServer) -> None: # GIVEN - def handler(request: BaseHTTPRequestHandler) -> None: - request.send_response(HTTPStatus.OK) - request.send_header("Content-Type", "application/json") - request.end_headers() - request.wfile.write("{'status':".encode("utf-8")) - - http_server.handler = handler + http_server.handler = truncated_json class TestClient(Client): @endpoint("/") @@ -128,13 +103,7 @@ async def get_response(self) -> dict[str, str]: @pytest.mark.asyncio() async def test_use_custom_decoder(http_server: HTTPTestServer) -> None: # GIVEN response have json which will lose precision if parsed as float - def handler(request: BaseHTTPRequestHandler) -> None: - request.send_response(HTTPStatus.OK) - request.send_header("Content-Type", "application/json") - request.end_headers() - request.wfile.write('{"foo": 42.000000000000001}'.encode("utf-8")) - - http_server.handler = handler + http_server.handler = magic_number async def custom_json(response: ClientResponse) -> dict[str, Decimal]: return await response.json(loads=partial(json.loads, parse_float=Decimal)) @@ -149,4 +118,4 @@ async def get_response(self) -> dict[str, Decimal]: response = await client.get_response() # THEN - assert {"foo": Decimal("42.000000000000001")} == response + assert {"number": Decimal("42.000000000000001")} == response diff --git a/tests/client/aiohttp/adapter/test_string.py b/tests/client/aiohttp/adapter/test_string.py index 27184e4..efb6f09 100644 --- a/tests/client/aiohttp/adapter/test_string.py +++ b/tests/client/aiohttp/adapter/test_string.py @@ -1,11 +1,10 @@ # Copyright 2024 The Meatie Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. -from http import HTTPStatus -from http.server import BaseHTTPRequestHandler import pytest from aiohttp import ClientSession from http_test import HTTPTestServer +from http_test.handlers import ascii_emoji, emoji from meatie import endpoint from meatie_aiohttp import Client @@ -13,12 +12,7 @@ @pytest.mark.asyncio() async def test_can_parse_string(http_server: HTTPTestServer) -> None: # GIVEN - def handler(request: BaseHTTPRequestHandler) -> None: - request.send_response(HTTPStatus.OK) - request.end_headers() - request.wfile.write(bytes([0xF0, 0x9F, 0x9A, 0x80])) - - http_server.handler = handler + http_server.handler = emoji class TestClient(Client): @endpoint("/") @@ -36,13 +30,7 @@ async def get_response(self) -> str: @pytest.mark.asyncio() async def test_can_handle_invalid_encoding(http_server: HTTPTestServer) -> None: # GIVEN - def handler(request: BaseHTTPRequestHandler) -> None: - request.send_response(HTTPStatus.OK) - request.send_header("Content-Type", "text/plain; charset=ascii") - request.end_headers() - request.wfile.write(bytes([0xF0, 0x9F, 0x9A, 0x80])) - - http_server.handler = handler + http_server.handler = ascii_emoji class TestClient(Client): @endpoint("/") diff --git a/tests/client/aiohttp/test_error.py b/tests/client/aiohttp/test_error.py new file mode 100644 index 0000000..183e1d2 --- /dev/null +++ b/tests/client/aiohttp/test_error.py @@ -0,0 +1,43 @@ +# Copyright 2024 The Meatie Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. + +import pytest +from aiohttp import ClientSession +from http_test import HTTPTestServer +from http_test.handlers import service_unavailable +from meatie import AsyncResponse, HttpStatusError, ResponseError, body, endpoint +from meatie_aiohttp import Client + + +async def get_error(response: AsyncResponse) -> Exception | None: + exc_type = HttpStatusError if response.status >= 300 else ResponseError + + body = await response.json() + if isinstance(body, dict): + error = body.get("error") + if error is not None: + return exc_type(response, error) + + if response.status >= 300: + return exc_type(response) + + return None + + +@pytest.mark.asyncio() +async def test_raises_error(http_server: HTTPTestServer) -> None: + # GIVEN + http_server.handler = service_unavailable + + class TestClient(Client): + @endpoint("/", body(error=get_error)) + async def get_response(self) -> dict[str, str]: + ... + + # WHEN + with pytest.raises(HttpStatusError) as exc_info: + async with TestClient(ClientSession(http_server.base_url)) as client: + await client.get_response() + + # THEN + assert exc_info.value.args == ("deployment in progress",) diff --git a/tests/client/httpx/adapter/test_json.py b/tests/client/httpx/adapter/test_json.py index 2753e1a..5945438 100644 --- a/tests/client/httpx/adapter/test_json.py +++ b/tests/client/httpx/adapter/test_json.py @@ -1,11 +1,10 @@ # Copyright 2024 The Meatie Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. from decimal import Decimal -from http import HTTPStatus -from http.server import BaseHTTPRequestHandler import httpx from http_test import HTTPTestServer +from http_test.handlers import MAGIC_NUMBER, magic_number from httpx import Response from meatie import body, endpoint from meatie_httpx import Client @@ -13,13 +12,7 @@ def test_use_custom_decoder(http_server: HTTPTestServer) -> None: # GIVEN response have json which will lose precision if parsed as float - def handler(request: BaseHTTPRequestHandler) -> None: - request.send_response(HTTPStatus.OK) - request.send_header("Content-Type", "application/json") - request.end_headers() - request.wfile.write('{"foo": 42.000000000000001}'.encode("utf-8")) - - http_server.handler = handler + http_server.handler = magic_number def custom_json(response: Response) -> dict[str, Decimal]: return response.json(parse_float=Decimal) @@ -34,4 +27,4 @@ def get_response(self) -> dict[str, Decimal]: response = client.get_response() # THEN - assert {"foo": Decimal("42.000000000000001")} == response + assert {"number": MAGIC_NUMBER} == response diff --git a/tests/client/httpx/test_error.py b/tests/client/httpx/test_error.py new file mode 100644 index 0000000..038d8ff --- /dev/null +++ b/tests/client/httpx/test_error.py @@ -0,0 +1,42 @@ +# Copyright 2024 The Meatie Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. + +import httpx +import pytest +from http_test import HTTPTestServer +from http_test.handlers import service_unavailable +from meatie import HttpStatusError, Response, ResponseError, body, endpoint +from meatie_httpx import Client + + +def get_error(response: Response) -> Exception | None: + exc_type = HttpStatusError if response.status >= 300 else ResponseError + + body = response.json() + if isinstance(body, dict): + error = body.get("error") + if error is not None: + return exc_type(response, error) + + if response.status >= 300: + return exc_type(response) + + return None + + +def test_raises_error(http_server: HTTPTestServer) -> None: + # GIVEN + http_server.handler = service_unavailable + + class TestClient(Client): + @endpoint("/", body(error=get_error)) + def get_response(self) -> dict[str, str]: + ... + + # WHEN + with pytest.raises(HttpStatusError) as exc_info: + with TestClient(httpx.Client(base_url=http_server.base_url)) as client: + client.get_response() + + # THEN + assert exc_info.value.args == ("deployment in progress",) diff --git a/tests/client/requests/adapter/test_json.py b/tests/client/requests/adapter/test_json.py index 0b07018..f6f4e96 100644 --- a/tests/client/requests/adapter/test_json.py +++ b/tests/client/requests/adapter/test_json.py @@ -1,24 +1,17 @@ # Copyright 2024 The Meatie Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. from decimal import Decimal -from http import HTTPStatus -from http.server import BaseHTTPRequestHandler import requests from http_test import HTTPTestServer +from http_test.handlers import MAGIC_NUMBER, magic_number from meatie import body, endpoint from meatie_requests import Client def test_use_custom_decoder(http_server: HTTPTestServer) -> None: # GIVEN response have json which will lose precision if parsed as float - def handler(request: BaseHTTPRequestHandler) -> None: - request.send_response(HTTPStatus.OK) - request.send_header("Content-Type", "application/json") - request.end_headers() - request.wfile.write('{"foo": 42.000000000000001}'.encode("utf-8")) - - http_server.handler = handler + http_server.handler = magic_number def custom_json(response: requests.Response) -> dict[str, Decimal]: return response.json(parse_float=Decimal) @@ -33,4 +26,4 @@ def get_response(self) -> dict[str, Decimal]: response = client.get_response() # THEN - assert {"foo": Decimal("42.000000000000001")} == response + assert {"number": MAGIC_NUMBER} == response diff --git a/tests/client/requests/test_error.py b/tests/client/requests/test_error.py new file mode 100644 index 0000000..2e5ac3d --- /dev/null +++ b/tests/client/requests/test_error.py @@ -0,0 +1,41 @@ +# Copyright 2024 The Meatie Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. +import pytest +import requests +from http_test import HTTPTestServer +from http_test.handlers import service_unavailable +from meatie import HttpStatusError, Response, ResponseError, body, endpoint +from meatie_requests import Client + + +def get_error(response: Response) -> Exception | None: + exc_type = HttpStatusError if response.status >= 300 else ResponseError + + body = response.json() + if isinstance(body, dict): + error = body.get("error") + if error is not None: + return exc_type(response, error) + + if response.status >= 300: + return exc_type(response) + + return None + + +def test_raises_error(http_server: HTTPTestServer) -> None: + # GIVEN + http_server.handler = service_unavailable + + class TestClient(Client): + @endpoint(http_server.base_url + "/", body(error=get_error)) + def get_response(self) -> dict[str, str]: + ... + + # WHEN + with pytest.raises(HttpStatusError) as exc_info: + with TestClient(requests.Session()) as client: + client.get_response() + + # THEN + assert exc_info.value.args == ("deployment in progress",) diff --git a/tests/shared/http_test/__init__.py b/tests/shared/http_test/__init__.py index 4dbb2c7..8474e6b 100644 --- a/tests/shared/http_test/__init__.py +++ b/tests/shared/http_test/__init__.py @@ -1,16 +1,12 @@ # Copyright 2024 The Meatie Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. - +from . import handlers from .async_adapter import ClientAdapter, ResponseAdapter +from .handlers import Handler, RequestHandler from .http_server import ( - Handler, HTTPSTestServer, HTTPTestServer, - RequestHandler, - StatusHandler, - diagnostic_handler, - echo_handler, ) __all__ = [ @@ -18,9 +14,7 @@ "RequestHandler", "HTTPTestServer", "HTTPSTestServer", - "StatusHandler", - "diagnostic_handler", - "echo_handler", "ClientAdapter", + "handlers", "ResponseAdapter", ] diff --git a/tests/shared/http_test/handlers.py b/tests/shared/http_test/handlers.py new file mode 100644 index 0000000..5a9a4eb --- /dev/null +++ b/tests/shared/http_test/handlers.py @@ -0,0 +1,177 @@ +import json +import urllib.parse +from decimal import Decimal +from http import HTTPStatus +from http.server import BaseHTTPRequestHandler, SimpleHTTPRequestHandler +from typing import Any, Callable, Optional + + +class Handler(SimpleHTTPRequestHandler): + def send_json(self, status_code: HTTPStatus, message: Any) -> None: + if isinstance(message, str): + json_data = message + else: + try: + json_data = json.dumps(message) + except Exception as exc: + self.send_text( + HTTPStatus.INTERNAL_SERVER_ERROR, "Failed to serialize response: " + str(exc) + ) + return + + self.send_bytes(status_code, "application/json", json_data.encode("utf-8")) + + def send_text(self, status_code: HTTPStatus, text: str) -> None: + self.send_bytes(status_code, "text/plain", text.encode("utf-8")) + + def send_bytes(self, status_code: HTTPStatus, content_type: str, message: bytes) -> None: + self.send_response(status_code) + self.send_header("Content-Type", content_type) + self.send_header("Content-Length", str(len(message))) + self.end_headers() + self.wfile.write(message) + + def safe_content_length(self) -> Optional[int]: + content_length_raw = self.headers.get("Content-Length", "0") + try: + return int(content_length_raw) + except ValueError: + self.send_text( + HTTPStatus.BAD_REQUEST, + f"Content length should be an integer: '{content_length_raw}'", + ) + return None + + def safe_bytes(self) -> Optional[bytes]: + content_length_opt = self.safe_content_length() + if content_length_opt is None: + return None + + try: + return self.rfile.read(content_length_opt) + except Exception as exc: + self.send_text(HTTPStatus.BAD_REQUEST, "Failed to read request body: " + str(exc)) + return None + + def safe_text(self) -> Optional[str]: + raw_body_opt = self.safe_bytes() + if raw_body_opt is None: + return None + + content_charset = self.headers.get_content_charset("utf-8") + try: + return raw_body_opt.decode(content_charset) + except Exception as exc: + self.send_text(HTTPStatus.BAD_REQUEST, "Failed to decode request body: " + str(exc)) + return None + + +RequestHandler = Callable[[Handler], None] + + +class StatusHandler: + def __init__(self, status: HTTPStatus) -> None: + self.status = status + + def __call__(self, handler: Handler) -> None: + handler.send_response(self.status) + handler.end_headers() + + +def echo_handler(handler: Handler) -> None: + raw_body_opt = handler.safe_bytes() + if raw_body_opt is None: + return + + content_type = handler.headers.get_content_type() + handler.send_bytes(HTTPStatus.OK, content_type, raw_body_opt) + + +def diagnostic_handler(handler: Handler) -> None: + body_opt = handler.safe_text() + if body_opt is None: + return + + headers = {key: value for key, value in handler.headers.items()} + try: + url = urllib.parse.urlparse(handler.path) + except Exception as exc: + handler.send_text(HTTPStatus.BAD_REQUEST, "Failed to parse URL: " + str(exc)) + return + + response = { + "path": url.path, + "scheme": url.scheme, + "params": url.params, + "query": url.query, + "fragment": url.fragment, + "netloc": url.netloc, + "headers": headers, + "body": body_opt, + } + handler.send_json(HTTPStatus.OK, response) + + +def service_unavailable(request: BaseHTTPRequestHandler) -> None: + request.send_response(HTTPStatus.SERVICE_UNAVAILABLE) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write('{"error": "deployment in progress"}'.encode("utf-8")) + + +def emoji(request: BaseHTTPRequestHandler) -> None: + request.send_response(HTTPStatus.OK) + request.end_headers() + request.wfile.write(bytes([0xF0, 0x9F, 0x9A, 0x80])) + + +def ascii_emoji(request: BaseHTTPRequestHandler) -> None: + request.send_response(HTTPStatus.OK) + request.send_header("Content-Type", "text/plain; charset=ascii") + request.end_headers() + request.wfile.write(bytes([0xF0, 0x9F, 0x9A, 0x80])) + + +def status_ok(request: BaseHTTPRequestHandler) -> None: + request.send_response(HTTPStatus.OK) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write('{"status": "ok"}'.encode("utf-8")) + + +def status_ok_as_text(request: BaseHTTPRequestHandler) -> None: + request.send_response(HTTPStatus.OK) + request.end_headers() + request.wfile.write("{'status': 'ok'}".encode("utf-8")) + + +NGINX_GATEWAY_TIMEOUT = ( + "" + "504 Gateway Time-out" + "

504 Gateway Time-out


nginx
" + "" +) + + +def nginx_gateway_timeout(request: BaseHTTPRequestHandler) -> None: + request.send_response(HTTPStatus.GATEWAY_TIMEOUT) + request.send_header("Content-Type", "text/html") + request.end_headers() + request.wfile.write(NGINX_GATEWAY_TIMEOUT.encode("utf-8")) + + +def truncated_json(request: BaseHTTPRequestHandler) -> None: + request.send_response(HTTPStatus.OK) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write("{'status':".encode("utf-8")) + + +MAGIC_NUMBER = Decimal("42.000000000000001") + + +def magic_number(request: BaseHTTPRequestHandler) -> None: + request.send_response(HTTPStatus.OK) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write(f'{{"number": {MAGIC_NUMBER}}}'.encode("utf-8")) diff --git a/tests/shared/http_test/http_server.py b/tests/shared/http_test/http_server.py index 926271b..73db1d0 100644 --- a/tests/shared/http_test/http_server.py +++ b/tests/shared/http_test/http_server.py @@ -1,120 +1,14 @@ # Copyright 2024 The Meatie Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. -import json import ssl -import urllib.parse from http import HTTPStatus -from http.server import BaseHTTPRequestHandler, HTTPServer, SimpleHTTPRequestHandler +from http.server import BaseHTTPRequestHandler, HTTPServer from threading import Thread -from typing import Any, Callable, Optional +from typing import Any, Optional from typing_extensions import Self - -class Handler(SimpleHTTPRequestHandler): - def send_json(self, status_code: HTTPStatus, message: Any) -> None: - if isinstance(message, str): - json_data = message - else: - try: - json_data = json.dumps(message) - except Exception as exc: - self.send_text( - HTTPStatus.INTERNAL_SERVER_ERROR, "Failed to serialize response: " + str(exc) - ) - return - - self.send_bytes(status_code, "application/json", json_data.encode("utf-8")) - - def send_text(self, status_code: HTTPStatus, text: str) -> None: - self.send_bytes(status_code, "text/plain", text.encode("utf-8")) - - def send_bytes(self, status_code: HTTPStatus, content_type: str, message: bytes) -> None: - self.send_response(status_code) - self.send_header("Content-Type", content_type) - self.send_header("Content-Length", str(len(message))) - self.end_headers() - self.wfile.write(message) - - def safe_content_length(self) -> Optional[int]: - content_length_raw = self.headers.get("Content-Length", "0") - try: - return int(content_length_raw) - except ValueError: - self.send_text( - HTTPStatus.BAD_REQUEST, - f"Content length should be an integer: '{content_length_raw}'", - ) - return None - - def safe_bytes(self) -> Optional[bytes]: - content_length_opt = self.safe_content_length() - if content_length_opt is None: - return None - - try: - return self.rfile.read(content_length_opt) - except Exception as exc: - self.send_text(HTTPStatus.BAD_REQUEST, "Failed to read request body: " + str(exc)) - return None - - def safe_text(self) -> Optional[str]: - raw_body_opt = self.safe_bytes() - if raw_body_opt is None: - return None - - content_charset = self.headers.get_content_charset("utf-8") - try: - return raw_body_opt.decode(content_charset) - except Exception as exc: - self.send_text(HTTPStatus.BAD_REQUEST, "Failed to decode request body: " + str(exc)) - return None - - -RequestHandler = Callable[[Handler], None] - - -class StatusHandler: - def __init__(self, status: HTTPStatus) -> None: - self.status = status - - def __call__(self, handler: Handler) -> None: - handler.send_response(self.status) - handler.end_headers() - - -def echo_handler(handler: Handler) -> None: - raw_body_opt = handler.safe_bytes() - if raw_body_opt is None: - return - - content_type = handler.headers.get_content_type() - handler.send_bytes(HTTPStatus.OK, content_type, raw_body_opt) - - -def diagnostic_handler(handler: Handler) -> None: - body_opt = handler.safe_text() - if body_opt is None: - return - - headers = {key: value for key, value in handler.headers.items()} - try: - url = urllib.parse.urlparse(handler.path) - except Exception as exc: - handler.send_text(HTTPStatus.BAD_REQUEST, "Failed to parse URL: " + str(exc)) - return - - response = { - "path": url.path, - "scheme": url.scheme, - "params": url.params, - "query": url.query, - "fragment": url.fragment, - "netloc": url.netloc, - "headers": headers, - "body": body_opt, - } - handler.send_json(HTTPStatus.OK, response) +from .handlers import Handler, RequestHandler class HTTPTestServer: diff --git a/tests/shared/suite/client/default_suite.py b/tests/shared/suite/client/default_suite.py index ef84d92..eea569f 100644 --- a/tests/shared/suite/client/default_suite.py +++ b/tests/shared/suite/client/default_suite.py @@ -5,13 +5,10 @@ import pytest from http_test import ( - Handler, HTTPSTestServer, HTTPTestServer, - StatusHandler, - diagnostic_handler, - echo_handler, ) +from http_test.handlers import Handler, StatusHandler, diagnostic_handler, echo_handler from meatie import ( ParseResponseError, Request,