diff --git a/Makefile b/Makefile index 2d80d56af..29fab21d1 100644 --- a/Makefile +++ b/Makefile @@ -19,7 +19,7 @@ test: pytest coverage: - pytest --cov=ormar --cov=tests --cov-fail-under=100 --cov-report=term-missing + pytest --cov=ormar --cov=tests --cov-fail-under=100 --cov-report=term-missing tests type_check: mkdir -p .mypy_cache && poetry run python -m mypy . --ignore-missing-imports --install-types --non-interactive diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index b6c63cbf2..f9c75e9b1 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -365,9 +365,9 @@ def default_source_field_name(self) -> str: def _evaluate_forward_ref(self, globalns: Any, localns: Any, is_through: bool = False) -> None: target = "through" if is_through else "to" target_obj = getattr(self, target) - if sys.version_info.minor <= 8: + if sys.version_info.minor <= 8: # pragma: no cover evaluated = target_obj._evaluate(globalns, localns) - else: + else: # pragma: no cover evaluated = target_obj._evaluate(globalns, localns, set()) setattr(self, target, evaluated) diff --git a/ormar/fields/parsers.py b/ormar/fields/parsers.py index dcc3f6cbe..c885e954b 100644 --- a/ormar/fields/parsers.py +++ b/ormar/fields/parsers.py @@ -31,7 +31,15 @@ def encode_decimal(value: decimal.Decimal, precision: int = None) -> float: return float(value) -def encode_bytes(value: Union[str, bytes], represent_as_string: bool = False) -> bytes: +def encode_bytes(value: Union[str, bytes], represent_as_string: bool = False) -> str: + if represent_as_string: + value = value if isinstance(value, str) else base64.b64encode(value).decode("utf-8") + else: + value = value if isinstance(value, str) else value.decode("utf-8") + return value + + +def decode_bytes(value: str, represent_as_string: bool = False) -> bytes: if represent_as_string: value = value if isinstance(value, bytes) else base64.b64decode(value) else: @@ -74,6 +82,11 @@ def re_dump_value(value: str) -> Union[str, bytes]: SQL_ENCODERS_MAP: Dict[type, Callable] = {bool: encode_bool, **ENCODERS_MAP} +ADDITIONAL_PARAMETERS_MAP: Dict[type, str] = { + bytes: "", + decimal.Decimal: "precision" +} + DECODERS_MAP = { bool: parse_bool, @@ -82,4 +95,5 @@ def re_dump_value(value: str) -> Union[str, bytes]: datetime.time: SchemaValidator(core_schema.time_schema()).validate_python, pydantic.Json: json.loads, decimal.Decimal: decimal.Decimal, + bytes: decode_bytes, } diff --git a/ormar/fields/sqlalchemy_encrypted.py b/ormar/fields/sqlalchemy_encrypted.py index 90518e33a..fe5d92f9c 100644 --- a/ormar/fields/sqlalchemy_encrypted.py +++ b/ormar/fields/sqlalchemy_encrypted.py @@ -162,7 +162,10 @@ def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]: except AttributeError: encoder = ormar.SQL_ENCODERS_MAP.get(self.type_, None) if encoder: - value = encoder(value) # type: ignore + if self.type_ == bytes: + value = encoder(value, self._field_type.represent_as_base64_str) + else: + value = encoder(value) # type: ignore encrypted_value = self.backend.encrypt(value) return encrypted_value @@ -177,6 +180,8 @@ def process_result_value(self, value: Any, dialect: Dialect) -> Any: except AttributeError: decoder = ormar.DECODERS_MAP.get(self.type_, None) if decoder: + if self.type_ == bytes: + return decoder(decrypted_value, self._field_type.represent_as_base64_str) return decoder(decrypted_value) # type: ignore return self._field_type.__type__(decrypted_value) # type: ignore diff --git a/ormar/models/descriptors/descriptors.py b/ormar/models/descriptors/descriptors.py index f8ea6043f..ff788edbe 100644 --- a/ormar/models/descriptors/descriptors.py +++ b/ormar/models/descriptors/descriptors.py @@ -1,7 +1,7 @@ import base64 from typing import TYPE_CHECKING, Any, List, Type -from ormar.fields.parsers import encode_json +from ormar.fields.parsers import encode_json, decode_bytes if TYPE_CHECKING: # pragma: no cover from ormar import Model @@ -65,10 +65,7 @@ def __get__(self, instance: "Model", owner: Type["Model"]) -> Any: def __set__(self, instance: "Model", value: Any) -> None: field = instance.ormar_config.model_fields[self.name] if isinstance(value, str): - if field.represent_as_base64_str: - value = base64.b64decode(value) - else: - value = value.encode("utf-8") + value = decode_bytes(value=value, represent_as_string=field.represent_as_base64_str) instance._internal_set(self.name, value) instance.set_save_status(False) diff --git a/ormar/models/helpers/relations.py b/ormar/models/helpers/relations.py index 8ba7f1a4b..886e05a3b 100644 --- a/ormar/models/helpers/relations.py +++ b/ormar/models/helpers/relations.py @@ -175,7 +175,7 @@ def serialize( try: serialized = handler([child]) except ValueError as exc: - if not str(exc).startswith("Circular reference"): + if not str(exc).startswith("Circular reference"): # pragma: no cover raise exc result.append({child.ormar_config.pkname: child.pk}) else: diff --git a/ormar/models/helpers/validation.py b/ormar/models/helpers/validation.py index 71129bed2..f3a08adf0 100644 --- a/ormar/models/helpers/validation.py +++ b/ormar/models/helpers/validation.py @@ -21,7 +21,6 @@ import pydantic -import ormar # noqa: I100, I202 from ormar.models.helpers.models import config_field_not_set from ormar.queryset.utils import translate_list_to_dict @@ -30,33 +29,6 @@ from ormar.fields import BaseField -def convert_value_if_needed(field: "BaseField", value: Any) -> Any: - """ - Converts dates to isoformat as fastapi can check this condition in routes - and the fields are not yet parsed. - Converts enums to list of it's values. - Converts uuids to strings. - Converts decimal to float with given scale. - - :param field: ormar field to check with choices - :type field: BaseField - :param value: current values of the model to verify - :type value: Any - :return: value, choices list - :rtype: Any - """ - encoder = ormar.ENCODERS_MAP.get(field.__type__, lambda x: x) - if field.__type__ == decimal.Decimal: - precision = field.scale # type: ignore - value = encoder(value, precision) - elif field.__type__ == bytes: - represent_as_string = field.represent_as_base64_str - value = encoder(value, represent_as_string) - elif encoder: - value = encoder(value) - return value - - def generate_model_example(model: Type["Model"], relation_map: Dict = None) -> Dict: """ Generates example to be included in schema in fastapi. @@ -198,8 +170,6 @@ def overwrite_example_and_description( :type model: Type["Model"] """ schema["example"] = generate_model_example(model=model) - if "Main base class of ormar Model." in schema.get("description", ""): - schema["description"] = f"{model.__name__}" def overwrite_binary_format(schema: Dict[str, Any], model: Type["Model"]) -> None: @@ -218,36 +188,6 @@ def overwrite_binary_format(schema: Dict[str, Any], model: Type["Model"]) -> Non and model.ormar_config.model_fields[field_id].represent_as_base64_str ): prop["format"] = "base64" - if prop.get("enum"): - prop["enum"] = [ - base64.b64encode(choice).decode() for choice in prop.get("enum", []) - ] - - -def construct_modify_schema_function(fields_with_choices: List) -> Callable: - """ - Modifies the schema to include fields with choices validator. - Those fields will be displayed in schema as Enum types with available choices - values listed next to them. - - Note that schema extra has to be a function, otherwise it's called to soon - before all the relations are expanded. - - :param fields_with_choices: list of fields with choices validation - :type fields_with_choices: List - :return: callable that will be run by pydantic to modify the schema - :rtype: Callable - """ - - def schema_extra(schema: Dict[str, Any], model: Type["Model"]) -> None: - for field_id, prop in schema.get("properties", {}).items(): - if field_id in fields_with_choices: - prop["enum"] = list(model.ormar_config.model_fields[field_id].choices) - prop["description"] = prop.get("description", "") + "An enumeration." - overwrite_example_and_description(schema=schema, model=model) - overwrite_binary_format(schema=schema, model=model) - - return staticmethod(schema_extra) # type: ignore def construct_schema_function_without_choices() -> Callable: diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 3ae65a4c9..778862a28 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -26,7 +26,7 @@ from ormar.exceptions import ModelError, ModelPersistenceError from ormar.fields import BaseField from ormar.fields.foreign_key import ForeignKeyField -from ormar.fields.parsers import encode_json +from ormar.fields.parsers import encode_json, decode_bytes from ormar.models.helpers import register_relation_in_alias_manager from ormar.models.helpers.relations import expand_reverse_relationship from ormar.models.helpers.sqlalchemy import ( @@ -425,22 +425,6 @@ def __same__(self, other: "NewBaseModel") -> bool: else: return hash(self) == other.__hash__() - def _copy_and_set_values( - self: "NewBaseModel", values: "DictStrAny", fields_set: "SetStr", *, deep: bool - ) -> "NewBaseModel": - """ - Overwrite related models values with dict representation to avoid infinite - recursion through related fields. - """ - self_dict = values - self_dict.update(self.dict(exclude_list=True)) - return cast( - "NewBaseModel", - super()._copy_and_set_values( - values=self_dict, fields_set=fields_set, deep=deep - ), - ) - @classmethod def get_name(cls, lower: bool = True) -> str: """ @@ -989,11 +973,8 @@ def _convert_to_bytes(self, column_name: str, value: Any) -> Union[str, Dict]: if column_name not in self._bytes_fields: return value field = self.ormar_config.model_fields[column_name] - if not isinstance(value, bytes) and value is not None: - if field.represent_as_base64_str: - value = base64.b64decode(value) - else: - value = value.encode("utf-8") + if value is not None: + value = decode_bytes(value=value, represent_as_string=field.represent_as_base64_str) return value def _convert_bytes_to_str(self, column_name: str, value: Any) -> Union[str, Dict]: diff --git a/ormar/queryset/field_accessor.py b/ormar/queryset/field_accessor.py index 81440e07c..2ff557a21 100644 --- a/ormar/queryset/field_accessor.py +++ b/ormar/queryset/field_accessor.py @@ -26,15 +26,6 @@ def __init__( self._model = model self._access_chain = access_chain - def __bool__(self) -> bool: - """ - Hack to avoid pydantic name check from parent model, returns false - - :return: False - :rtype: bool - """ - return False - def __getattr__(self, item: str) -> Any: """ Accessor return new accessor for each field and nested models. diff --git a/tests/test_encryption/test_encrypted_columns.py b/tests/test_encryption/test_encrypted_columns.py index 01546baf4..35e317c4b 100644 --- a/tests/test_encryption/test_encrypted_columns.py +++ b/tests/test_encryption/test_encrypted_columns.py @@ -69,6 +69,8 @@ class Author(ormar.Model): test_smallint: int = ormar.SmallInteger(default=0, **default_fernet) test_decimal = ormar.Decimal(scale=2, precision=10, **default_fernet) test_decimal2 = ormar.Decimal(max_digits=10, decimal_places=2, **default_fernet) + test_bytes = ormar.LargeBinary(max_length=100, **default_fernet) + test_b64bytes = ormar.LargeBinary(max_length=100, represent_as_base64_str=True, **default_fernet) custom_backend: str = ormar.String( max_length=200, encrypt_secret="asda8", @@ -187,6 +189,8 @@ async def test_save_and_retrieve(): test_decimal2=decimal.Decimal(5.5), test_json=dict(aa=12), custom_backend="test12", + test_bytes=b"test", + test_b64bytes=b"test2" ).save() author = await Author.objects.get() @@ -209,6 +213,9 @@ async def test_save_and_retrieve(): assert author.test_decimal == 3.5 assert author.test_decimal2 == 5.5 assert author.custom_backend == "test12" + assert author.test_bytes == "test".encode("utf-8") + assert author.test_b64bytes == "dGVzdDI=" + assert base64.b64decode(author.test_b64bytes) == b"test2" @pytest.mark.asyncio diff --git a/tests/test_fastapi/test_skip_reverse_models.py b/tests/test_fastapi/test_skip_reverse_models.py index e7a8edb2d..4e879e2ff 100644 --- a/tests/test_fastapi/test_skip_reverse_models.py +++ b/tests/test_fastapi/test_skip_reverse_models.py @@ -74,15 +74,13 @@ def create_test_database(): metadata.drop_all(engine) -@app.post("/categories/", response_model=Category) -async def create_category(category: Category): - await category.save() - await category.save_related(follow=True, save_all=True) - return category +@app.post("/categories/forbid/", response_model=Category2) +async def create_category_forbid(category: Category2): # pragma: no cover + pass -@app.post("/categories/forbid/", response_model=Category2) -async def create_category_forbid(category: Category2): +@app.post("/categories/", response_model=Category) +async def create_category(category: Category): await category.save() await category.save_related(follow=True, save_all=True) return category diff --git a/tests/test_model_definition/test_model_definition.py b/tests/test_model_definition/test_model_definition.py index 82ce8c96b..57fd521c5 100644 --- a/tests/test_model_definition/test_model_definition.py +++ b/tests/test_model_definition/test_model_definition.py @@ -134,6 +134,17 @@ class JsonSample3(ormar.Model): test_json = ormar.JSON(nullable=True) +def test_wrong_pydantic_config(): + with pytest.raises(ModelDefinitionError): + + class ErrorSample(ormar.Model): + model_config = ["test"] + ormar_config = ormar.OrmarConfig(tablename="jsons3") + + id: int = ormar.Integer(primary_key=True) + test_json = ormar.JSON(nullable=True) + + def test_non_existing_attr(example): with pytest.raises(ValueError): example.new_attr = 12 diff --git a/tests/test_model_definition/test_models.py b/tests/test_model_definition/test_models.py index a7f664b38..efdfa6ae9 100644 --- a/tests/test_model_definition/test_models.py +++ b/tests/test_model_definition/test_models.py @@ -41,7 +41,7 @@ class LargeBinarySample(ormar.Model): ) id: int = ormar.Integer(primary_key=True) - test_binary: bytes = ormar.LargeBinary(max_length=100000, choices=[blob, blob2]) + test_binary: bytes = ormar.LargeBinary(max_length=100000) blob3 = os.urandom(64) @@ -518,15 +518,6 @@ async def test_model_first(): assert await User.objects.order_by("name").first() == jane - -def not_contains(a, b): - return a not in b - - -def contains(a, b): - return a in b - - @pytest.mark.asyncio async def test_model_choices(): """Test that choices work properly for various types of fields."""