From 228a6331d78c2a64462c3f5e13d33a2b1d97c14f Mon Sep 17 00:00:00 2001 From: Matthias Kestenholz Date: Thu, 25 Apr 2024 14:09:05 +0200 Subject: [PATCH] Ugly hack for supporting additional recursive fields (#67) --- .pre-commit-config.yaml | 10 ++++----- pyproject.toml | 17 ++++++++------ tests/testapp/models.py | 5 ++++- tests/testapp/test_queries.py | 25 +++++++++++++++++---- tree_queries/compiler.py | 42 +++++++++++++++++++++++++++++------ tree_queries/fields.py | 2 +- tree_queries/query.py | 15 ++++++++----- 7 files changed, 85 insertions(+), 31 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index da146f6..0f399c1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ exclude: ".yarn/|yarn.lock|\\.min\\.(css|js)$" repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: check-added-large-files - id: check-builtin-literals @@ -14,7 +14,7 @@ repos: - id: mixed-line-ending - id: trailing-whitespace - repo: https://github.com/adamchainz/django-upgrade - rev: 1.15.0 + rev: 1.16.0 hooks: - id: django-upgrade args: [--target-version, "3.2"] @@ -23,7 +23,7 @@ repos: hooks: - id: absolufy-imports - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.1.6" + rev: "v0.4.1" hooks: - id: ruff - id: ruff-format @@ -34,10 +34,10 @@ repos: args: [--list-different, --no-semi] exclude: "^conf/|.*\\.html$" - repo: https://github.com/tox-dev/pyproject-fmt - rev: 1.5.1 + rev: 1.8.0 hooks: - id: pyproject-fmt - repo: https://github.com/abravalheri/validate-pyproject - rev: v0.15 + rev: v0.16 hooks: - id: validate-pyproject diff --git a/pyproject.toml b/pyproject.toml index a367b68..a5d782d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,12 @@ include = ["tree_queries/"] path = "tree_queries/__init__.py" [tool.ruff] +fix = true +preview = true +show-fixes = true +target-version = "py38" + +[tool.ruff.lint] extend-select = [ # pyflakes, pycodestyle "F", "E", "W", @@ -80,7 +86,7 @@ extend-select = [ # pygrep-hooks "PGH", # pylint - "PL", + "PLC", "PLE", "PLW", # unused noqa "RUF100", ] @@ -90,18 +96,15 @@ extend-ignore = [ # No line length errors "E501", ] -fix = true -show-fixes = true -target-version = "py38" -[tool.ruff.isort] +[tool.ruff.lint.isort] combine-as-imports = true lines-after-imports = 2 -[tool.ruff.mccabe] +[tool.ruff.lint.mccabe] max-complexity = 15 -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "*/migrat*/*" = [ # Allow using PascalCase model names in migrations "N806", diff --git a/tests/testapp/models.py b/tests/testapp/models.py index 03ded71..4954f41 100644 --- a/tests/testapp/models.py +++ b/tests/testapp/models.py @@ -75,7 +75,7 @@ def __str__(self): class UUIDModel(TreeNode): - id = models.UUIDField(primary_key=True, default=uuid.uuid4) # noqa: A003 + id = models.UUIDField(primary_key=True, default=uuid.uuid4) name = models.CharField(max_length=100) def __str__(self): @@ -136,3 +136,6 @@ class OneToOneRelatedOrder(models.Model): related_name="related", ) order = models.PositiveIntegerField(default=0) + + def __str__(self): + return "" diff --git a/tests/testapp/test_queries.py b/tests/testapp/test_queries.py index 044b67f..8d034bd 100644 --- a/tests/testapp/test_queries.py +++ b/tests/testapp/test_queries.py @@ -1,8 +1,9 @@ +import unittest from types import SimpleNamespace from django import forms from django.core.exceptions import ValidationError -from django.db import connections, models +from django.db import connection, connections, models from django.db.models import Count, Q, Sum from django.db.models.expressions import RawSQL from django.test import TestCase, override_settings @@ -502,9 +503,7 @@ def test_annotate_tree(self): else: qs = qs.annotate( is_my_field=RawSQL( - 'instr(__tree.tree_path, "{sep}{pk}{sep}") <> 0'.format( - pk=pk(tree.child2_1), sep=SEPARATOR - ), + f'instr(__tree.tree_path, "{SEPARATOR}{pk(tree.child2_1)}{SEPARATOR}") <> 0', [], output_field=models.BooleanField(), ) @@ -926,3 +925,21 @@ def test_tree_filter_q_mix(self): tree.child2_2, ], ) + + @unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL tests") + def test_extra_fields(self): + self.create_tree() + names = [ + obj.tree_names for obj in Model.objects.extra_fields(tree_names="name") + ] + self.assertEqual( + names, + [ + ["root"], + ["root", "1"], + ["root", "1", "1-1"], + ["root", "2"], + ["root", "2", "2-1"], + ["root", "2", "2-2"], + ], + ) diff --git a/tree_queries/compiler.py b/tree_queries/compiler.py index e5827bc..b5578e4 100644 --- a/tree_queries/compiler.py +++ b/tree_queries/compiler.py @@ -39,6 +39,9 @@ def _setup_query(self): # so we can avoid recursion self.rank_table_query = QuerySet(model=_find_tree_model(self.model)) + if not hasattr(self, "extra_fields"): + self.extra_fields = {} + def get_compiler(self, using=None, connection=None, **kwargs): # Copied from django/db/models/sql/query.py if using is None and connection is None: @@ -57,10 +60,14 @@ def get_sibling_order(self): def get_rank_table_query(self): return self.rank_table_query + def get_extra_fields(self): + return self.extra_fields + class TreeCompiler(SQLCompiler): CTE_POSTGRESQL = """ WITH RECURSIVE __rank_table( + {extra_fields_columns} "{pk}", "{parent}", "rank_order" @@ -68,12 +75,14 @@ class TreeCompiler(SQLCompiler): {rank_table} ), __tree ( + {extra_fields_names} "tree_depth", "tree_path", "tree_ordering", "tree_pk" ) AS ( SELECT + {extra_fields_initial} 0 AS tree_depth, array[T.{pk}] AS tree_path, array[T.rank_order] AS tree_ordering, @@ -84,6 +93,7 @@ class TreeCompiler(SQLCompiler): UNION ALL SELECT + {extra_fields_recursive} __tree.tree_depth + 1 AS tree_depth, __tree.tree_path || T.{pk}, __tree.tree_ordering || T.rank_order, @@ -180,6 +190,7 @@ def get_rank_table(self): # Values allows us to both limit and specify the order of # the columns selected so that they match the CTE .values( + *self.query.get_extra_fields().values(), "pk", "parent", rank_order=Window( @@ -240,6 +251,23 @@ def as_sql(self, *args, **kwargs): rank_table_sql, rank_table_params = self.get_rank_table() params["rank_table"] = rank_table_sql + extra_fields = self.query.get_extra_fields() + qn = self.connection.ops.quote_name + params.update({ + "extra_fields_columns": "".join( + f"{qn(column)}, " for column in extra_fields.values() + ), + "extra_fields_names": "".join(f"{qn(name)}, " for name in extra_fields), + "extra_fields_initial": "".join( + f"array[T.{qn(column)}]::text[] AS {qn(name)}, " + for name, column in extra_fields.items() + ), + "extra_fields_recursive": "".join( + f"__tree.{qn(name)} || T.{qn(column)}, " + for name, column in extra_fields.items() + ), + }) + if "__tree" not in self.query.extra_tables: # pragma: no branch - unlikely tree_params = params.copy() @@ -254,16 +282,16 @@ def as_sql(self, *args, **kwargs): if aliases: tree_params["db_table"] = aliases[0] + select = { + "tree_depth": "__tree.tree_depth", + "tree_path": "__tree.tree_path", + "tree_ordering": "__tree.tree_ordering", + } + select.update({name: f"__tree.{name}" for name in extra_fields}) self.query.add_extra( # Do not add extra fields to the select statement when it is a # summary query or when using .values() or .values_list() - select={} - if skip_tree_fields or self.query.values_select - else { - "tree_depth": "__tree.tree_depth", - "tree_path": "__tree.tree_path", - "tree_ordering": "__tree.tree_ordering", - }, + select={} if skip_tree_fields or self.query.values_select else select, select_params=None, where=["__tree.tree_pk = {db_table}.{pk}".format(**tree_params)], params=None, diff --git a/tree_queries/fields.py b/tree_queries/fields.py index 76340bd..fe5dd07 100644 --- a/tree_queries/fields.py +++ b/tree_queries/fields.py @@ -5,7 +5,7 @@ class TreeNodeForeignKey(models.ForeignKey): def deconstruct(self): - name, path, args, kwargs = super().deconstruct() + name, _path, args, kwargs = super().deconstruct() return (name, "django.db.models.ForeignKey", args, kwargs) def formfield(self, **kwargs): diff --git a/tree_queries/query.py b/tree_queries/query.py index 3a61b8e..3fac05d 100644 --- a/tree_queries/query.py +++ b/tree_queries/query.py @@ -76,7 +76,14 @@ def tree_exclude(self, *args, **kwargs): ) return self - def as_manager(cls, *, with_tree_fields=False): # noqa: N805 + def extra_fields(self, **extra_fields): + self.query.__class__ = TreeQuery + self.query._setup_query() + self.query.extra_fields = extra_fields + return self + + @classmethod + def as_manager(cls, *, with_tree_fields=False): manager_class = TreeManager.from_queryset(cls) # Only used in deconstruct: manager_class._built_with_as_manager = True @@ -87,7 +94,6 @@ def as_manager(cls, *, with_tree_fields=False): # noqa: N805 return manager_class() as_manager.queryset_only = True - as_manager = classmethod(as_manager) def ancestors(self, of, *, include_self=False): """ @@ -122,10 +128,7 @@ def descendants(self, of, *, include_self=False): where=[ # XXX This *may* be unsafe with some primary key field types. # It is certainly safe with integers. - 'instr(__tree.tree_path, "{sep}{pk}{sep}") <> 0'.format( - pk=self.model._meta.pk.get_db_prep_value(pk(of), connection), - sep=SEPARATOR, - ) + f'instr(__tree.tree_path, "{SEPARATOR}{self.model._meta.pk.get_db_prep_value(pk(of), connection)}{SEPARATOR}") <> 0' ] )