diff --git a/ormar/fields/many_to_many.py b/ormar/fields/many_to_many.py index 73cc4f1d5..7e43c0eeb 100644 --- a/ormar/fields/many_to_many.py +++ b/ormar/fields/many_to_many.py @@ -12,7 +12,6 @@ overload, ) -from pydantic.v1.typing import evaluate_forwardref import ormar # noqa: I100 from ormar import ModelDefinitionError diff --git a/ormar/models/helpers/relations.py b/ormar/models/helpers/relations.py index 1ff41dc0b..1f2d71da0 100644 --- a/ormar/models/helpers/relations.py +++ b/ormar/models/helpers/relations.py @@ -1,5 +1,5 @@ import inspect -from typing import List, TYPE_CHECKING, Type, Union, cast +from typing import List, TYPE_CHECKING, Optional, Type, Union, cast from pydantic import BaseModel, create_model from pydantic.fields import FieldInfo @@ -138,7 +138,7 @@ def register_reverse_model_fields(model_field: "ForeignKeyField") -> None: ) if not model_field.skip_reverse: field_type = model_field.to.ormar_config.model_fields[related_name].__type__ - field_type = replace_models_with_copy(annotation=field_type) + field_type = replace_models_with_copy(annotation=field_type, source_model_field=model_field.name) if not model_field.is_multi: field_type = Union[field_type, List[field_type], None] model_field.to.model_fields[related_name] = FieldInfo.from_annotated_attribute( @@ -148,7 +148,7 @@ def register_reverse_model_fields(model_field: "ForeignKeyField") -> None: setattr(model_field.to, related_name, RelationDescriptor(name=related_name)) -def replace_models_with_copy(annotation: Type) -> Type: +def replace_models_with_copy(annotation: Type, source_model_field: Optional[str] = None) -> Type: """ Replaces all models in annotation with their copies to avoid circular references. @@ -157,16 +157,16 @@ def replace_models_with_copy(annotation: Type) -> Type: :return: annotation with replaced models :rtype: Type """ - if inspect.isclass(annotation) and issubclass(annotation, BaseModel): - return create_copy_to_avoid_circular_references(model=annotation) + if inspect.isclass(annotation) and issubclass(annotation, ormar.Model): + return create_copy_to_avoid_circular_references(model=annotation, source_model_field=source_model_field) elif hasattr(annotation, "__origin__"): if annotation.__origin__ == list: return List[ - replace_models_with_copy(annotation=annotation.__args__[0]) + replace_models_with_copy(annotation=annotation.__args__[0], source_model_field=source_model_field) ] # type: ignore elif annotation.__origin__ == Union: args = annotation.__args__ - new_args = [replace_models_with_copy(annotation=arg) for arg in args] + new_args = [replace_models_with_copy(annotation=arg, source_model_field=source_model_field) for arg in args] return Union[tuple(new_args)] else: return annotation @@ -174,8 +174,14 @@ def replace_models_with_copy(annotation: Type) -> Type: return annotation -def create_copy_to_avoid_circular_references(model: Type["Model"]) -> Type["BaseModel"]: - return cast(Type[BaseModel], type(model.__name__, (model,), {})) +def create_copy_to_avoid_circular_references(model: Type["Model"], source_model_field: Optional[str] = None) -> Type["BaseModel"]: + new_model = create_model( + model.__name__, + __base__=model, + **{k: (v.annotation, v.default) for k, v in model.model_fields.items() if k != source_model_field}, + ) + new_model.model_fields.pop(source_model_field, None) + return new_model def register_through_shortcut_fields(model_field: "ManyToManyField") -> None: diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index cb2582fb0..8944d03e9 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -16,6 +16,7 @@ import databases import pydantic import sqlalchemy +from pydantic._internal._generics import PydanticGenericMetadata from pydantic._internal._model_construction import complete_model_class from pydantic.fields import ComputedFieldInfo, FieldInfo from sqlalchemy.sql.schema import ColumnCollectionConstraint @@ -569,7 +570,14 @@ def add_field_descriptor( class ModelMetaclass(pydantic._internal._model_construction.ModelMetaclass): def __new__( # type: ignore # noqa: CCR001 - mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict + mcs: "ModelMetaclass", + name: str, + bases: Any, + attrs: dict, + __pydantic_generic_metadata__: PydanticGenericMetadata | None = None, + __pydantic_reset_parent_namespace__: bool = True, + _create_model_module: str | None = None, + **kwargs ) -> "ModelMetaclass": """ Metaclass used by ormar Models that performs configuration @@ -614,7 +622,15 @@ def __new__( # type: ignore # noqa: CCR001 if "ormar_config" in attrs: attrs["model_config"]["ignored_types"] = (OrmarConfig,) attrs["model_config"]["from_attributes"] = True - new_model = super().__new__(mcs, name, bases, attrs) # type: ignore + new_model = super().__new__( + mcs, # type: ignore + name, + bases, + attrs, + __pydantic_generic_metadata__=__pydantic_generic_metadata__, + __pydantic_reset_parent_namespace__=__pydantic_reset_parent_namespace__, + _create_model_module=_create_model_module, + **kwargs) add_cached_properties(new_model) diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 51ea6ce6b..9989a9b09 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -206,8 +206,6 @@ def __getattr__(self, item: str) -> Any: # TODO: Check __pydantic_extra__ if item == "__pydantic_extra__": return None - if item =="__pydantic_serializer__": - breakpoint() return super().__getattr__(item) def __getstate__(self) -> Dict[Any, Any]: diff --git a/tests/test_fastapi/test_binary_fields.py b/tests/test_fastapi/test_binary_fields.py index 5dce624e5..cd06831c9 100644 --- a/tests/test_fastapi/test_binary_fields.py +++ b/tests/test_fastapi/test_binary_fields.py @@ -1,6 +1,7 @@ import base64 import json import uuid +from enum import Enum from typing import List import databases @@ -47,15 +48,20 @@ async def shutdown() -> None: database=database, ) + +class BinaryEnum(Enum): + blob3 = blob3 + blob4 = blob4 + blob5 = blob5 + blob6 = blob6 + + class BinaryThing(ormar.Model): ormar_config = base_ormar_config.copy(tablename = "things") id: uuid.UUID = ormar.UUID(primary_key=True, default=uuid.uuid4) name: str = ormar.Text(default="") - bt: str = ormar.LargeBinary( - max_length=1000, - choices=[blob3, blob4, blob5, blob6], - represent_as_base64_str=True, + bt: str = ormar.Enum(enum_class=BinaryEnum, represent_as_base64_str=True, ) diff --git a/tests/test_fastapi/test_fastapi_docs.py b/tests/test_fastapi/test_fastapi_docs.py index fe0662f79..085da98ad 100644 --- a/tests/test_fastapi/test_fastapi_docs.py +++ b/tests/test_fastapi/test_fastapi_docs.py @@ -5,6 +5,7 @@ import pydantic import pytest import sqlalchemy +import uvicorn from asgi_lifespan import LifespanManager from fastapi import FastAPI from httpx import AsyncClient @@ -160,4 +161,6 @@ def test_schema_modification(): def test_schema_gen(): schema = app.openapi() assert "Category" in schema["components"]["schemas"] - assert "Item" in schema["components"]["schemas"] + subschemas = [x.split("__")[-1] for x in schema["components"]["schemas"]] + assert "Item-Input" in subschemas + assert "Item-Output" in subschemas