From 9b52406167b8338499274157dbe2690fc962af27 Mon Sep 17 00:00:00 2001 From: Raphael Odini Date: Wed, 15 Nov 2023 21:13:23 +0100 Subject: [PATCH] feat: add advanced filtering on GET /prices (#29) * Add fastapi-filter package * Replace existing filters with new library --- app/api.py | 14 ++------------ app/crud.py | 13 ++++--------- app/schemas.py | 23 +++++++++++++++++++++++ poetry.lock | 22 +++++++++++++++++++++- pyproject.toml | 6 +++++- 5 files changed, 55 insertions(+), 23 deletions(-) diff --git a/app/api.py b/app/api.py index 90de0ea0..a4363db0 100644 --- a/app/api.py +++ b/app/api.py @@ -1,6 +1,5 @@ import asyncio import uuid -from datetime import date from pathlib import Path from typing import Annotated @@ -9,7 +8,6 @@ Depends, FastAPI, HTTPException, - Query, Request, Response, UploadFile, @@ -18,6 +16,7 @@ from fastapi.responses import HTMLResponse, PlainTextResponse from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from fastapi.templating import Jinja2Templates +from fastapi_filter import FilterDepends from fastapi_pagination import Page, add_pagination from fastapi_pagination.ext.sqlalchemy import paginate from openfoodfacts.utils import get_logger @@ -133,16 +132,7 @@ async def authentication( @app.get("/prices", response_model=Page[schemas.PriceBase]) -async def get_price( - product_code: str | None = None, - location_osm_id: int | None = None, - date: date | None = None, -): - filters = { - "product_code": product_code, - "location_osm_id": location_osm_id, - "date": date, - } +async def get_price(filters: schemas.PriceFilter = FilterDepends(schemas.PriceFilter)): return paginate(db, crud.get_prices_query(filters=filters)) diff --git a/app/crud.py b/app/crud.py index cab3685a..5fac025f 100644 --- a/app/crud.py +++ b/app/crud.py @@ -9,7 +9,7 @@ from app import config from app.models import Price, Proof, User -from app.schemas import PriceCreate, ProofBase, UserBase +from app.schemas import PriceCreate, PriceFilter, UserBase def get_user(db: Session, user_id: str): @@ -56,20 +56,15 @@ def delete_user(db: Session, user_id: UserBase): return False -def get_prices_query(filters: dict | None = None): +def get_prices_query(filters: PriceFilter | None = None): """Useful for pagination.""" query = select(Price) if filters: - if filters.get("product_code", None): - query = query.filter(Price.product_code == filters["product_code"]) - if filters.get("location_osm_id", None): - query = query.filter(Price.location_osm_id == filters["location_osm_id"]) - if filters.get("date", None): - query = query.filter(Price.date == filters["date"]) + query = filters.filter(query) return query -def get_prices(db: Session, filters: dict | None = None): +def get_prices(db: Session, filters: PriceFilter | None = None): return db.execute(get_prices_query(filters=filters)).all() diff --git a/app/schemas.py b/app/schemas.py index c59f280b..0824c730 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -1,9 +1,12 @@ from datetime import date, datetime +from typing import Optional +from fastapi_filter.contrib.sqlalchemy import Filter from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator from sqlalchemy_utils import Currency from app.enums import PriceLocationOSMType +from app.models import Price class UserBase(BaseModel): @@ -54,3 +57,23 @@ class ProofBase(ProofCreate): id: int owner: str created: datetime + + +class PriceFilter(Filter): + product_code: Optional[str] | None = None + location_osm_id: Optional[int] | None = None + location_osm_type: Optional[PriceLocationOSMType] | None = None + price: Optional[int] | None = None + currency: Optional[str] | None = None + price__gt: Optional[int] | None = None + price__gte: Optional[int] | None = None + price__lt: Optional[int] | None = None + price__lte: Optional[int] | None = None + date: Optional[str] | None = None + date__gt: Optional[str] | None = None + date__gte: Optional[str] | None = None + date__lt: Optional[str] | None = None + date__lte: Optional[str] | None = None + + class Constants(Filter.Constants): + model = Price diff --git a/poetry.lock b/poetry.lock index c535f49c..fc8d1644 100644 --- a/poetry.lock +++ b/poetry.lock @@ -279,6 +279,26 @@ typing-extensions = ">=4.5.0" [package.extras] all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.5)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] +[[package]] +name = "fastapi-filter" +version = "1.0.0" +description = "FastAPI filter" +optional = false +python-versions = ">=3.8,<4.0" +files = [ + {file = "fastapi_filter-1.0.0-py3-none-any.whl", hash = "sha256:8bbdcfe71537ee6a33334d895a4a78098ddd8049e705c9a4dc4d24bca3e6c632"}, + {file = "fastapi_filter-1.0.0.tar.gz", hash = "sha256:f4ff5eed250d746e19ae4075a04a4bc9d0c46cc49b112c7689e13db978b28c42"}, +] + +[package.dependencies] +fastapi = ">=0.100.0,<1.0" +pydantic = ">=2.0.0,<3.0.0" + +[package.extras] +all = ["SQLAlchemy (>=1.4.36,<2.1.0)", "mongoengine (>=0.24.1,<0.28.0)"] +mongoengine = ["mongoengine (>=0.24.1,<0.28.0)"] +sqlalchemy = ["SQLAlchemy (>=1.4.36,<2.1.0)"] + [[package]] name = "fastapi-pagination" version = "0.12.12" @@ -1253,4 +1273,4 @@ standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "d2303c7fd7cf77afff6d61080be615f36b06f9b9590b9bb3a5831256119dbcd4" +content-hash = "c7601c7b9257f1f122b5a9492b9eb5dc5f5424017cd6b32a1d3906c5081f7c6a" diff --git a/pyproject.toml b/pyproject.toml index ad9f2b47..0906f5f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ sqlalchemy = "~2.0.23" sqlalchemy-utils = "~0.41.1" uvicorn = "~0.23.2" fastapi-pagination = "^0.12.12" +fastapi-filter = "^1.0.0" [tool.poetry.group.dev.dependencies] @@ -44,5 +45,8 @@ exclude = ''' )/ ''' +[tool.isort] +profile = "black" + [tool.mypy] -ignore_missing_imports = true \ No newline at end of file +ignore_missing_imports = true