diff --git a/src/meatie/internal/template/request.py b/src/meatie/internal/template/request.py index 1c1f9e9..690b094 100644 --- a/src/meatie/internal/template/request.py +++ b/src/meatie/internal/template/request.py @@ -12,7 +12,7 @@ from meatie import Method, Request from meatie.api_reference import ApiReference -from meatie.internal.adapter import JsonAdapter, TypeAdapter, get_adapter +from meatie.internal.adapter import JsonAdapter, StringAdapter, TypeAdapter, get_adapter from meatie.internal.types import PT, RequestBodyType, T from typing_extensions import Callable, Self, Union, get_type_hints @@ -73,7 +73,8 @@ def build_request(self, *args: Any, **kwargs: Any) -> Request: path_kwargs = {} query_kwargs = {} - body_value: Any = None + body_json: Any = None + body_data: Any = None for param, value in value_by_param.items(): if param.kind == Kind.PATH: if param.formatter is not None: @@ -93,27 +94,27 @@ def build_request(self, *args: Any, **kwargs: Any) -> Request: if param.formatter is not None: value = param.formatter(value) + query_kwargs[param.api_ref] = value continue if param.kind == Kind.BODY: if param.formatter is not None: - body_value = param.formatter(value) + body_data = param.formatter(value) else: - body_value = value + body_json = self.request_encoder.to_content(value) continue raise NotImplementedError(f"Kind {param.kind} is not supported") # pragma: no cover path = self.template.format(**path_kwargs) - - if body_value is not None: - body_json = self.request_encoder.to_content(body_value) - else: - body_json = None - return Request( - method=self.method, path=path, params=query_kwargs, headers={}, json=body_json + method=self.method, + path=path, + params=query_kwargs, + headers={}, + json=body_json, + data=body_data, ) @classmethod @@ -146,7 +147,11 @@ def from_signature( kind = Kind.QUERY if api_ref.name == "body": kind = Kind.BODY - request_encoder = get_adapter(param_type) + if api_ref.fmt is None: + request_encoder = get_adapter(param_type) + else: + request_encoder = StringAdapter + elif api_ref.name in template: kind = Kind.PATH diff --git a/tests/client/aiohttp/adapter/test_pydantic_v2.py b/tests/client/aiohttp/adapter/test_pydantic_v2.py new file mode 100644 index 0000000..79447e7 --- /dev/null +++ b/tests/client/aiohttp/adapter/test_pydantic_v2.py @@ -0,0 +1,52 @@ +# Copyright 2024 The Meatie Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. +import json +from http import HTTPStatus +from http.server import BaseHTTPRequestHandler +from typing import Annotated, Any, Optional + +import pytest +from aiohttp import ClientSession +from http_test import HTTPTestServer +from meatie import AsyncResponse, api_ref, endpoint +from meatie_aiohttp import Client + +pydantic = pytest.importorskip("pydantic", minversion="2.0.0") +BaseModel: type = pydantic.BaseModel + + +@pytest.mark.asyncio() +async def test_post_request_body_with_fmt(http_server: HTTPTestServer) -> None: + # GIVEN + class Request(pydantic.BaseModel): + data: Optional[dict[str, Any]] + + def handler(request: BaseHTTPRequestHandler) -> None: + content_length = request.headers.get("Content-Length", "0") + raw_body = request.rfile.read(int(content_length)) + body = json.loads(raw_body.decode("utf-8")) + + if body.get("data") is not None: + request.send_response(HTTPStatus.BAD_REQUEST) + else: + request.send_response(HTTPStatus.OK) + request.end_headers() + + http_server.handler = handler + + def dump_body(model: pydantic.BaseModel) -> str: + return model.model_dump_json(by_alias=True, exclude_none=True) + + class TestClient(Client): + @endpoint("/") + async def post_request( + self, body: Annotated[Request, api_ref(fmt=dump_body)] + ) -> AsyncResponse: + ... + + # WHEN + async with TestClient(ClientSession(http_server.base_url)) as client: + response = await client.post_request(Request(data=None)) + + # THEN + assert response.status == HTTPStatus.OK