Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Check the HTTP response for errors #99

Merged
merged 3 commits into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions src/meatie/aio/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(
self.response_decoder = response_decoder
self.get_json: Optional[Callable[[Any], Awaitable[dict[str, Any]]]] = None
self.get_text: Optional[Callable[[Any], Awaitable[str]]] = None
self.get_error: Optional[Callable[[AsyncResponse], Awaitable[Optional[Exception]]]] = None
self.__operator_by_priority: dict[int, AsyncOperator[ResponseBodyType]] = {}

def __set_name__(self, owner: type[object], name: str) -> None:
Expand Down Expand Up @@ -91,7 +92,13 @@ def __get__(
operators = [operator for _, operator in priority_operator_pair]

return BoundAsyncEndpointDescriptor( # type: ignore[return-value]
instance, operators, self.template, self.response_decoder, self.get_json, self.get_text
instance,
operators,
self.template,
self.response_decoder,
self.get_json,
self.get_text,
self.get_error,
)


Expand All @@ -102,8 +109,9 @@ def __init__(
operators: Iterable[AsyncOperator[ResponseBodyType]],
template: RequestTemplate[Any],
response_decoder: TypeAdapter[ResponseBodyType],
get_json: Optional[Callable[[Any], Awaitable[dict[str, Any]]]],
get_json: Optional[Callable[[Any], Awaitable[Any]]],
get_text: Optional[Callable[[Any], Awaitable[str]]],
get_error: Optional[Callable[[AsyncResponse], Awaitable[Optional[Exception]]]],
) -> None:
self.__instance = instance
self.__operators = list(operators)
Expand All @@ -112,6 +120,7 @@ def __init__(
self.__response_decoder = response_decoder
self.__get_json = get_json
self.__get_text = get_text
self.__get_error = get_error

async def __call__(self, *args: PT.args, **kwargs: PT.kwargs) -> ResponseBodyType:
request = self.__template.build_request(*args, **kwargs)
Expand All @@ -124,5 +133,12 @@ async def __send_request(self, context: AsyncContext[ResponseBodyType]) -> Respo
response.get_json = self.__get_json
if self.__get_text is not None:
response.get_text = self.__get_text

context.response = response

if self.__get_error is not None:
error = await self.__get_error(response)
if error is not None:
raise error

return await self.__response_decoder.from_async_response(response)
19 changes: 17 additions & 2 deletions src/meatie/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ def __init__(
) -> None:
self.template = template
self.response_decoder = response_decoder
self.get_json: Optional[Callable[[Any], dict[str, Any]]] = None
self.get_json: Optional[Callable[[Any], Any]] = None
self.get_text: Optional[Callable[[Any], str]] = None
self.get_error: Optional[Callable[[Response], Optional[Exception]]] = None
self.__operator_by_priority: dict[int, Operator[ResponseBodyType]] = {}

def __set_name__(self, owner: type[object], name: str) -> None:
Expand Down Expand Up @@ -61,7 +62,13 @@ def __get__(
operators = [operator for _, operator in priority_operator_pair]

return BoundEndpointDescriptor( # type: ignore[return-value]
instance, operators, self.template, self.response_decoder, self.get_json, self.get_text
instance,
operators,
self.template,
self.response_decoder,
self.get_json,
self.get_text,
self.get_error,
)


Expand Down Expand Up @@ -110,6 +117,7 @@ def __init__(
response_decoder: TypeAdapter[ResponseBodyType],
get_json: Optional[Callable[[Any], dict[str, Any]]],
get_text: Optional[Callable[[Any], str]],
get_error: Optional[Callable[[Response], Optional[Exception]]],
) -> None:
self.__instance = instance
self.__operators = list(operators)
Expand All @@ -118,6 +126,7 @@ def __init__(
self.__response_decoder = response_decoder
self.__get_json = get_json
self.__get_text = get_text
self.__get_error = get_error

def __call__(self, *args: PT.args, **kwargs: PT.kwargs) -> ResponseBodyType:
request = self.__template.build_request(*args, **kwargs)
Expand All @@ -131,4 +140,10 @@ def __send_request(self, context: Context[ResponseBodyType]) -> ResponseBodyType
if self.__get_text is not None:
response.get_text = self.__get_text
context.response = response

if self.__get_error is not None:
error = self.__get_error(response)
if error is not None:
raise error

return self.__response_decoder.from_response(response)
13 changes: 11 additions & 2 deletions src/meatie/option/body_option.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,26 @@
# Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
from typing import Any, Awaitable, Callable, Optional, Union

from meatie import EndpointDescriptor
from meatie import AsyncResponse, EndpointDescriptor, Response
from meatie.aio import AsyncEndpointDescriptor
from meatie.internal.types import PT, T

__all__ = ["body"]


class BodyOption:
__slots__ = ("json", "text")
__slots__ = ("json", "text", "error")

def __init__(
self,
json: Optional[Union[Callable[[Any], Any], Callable[[Any], Awaitable[Any]]]] = None,
text: Optional[Union[Callable[[Any], str], Callable[[Any], Awaitable[str]]]] = None,
error: Optional[
Union[
Callable[[Response], Optional[Exception]],
Callable[[AsyncResponse], Awaitable[Optional[Exception]]],
]
] = None,
) -> None:
"""
Customize handling of HTTP response body.
Expand All @@ -24,16 +30,19 @@ def __init__(
implemented in the client library
:param text: function to apply on the HTTP response to extract response text, otherwise use the default method
implemented in the client library
:param error: function to apply on the HTTP response to extract an error
"""

self.json = json
self.text = text
self.error = error

def __call__(
self, descriptor: Union[EndpointDescriptor[PT, T], AsyncEndpointDescriptor[PT, T]]
) -> None:
descriptor.get_text = self.text
descriptor.get_json = self.json
descriptor.get_error = self.error


body = BodyOption
63 changes: 16 additions & 47 deletions tests/client/aiohttp/adapter/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,26 @@
from decimal import Decimal
from functools import partial
from http import HTTPStatus
from http.server import BaseHTTPRequestHandler

import pytest
from aiohttp import ClientResponse, ClientSession
from http_test import HTTPTestServer
from http_test.handlers import (
NGINX_GATEWAY_TIMEOUT,
magic_number,
nginx_gateway_timeout,
status_ok,
status_ok_as_text,
truncated_json,
)
from meatie import ParseResponseError, body, endpoint
from meatie_aiohttp import Client


@pytest.mark.asyncio()
async def test_can_parse_json(http_server: HTTPTestServer) -> None:
# GIVEN
def handler(request: BaseHTTPRequestHandler) -> None:
request.send_response(HTTPStatus.OK)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write('{"status": "ok"}'.encode("utf-8"))

http_server.handler = handler
http_server.handler = status_ok

class TestClient(Client):
@endpoint("/")
Expand All @@ -40,14 +41,7 @@ async def get_response(self) -> dict[str, str]:
@pytest.mark.asyncio()
async def test_can_handle_invalid_content_type(http_server: HTTPTestServer) -> None:
# GIVEN
content = "{'status': 'ok'}"

def handler(request: BaseHTTPRequestHandler) -> None:
request.send_response(HTTPStatus.OK)
request.end_headers()
request.wfile.write(content.encode("utf-8"))

http_server.handler = handler
http_server.handler = status_ok_as_text

class TestClient(Client):
@endpoint("/")
Expand All @@ -61,27 +55,14 @@ async def get_response(self) -> dict[str, str]:

# THEN
exc = exc_info.value
assert content == exc.text
assert "{'status': 'ok'}" == exc.text
assert HTTPStatus.OK == exc.response.status


@pytest.mark.asyncio()
async def test_can_handle_html_response(http_server: HTTPTestServer) -> None:
# GIVEN
content = (
"<html>"
"<head><title>504 Gateway Time-out</title></head>"
"<body><center><h1>504 Gateway Time-out</h1></center><hr><center>nginx</center></body>"
"</html>"
)

def handler(request: BaseHTTPRequestHandler) -> None:
request.send_response(HTTPStatus.GATEWAY_TIMEOUT)
request.send_header("Content-Type", "text/html")
request.end_headers()
request.wfile.write(content.encode("utf-8"))

http_server.handler = handler
http_server.handler = nginx_gateway_timeout

class TestClient(Client):
@endpoint("/")
Expand All @@ -95,20 +76,14 @@ async def get_response(self) -> dict[str, str]:

# THEN
exc = exc_info.value
assert content == exc.text
assert NGINX_GATEWAY_TIMEOUT == exc.text
assert HTTPStatus.GATEWAY_TIMEOUT == exc.response.status


@pytest.mark.asyncio()
async def test_can_handle_corrupted_json(http_server: HTTPTestServer) -> None:
# GIVEN
def handler(request: BaseHTTPRequestHandler) -> None:
request.send_response(HTTPStatus.OK)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write("{'status':".encode("utf-8"))

http_server.handler = handler
http_server.handler = truncated_json

class TestClient(Client):
@endpoint("/")
Expand All @@ -128,13 +103,7 @@ async def get_response(self) -> dict[str, str]:
@pytest.mark.asyncio()
async def test_use_custom_decoder(http_server: HTTPTestServer) -> None:
# GIVEN response have json which will lose precision if parsed as float
def handler(request: BaseHTTPRequestHandler) -> None:
request.send_response(HTTPStatus.OK)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write('{"foo": 42.000000000000001}'.encode("utf-8"))

http_server.handler = handler
http_server.handler = magic_number

async def custom_json(response: ClientResponse) -> dict[str, Decimal]:
return await response.json(loads=partial(json.loads, parse_float=Decimal))
Expand All @@ -149,4 +118,4 @@ async def get_response(self) -> dict[str, Decimal]:
response = await client.get_response()

# THEN
assert {"foo": Decimal("42.000000000000001")} == response
assert {"number": Decimal("42.000000000000001")} == response
18 changes: 3 additions & 15 deletions tests/client/aiohttp/adapter/test_string.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,18 @@
# Copyright 2024 The Meatie Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
from http import HTTPStatus
from http.server import BaseHTTPRequestHandler

import pytest
from aiohttp import ClientSession
from http_test import HTTPTestServer
from http_test.handlers import ascii_emoji, emoji
from meatie import endpoint
from meatie_aiohttp import Client


@pytest.mark.asyncio()
async def test_can_parse_string(http_server: HTTPTestServer) -> None:
# GIVEN
def handler(request: BaseHTTPRequestHandler) -> None:
request.send_response(HTTPStatus.OK)
request.end_headers()
request.wfile.write(bytes([0xF0, 0x9F, 0x9A, 0x80]))

http_server.handler = handler
http_server.handler = emoji

class TestClient(Client):
@endpoint("/")
Expand All @@ -36,13 +30,7 @@ async def get_response(self) -> str:
@pytest.mark.asyncio()
async def test_can_handle_invalid_encoding(http_server: HTTPTestServer) -> None:
# GIVEN
def handler(request: BaseHTTPRequestHandler) -> None:
request.send_response(HTTPStatus.OK)
request.send_header("Content-Type", "text/plain; charset=ascii")
request.end_headers()
request.wfile.write(bytes([0xF0, 0x9F, 0x9A, 0x80]))

http_server.handler = handler
http_server.handler = ascii_emoji

class TestClient(Client):
@endpoint("/")
Expand Down
43 changes: 43 additions & 0 deletions tests/client/aiohttp/test_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright 2024 The Meatie Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.

import pytest
from aiohttp import ClientSession
from http_test import HTTPTestServer
from http_test.handlers import service_unavailable
from meatie import AsyncResponse, HttpStatusError, ResponseError, body, endpoint
from meatie_aiohttp import Client


async def get_error(response: AsyncResponse) -> Exception | None:
exc_type = HttpStatusError if response.status >= 300 else ResponseError

body = await response.json()
if isinstance(body, dict):
error = body.get("error")
if error is not None:
return exc_type(response, error)

if response.status >= 300:
return exc_type(response)

return None


@pytest.mark.asyncio()
async def test_raises_error(http_server: HTTPTestServer) -> None:
# GIVEN
http_server.handler = service_unavailable

class TestClient(Client):
@endpoint("/", body(error=get_error))
async def get_response(self) -> dict[str, str]:
...

# WHEN
with pytest.raises(HttpStatusError) as exc_info:
async with TestClient(ClientSession(http_server.base_url)) as client:
await client.get_response()

# THEN
assert exc_info.value.args == ("deployment in progress",)
Loading
Loading