Skip to content

Commit

Permalink
Merge pull request #549 from yukinarit/raise-error-unsupported-flatten
Browse files Browse the repository at this point in the history
Raise error when flatten can not be used
  • Loading branch information
yukinarit authored Jun 13, 2024
2 parents 3eb4d07 + ae99304 commit cbfd8af
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 1 deletion.
19 changes: 19 additions & 0 deletions serde/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,25 @@ def is_bare_opt(typ: Any) -> bool:
return not type_args(typ) and typ is Optional


@cache
def is_opt_dataclass(typ: Any) -> bool:
"""
Test if the type is optional dataclass.
>>> is_opt_dataclass(Optional[int])
False
>>> @dataclasses.dataclass
... class Foo:
... pass
>>> is_opt_dataclass(Foo)
False
>>> is_opt_dataclass(Optional[Foo])
False
"""
args = get_args(typ)
return is_opt(typ) and len(args) > 0 and is_dataclass(args[0])


@cache
def is_list(typ: type[Any]) -> bool:
"""
Expand Down
3 changes: 3 additions & 0 deletions serde/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
is_new_type_primitive,
is_any,
is_opt,
is_opt_dataclass,
is_set,
is_tuple,
is_union,
Expand Down Expand Up @@ -620,6 +621,8 @@ def from_dataclass(cls, f: dataclasses.Field[T], parent: Optional[Any] = None) -
flatten = f.metadata.get("serde_flatten")
if flatten is True:
flatten = FlattenOpts()
if flatten and not (dataclasses.is_dataclass(f.type) or is_opt_dataclass(f.type)):
raise SerdeError(f"pyserde does not support flatten attribute for {typename(f.type)}")

kw_only = bool(f.kw_only) if sys.version_info >= (3, 10) else False

Expand Down
15 changes: 14 additions & 1 deletion tests/test_flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import pytest

from serde import field, serde
from serde import field, serde, SerdeError
from serde.json import from_json, to_json

from .common import all_formats
Expand Down Expand Up @@ -68,3 +68,16 @@ class Foo:

f = Foo(a=10, b="foo", bar=Bar(c=100.0, d=True))
assert de(Foo, se(f)) == f


@pytest.mark.parametrize("se,de", all_formats)
def test_flatten_not_supported(se: Any, de: Any) -> None:
@serde
class Bar:
pass

with pytest.raises(SerdeError):

@serde
class Foo:
bar: list[Bar] = field(flatten=True)

0 comments on commit cbfd8af

Please sign in to comment.