Skip to content

Commit

Permalink
feat: support custom formatters parameters (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmateusz authored Jul 24, 2024
1 parent 3029b24 commit 5792d19
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 12 deletions.
13 changes: 9 additions & 4 deletions src/meatie/api_reference.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
# Copyright 2024 The Meatie Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
import inspect
from typing import Any, get_args
from typing import Any, Callable, Optional, get_args

from typing_extensions import Self

__all__ = ["api_ref", "ApiReference"]


class ApiReference:
__slots__ = ("name",)
__slots__ = ("name", "fmt")

def __init__(self, name: str) -> None:
def __init__(
self, name: Optional[str] = None, fmt: Optional[Callable[[Any], Any]] = None
) -> None:
self.name = name
self.fmt = fmt

def __hash__(self) -> int:
return hash(self.name)
Expand All @@ -26,8 +29,10 @@ def __eq__(self, other: Any) -> bool:
def from_signature(cls, parameter: inspect.Parameter) -> Self:
for arg in get_args(parameter.annotation):
if isinstance(arg, cls):
if arg.name is None:
arg.name = parameter.name
return arg
return cls(name=parameter.name)
return cls(name=parameter.name, fmt=None)


api_ref = ApiReference
14 changes: 11 additions & 3 deletions src/meatie/internal/template/parameter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright 2024 The Meatie Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
from enum import Enum
from typing import Any
from typing import Any, Callable, Optional


class Kind(Enum):
Expand All @@ -12,13 +12,21 @@ class Kind(Enum):


class Parameter:
__slots__ = ("kind", "name", "api_ref", "default_value")
__slots__ = ("kind", "name", "api_ref", "default_value", "formatter")

def __init__(self, kind: Kind, name: str, api_ref: str, default_value: Any = None) -> None:
def __init__(
self,
kind: Kind,
name: str,
api_ref: str,
default_value: Any = None,
formatter: Optional[Callable[[Any], Any]] = None,
) -> None:
self.kind = kind
self.name = name
self.api_ref = api_ref
self.default_value = default_value
self.formatter = formatter

def __hash__(self) -> int:
return hash(self.name)
Expand Down
17 changes: 15 additions & 2 deletions src/meatie/internal/template/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,24 @@ def build_request(self, *args: Any, **kwargs: Any) -> Request:
body_value: Any = None
for param, value in value_by_param.items():
if param.kind == Kind.PATH:
if param.formatter is not None:
value = param.formatter(value)
path_kwargs[param.api_ref] = value
continue

if param.kind == Kind.QUERY:
# emit query parameters only if underlying value is not None
if value is not None:
if param.formatter is not None:
value = param.formatter(value)
query_kwargs[param.api_ref] = value
continue

if param.kind == Kind.BODY:
body_value = value
if param.formatter is not None:
body_value = param.formatter(value)
else:
body_value = value
continue

raise NotImplementedError(f"Kind {param.kind} is not supported") # pragma: no cover
Expand Down Expand Up @@ -127,6 +134,8 @@ def from_signature(
param_type = type_hints[param_name]
sig_param = signature.parameters[param_name]
api_ref = ApiReference.from_signature(sig_param)
assert api_ref.name is not None

kind = Kind.QUERY
if api_ref.name == "body":
kind = Kind.BODY
Expand All @@ -137,7 +146,11 @@ def from_signature(
default_value = None
if sig_param.default is not inspect.Parameter.empty:
default_value = sig_param.default
parameter = Parameter(kind, param_name, api_ref.name, default_value)

formatter = None
if api_ref.fmt is not None:
formatter = api_ref.fmt
parameter = Parameter(kind, param_name, api_ref.name, default_value, formatter)
parameters.append(parameter)

return cls.validate_object(template, parameters, signature, request_encoder, method)
Expand Down
35 changes: 32 additions & 3 deletions tests/client/aiohttp_/test_descriptor.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Copyright 2024 The Meatie Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.

from typing import Any, Optional
import datetime
from typing import Annotated, Any, Optional
from unittest.mock import ANY, Mock

import pytest
from meatie import Request, endpoint
from meatie import Request, api_ref, endpoint
from meatie.aio import AsyncContext, AsyncEndpointDescriptor
from meatie.internal.template import RequestTemplate
from meatie.internal.types import AsyncClient
Expand Down Expand Up @@ -37,6 +37,35 @@ async def get_products(self) -> list[Any]:
session.request.assert_awaited_once_with("GET", "/api/v1/products")


@pytest.mark.asyncio()
async def test_get_with_formatter(mock_tools: AiohttpMockTools) -> None:
# GIVEN
session = mock_tools.session_with_json_response(json=PRODUCTS)

def format_date(date: datetime.datetime) -> str:
return date.strftime("%Y-%m-%d")

class Store(Client):
def __init__(self) -> None:
super().__init__(session)

@endpoint("/api/v1/transactions")
async def get_transactions(
self, since: Annotated[datetime.datetime, api_ref(fmt=format_date)]
) -> list[Any]:
...

# WHEN
async with Store() as api:
result = await api.get_transactions(since=datetime.datetime(2006, 1, 2))

# THEN
assert PRODUCTS == result
session.request.assert_awaited_once_with(
"GET", "/api/v1/transactions", params={"since": "2006-01-02"}
)


@pytest.mark.asyncio()
async def test_post_with_body(mock_tools: AiohttpMockTools) -> None:
# GIVEN
Expand Down

0 comments on commit 5792d19

Please sign in to comment.