Skip to content

Commit

Permalink
fix(models): coerce numbers to strings
Browse files Browse the repository at this point in the history
  • Loading branch information
lengau committed Sep 16, 2024
1 parent 044c7ce commit f7f01ca
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 42 deletions.
132 changes: 103 additions & 29 deletions craft_grammar/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import Any, Generic, TypeVar, get_args, get_origin

import pydantic
import pydantic_core
from overrides import overrides
from pydantic import GetCoreSchemaHandler, ValidationError, ValidationInfo
from pydantic_core import core_schema
Expand Down Expand Up @@ -64,6 +65,100 @@ def _grammar_append(cls, entry: list[Any], item: Any, info: ValidationInfo) -> N
key, value = tuple(item.items())[0]
_mark_and_append(entry, {key: cls.validate(value, info)})

@classmethod
def _validate_grammar_list( # noqa: PLR0912
cls,
type_: type[list[T]],
input_value: list[Any],
info: ValidationInfo,
) -> list[T]:
# Check if the type_ is supposed to be a list
sub_type: Any = get_args(type_)

# handle typed list
if sub_type:
sub_type = sub_type[0]
if sub_type is Any:
sub_type = None

new_entry: list[Any] = []
errors: list[pydantic_core.InitErrorDetails] = []
for index, item in enumerate(input_value):
# Check if the item is a valid grammar clause
try:
if _is_grammar_clause(item):
cls._grammar_append(new_entry, item, info)
continue
except pydantic.ValidationError as exc:
errors.extend(
pydantic_core.InitErrorDetails(
type=err["type"],
loc=(index, *err["loc"]),
input=err["input"],
ctx=err.get("ctx", {"error": err}),
)
for err in exc.errors()
)
break
except ValueError as exc:
errors.append(
pydantic_core.InitErrorDetails(
type="value_error",
loc=(index,),
input=item,
ctx={"error": exc},
),
)
continue
if sub_type:
sub_type_adapter = pydantic.TypeAdapter(
sub_type,
config=pydantic.ConfigDict(coerce_numbers_to_str=True),
)
try:
new_entry.append(sub_type_adapter.validate_python(item))
except ValidationError:
pass
else:
continue
if issubclass(type_, str):
if isinstance(item, dict):
errors.append(
pydantic_core.InitErrorDetails(
type="value_error",
loc=(index,),
input=item,
ctx={
"error": ValueError(
f"value must be a str or valid grammar dict: {input_value!r}",
),
},
),
)
else:
raise pydantic.ValidationError.from_exception_data(
title=f"Grammar[{type_.__name__}]",
line_errors=[
pydantic_core.InitErrorDetails(
type="string_type",
loc=(),
input=item,
),
],
)
break
else:
raise ValueError( # noqa: TRY004
_format_type_error(type_, input_value),
)

if errors:
raise pydantic.ValidationError.from_exception_data(
title=f"Grammar[{type_.__name__}]",
line_errors=errors,
)
return new_entry


def _format_type_error(type_: type, entry: Any) -> str:
"""Format a type error message."""
Expand Down Expand Up @@ -106,35 +201,12 @@ class GrammarScalar(_GrammarBase):
def validate(cls, input_value: Any, /, info: ValidationInfo) -> Any:
# Grammar[T] entry can be a list if it contains clauses
if isinstance(input_value, list):
# Check if the type_ supposed to be a list
sub_type: Any = get_args(type_)

# handle typed list
if sub_type:
sub_type = sub_type[0]
if sub_type is Any:
sub_type = None

new_entry: list[Any] = []
for item in input_value:
# Check if the item is a valid grammar clause
if _is_grammar_clause(item):
cls._grammar_append(new_entry, item, info)
continue
if sub_type:
sub_type_adapter = pydantic.TypeAdapter(sub_type)
try:
sub_type_adapter.validate_python(item)
except ValidationError:
pass
else:
new_entry.append(item)
continue
raise ValueError(_format_type_error(type_, input_value))

return new_entry

type_adapter = pydantic.TypeAdapter(type_)
return cls._validate_grammar_list(type_, input_value, info)

type_adapter = pydantic.TypeAdapter(
type_,
config=pydantic.ConfigDict(coerce_numbers_to_str=True),
)

# Not a valid grammar, check if it is a dict
if isinstance(input_value, dict):
Expand All @@ -143,6 +215,8 @@ def validate(cls, input_value: Any, /, info: ValidationInfo) -> Any:
return input_value

# handle primitive types with pydantic validators
if isinstance(type_, type) and issubclass(type_, str):
return type_adapter.validate_python(input_value, strict=False)
try:
type_adapter.validate_python(input_value)
except ValidationError as err:
Expand Down
52 changes: 39 additions & 13 deletions tests/unit/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,9 +361,21 @@ def test_validate_grammar_recursive():
]


@pytest.mark.parametrize("value", ["foo", 13, 3.14159])
def test_grammar_str_success(value):
class GrammarValidation(pydantic.BaseModel):
"""Test validation of grammar-enabled types."""

x: Grammar[str]

actual = GrammarValidation(x=value)

assert actual.x == str(value)


@pytest.mark.parametrize(
"value",
[["foo"], {"x"}, [{"a": "b"}]],
[["foo"], {"x"}],
)
def test_grammar_str_error(value):
class GrammarValidation(pydantic.BaseModel):
Expand All @@ -377,13 +389,27 @@ class GrammarValidation(pydantic.BaseModel):
err = raised.value.errors()
assert len(err) == 1
assert err[0]["loc"] == ("x",)
assert err[0]["type"] == "value_error"
assert err[0]["msg"] == f"Value error, value must be a str: {value!r}"
assert err[0]["type"] == "string_type"
assert err[0]["msg"] == "Input should be a valid string"


@pytest.mark.parametrize(
"value",
[["foo"], ["foo", 23]],
)
def test_grammar_strlist_success(value):
class GrammarValidation(pydantic.BaseModel):
"""Test validation of grammar-enabled types."""

x: Grammar[list[str]]

actual = GrammarValidation(x=value)
assert actual.x == [str(i) for i in value]


@pytest.mark.parametrize(
"value",
[23, "foo", ["foo", 23], [{"a": "b"}]],
[23, "foo", [{"a": "b"}]],
)
def test_grammar_strlist_error(value):
class GrammarValidation(pydantic.BaseModel):
Expand All @@ -393,7 +419,6 @@ class GrammarValidation(pydantic.BaseModel):

with pytest.raises(pydantic.ValidationError) as raised:
GrammarValidation(x=value)

err = raised.value.errors()
assert len(err) == 1
assert err[0]["loc"] == ("x",)
Expand All @@ -415,9 +440,9 @@ class GrammarValidation(pydantic.BaseModel):
)
err = raised.value.errors()
assert len(err) == 1
assert err[0]["loc"] == ("x",)
assert err[0]["type"] == "value_error"
assert err[0]["msg"] == "Value error, value must be a str: [35]"
assert err[0]["loc"] == ("x", 0, 1)
assert err[0]["type"] == "string_type"
assert err[0]["msg"] == "Input should be a valid string"


def test_grammar_str_elsefail():
Expand Down Expand Up @@ -453,7 +478,7 @@ class GrammarValidation(pydantic.BaseModel):

err = raised.value.errors()
assert len(err) == 1
assert err[0]["loc"] == ("x",)
assert err[0]["loc"] == ("x", 0)
assert err[0]["type"] == "value_error"
assert (
err[0]["msg"]
Expand All @@ -464,13 +489,14 @@ class GrammarValidation(pydantic.BaseModel):
@pytest.mark.parametrize(
("clause", "err_msg"),
[
("on", "value must be a str: [{'on': 'foo'}]"),
("a", "value must be a str or valid grammar dict: [{'a': 'foo'}]"),
("on", "value must be a str or valid grammar dict: [{'on': 'foo'}]"),
("on ,", "syntax error in 'on' selector"),
("on ,arch", "syntax error in 'on' selector"),
("on arch,", "syntax error in 'on' selector"),
("on arch,,arch", "syntax error in 'on' selector"),
("on arch, arch", "spaces are not allowed in 'on' selector"),
("to", "value must be a str: [{'to': 'foo'}]"),
("to", "value must be a str or valid grammar dict: [{'to': 'foo'}]"),
("to ,", "syntax error in 'to' selector"),
("to ,arch", "syntax error in 'to' selector"),
("to arch,", "syntax error in 'to' selector"),
Expand Down Expand Up @@ -499,5 +525,5 @@ class GrammarValidation(pydantic.BaseModel):

err = raised.value.errors()
assert len(err) == 1
assert err[0]["loc"] == ("x",)
assert err_msg in err[0]["msg"]
assert err[0]["loc"] == ("x", 0)
assert err[0]["msg"].endswith(err_msg)

0 comments on commit f7f01ca

Please sign in to comment.