Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the ability to filter queries before construction of the CTE #66

Merged
merged 5 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 169 additions & 0 deletions tests/testapp/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,3 +753,172 @@ def test_order_by_related(self):
tree.child2_2,
],
)

def test_tree_exclude(self):
    """tree_exclude() drops a matching node together with its subtree."""
    tree = self.create_tree()
    # The node named "2" matches the exclusion, so neither it nor the
    # nodes below it (2-1, 2-2) are expected in the result.
    remaining = Model.objects.tree_exclude(name="2")
    expected = [tree.root, tree.child1, tree.child1_1]
    self.assertEqual(list(remaining), expected)

def test_tree_filter(self):
    """tree_filter() prunes nodes whose parent does not match the filter."""
    tree = self.create_tree()
    # "1-1" is listed but its parent "1" is not, so the whole "1" branch
    # is expected to be missing from the result.
    allowed_names = ["root", "1-1", "2", "2-1", "2-2"]
    kept = Model.objects.tree_filter(name__in=allowed_names)
    expected = [tree.root, tree.child2, tree.child2_1, tree.child2_2]
    self.assertEqual(list(kept), expected)

def test_tree_filter_chaining(self):
    """tree_exclude() and tree_filter() can be chained on one queryset."""
    tree = self.create_tree()
    # The filter alone would keep root, 2, 2-1 and 2-2; the chained
    # exclusion additionally removes the node named "2-2".
    allowed_names = ["root", "1-1", "2", "2-1", "2-2"]
    kept = Model.objects.tree_exclude(name="2-2").tree_filter(
        name__in=allowed_names
    )
    expected = [tree.root, tree.child2, tree.child2_1]
    self.assertEqual(list(kept), expected)

def test_tree_filter_related(self):
    """tree_filter() accepts lookups that span a one-to-one relation."""
    tree = type("Namespace", (), {})()  # SimpleNamespace for PY2...

    def add_node(attr, name, order, **node_kwargs):
        # Create the node first and its related ordering row second, and
        # expose both on the namespace as ``<attr>`` / ``<attr>_related``.
        node = RelatedOrderModel.objects.create(name=name, **node_kwargs)
        setattr(tree, attr, node)
        related = OneToOneRelatedOrder.objects.create(
            relatedmodel=node, order=order
        )
        setattr(tree, attr + "_related", related)
        return node

    # Creation order is kept stable: root, both children, then grandchildren.
    root = add_node("root", "root", 0)
    child1 = add_node("child1", "1", 0, parent=root)
    child2 = add_node("child2", "2", 1, parent=root)
    add_node("child1_1", "1-1", 0, parent=child1)
    add_node("child2_1", "2-1", 0, parent=child2)
    add_node("child2_2", "2-2", 1, parent=child2)

    # "2" has order=1 and fails the filter, so "2-1" (order=0) is expected
    # to be missing as well even though it matches on its own.
    kept = RelatedOrderModel.objects.tree_filter(related__order=0)
    expected = [tree.root, tree.child1, tree.child1_1]
    self.assertEqual(list(kept), expected)

def test_tree_filter_with_order(self):
    """tree_filter() can be combined with order_siblings_by()."""
    tree = type("Namespace", (), {})()  # SimpleNamespace for PY2...
    create = MultiOrderedModel.objects.create

    tree.root = create(name="root", first_position=1)
    tree.child1 = create(
        name="1", parent=tree.root, first_position=0, second_position=1
    )
    tree.child2 = create(
        name="2", parent=tree.root, first_position=1, second_position=0
    )
    tree.child1_1 = create(
        name="1-1", parent=tree.child1, first_position=1, second_position=1
    )
    tree.child2_1 = create(
        name="2-1", parent=tree.child2, first_position=1, second_position=1
    )
    tree.child2_2 = create(
        name="2-2", parent=tree.child2, first_position=1, second_position=0
    )

    # "1" has first_position=0 and fails the filter, which is expected to
    # remove "1-1" too. Siblings are ordered by descending second_position,
    # hence "2-1" (1) comes before "2-2" (0).
    filtered = MultiOrderedModel.objects.tree_filter(first_position__gt=0)
    ordered = filtered.order_siblings_by("-second_position")
    expected = [tree.root, tree.child2, tree.child2_1, tree.child2_2]
    self.assertEqual(list(ordered), expected)

def test_tree_filter_Q_objects(self):
    """tree_filter() accepts a positional Q object as its condition."""
    tree = self.create_tree()
    # "1" is not among the allowed names, so its child "1-1" is expected
    # to be dropped although it is listed itself.
    condition = Q(name__in=["root", "1-1", "2", "2-1", "2-2"])
    kept = Model.objects.tree_filter(condition)
    expected = [tree.root, tree.child2, tree.child2_1, tree.child2_2]
    self.assertEqual(list(kept), expected)

def test_tree_filter_Q_mix(self):
    """tree_filter() accepts a Q object mixed with keyword lookups."""
    tree = type("Namespace", (), {})()  # SimpleNamespace for PY2...
    create = MultiOrderedModel.objects.create

    tree.root = create(name="root", first_position=1, second_position=2)
    tree.child1 = create(
        name="1", parent=tree.root, first_position=1, second_position=0
    )
    tree.child2 = create(
        name="2", parent=tree.root, first_position=1, second_position=2
    )
    tree.child1_1 = create(
        name="1-1", parent=tree.child1, first_position=1, second_position=1
    )
    tree.child2_1 = create(
        name="2-1", parent=tree.child2, first_position=1, second_position=1
    )
    tree.child2_2 = create(
        name="2-2", parent=tree.child2, first_position=1, second_position=2
    )

    # Every node has first_position=1, so second_position=2 decides:
    # only root, "2" and "2-2" satisfy both conditions.
    kept = MultiOrderedModel.objects.tree_filter(
        Q(first_position=1), second_position=2
    )
    expected = [tree.root, tree.child2, tree.child2_2]
    self.assertEqual(list(kept), expected)
147 changes: 82 additions & 65 deletions tree_queries/compiler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from django.db import connections
from django.db.models import Value
from django.db.models import Value, F, Window, Expression, QuerySet
from django.db.models.functions import RowNumber
from django.db.models.sql.compiler import SQLCompiler
from django.db.models.sql.query import Query

Expand All @@ -12,8 +13,38 @@ def _find_tree_model(cls):


class TreeQuery(Query):
# Set by TreeQuerySet.order_siblings_by
sibling_order = None

def __init__(self, *args, **kwargs):
    """Initialize the base query, then install the tree-specific state."""
    super().__init__(*args, **kwargs)
    # Tree-specific attributes are assigned in _setup_query(), not here.
    self._setup_query()

def _setup_query(self):
    """
    Install per-instance defaults for the tree-specific query state.

    Called from __init__(); per the original author's note it is also
    meant to run at the end of chaining. Attributes that would normally
    be assigned in __init__() belong here instead.

    Both attributes are assigned on the instance (not the class) so they
    do not persist between unrelated user queries, and each one is only
    assigned when missing so that values carried over by cloning survive.
    """
    # Controls the ordering of siblings within trees: the tree model's
    # Meta.ordering when it is set, otherwise the primary key column.
    if not hasattr(self, "sibling_order"):
        tree_meta = _find_tree_model(self.model)._meta
        self.sibling_order = tree_meta.ordering or tree_meta.pk.attname

    # A plain QuerySet on the tree model is the default source for the
    # rank table; using it avoids recursion.
    if not hasattr(self, "rank_table_query"):
        self.rank_table_query = QuerySet(model=_find_tree_model(self.model))


def get_compiler(self, using=None, connection=None, **kwargs):
# Copied from django/db/models/sql/query.py
Expand All @@ -28,13 +59,10 @@ def get_compiler(self, using=None, connection=None, **kwargs):
return TreeCompiler(self, connection, using, **kwargs)

def get_sibling_order(self):
if self.sibling_order is not None:
return self.sibling_order
opts = _find_tree_model(self.model)._meta
if opts.ordering:
return opts.ordering
return opts.pk.attname
return self.sibling_order

def get_rank_table_query(self):
    """Return the queryset stored on this query as ``rank_table_query``."""
    return self.rank_table_query

class TreeCompiler(SQLCompiler):
CTE_POSTGRESQL = """
Expand All @@ -43,11 +71,7 @@ class TreeCompiler(SQLCompiler):
"{parent}",
"rank_order"
) AS (
SELECT
{rank_pk},
{rank_parent},
ROW_NUMBER() OVER (ORDER BY {rank_order_by})
FROM {rank_from}
{rank_table}
),
__tree (
"tree_depth",
Expand Down Expand Up @@ -77,11 +101,7 @@ class TreeCompiler(SQLCompiler):

CTE_MYSQL = """
WITH RECURSIVE __rank_table({pk}, {parent}, rank_order) AS (
SELECT
{rank_pk},
{rank_parent},
ROW_NUMBER() OVER (ORDER BY {rank_order_by})
FROM {rank_from}
{rank_table}
),
__tree(tree_depth, tree_path, tree_ordering, tree_pk) AS (
SELECT
Expand All @@ -108,11 +128,7 @@ class TreeCompiler(SQLCompiler):

CTE_SQLITE3 = """
WITH RECURSIVE __rank_table({pk}, {parent}, rank_order) AS (
SELECT
{rank_pk},
{rank_parent},
row_number() OVER (ORDER BY {rank_order_by})
FROM {rank_from}
{rank_table}
),
__tree(tree_depth, tree_path, tree_ordering, tree_pk) AS (
SELECT
Expand All @@ -135,13 +151,8 @@ class TreeCompiler(SQLCompiler):
)
"""

def get_sibling_order_params(self):
"""
This method uses a simple django queryset to generate sql
that can be used to create the __rank_table that orders
siblings. This is done so that any joins required by order_by
are pre-calculated by django
"""
def get_rank_table(self):
# Get and validate sibling_order
sibling_order = self.query.get_sibling_order()

if isinstance(sibling_order, (list, tuple)):
Expand All @@ -152,39 +163,42 @@ def get_sibling_order_params(self):
raise ValueError(
"Sibling order must be a string or a list or tuple of strings."
)

# Use Django to make a SQL query whose parts can be repurposed for __rank_table
base_query = (
_find_tree_model(self.query.model)
.objects.only("pk", "parent")
.order_by(*order_fields)
.query

# Convert strings to expressions. This is to maintain backwards compatibility
# with Django versions < 4.1
rhomboss marked this conversation as resolved.
Show resolved Hide resolved
base_order = []
for field in order_fields:
if isinstance(field, Expression):
base_order.append(field)
elif isinstance(field, str):
if field[0] == "-":
base_order.append(F(field[1:]).desc())
else:
base_order.append(F(field).asc())
order_fields = base_order
# End of back compat code

# Get the rank table query
rank_table_query = self.query.get_rank_table_query()

rank_table_query = (
rank_table_query
.order_by() # Ensure there is no ORDER BY at the end of the SQL
# Values allows us to both limit and specify the order of
# the columns selected so that they match the CTE
.values(
"pk",
"parent",
rank_order=Window(
expression=RowNumber(),
order_by=order_fields,
),
)
)

rank_table_sql, rank_table_params = rank_table_query.query.sql_with_params()

# Use the base compiler because we want vanilla sql and want to avoid recursion.
base_compiler = SQLCompiler(base_query, self.connection, None)
base_sql, base_params = base_compiler.as_sql()
result_sql = base_sql % base_params

# Split the base SQL string on the SQL keywords 'FROM' and 'ORDER BY'
from_split = result_sql.split("FROM")
order_split = from_split[1].split("ORDER BY")

# Identify the FROM and ORDER BY parts of the base SQL
ordering_params = {
"rank_from": order_split[0].strip(),
"rank_order_by": order_split[1].strip(),
}

# Identify the primary key field and parent_id field from the SELECT section
base_select = from_split[0][6:]
for field in base_select.split(","):
if "parent_id" in field: # XXX Taking advantage of Hardcoded.
ordering_params["rank_parent"] = field.strip()
else:
ordering_params["rank_pk"] = field.strip()

return ordering_params
return rank_table_sql, rank_table_params

def as_sql(self, *args, **kwargs):
# Try detecting if we're used in a EXISTS(1 as "a") subquery like
Expand Down Expand Up @@ -229,8 +243,9 @@ def as_sql(self, *args, **kwargs):
"sep": SEPARATOR,
}

# Add ordering params to params
params.update(self.get_sibling_order_params())
# Get the rank_table SQL and params
rank_table_sql, rank_table_params = self.get_rank_table()
params["rank_table"] = rank_table_sql

if "__tree" not in self.query.extra_tables: # pragma: no branch - unlikely
tree_params = params.copy()
Expand Down Expand Up @@ -280,7 +295,9 @@ def as_sql(self, *args, **kwargs):
if sql_0.startswith("EXPLAIN "):
explain, sql_0 = sql_0.split(" ", 1)

return ("".join([explain, cte.format(**params), sql_0]), sql_1)
# Pass any additional rank table sql parameters so that the db backend can handle them.
# This only works because we know that the CTE is at the start of the query.
return ("".join([explain, cte.format(**params), sql_0]), rank_table_params + sql_1)

def get_converters(self, expressions):
converters = super().get_converters(expressions)
Expand Down
Loading
Loading