Skip to content

Commit

Permalink
Ensure recursive field annotation resolution
Browse files Browse the repository at this point in the history
At the end of the registry initialisation, we need to walk across all the fields defined on extendable models to replace the annoted orignal type by the one build into the registry if the field reference a extendable model. Once we replace the declared type buy the resolved one, we also need to rebuild the model schema to take into account this change. Prior to this change, the result was consistent. Indeed, the resolution mechanism was not applied recursively. As result the model rebuild for a class with a field declared as an extendable model type, could not contain  the resolved definition of fields declared into this referenced extendable model type if the resolution of the last one was not already done. The resolution mechanism is now recursive and when a annotation is resolved, we ensure that the new type is also resolved.
  • Loading branch information
lmignon committed Nov 20, 2023
1 parent 55cd466 commit 8f9fb77
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 4 deletions.
10 changes: 8 additions & 2 deletions src/extendable_pydantic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import inspect
import typing
import warnings
from typing import Any, Dict, List, Optional, cast, no_type_check
from typing import Any, Dict, List, Optional, cast, no_type_check, Set

from extendable import context, main
from extendable.main import ExtendableMeta
Expand Down Expand Up @@ -93,6 +93,13 @@ def _resolve_submodel_fields(
"""Replace the original field type into the definition of the field by the one
from the registry."""
registry = registry if registry else context.extendable_registry.get()
resolved: Set["ExtendableModelMeta"] = getattr(
registry, "_resolved_models", set()
)
if cls in resolved:
return
resolved.add(cls)
registry._resolved_models = resolved # type: ignore[union-attr]
to_rebuild = False
if issubclass(cls, BaseModel):
for field_name, field_info in cast(BaseModel, cls).model_fields.items():
Expand All @@ -105,7 +112,6 @@ def _resolve_submodel_fields(
if to_rebuild:
delattr(cls, "__pydantic_core_schema__")
cast(BaseModel, cls).model_rebuild(force=True)
return


class RegistryListener(ExtendableRegistryListener):
Expand Down
5 changes: 4 additions & 1 deletion src/extendable_pydantic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,10 @@ def resolve_annotation(
# semantics as "typing" classes or generic aliases

if not origin_type and issubclass(type(type_), ExtendableMeta):
return type_._get_assembled_cls(registry)
final_type = type_._get_assembled_cls(registry)
if final_type is not type_:
final_type._resolve_submodel_fields(registry)
return final_type

# Handle special case for typehints that can have lists as arguments.
# `typing.Callable[[int, str], int]` is an example for this.
Expand Down
61 changes: 60 additions & 1 deletion tests/test_generics_inheritance.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Test generics model inheritance."""
from typing import Generic, List, TypeVar
from typing import Generic, List, TypeVar, Optional

try:
from typing import Literal
Expand All @@ -9,6 +9,7 @@
from pydantic.main import BaseModel

from extendable_pydantic import ExtendableModelMeta
from extendable_pydantic.models import ExtendableBaseModel

from .conftest import skip_not_supported_version_for_generics

Expand Down Expand Up @@ -117,6 +118,64 @@ class SearchResultExtended(SearchResult[T], Generic[T], extends=SearchResult[T])
}


@skip_not_supported_version_for_generics
def test_generic_with_nested_extended(test_registry):
T = TypeVar("T")

class SearchResult(ExtendableBaseModel, Generic[T]):
total: int
results: List[T]

class Level(ExtendableBaseModel):
val: int

class SearchLevelResult(SearchResult[Level]):
pass

class Level11(ExtendableBaseModel):
val: int

class Level1(ExtendableBaseModel):
val: int
level11: Optional[Level11]

class Level11Extended(Level11, extends=True):
name: str = "level11"

class Level1Extended(Level1, extends=True):
name: str = "level1"

class LevelExtended(Level, extends=True):
name: str = "level"
level1: Optional[Level1]

test_registry.init_registry()

assert Level11(val=3).model_dump() == {"val": 3, "name": "level11"}

item = SearchLevelResult(
total=0,
results=[Level(val=1, level1=Level1(val=2, level11=Level11(val=3)))],
)
assert item.model_dump() == {
"total": 0,
"results": [
{
"val": 1,
"level1": {
"val": 2,
"level11": {
"val": 3,
"name": "level11",
},
"name": "level1",
},
"name": "level",
}
],
}


@skip_not_supported_version_for_generics
def test_extended_generics_of_extended_model(test_registry):
"""In this test we check that the extension of a genrics of extended model
Expand Down

0 comments on commit 8f9fb77

Please sign in to comment.