Skip to content

Commit

Permalink
Ugly hack for supporting additional recursive fields (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiask authored Apr 25, 2024
1 parent 22e230f commit 228a633
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 31 deletions.
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
exclude: ".yarn/|yarn.lock|\\.min\\.(css|js)$"
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: check-added-large-files
- id: check-builtin-literals
Expand All @@ -14,7 +14,7 @@ repos:
- id: mixed-line-ending
- id: trailing-whitespace
- repo: https://github.com/adamchainz/django-upgrade
rev: 1.15.0
rev: 1.16.0
hooks:
- id: django-upgrade
args: [--target-version, "3.2"]
Expand All @@ -23,7 +23,7 @@ repos:
hooks:
- id: absolufy-imports
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.1.6"
rev: "v0.4.1"
hooks:
- id: ruff
- id: ruff-format
Expand All @@ -34,10 +34,10 @@ repos:
args: [--list-different, --no-semi]
exclude: "^conf/|.*\\.html$"
- repo: https://github.com/tox-dev/pyproject-fmt
rev: 1.5.1
rev: 1.8.0
hooks:
- id: pyproject-fmt
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.15
rev: v0.16
hooks:
- id: validate-pyproject
17 changes: 10 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ include = ["tree_queries/"]
path = "tree_queries/__init__.py"

[tool.ruff]
fix = true
preview = true
show-fixes = true
target-version = "py38"

[tool.ruff.lint]
extend-select = [
# pyflakes, pycodestyle
"F", "E", "W",
Expand Down Expand Up @@ -80,7 +86,7 @@ extend-select = [
# pygrep-hooks
"PGH",
# pylint
"PL",
"PLC", "PLE", "PLW",
# unused noqa
"RUF100",
]
Expand All @@ -90,18 +96,15 @@ extend-ignore = [
# No line length errors
"E501",
]
fix = true
show-fixes = true
target-version = "py38"

[tool.ruff.isort]
[tool.ruff.lint.isort]
combine-as-imports = true
lines-after-imports = 2

[tool.ruff.mccabe]
[tool.ruff.lint.mccabe]
max-complexity = 15

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"*/migrat*/*" = [
# Allow using PascalCase model names in migrations
"N806",
Expand Down
5 changes: 4 additions & 1 deletion tests/testapp/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __str__(self):


class UUIDModel(TreeNode):
id = models.UUIDField(primary_key=True, default=uuid.uuid4) # noqa: A003
id = models.UUIDField(primary_key=True, default=uuid.uuid4)
name = models.CharField(max_length=100)

def __str__(self):
Expand Down Expand Up @@ -136,3 +136,6 @@ class OneToOneRelatedOrder(models.Model):
related_name="related",
)
order = models.PositiveIntegerField(default=0)

def __str__(self):
return ""
25 changes: 21 additions & 4 deletions tests/testapp/test_queries.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import unittest
from types import SimpleNamespace

from django import forms
from django.core.exceptions import ValidationError
from django.db import connections, models
from django.db import connection, connections, models
from django.db.models import Count, Q, Sum
from django.db.models.expressions import RawSQL
from django.test import TestCase, override_settings
Expand Down Expand Up @@ -502,9 +503,7 @@ def test_annotate_tree(self):
else:
qs = qs.annotate(
is_my_field=RawSQL(
'instr(__tree.tree_path, "{sep}{pk}{sep}") <> 0'.format(
pk=pk(tree.child2_1), sep=SEPARATOR
),
f'instr(__tree.tree_path, "{SEPARATOR}{pk(tree.child2_1)}{SEPARATOR}") <> 0',
[],
output_field=models.BooleanField(),
)
Expand Down Expand Up @@ -926,3 +925,21 @@ def test_tree_filter_q_mix(self):
tree.child2_2,
],
)

@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL tests")
def test_extra_fields(self):
self.create_tree()
names = [
obj.tree_names for obj in Model.objects.extra_fields(tree_names="name")
]
self.assertEqual(
names,
[
["root"],
["root", "1"],
["root", "1", "1-1"],
["root", "2"],
["root", "2", "2-1"],
["root", "2", "2-2"],
],
)
42 changes: 35 additions & 7 deletions tree_queries/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ def _setup_query(self):
# so we can avoid recursion
self.rank_table_query = QuerySet(model=_find_tree_model(self.model))

if not hasattr(self, "extra_fields"):
self.extra_fields = {}

def get_compiler(self, using=None, connection=None, **kwargs):
# Copied from django/db/models/sql/query.py
if using is None and connection is None:
Expand All @@ -57,23 +60,29 @@ def get_sibling_order(self):
def get_rank_table_query(self):
return self.rank_table_query

def get_extra_fields(self):
return self.extra_fields


class TreeCompiler(SQLCompiler):
CTE_POSTGRESQL = """
WITH RECURSIVE __rank_table(
{extra_fields_columns}
"{pk}",
"{parent}",
"rank_order"
) AS (
{rank_table}
),
__tree (
{extra_fields_names}
"tree_depth",
"tree_path",
"tree_ordering",
"tree_pk"
) AS (
SELECT
{extra_fields_initial}
0 AS tree_depth,
array[T.{pk}] AS tree_path,
array[T.rank_order] AS tree_ordering,
Expand All @@ -84,6 +93,7 @@ class TreeCompiler(SQLCompiler):
UNION ALL
SELECT
{extra_fields_recursive}
__tree.tree_depth + 1 AS tree_depth,
__tree.tree_path || T.{pk},
__tree.tree_ordering || T.rank_order,
Expand Down Expand Up @@ -180,6 +190,7 @@ def get_rank_table(self):
# Values allows us to both limit and specify the order of
# the columns selected so that they match the CTE
.values(
*self.query.get_extra_fields().values(),
"pk",
"parent",
rank_order=Window(
Expand Down Expand Up @@ -240,6 +251,23 @@ def as_sql(self, *args, **kwargs):
rank_table_sql, rank_table_params = self.get_rank_table()
params["rank_table"] = rank_table_sql

extra_fields = self.query.get_extra_fields()
qn = self.connection.ops.quote_name
params.update({
"extra_fields_columns": "".join(
f"{qn(column)}, " for column in extra_fields.values()
),
"extra_fields_names": "".join(f"{qn(name)}, " for name in extra_fields),
"extra_fields_initial": "".join(
f"array[T.{qn(column)}]::text[] AS {qn(name)}, "
for name, column in extra_fields.items()
),
"extra_fields_recursive": "".join(
f"__tree.{qn(name)} || T.{qn(column)}, "
for name, column in extra_fields.items()
),
})

if "__tree" not in self.query.extra_tables: # pragma: no branch - unlikely
tree_params = params.copy()

Expand All @@ -254,16 +282,16 @@ def as_sql(self, *args, **kwargs):
if aliases:
tree_params["db_table"] = aliases[0]

select = {
"tree_depth": "__tree.tree_depth",
"tree_path": "__tree.tree_path",
"tree_ordering": "__tree.tree_ordering",
}
select.update({name: f"__tree.{name}" for name in extra_fields})
self.query.add_extra(
# Do not add extra fields to the select statement when it is a
# summary query or when using .values() or .values_list()
select={}
if skip_tree_fields or self.query.values_select
else {
"tree_depth": "__tree.tree_depth",
"tree_path": "__tree.tree_path",
"tree_ordering": "__tree.tree_ordering",
},
select={} if skip_tree_fields or self.query.values_select else select,
select_params=None,
where=["__tree.tree_pk = {db_table}.{pk}".format(**tree_params)],
params=None,
Expand Down
2 changes: 1 addition & 1 deletion tree_queries/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

class TreeNodeForeignKey(models.ForeignKey):
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
name, _path, args, kwargs = super().deconstruct()
return (name, "django.db.models.ForeignKey", args, kwargs)

def formfield(self, **kwargs):
Expand Down
15 changes: 9 additions & 6 deletions tree_queries/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,14 @@ def tree_exclude(self, *args, **kwargs):
)
return self

def as_manager(cls, *, with_tree_fields=False): # noqa: N805
def extra_fields(self, **extra_fields):
self.query.__class__ = TreeQuery
self.query._setup_query()
self.query.extra_fields = extra_fields
return self

@classmethod
def as_manager(cls, *, with_tree_fields=False):
manager_class = TreeManager.from_queryset(cls)
# Only used in deconstruct:
manager_class._built_with_as_manager = True
Expand All @@ -87,7 +94,6 @@ def as_manager(cls, *, with_tree_fields=False): # noqa: N805
return manager_class()

as_manager.queryset_only = True
as_manager = classmethod(as_manager)

def ancestors(self, of, *, include_self=False):
"""
Expand Down Expand Up @@ -122,10 +128,7 @@ def descendants(self, of, *, include_self=False):
where=[
# XXX This *may* be unsafe with some primary key field types.
# It is certainly safe with integers.
'instr(__tree.tree_path, "{sep}{pk}{sep}") <> 0'.format(
pk=self.model._meta.pk.get_db_prep_value(pk(of), connection),
sep=SEPARATOR,
)
f'instr(__tree.tree_path, "{SEPARATOR}{self.model._meta.pk.get_db_prep_value(pk(of), connection)}{SEPARATOR}") <> 0'
]
)

Expand Down

0 comments on commit 228a633

Please sign in to comment.