Skip to content

Commit

Permalink
partially fix coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
collerek committed Jan 25, 2024
1 parent 8d29c1f commit d286de5
Show file tree
Hide file tree
Showing 13 changed files with 54 additions and 119 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ test:
pytest

coverage:
pytest --cov=ormar --cov=tests --cov-fail-under=100 --cov-report=term-missing
pytest --cov=ormar --cov=tests --cov-fail-under=100 --cov-report=term-missing tests

type_check:
mkdir -p .mypy_cache && poetry run python -m mypy . --ignore-missing-imports --install-types --non-interactive
Expand Down
4 changes: 2 additions & 2 deletions ormar/fields/foreign_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,9 +365,9 @@ def default_source_field_name(self) -> str:
def _evaluate_forward_ref(self, globalns: Any, localns: Any, is_through: bool = False) -> None:
target = "through" if is_through else "to"
target_obj = getattr(self, target)
if sys.version_info.minor <= 8:
if sys.version_info.minor <= 8: # pragma: no cover
evaluated = target_obj._evaluate(globalns, localns)
else:
else: # pragma: no cover
evaluated = target_obj._evaluate(globalns, localns, set())
setattr(self, target, evaluated)

Expand Down
16 changes: 15 additions & 1 deletion ormar/fields/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,15 @@ def encode_decimal(value: decimal.Decimal, precision: int = None) -> float:
return float(value)


def encode_bytes(value: Union[str, bytes], represent_as_string: bool = False) -> bytes:
def encode_bytes(value: Union[str, bytes], represent_as_string: bool = False) -> str:
if represent_as_string:
value = value if isinstance(value, str) else base64.b64encode(value).decode("utf-8")
else:
value = value if isinstance(value, str) else value.decode("utf-8")
return value


def decode_bytes(value: str, represent_as_string: bool = False) -> bytes:
if represent_as_string:
value = value if isinstance(value, bytes) else base64.b64decode(value)
else:
Expand Down Expand Up @@ -74,6 +82,11 @@ def re_dump_value(value: str) -> Union[str, bytes]:

SQL_ENCODERS_MAP: Dict[type, Callable] = {bool: encode_bool, **ENCODERS_MAP}

ADDITIONAL_PARAMETERS_MAP: Dict[type, str] = {
bytes: "",
decimal.Decimal: "precision"
}


DECODERS_MAP = {
bool: parse_bool,
Expand All @@ -82,4 +95,5 @@ def re_dump_value(value: str) -> Union[str, bytes]:
datetime.time: SchemaValidator(core_schema.time_schema()).validate_python,
pydantic.Json: json.loads,
decimal.Decimal: decimal.Decimal,
bytes: decode_bytes,
}
7 changes: 6 additions & 1 deletion ormar/fields/sqlalchemy_encrypted.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,10 @@ def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]:
except AttributeError:
encoder = ormar.SQL_ENCODERS_MAP.get(self.type_, None)
if encoder:
value = encoder(value) # type: ignore
if self.type_ == bytes:
value = encoder(value, self._field_type.represent_as_base64_str)
else:
value = encoder(value) # type: ignore

encrypted_value = self.backend.encrypt(value)
return encrypted_value
Expand All @@ -177,6 +180,8 @@ def process_result_value(self, value: Any, dialect: Dialect) -> Any:
except AttributeError:
decoder = ormar.DECODERS_MAP.get(self.type_, None)
if decoder:
if self.type_ == bytes:
return decoder(decrypted_value, self._field_type.represent_as_base64_str)
return decoder(decrypted_value) # type: ignore

return self._field_type.__type__(decrypted_value) # type: ignore
7 changes: 2 additions & 5 deletions ormar/models/descriptors/descriptors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
from typing import TYPE_CHECKING, Any, List, Type

from ormar.fields.parsers import encode_json
from ormar.fields.parsers import encode_json, decode_bytes

if TYPE_CHECKING: # pragma: no cover
from ormar import Model
Expand Down Expand Up @@ -65,10 +65,7 @@ def __get__(self, instance: "Model", owner: Type["Model"]) -> Any:
def __set__(self, instance: "Model", value: Any) -> None:
field = instance.ormar_config.model_fields[self.name]
if isinstance(value, str):
if field.represent_as_base64_str:
value = base64.b64decode(value)
else:
value = value.encode("utf-8")
value = decode_bytes(value=value, represent_as_string=field.represent_as_base64_str)
instance._internal_set(self.name, value)
instance.set_save_status(False)

Expand Down
2 changes: 1 addition & 1 deletion ormar/models/helpers/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def serialize(
try:
serialized = handler([child])
except ValueError as exc:
if not str(exc).startswith("Circular reference"):
if not str(exc).startswith("Circular reference"): # pragma: no cover
raise exc
result.append({child.ormar_config.pkname: child.pk})
else:
Expand Down
60 changes: 0 additions & 60 deletions ormar/models/helpers/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import pydantic

import ormar # noqa: I100, I202
from ormar.models.helpers.models import config_field_not_set
from ormar.queryset.utils import translate_list_to_dict

Expand All @@ -30,33 +29,6 @@
from ormar.fields import BaseField


def convert_value_if_needed(field: "BaseField", value: Any) -> Any:
"""
Converts dates to isoformat as fastapi can check this condition in routes
and the fields are not yet parsed.
Converts enums to list of it's values.
Converts uuids to strings.
Converts decimal to float with given scale.
:param field: ormar field to check with choices
:type field: BaseField
:param value: current values of the model to verify
:type value: Any
:return: value, choices list
:rtype: Any
"""
encoder = ormar.ENCODERS_MAP.get(field.__type__, lambda x: x)
if field.__type__ == decimal.Decimal:
precision = field.scale # type: ignore
value = encoder(value, precision)
elif field.__type__ == bytes:
represent_as_string = field.represent_as_base64_str
value = encoder(value, represent_as_string)
elif encoder:
value = encoder(value)
return value


def generate_model_example(model: Type["Model"], relation_map: Dict = None) -> Dict:
"""
Generates example to be included in schema in fastapi.
Expand Down Expand Up @@ -198,8 +170,6 @@ def overwrite_example_and_description(
:type model: Type["Model"]
"""
schema["example"] = generate_model_example(model=model)
if "Main base class of ormar Model." in schema.get("description", ""):
schema["description"] = f"{model.__name__}"


def overwrite_binary_format(schema: Dict[str, Any], model: Type["Model"]) -> None:
Expand All @@ -218,36 +188,6 @@ def overwrite_binary_format(schema: Dict[str, Any], model: Type["Model"]) -> Non
and model.ormar_config.model_fields[field_id].represent_as_base64_str
):
prop["format"] = "base64"
if prop.get("enum"):
prop["enum"] = [
base64.b64encode(choice).decode() for choice in prop.get("enum", [])
]


def construct_modify_schema_function(fields_with_choices: List) -> Callable:
"""
Modifies the schema to include fields with choices validator.
Those fields will be displayed in schema as Enum types with available choices
values listed next to them.
Note that schema extra has to be a function, otherwise it's called to soon
before all the relations are expanded.
:param fields_with_choices: list of fields with choices validation
:type fields_with_choices: List
:return: callable that will be run by pydantic to modify the schema
:rtype: Callable
"""

def schema_extra(schema: Dict[str, Any], model: Type["Model"]) -> None:
for field_id, prop in schema.get("properties", {}).items():
if field_id in fields_with_choices:
prop["enum"] = list(model.ormar_config.model_fields[field_id].choices)
prop["description"] = prop.get("description", "") + "An enumeration."
overwrite_example_and_description(schema=schema, model=model)
overwrite_binary_format(schema=schema, model=model)

return staticmethod(schema_extra) # type: ignore


def construct_schema_function_without_choices() -> Callable:
Expand Down
25 changes: 3 additions & 22 deletions ormar/models/newbasemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ormar.exceptions import ModelError, ModelPersistenceError
from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
from ormar.fields.parsers import encode_json
from ormar.fields.parsers import encode_json, decode_bytes
from ormar.models.helpers import register_relation_in_alias_manager
from ormar.models.helpers.relations import expand_reverse_relationship
from ormar.models.helpers.sqlalchemy import (
Expand Down Expand Up @@ -425,22 +425,6 @@ def __same__(self, other: "NewBaseModel") -> bool:
else:
return hash(self) == other.__hash__()

def _copy_and_set_values(
self: "NewBaseModel", values: "DictStrAny", fields_set: "SetStr", *, deep: bool
) -> "NewBaseModel":
"""
Overwrite related models values with dict representation to avoid infinite
recursion through related fields.
"""
self_dict = values
self_dict.update(self.dict(exclude_list=True))
return cast(
"NewBaseModel",
super()._copy_and_set_values(
values=self_dict, fields_set=fields_set, deep=deep
),
)

@classmethod
def get_name(cls, lower: bool = True) -> str:
"""
Expand Down Expand Up @@ -989,11 +973,8 @@ def _convert_to_bytes(self, column_name: str, value: Any) -> Union[str, Dict]:
if column_name not in self._bytes_fields:
return value
field = self.ormar_config.model_fields[column_name]
if not isinstance(value, bytes) and value is not None:
if field.represent_as_base64_str:
value = base64.b64decode(value)
else:
value = value.encode("utf-8")
if value is not None:
value = decode_bytes(value=value, represent_as_string=field.represent_as_base64_str)
return value

def _convert_bytes_to_str(self, column_name: str, value: Any) -> Union[str, Dict]:
Expand Down
9 changes: 0 additions & 9 deletions ormar/queryset/field_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,6 @@ def __init__(
self._model = model
self._access_chain = access_chain

def __bool__(self) -> bool:
"""
Hack to avoid pydantic name check from parent model, returns false
:return: False
:rtype: bool
"""
return False

def __getattr__(self, item: str) -> Any:
"""
Accessor return new accessor for each field and nested models.
Expand Down
7 changes: 7 additions & 0 deletions tests/test_encryption/test_encrypted_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ class Author(ormar.Model):
test_smallint: int = ormar.SmallInteger(default=0, **default_fernet)
test_decimal = ormar.Decimal(scale=2, precision=10, **default_fernet)
test_decimal2 = ormar.Decimal(max_digits=10, decimal_places=2, **default_fernet)
test_bytes = ormar.LargeBinary(max_length=100, **default_fernet)
test_b64bytes = ormar.LargeBinary(max_length=100, represent_as_base64_str=True, **default_fernet)
custom_backend: str = ormar.String(
max_length=200,
encrypt_secret="asda8",
Expand Down Expand Up @@ -187,6 +189,8 @@ async def test_save_and_retrieve():
test_decimal2=decimal.Decimal(5.5),
test_json=dict(aa=12),
custom_backend="test12",
test_bytes=b"test",
test_b64bytes=b"test2"
).save()
author = await Author.objects.get()

Expand All @@ -209,6 +213,9 @@ async def test_save_and_retrieve():
assert author.test_decimal == 3.5
assert author.test_decimal2 == 5.5
assert author.custom_backend == "test12"
assert author.test_bytes == "test".encode("utf-8")
assert author.test_b64bytes == "dGVzdDI="
assert base64.b64decode(author.test_b64bytes) == b"test2"


@pytest.mark.asyncio
Expand Down
12 changes: 5 additions & 7 deletions tests/test_fastapi/test_skip_reverse_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,13 @@ def create_test_database():
metadata.drop_all(engine)


@app.post("/categories/", response_model=Category)
async def create_category(category: Category):
await category.save()
await category.save_related(follow=True, save_all=True)
return category
@app.post("/categories/forbid/", response_model=Category2)
async def create_category_forbid(category: Category2): # pragma: no cover
pass


@app.post("/categories/forbid/", response_model=Category2)
async def create_category_forbid(category: Category2):
@app.post("/categories/", response_model=Category)
async def create_category(category: Category):
await category.save()
await category.save_related(follow=True, save_all=True)
return category
Expand Down
11 changes: 11 additions & 0 deletions tests/test_model_definition/test_model_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,17 @@ class JsonSample3(ormar.Model):
test_json = ormar.JSON(nullable=True)


def test_wrong_pydantic_config():
with pytest.raises(ModelDefinitionError):

class ErrorSample(ormar.Model):
model_config = ["test"]
ormar_config = ormar.OrmarConfig(tablename="jsons3")

id: int = ormar.Integer(primary_key=True)
test_json = ormar.JSON(nullable=True)


def test_non_existing_attr(example):
with pytest.raises(ValueError):
example.new_attr = 12
Expand Down
11 changes: 1 addition & 10 deletions tests/test_model_definition/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class LargeBinarySample(ormar.Model):
)

id: int = ormar.Integer(primary_key=True)
test_binary: bytes = ormar.LargeBinary(max_length=100000, choices=[blob, blob2])
test_binary: bytes = ormar.LargeBinary(max_length=100000)


blob3 = os.urandom(64)
Expand Down Expand Up @@ -518,15 +518,6 @@ async def test_model_first():

assert await User.objects.order_by("name").first() == jane


def not_contains(a, b):
return a not in b


def contains(a, b):
return a in b


@pytest.mark.asyncio
async def test_model_choices():
"""Test that choices work properly for various types of fields."""
Expand Down

0 comments on commit d286de5

Please sign in to comment.