Commit
fix: type hints after pre-commit update (#269)
alexgarel authored Jan 3, 2025
1 parent baf6387 commit acd55a1
Showing 6 changed files with 46 additions and 41 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -1,13 +1,13 @@
repos:
# Note for all linters: do not forget to update pyproject.toml when updating version.
- repo: https://github.com/python-poetry/poetry
rev: 1.8.4
rev: 1.8.0
hooks:
- id: poetry-lock
args: ["--check"]

- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.8.0
rev: 24.10.0
hooks:
- id: black
language_version: python3
@@ -23,7 +23,7 @@ repos:
- id: isort

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.1
rev: v1.14.1
hooks:
- id: mypy
additional_dependencies: [types-all-v2]
28 changes: 14 additions & 14 deletions app/_types.py
@@ -367,6 +367,14 @@ def uses_sort_script(self):
_, sort_by = self.sign_sort_by
return index_config.scripts and sort_by in index_config.scripts.keys()

@cached_property
def sign_sort_by(self) -> Tuple[str_utils.BoolOperator, str | None]:
return (
("+", None)
if self.sort_by is None
else str_utils.split_sort_by_sign(self.sort_by)
)


class AggregateSearchParameters(BaseModel):

@@ -433,14 +441,6 @@ def check_charts_are_valid(self):
raise ValueError(errors)
return self

@cached_property
def sign_sort_by(self) -> Tuple[str_utils.BoolOperator, str | None]:
return (
("+", None)
if self.sort_by is None
else str_utils.split_sort_by_sign(self.sort_by)
)


def _prepare_str_list(item: Any) -> str | None:
if isinstance(item, str):
@@ -503,19 +503,19 @@ def parse_charts_str(
"""
str_charts = _prepare_str_list(charts)
if str_charts:
charts = []
parsed_charts: list[DistributionChart | ScatterChart] = []
charts_list = str_charts.split(",")
for c in charts_list:
if ":" in c:
[x, y] = c.split(":")
charts.append(ScatterChart(x=x, y=y))
parsed_charts.append(ScatterChart(x=x, y=y))
else:
charts.append(DistributionChart(field=c))
if charts is not None:
parsed_charts.append(DistributionChart(field=c))
if parsed_charts is not None:
# we already know because of code logic that charts is the right type
# but we need to cast for mypy type checking
charts = cast(list[ChartType], charts)
return charts
result_charts = cast(list[ChartType], parsed_charts)
return result_charts

@model_validator(mode="after")
def validate_q_or_sort_by(self):
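Note on the app/_types.py diff above: sign_sort_by moves next to uses_sort_script, the method that consumes it, and parse_charts_str now accumulates results in a new, explicitly typed parsed_charts list instead of re-binding the charts argument, which the updated mypy flags as an incompatible assignment. A minimal runnable sketch of that second pattern follows; DistributionChart, ScatterChart and ChartType are simplified stand-ins here, not the project's real models.

from dataclasses import dataclass

@dataclass
class DistributionChart:
    field: str

@dataclass
class ScatterChart:
    x: str
    y: str

ChartType = DistributionChart | ScatterChart

def parse_charts(str_charts: str) -> list[ChartType]:
    # Building a fresh, explicitly typed list keeps the argument's declared
    # type unchanged, so mypy no longer reports an incompatible re-assignment.
    parsed_charts: list[ChartType] = []
    for c in str_charts.split(","):
        if ":" in c:
            x, y = c.split(":")
            parsed_charts.append(ScatterChart(x=x, y=y))
        else:
            parsed_charts.append(DistributionChart(field=c))
    return parsed_charts

print(parse_charts("nutriscore_grade,nutriments.salt_100g:unique_scans_n"))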
14 changes: 7 additions & 7 deletions app/charts.py
@@ -50,15 +50,15 @@ def build_distribution_chart(
Return the vega structure for a Bar Chart
Inspiration: https://vega.github.io/vega/examples/bar-chart/
"""
chart = empty_chart(chart.field)
chart["data"] = [
vega_chart = empty_chart(chart.field)
vega_chart["data"] = [
{
"name": "table",
"values": values,
"transform": [{"type": "filter", "expr": "datum['category'] != 'unknown'"}],
},
]
chart["signals"].append(
vega_chart["signals"].append(
{
"name": "tooltip",
"value": {},
@@ -68,7 +68,7 @@
],
}
)
chart["scales"] = [
vega_chart["scales"] = [
{
"name": "xscale",
"type": "band",
@@ -86,10 +86,10 @@
]
# How to hide vertical axis: do not add { scale: yscale, ...}
# in axes section
chart["axes"] = [
vega_chart["axes"] = [
{"orient": "bottom", "scale": "xscale", "domain": False, "ticks": False}
]
chart["marks"] = [
vega_chart["marks"] = [
{
"type": "rect",
"from": {"data": "table"},
@@ -135,7 +135,7 @@ def build_distribution_chart(
},
},
]
return chart
return vega_chart


def build_scatter_chart(
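Note on the app/charts.py diff above: the Vega spec used to be assigned back to chart, re-binding the chart parameter (presumably a DistributionChart) to a plain dict, which the updated mypy reports as an incompatible assignment; renaming the local to vega_chart keeps the two types apart. A hedged, self-contained sketch of the pattern; the empty_chart body and field set below are guesses made only to keep the example runnable.

from dataclasses import dataclass
from typing import Any

@dataclass
class DistributionChart:
    field: str

def empty_chart(title: str) -> dict[str, Any]:
    # Stand-in for the real empty_chart helper in app/charts.py.
    return {"title": title, "data": [], "signals": [], "scales": [], "axes": [], "marks": []}

def build_distribution_chart(chart: DistributionChart, values: list[dict[str, Any]]) -> dict[str, Any]:
    # The spec lives in its own variable, so the `chart` parameter keeps its
    # declared type and mypy is satisfied.
    vega_chart = empty_chart(chart.field)
    vega_chart["data"] = [{"name": "table", "values": values}]
    return vega_chart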
4 changes: 2 additions & 2 deletions app/openfoodfacts.py
@@ -123,7 +123,7 @@ class DocumentFetcher(BaseDocumentFetcher):
def fetch_document(self, stream_name: str, item: JSONType) -> FetcherResult:
if item.get("action") == "deleted":
# this is a deleted product, no need to fetch
return FetcherResult(status=FetcherStatus.REMOVED, data=None)
return FetcherResult(status=FetcherStatus.REMOVED, document=None)

code = item["code"]
url = f"{OFF_API_URL}/api/v2/product/{code}"
@@ -144,7 +144,7 @@ def fetch_document(self, stream_name: str, item: JSONType) -> FetcherResult:
or not json_response.get("product")
):
# consider it removed
return FetcherResult(status=FetcherStatus.REMOVED, data=None)
return FetcherResult(status=FetcherStatus.REMOVED, document=None)

return FetcherResult(
status=FetcherStatus.FOUND, document=json_response["product"]
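Note on the app/openfoodfacts.py diff above: the deleted-product branches now pass document=None instead of data=None, matching the actual field name of FetcherResult so the calls type-check. FetcherResult and FetcherStatus come from the openfoodfacts SDK; the model below is only a rough stand-in to show why the keyword matters.

from dataclasses import dataclass
from enum import Enum
from typing import Any

class FetcherStatus(Enum):
    # Rough stand-in; the real enum has more members.
    FOUND = "found"
    REMOVED = "removed"

@dataclass
class FetcherResult:
    status: FetcherStatus
    document: dict[str, Any] | None

FetcherResult(status=FetcherStatus.REMOVED, document=None)   # accepted
# FetcherResult(status=FetcherStatus.REMOVED, data=None)     # unexpected keyword, caught by mypy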
9 changes: 7 additions & 2 deletions app/query.py
@@ -12,6 +12,7 @@
from ._types import (
ErrorSearchResponse,
JSONType,
PostSearchParameters,
QueryAnalysis,
SearchParameters,
SearchResponse,
@@ -303,7 +304,11 @@ def build_es_query(
es_query.aggs.bucket(agg_name, agg)

sort_by: JSONType | str | None = None
if params.uses_sort_script and params.sort_by is not None:
if (
isinstance(params, PostSearchParameters)
and params.uses_sort_script
and params.sort_by is not None
):
sort_by = parse_sort_by_script(
params.sort_by, params.sort_params, config, params.valid_index_id
)
@@ -367,7 +372,7 @@ def execute_query(
projection: set[str] | None = None,
) -> SearchResponse:
errors = []
debug = SearchResponseDebug(query=query.to_dict())
debug = SearchResponseDebug(es_query=query.to_dict())
try:
results = query.execute()
except elasticsearch.ApiError as e:
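Note on the app/query.py diff above: uses_sort_script and sort_params appear to live only on PostSearchParameters, so build_es_query now narrows params with an isinstance() check before using them, and the debug payload passes es_query= to match SearchResponseDebug's field name. A minimal sketch of the narrowing pattern with simplified stand-in classes:

class SearchParameters:
    def __init__(self, sort_by: str | None = None) -> None:
        self.sort_by = sort_by

class PostSearchParameters(SearchParameters):
    def __init__(self, sort_by: str | None = None, sort_params: dict | None = None) -> None:
        super().__init__(sort_by)
        self.sort_params = sort_params

    @property
    def uses_sort_script(self) -> bool:
        return self.sort_params is not None

def resolve_sort(params: SearchParameters) -> str | None:
    # Inside the isinstance() branch mypy knows params is PostSearchParameters,
    # so attributes that exist only on the subclass can be accessed safely.
    if (
        isinstance(params, PostSearchParameters)
        and params.uses_sort_script
        and params.sort_by is not None
    ):
        return f"script:{params.sort_by}:{params.sort_params}"
    return params.sort_by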
26 changes: 13 additions & 13 deletions tests/unit/test_query.py
@@ -5,7 +5,7 @@
import pytest
from luqum.parser import parser

from app._types import QueryAnalysis, SearchParameters
from app._types import JSONType, QueryAnalysis, SearchParameters
from app.config import IndexConfig
from app.es_query_builder import FullTextQueryBuilder
from app.exceptions import QueryAnalysisError
@@ -72,11 +72,11 @@ def test_boost_phrases(query: str, proximity: int | None, expected: str):
@pytest.mark.parametrize(
"id_,q,langs,size,page,sort_by,facets,boost_phrase",
[
("simple_full_text_query", "flocons d'avoine", {"fr"}, 10, 1, None, None, True),
("simple_full_text_query", "flocons d'avoine", ["fr"], 10, 1, None, None, True),
(
"simple_full_text_query_facets",
"flocons d'avoine",
{"fr"},
["fr"],
10,
1,
None,
@@ -87,7 +87,7 @@ def test_boost_phrases(query: str, proximity: int | None, expected: str):
(
"sort_by_query",
"flocons d'avoine",
{"fr"},
["fr"],
10,
1,
"-unique_scans_n",
@@ -98,7 +98,7 @@ def test_boost_phrases(query: str, proximity: int | None, expected: str):
(
"simple_filter_query",
'countries:"en:italy"',
{"en"},
["en"],
25,
2,
None,
@@ -109,7 +109,7 @@ def test_boost_phrases(query: str, proximity: int | None, expected: str):
"complex_query",
'bacon de boeuf (countries:italy AND (categories:"en:beef" AND '
"(nutriments.salt_100g:[2 TO *] OR nutriments.salt_100g:[0 TO 0.05])))",
{"en"},
["en"],
25,
2,
None,
@@ -119,7 +119,7 @@ def test_boost_phrases(query: str, proximity: int | None, expected: str):
(
"empty_query_with_sort_by",
None,
{"en"},
["en"],
25,
2,
"unique_scans_n",
@@ -129,7 +129,7 @@ def test_boost_phrases(query: str, proximity: int | None, expected: str):
(
"empty_query_with_sort_by_and_facets",
None,
{"en"},
["en"],
25,
2,
"unique_scans_n",
@@ -139,7 +139,7 @@ def test_boost_phrases(query: str, proximity: int | None, expected: str):
(
"open_range",
"(unique_scans_n:>2 AND unique_scans_n:<3) OR unique_scans_n:>=10",
{"en"},
["en"],
25,
2,
None,
@@ -150,7 +150,7 @@ def test_boost_phrases(query: str, proximity: int | None, expected: str):
# it should be ok for now, until we implement subfields
"non_existing_subfield",
"Milk AND nutriments:(nonexisting:>=3)",
{"en"},
["en"],
25,
2,
None,
@@ -161,7 +161,7 @@ def test_boost_phrases(query: str, proximity: int | None, expected: str):
# * in a phrase is legit, it does not have the wildcard meaning
"wildcard_in_phrase_is_legit",
'Milk AND "*" AND categories:"*"',
{"en"},
["en"],
25,
2,
None,
Expand All @@ -177,7 +177,7 @@ def test_build_search_query(
# parameters
id_: str,
q: str,
langs: set[str],
langs: list[str],
size: int,
page: int,
sort_by: str | None,
@@ -244,7 +244,7 @@ def test_build_search_query_failure(
default_filter_query_builder: FullTextQueryBuilder,
):
# base search params
params = {
params: JSONType = {
"q": "Milk",
"langs": ["fr", "en"],
"page_size": 5,
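Note on the tests/unit/test_query.py diff above: the langs values in the parametrize table switch from set literals ({"fr"}) to lists (["fr"]) so they match the test function's langs: list[str] annotation, and the base params dict gains an explicit JSONType annotation so mypy accepts its mixed value types. For example (JSONType is assumed here to alias dict[str, Any]):

from typing import Any

JSONType = dict[str, Any]  # stand-in for the alias imported from app._types

langs: list[str] = ["fr"]       # matches the declared type
# langs: list[str] = {"fr"}     # set literal: mypy reports an incompatible type

params: JSONType = {
    "q": "Milk",
    "langs": ["fr", "en"],
    "page_size": 5,
}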

0 comments on commit acd55a1
