diff --git a/src/meatie/api_reference.py b/src/meatie/api_reference.py index 2fbfdec..6f4e2f6 100644 --- a/src/meatie/api_reference.py +++ b/src/meatie/api_reference.py @@ -1,7 +1,7 @@ # Copyright 2024 The Meatie Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. import inspect -from typing import Any, get_args +from typing import Any, Callable, Optional, get_args from typing_extensions import Self @@ -9,10 +9,13 @@ class ApiReference: - __slots__ = ("name",) + __slots__ = ("name", "fmt") - def __init__(self, name: str) -> None: + def __init__( + self, name: Optional[str] = None, fmt: Optional[Callable[[Any], Any]] = None + ) -> None: self.name = name + self.fmt = fmt def __hash__(self) -> int: return hash(self.name) @@ -26,8 +29,10 @@ def __eq__(self, other: Any) -> bool: def from_signature(cls, parameter: inspect.Parameter) -> Self: for arg in get_args(parameter.annotation): if isinstance(arg, cls): + if arg.name is None: + arg.name = parameter.name return arg - return cls(name=parameter.name) + return cls(name=parameter.name, fmt=None) api_ref = ApiReference diff --git a/src/meatie/internal/template/parameter.py b/src/meatie/internal/template/parameter.py index cd7ac0a..b39f4f1 100644 --- a/src/meatie/internal/template/parameter.py +++ b/src/meatie/internal/template/parameter.py @@ -1,7 +1,7 @@ # Copyright 2024 The Meatie Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. from enum import Enum -from typing import Any +from typing import Any, Callable, Optional class Kind(Enum): @@ -12,13 +12,21 @@ class Kind(Enum): class Parameter: - __slots__ = ("kind", "name", "api_ref", "default_value") + __slots__ = ("kind", "name", "api_ref", "default_value", "formatter") - def __init__(self, kind: Kind, name: str, api_ref: str, default_value: Any = None) -> None: + def __init__( + self, + kind: Kind, + name: str, + api_ref: str, + default_value: Any = None, + formatter: Optional[Callable[[Any], Any]] = None, + ) -> None: self.kind = kind self.name = name self.api_ref = api_ref self.default_value = default_value + self.formatter = formatter def __hash__(self) -> int: return hash(self.name) diff --git a/src/meatie/internal/template/request.py b/src/meatie/internal/template/request.py index 855aa24..1a9db42 100644 --- a/src/meatie/internal/template/request.py +++ b/src/meatie/internal/template/request.py @@ -76,17 +76,24 @@ def build_request(self, *args: Any, **kwargs: Any) -> Request: body_value: Any = None for param, value in value_by_param.items(): if param.kind == Kind.PATH: + if param.formatter is not None: + value = param.formatter(value) path_kwargs[param.api_ref] = value continue if param.kind == Kind.QUERY: # emit query parameters only if underlying value is not None if value is not None: + if param.formatter is not None: + value = param.formatter(value) query_kwargs[param.api_ref] = value continue if param.kind == Kind.BODY: - body_value = value + if param.formatter is not None: + body_value = param.formatter(value) + else: + body_value = value continue raise NotImplementedError(f"Kind {param.kind} is not supported") # pragma: no cover @@ -127,6 +134,8 @@ def from_signature( param_type = type_hints[param_name] sig_param = signature.parameters[param_name] api_ref = ApiReference.from_signature(sig_param) + assert api_ref.name is not None + kind = Kind.QUERY if api_ref.name == "body": kind = Kind.BODY @@ -137,7 +146,11 @@ def from_signature( default_value = None if sig_param.default is not inspect.Parameter.empty: default_value = sig_param.default - parameter = Parameter(kind, param_name, api_ref.name, default_value) + + formatter = None + if api_ref.fmt is not None: + formatter = api_ref.fmt + parameter = Parameter(kind, param_name, api_ref.name, default_value, formatter) parameters.append(parameter) return cls.validate_object(template, parameters, signature, request_encoder, method) diff --git a/tests/client/aiohttp_/test_descriptor.py b/tests/client/aiohttp_/test_descriptor.py index 556ad92..d51279c 100644 --- a/tests/client/aiohttp_/test_descriptor.py +++ b/tests/client/aiohttp_/test_descriptor.py @@ -1,11 +1,11 @@ # Copyright 2024 The Meatie Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. - -from typing import Any, Optional +import datetime +from typing import Annotated, Any, Optional from unittest.mock import ANY, Mock import pytest -from meatie import Request, endpoint +from meatie import Request, api_ref, endpoint from meatie.aio import AsyncContext, AsyncEndpointDescriptor from meatie.internal.template import RequestTemplate from meatie.internal.types import AsyncClient @@ -37,6 +37,35 @@ async def get_products(self) -> list[Any]: session.request.assert_awaited_once_with("GET", "/api/v1/products") +@pytest.mark.asyncio() +async def test_get_with_formatter(mock_tools: AiohttpMockTools) -> None: + # GIVEN + session = mock_tools.session_with_json_response(json=PRODUCTS) + + def format_date(date: datetime.datetime) -> str: + return date.strftime("%Y-%m-%d") + + class Store(Client): + def __init__(self) -> None: + super().__init__(session) + + @endpoint("/api/v1/transactions") + async def get_transactions( + self, since: Annotated[datetime.datetime, api_ref(fmt=format_date)] + ) -> list[Any]: + ... + + # WHEN + async with Store() as api: + result = await api.get_transactions(since=datetime.datetime(2006, 1, 2)) + + # THEN + assert PRODUCTS == result + session.request.assert_awaited_once_with( + "GET", "/api/v1/transactions", params={"since": "2006-01-02"} + ) + + @pytest.mark.asyncio() async def test_post_with_body(mock_tools: AiohttpMockTools) -> None: # GIVEN