Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Getting test suites to pass #1249

Merged
merged 8 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ormar/fields/foreign_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def evaluate_forward_ref(self, globalns: Any, localns: Any) -> None:
:rtype: None
"""
if self.to.__class__ == ForwardRef:
self.to = self.to._evaluate(globalns, localns, set())
self.to = self.to._evaluate(globalns, localns)
(
self.__type__,
self.constraints,
Expand Down
4 changes: 2 additions & 2 deletions ormar/fields/many_to_many.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def evaluate_forward_ref(self, globalns: Any, localns: Any) -> None:
:rtype: None
"""
if self.to.__class__ == ForwardRef:
self.to = self.to._evaluate(globalns, localns, set())
self.to = self.to._evaluate(globalns, localns)

(
self.__type__,
Expand All @@ -242,7 +242,7 @@ def evaluate_forward_ref(self, globalns: Any, localns: Any) -> None:
self.to_pk_only = pk_only_model

if self.through.__class__ == ForwardRef:
self.through = self.through._evaluate(globalns, localns, set())
self.through = self.through._evaluate(globalns, localns)

forbid_through_relations(self.through)

Expand Down
6 changes: 4 additions & 2 deletions ormar/fields/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ def encode_decimal(value: decimal.Decimal, precision: int = None) -> float:

def encode_bytes(value: Union[str, bytes], represent_as_string: bool = False) -> bytes:
if represent_as_string:
return value if isinstance(value, bytes) else base64.b64decode(value)
return value if isinstance(value, bytes) else value.encode("utf-8")
value= value if isinstance(value, bytes) else base64.b64decode(value)
else:
value = value if isinstance(value, bytes) else value.encode("utf-8")
return value


def encode_json(value: Any) -> Optional[str]:
Expand Down
3 changes: 1 addition & 2 deletions ormar/models/helpers/validation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import base64
import decimal
import numbers
from types import NoneType
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -170,7 +169,7 @@ def get_pydantic_example_repr(type_: Any) -> Any:
"""
if hasattr(type_, "__origin__"):
if type_.__origin__ == Union:
values = tuple(get_pydantic_example_repr(x) for x in type_.__args__ if x is not NoneType)
values = tuple(get_pydantic_example_repr(x) for x in type_.__args__ if x is not type(None))
if len(values) == 1:
return values[0]
return values
Expand Down
26 changes: 13 additions & 13 deletions ormar/models/metaclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def copy_and_replace_m2m_through_model( # noqa: CFQ002
table_name: str,
parent_fields: Dict,
attrs: Dict,
meta: ModelMeta,
ormar_config: OrmarConfig,
base_class: Type["Model"],
) -> None:
"""
Expand All @@ -292,8 +292,8 @@ def copy_and_replace_m2m_through_model( # noqa: CFQ002
:type parent_fields: Dict
:param attrs: new namespace for class being constructed
:type attrs: Dict
:param meta: metaclass of currently created model
:type meta: ModelMeta
:param ormar_config: metaclass of currently created model
:type ormar_config: OrmarConfig
"""
Field: Type[BaseField] = type( # type: ignore
field.__class__.__name__, (ManyToManyField, BaseField), {}
Expand Down Expand Up @@ -333,7 +333,7 @@ def copy_and_replace_m2m_through_model( # noqa: CFQ002
# create new table with copied columns but remove foreign keys
# they will be populated later in expanding reverse relation
# if hasattr(new_meta, "table"):
new_meta.tablename += "_" + meta.tablename
new_meta.tablename += "_" + ormar_config.tablename
new_meta.table = None
new_meta.model_fields = {
name: field
Expand Down Expand Up @@ -392,20 +392,20 @@ def copy_data_from_parent_model( # noqa: CCR001
model_fields=model_fields,
)
parent_fields: Dict = dict()
meta = attrs.get("ormar_config")
if not meta: # pragma: no cover
ormar_config = attrs.get("ormar_config")
if not ormar_config: # pragma: no cover
raise ModelDefinitionError(
f"Model {curr_class.__name__} declared without ormar_config"
)
table_name = (
meta.tablename
if hasattr(meta, "tablename") and meta.tablename
ormar_config.tablename
if hasattr(ormar_config, "tablename") and ormar_config.tablename
else attrs.get("__name__", "").lower() + "s"
)
for field_name, field in base_class.ormar_config.model_fields.items():
if (
hasattr(meta, "exclude_parent_fields")
and field_name in meta.exclude_parent_fields
hasattr(ormar_config, "exclude_parent_fields")
and field_name in ormar_config.exclude_parent_fields
):
continue
if field.is_multi:
Expand All @@ -416,7 +416,7 @@ def copy_data_from_parent_model( # noqa: CCR001
table_name=table_name,
parent_fields=parent_fields,
attrs=attrs,
meta=meta,
ormar_config=ormar_config,
base_class=base_class, # type: ignore
)

Expand Down Expand Up @@ -574,9 +574,9 @@ def __new__( # type: ignore # noqa: CCR001
name: str,
bases: Any,
attrs: dict,
__pydantic_generic_metadata__: PydanticGenericMetadata | None = None,
__pydantic_generic_metadata__: Union[PydanticGenericMetadata, None] = None,
__pydantic_reset_parent_namespace__: bool = True,
_create_model_module: str | None = None,
_create_model_module: Union[str, None] = None,
**kwargs
) -> "ModelMetaclass":
"""
Expand Down
18 changes: 7 additions & 11 deletions ormar/models/newbasemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import databases
import pydantic
import sqlalchemy
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict


import ormar # noqa I100
Expand Down Expand Up @@ -152,7 +152,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # type: ignore
values = new_kwargs
object.__setattr__(self, "__dict__", values)
object.__setattr__(self, "__pydantic_fields_set__", fields_set)

# add back through fields
new_kwargs.update(through_tmp_dict)
model_fields = object.__getattribute__(self, "ormar_config").model_fields
Expand All @@ -162,9 +161,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # type: ignore
new_kwargs.get(related), self, to_register=True
)

if hasattr(self, "_init_private_attributes"):
# introduced in pydantic 1.7
self._init_private_attributes()

def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001
"""
Expand Down Expand Up @@ -208,17 +204,17 @@ def __getattr__(self, item: str) -> Any:
return None
return super().__getattr__(item)

def __getattribute__(self, item: str) -> Any:
if item == "__dict__":
TouwaStar marked this conversation as resolved.
Show resolved Hide resolved
print(item, sys._getframe(1))
return super().__getattribute__(item)

def __getstate__(self) -> Dict[Any, Any]:
state = super().__getstate__()
self_dict = self.dict()
state["__dict__"].update(**self_dict)
return state

def __getattribute__(self, item: str) -> Any:
if item == "__dict__":
print(item, sys._getframe(1))
return super().__getattribute__(item)

def __setstate__(self, state: Dict[Any, Any]) -> None:
relations = {
k: v
Expand Down Expand Up @@ -293,7 +289,7 @@ def _process_kwargs(self, kwargs: Dict) -> Tuple[Dict, Dict]: # noqa: CCR001

Removes property_fields

Checks if field is in the model fields or pydatnic fields.
Checks if field is in the model fields or pydantic fields.

Nullifies fields that should be excluded.

Expand Down
4 changes: 2 additions & 2 deletions ormar/models/ormar_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(
metadata: Optional[sqlalchemy.MetaData] = None,
database: Optional[databases.Database] = None,
tablename: Optional[str] = None,
order_by: Optional[list[str]] = None,
order_by: Optional[List[str]] = None,
abstract: bool = False,
exclude_parent_fields: Optional[List[str]] = None,
queryset_class: Type[QuerySet] = QuerySet,
Expand Down Expand Up @@ -50,7 +50,7 @@ def copy(
metadata: Optional[sqlalchemy.MetaData] = None,
database: Optional[databases.Database] = None,
tablename: Optional[str] = None,
order_by: Optional[list[str]] = None,
order_by: Optional[List[str]] = None,
abstract: Optional[bool] = None,
exclude_parent_fields: Optional[List[str]] = None,
queryset_class: Optional[Type[QuerySet]] = None,
Expand Down
2 changes: 1 addition & 1 deletion ormar/relations/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def add(self, child: "Model") -> None:
self._populate_owner_side_dict(rel=rel, child=child)
self._owner.__dict__[relation_name] = rel

def _populate_owner_side_dict(self, rel:List["Model"], child: "Model") -> None:
def _populate_owner_side_dict(self, rel: List["Model"], child: "Model") -> None:
try:
if child not in rel:
rel.append(child)
Expand Down
10 changes: 6 additions & 4 deletions tests/test_fastapi/test_binary_fields.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import base64
import json
import uuid
from enum import Enum
from typing import List
Expand Down Expand Up @@ -37,7 +36,7 @@ async def shutdown() -> None:
await database_.disconnect()


blob3 = b"\xc3\x28"
blob3 = b"\xc3\x83\x28"
collerek marked this conversation as resolved.
Show resolved Hide resolved
blob4 = b"\xf0\x28\x8c\x28"
blob5 = b"\xee"
blob6 = b"\xff"
Expand Down Expand Up @@ -94,9 +93,12 @@ async def test_read_main():
)
assert response.status_code == 200
response = await client.get("/things")
assert response.json()[0]["bt"] == base64.b64encode(blob3).decode()
thing = BinaryThing(**response.json()[0])
assert response.json()[0]["bt"] == blob3.decode()
resp_json = response.json()
resp_json[0]["bt"] = resp_json[0]["bt"].encode()
thing = BinaryThing(**resp_json[0])
assert thing.__dict__["bt"] == blob3
assert thing.bt == base64.b64encode(blob3).decode()


def test_schema():
Expand Down
1 change: 0 additions & 1 deletion tests/test_fastapi/test_fastapi_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import pydantic
import pytest
import sqlalchemy
import uvicorn
from asgi_lifespan import LifespanManager
from fastapi import FastAPI
from httpx import AsyncClient
Expand Down