diff --git a/src/extendable_pydantic/main.py b/src/extendable_pydantic/main.py index 7a03109..f968627 100644 --- a/src/extendable_pydantic/main.py +++ b/src/extendable_pydantic/main.py @@ -3,7 +3,7 @@ import inspect import typing import warnings -from typing import Any, Dict, List, Optional, cast, no_type_check +from typing import Any, Dict, List, Optional, cast, no_type_check, Set from extendable import context, main from extendable.main import ExtendableMeta @@ -93,6 +93,13 @@ def _resolve_submodel_fields( """Replace the original field type into the definition of the field by the one from the registry.""" registry = registry if registry else context.extendable_registry.get() + resolved: Set["ExtendableModelMeta"] = getattr( + registry, "_resolved_models", set() + ) + if cls in resolved: + return + resolved.add(cls) + registry._resolved_models = resolved # type: ignore[union-attr] to_rebuild = False if issubclass(cls, BaseModel): for field_name, field_info in cast(BaseModel, cls).model_fields.items(): @@ -105,7 +112,6 @@ def _resolve_submodel_fields( if to_rebuild: delattr(cls, "__pydantic_core_schema__") cast(BaseModel, cls).model_rebuild(force=True) - return class RegistryListener(ExtendableRegistryListener): diff --git a/src/extendable_pydantic/utils.py b/src/extendable_pydantic/utils.py index 23b80a1..42afd68 100644 --- a/src/extendable_pydantic/utils.py +++ b/src/extendable_pydantic/utils.py @@ -113,7 +113,10 @@ def resolve_annotation( # semantics as "typing" classes or generic aliases if not origin_type and issubclass(type(type_), ExtendableMeta): - return type_._get_assembled_cls(registry) + final_type = type_._get_assembled_cls(registry) + if final_type is not type_: + final_type._resolve_submodel_fields(registry) + return final_type # Handle special case for typehints that can have lists as arguments. # `typing.Callable[[int, str], int]` is an example for this. diff --git a/tests/test_generics_inheritance.py b/tests/test_generics_inheritance.py index ba006a5..7c1f015 100644 --- a/tests/test_generics_inheritance.py +++ b/tests/test_generics_inheritance.py @@ -1,5 +1,5 @@ """Test generics model inheritance.""" -from typing import Generic, List, TypeVar +from typing import Generic, List, TypeVar, Optional try: from typing import Literal @@ -9,6 +9,7 @@ from pydantic.main import BaseModel from extendable_pydantic import ExtendableModelMeta +from extendable_pydantic.models import ExtendableBaseModel from .conftest import skip_not_supported_version_for_generics @@ -117,6 +118,64 @@ class SearchResultExtended(SearchResult[T], Generic[T], extends=SearchResult[T]) } +@skip_not_supported_version_for_generics +def test_generic_with_nested_extended(test_registry): + T = TypeVar("T") + + class SearchResult(ExtendableBaseModel, Generic[T]): + total: int + results: List[T] + + class Level(ExtendableBaseModel): + val: int + + class SearchLevelResult(SearchResult[Level]): + pass + + class Level11(ExtendableBaseModel): + val: int + + class Level1(ExtendableBaseModel): + val: int + level11: Optional[Level11] + + class Level11Extended(Level11, extends=True): + name: str = "level11" + + class Level1Extended(Level1, extends=True): + name: str = "level1" + + class LevelExtended(Level, extends=True): + name: str = "level" + level1: Optional[Level1] + + test_registry.init_registry() + + assert Level11(val=3).model_dump() == {"val": 3, "name": "level11"} + + item = SearchLevelResult( + total=0, + results=[Level(val=1, level1=Level1(val=2, level11=Level11(val=3)))], + ) + assert item.model_dump() == { + "total": 0, + "results": [ + { + "val": 1, + "level1": { + "val": 2, + "level11": { + "val": 3, + "name": "level11", + }, + "name": "level1", + }, + "name": "level", + } + ], + } + + @skip_not_supported_version_for_generics def test_extended_generics_of_extended_model(test_registry): """In this test we check that the extension of a genrics of extended model