From 6a95af18f6ef9fa135f1dc313a6abffdb0699580 Mon Sep 17 00:00:00 2001 From: rhomboss <54870528+rhomboss@users.noreply.github.com> Date: Mon, 15 Apr 2024 12:56:40 -0500 Subject: [PATCH 1/5] Update query.py (#6) * Update query.py Add `pre_filter` and `pre_exclude` methods to TreeQuerySet. Make query methods call `_setup_query` to deal with sibling order and pre_filter persistence issues. * Update compiler.py Replace the `get_sibling_order_params` with `get_rank_table_params` to support early tree filtering. Change `sibling_order` and `pre_filter` from class variables to variables that have to be initiated by `_setup_query` so that they don't persist between user queries. Handle pre_filter params by passing them to the django backend. * Update test_queries.py Added tests for the `pre_filter` and `pre_exclude` methods. --- tests/testapp/test_queries.py | 118 +++++++++++++++++++++++++++++++ tree_queries/compiler.py | 126 +++++++++++++++++++++++----------- tree_queries/query.py | 26 +++++++ 3 files changed, 231 insertions(+), 39 deletions(-) diff --git a/tests/testapp/test_queries.py b/tests/testapp/test_queries.py index 344ce45..2a935a9 100644 --- a/tests/testapp/test_queries.py +++ b/tests/testapp/test_queries.py @@ -753,3 +753,121 @@ def test_order_by_related(self): tree.child2_2, ], ) + + def test_pre_exclude(self): + tree = self.create_tree() + # Pre-filter should remove children if + # the parent meets the filtering criteria + nodes = Model.objects.pre_exclude(name="2") + self.assertEqual( + list(nodes), + [ + tree.root, + tree.child1, + tree.child1_1, + ], + ) + + def test_pre_filter(self): + tree = self.create_tree() + # Pre-filter should remove children if + # the parent does not meet the filtering criteria + nodes = Model.objects.pre_filter(name__in=["root","1-1","2","2-1","2-2"]) + self.assertEqual( + list(nodes), + [ + tree.root, + tree.child2, + tree.child2_1, + tree.child2_2, + ], + ) + + def test_pre_filter_chaining(self): + tree = self.create_tree() + # Pre-filter should remove children if + # the parent does not meet the filtering criteria + nodes = Model.objects.pre_exclude(name="2-2").pre_filter(name__in=["root","1-1","2","2-1","2-2"]) + self.assertEqual( + list(nodes), + [ + tree.root, + tree.child2, + tree.child2_1, + ], + ) + + def test_pre_filter_related(self): + tree = type("Namespace", (), {})() # SimpleNamespace for PY2... + + tree.root = RelatedOrderModel.objects.create(name="root") + tree.root_related = OneToOneRelatedOrder.objects.create( + relatedmodel=tree.root, order=0 + ) + tree.child1 = RelatedOrderModel.objects.create(parent=tree.root, name="1") + tree.child1_related = OneToOneRelatedOrder.objects.create( + relatedmodel=tree.child1, order=0 + ) + tree.child2 = RelatedOrderModel.objects.create(parent=tree.root, name="2") + tree.child2_related = OneToOneRelatedOrder.objects.create( + relatedmodel=tree.child2, order=1 + ) + tree.child1_1 = RelatedOrderModel.objects.create(parent=tree.child1, name="1-1") + tree.child1_1_related = OneToOneRelatedOrder.objects.create( + relatedmodel=tree.child1_1, order=0 + ) + tree.child2_1 = RelatedOrderModel.objects.create(parent=tree.child2, name="2-1") + tree.child2_1_related = OneToOneRelatedOrder.objects.create( + relatedmodel=tree.child2_1, order=0 + ) + tree.child2_2 = RelatedOrderModel.objects.create(parent=tree.child2, name="2-2") + tree.child2_2_related = OneToOneRelatedOrder.objects.create( + relatedmodel=tree.child2_2, order=1 + ) + + nodes = RelatedOrderModel.objects.pre_filter(related__order=0) + self.assertEqual( + list(nodes), + [ + tree.root, + tree.child1, + tree.child1_1, + ], + ) + + def test_pre_filter_with_order(self): + tree = type("Namespace", (), {})() # SimpleNamespace for PY2... + + tree.root = MultiOrderedModel.objects.create( + name="root", first_position=1, + ) + tree.child1 = MultiOrderedModel.objects.create( + parent=tree.root, first_position=0, second_position=1, name="1" + ) + tree.child2 = MultiOrderedModel.objects.create( + parent=tree.root, first_position=1, second_position=0, name="2" + ) + tree.child1_1 = MultiOrderedModel.objects.create( + parent=tree.child1, first_position=1, second_position=1, name="1-1" + ) + tree.child2_1 = MultiOrderedModel.objects.create( + parent=tree.child2, first_position=1, second_position=1, name="2-1" + ) + tree.child2_2 = MultiOrderedModel.objects.create( + parent=tree.child2, first_position=1, second_position=0, name="2-2" + ) + + nodes = ( + MultiOrderedModel.objects + .pre_filter(first_position__gt=0) + .order_siblings_by("-second_position") + ) + self.assertEqual( + list(nodes), + [ + tree.root, + tree.child2, + tree.child2_1, + tree.child2_2, + ], + ) diff --git a/tree_queries/compiler.py b/tree_queries/compiler.py index 74cc32c..dfbc8bc 100644 --- a/tree_queries/compiler.py +++ b/tree_queries/compiler.py @@ -12,8 +12,34 @@ def _find_tree_model(cls): class TreeQuery(Query): - # Set by TreeQuerySet.order_siblings_by - sibling_order = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._setup_query() + + def _setup_query(self): + """ + Run on initialization and at the end of chaining. Any attributes that + would normally be set in __init__() should go here instead. + """ + # We add the variables for `sibling_order` and `pre_filter` here so they + # act as instance variables which do not persist between user queries + # the way class variables do + + # Only add the sibling_order attribute if the query doesn't already have one to preserve cloning behavior + if not hasattr(self, "sibling_order"): + # Add an attribute to control the ordering of siblings within trees + opts = _find_tree_model(self.model)._meta + self.sibling_order = ( + opts.ordering + if opts.ordering + else opts.pk.attname + ) + + # Only add the pre_filter attribute if the query doesn't already have one to preserve cloning behavior + if not hasattr(self, "pre_filter"): + self.pre_filter = [] + def get_compiler(self, using=None, connection=None, **kwargs): # Copied from django/db/models/sql/query.py @@ -28,13 +54,10 @@ def get_compiler(self, using=None, connection=None, **kwargs): return TreeCompiler(self, connection, using, **kwargs) def get_sibling_order(self): - if self.sibling_order is not None: - return self.sibling_order - opts = _find_tree_model(self.model)._meta - if opts.ordering: - return opts.ordering - return opts.pk.attname + return self.sibling_order + def get_pre_filter(self): + return self.pre_filter class TreeCompiler(SQLCompiler): CTE_POSTGRESQL = """ @@ -48,6 +71,7 @@ class TreeCompiler(SQLCompiler): {rank_parent}, ROW_NUMBER() OVER (ORDER BY {rank_order_by}) FROM {rank_from} + {pre_filter} ), __tree ( "tree_depth", @@ -82,6 +106,7 @@ class TreeCompiler(SQLCompiler): {rank_parent}, ROW_NUMBER() OVER (ORDER BY {rank_order_by}) FROM {rank_from} + {pre_filter} ), __tree(tree_depth, tree_path, tree_ordering, tree_pk) AS ( SELECT @@ -113,6 +138,7 @@ class TreeCompiler(SQLCompiler): {rank_parent}, row_number() OVER (ORDER BY {rank_order_by}) FROM {rank_from} + {pre_filter} ), __tree(tree_depth, tree_path, tree_ordering, tree_pk) AS ( SELECT @@ -135,13 +161,14 @@ class TreeCompiler(SQLCompiler): ) """ - def get_sibling_order_params(self): + def get_rank_table_params(self): """ This method uses a simple django queryset to generate sql - that can be used to create the __rank_table that orders - siblings. This is done so that any joins required by order_by - are pre-calculated by django + that can be used to create the __rank_table that pre-filters + and orders siblings. This is done so that any joins required + by order_by or filter/exclude are pre-calculated by django """ + # Get can validate sibling_order sibling_order = self.query.get_sibling_order() if isinstance(sibling_order, (list, tuple)): @@ -152,39 +179,57 @@ def get_sibling_order_params(self): raise ValueError( "Sibling order must be a string or a list or tuple of strings." ) - - # Use Django to make a SQL query whose parts can be repurposed for __rank_table - base_query = ( - _find_tree_model(self.query.model) - .objects.only("pk", "parent") - .order_by(*order_fields) - .query - ) - - # Use the base compiler because we want vanilla sql and want to avoid recursion. + + # Get pre_filter + pre_filter = self.query.get_pre_filter() + + # Use Django to make a SQL query that can be repurposed for __rank_table + base_query = _find_tree_model(self.query.model).objects.only("pk", "parent") + + # Add pre_filters if they exist + if pre_filter: + # Apply filters and excludes to the query in the order provided by the user + for is_filter, filter_fields in pre_filter: + if is_filter: + base_query = base_query.filter(**filter_fields) + else: + base_query = base_query.exclude(**filter_fields) + + # Apply sibling_order + base_query = base_query.order_by(*order_fields).query + + # Get SQL and parameters base_compiler = SQLCompiler(base_query, self.connection, None) base_sql, base_params = base_compiler.as_sql() - result_sql = base_sql % base_params - # Split the base SQL string on the SQL keywords 'FROM' and 'ORDER BY' - from_split = result_sql.split("FROM") - order_split = from_split[1].split("ORDER BY") + # Split sql on the last ORDER BY to get the rank_order param + head, sep, tail = base_sql.rpartition("ORDER BY") - # Identify the FROM and ORDER BY parts of the base SQL - ordering_params = { - "rank_from": order_split[0].strip(), - "rank_order_by": order_split[1].strip(), + # Add rank_order_by to params + rank_table_params = { + "rank_order_by": tail.strip(), } - # Identify the primary key field and parent_id field from the SELECT section - base_select = from_split[0][6:] - for field in base_select.split(","): + # Split on the first WHERE if present to get the pre_filter param + if pre_filter: + head, sep, tail = head.partition("WHERE") + rank_table_params["pre_filter"] = "WHERE " + tail.strip() # Note the space after WHERE + else: + rank_table_params["pre_filter"] = "" + + # Split on the first FROM to get any joins etc. + head, sep, tail = head.partition("FROM") + rank_table_params["rank_from"] = tail.strip() + + # Identify the parent and primary key fields + head, sep, tail = head.partition("SELECT") + for field in tail.split(","): if "parent_id" in field: # XXX Taking advantage of Hardcoded. - ordering_params["rank_parent"] = field.strip() + rank_table_params["rank_parent"] = field.strip() else: - ordering_params["rank_pk"] = field.strip() + rank_table_params["rank_pk"] = field.strip() - return ordering_params + return rank_table_params, base_params def as_sql(self, *args, **kwargs): # Try detecting if we're used in a EXISTS(1 as "a") subquery like @@ -229,8 +274,9 @@ def as_sql(self, *args, **kwargs): "sep": SEPARATOR, } - # Add ordering params to params - params.update(self.get_sibling_order_params()) + # Get params needed by the rank_table + rank_table_params, rank_table_sql_params = self.get_rank_table_params() + params.update(rank_table_params) if "__tree" not in self.query.extra_tables: # pragma: no branch - unlikely tree_params = params.copy() @@ -280,7 +326,9 @@ def as_sql(self, *args, **kwargs): if sql_0.startswith("EXPLAIN "): explain, sql_0 = sql_0.split(" ", 1) - return ("".join([explain, cte.format(**params), sql_0]), sql_1) + # Pass any additional rank table sql paramaters so that the db backend can handle them. + # This only works because we know that the CTE is at the start of the query. + return ("".join([explain, cte.format(**params), sql_0]), rank_table_sql_params + sql_1) def get_converters(self, expressions): converters = super().get_converters(expressions) diff --git a/tree_queries/query.py b/tree_queries/query.py index 250ad48..b999ffd 100644 --- a/tree_queries/query.py +++ b/tree_queries/query.py @@ -27,6 +27,7 @@ def with_tree_fields(self, tree_fields=True): # noqa: FBT002 """ if tree_fields: self.query.__class__ = TreeQuery + self.query._setup_query() else: self.query.__class__ = Query return self @@ -45,9 +46,34 @@ def order_siblings_by(self, *order_by): to order tree siblings by those model fields """ self.query.__class__ = TreeQuery + self.query._setup_query() self.query.sibling_order = order_by return self + def pre_filter(self, **filter): + """ + Sets TreeQuery pre_filter attribute + + Pass a dict of fields and their values to filter by + """ + self.query.__class__ = TreeQuery + self.query._setup_query() + filter_tuple = (True, filter) + self.query.pre_filter.append(filter_tuple) + return self + + def pre_exclude(self, **filter): + """ + Sets TreeQuery pre_filter attribute + + Pass a dict of fields and their values to filter by + """ + self.query.__class__ = TreeQuery + self.query._setup_query() + exclude_tuple = (False, filter) + self.query.pre_filter.append(exclude_tuple) + return self + def as_manager(cls, *, with_tree_fields=False): # noqa: N805 manager_class = TreeManager.from_queryset(cls) # Only used in deconstruct: From 370ab510ab5f598edf886e0a5b019b28cdca28ad Mon Sep 17 00:00:00 2001 From: rhomboss <54870528+rhomboss@users.noreply.github.com> Date: Mon, 15 Apr 2024 15:52:56 -0500 Subject: [PATCH 2/5] Support Q objects and change name to tree_filter Change `pre_filter` to `tree_filter` and add support for `Q` objects --- tests/testapp/test_queries.py | 77 +++++++++++++++++++++++++++++------ tree_queries/compiler.py | 42 +++++++++---------- tree_queries/query.py | 16 ++++---- 3 files changed, 93 insertions(+), 42 deletions(-) diff --git a/tests/testapp/test_queries.py b/tests/testapp/test_queries.py index 2a935a9..9458bd6 100644 --- a/tests/testapp/test_queries.py +++ b/tests/testapp/test_queries.py @@ -754,11 +754,11 @@ def test_order_by_related(self): ], ) - def test_pre_exclude(self): + def test_tree_exclude(self): tree = self.create_tree() - # Pre-filter should remove children if + # Tree-filter should remove children if # the parent meets the filtering criteria - nodes = Model.objects.pre_exclude(name="2") + nodes = Model.objects.tree_exclude(name="2") self.assertEqual( list(nodes), [ @@ -768,11 +768,11 @@ def test_pre_exclude(self): ], ) - def test_pre_filter(self): + def test_tree_filter(self): tree = self.create_tree() - # Pre-filter should remove children if + # Tree-filter should remove children if # the parent does not meet the filtering criteria - nodes = Model.objects.pre_filter(name__in=["root","1-1","2","2-1","2-2"]) + nodes = Model.objects.tree_filter(name__in=["root","1-1","2","2-1","2-2"]) self.assertEqual( list(nodes), [ @@ -783,11 +783,11 @@ def test_pre_filter(self): ], ) - def test_pre_filter_chaining(self): + def test_tree_filter_chaining(self): tree = self.create_tree() - # Pre-filter should remove children if + # Tree-filter should remove children if # the parent does not meet the filtering criteria - nodes = Model.objects.pre_exclude(name="2-2").pre_filter(name__in=["root","1-1","2","2-1","2-2"]) + nodes = Model.objects.tree_exclude(name="2-2").tree_filter(name__in=["root","1-1","2","2-1","2-2"]) self.assertEqual( list(nodes), [ @@ -797,7 +797,7 @@ def test_pre_filter_chaining(self): ], ) - def test_pre_filter_related(self): + def test_tree_filter_related(self): tree = type("Namespace", (), {})() # SimpleNamespace for PY2... tree.root = RelatedOrderModel.objects.create(name="root") @@ -825,7 +825,7 @@ def test_pre_filter_related(self): relatedmodel=tree.child2_2, order=1 ) - nodes = RelatedOrderModel.objects.pre_filter(related__order=0) + nodes = RelatedOrderModel.objects.tree_filter(related__order=0) self.assertEqual( list(nodes), [ @@ -835,7 +835,7 @@ def test_pre_filter_related(self): ], ) - def test_pre_filter_with_order(self): + def test_tree_filter_with_order(self): tree = type("Namespace", (), {})() # SimpleNamespace for PY2... tree.root = MultiOrderedModel.objects.create( @@ -859,7 +859,7 @@ def test_pre_filter_with_order(self): nodes = ( MultiOrderedModel.objects - .pre_filter(first_position__gt=0) + .tree_filter(first_position__gt=0) .order_siblings_by("-second_position") ) self.assertEqual( @@ -871,3 +871,54 @@ def test_pre_filter_with_order(self): tree.child2_2, ], ) + + def test_tree_filter_Q_objects(self): + tree = self.create_tree() + # Tree-filter should remove children if + # the parent does not meet the filtering criteria + nodes = Model.objects.tree_filter(Q(name__in=["root","1-1","2","2-1","2-2"])) + self.assertEqual( + list(nodes), + [ + tree.root, + tree.child2, + tree.child2_1, + tree.child2_2, + ], + ) + + def test_tree_filter_Q_mix(self): + tree = type("Namespace", (), {})() # SimpleNamespace for PY2... + + tree.root = MultiOrderedModel.objects.create( + name="root", first_position=1, second_position=2 + ) + tree.child1 = MultiOrderedModel.objects.create( + parent=tree.root, first_position=1, second_position=0, name="1" + ) + tree.child2 = MultiOrderedModel.objects.create( + parent=tree.root, first_position=1, second_position=2, name="2" + ) + tree.child1_1 = MultiOrderedModel.objects.create( + parent=tree.child1, first_position=1, second_position=1, name="1-1" + ) + tree.child2_1 = MultiOrderedModel.objects.create( + parent=tree.child2, first_position=1, second_position=1, name="2-1" + ) + tree.child2_2 = MultiOrderedModel.objects.create( + parent=tree.child2, first_position=1, second_position=2, name="2-2" + ) + # Tree-filter should remove children if + # the parent does not meet the filtering criteria + nodes = ( + MultiOrderedModel.objects + .tree_filter(Q(first_position=1), second_position=2) + ) + self.assertEqual( + list(nodes), + [ + tree.root, + tree.child2, + tree.child2_2, + ], + ) diff --git a/tree_queries/compiler.py b/tree_queries/compiler.py index dfbc8bc..fcb73e8 100644 --- a/tree_queries/compiler.py +++ b/tree_queries/compiler.py @@ -22,7 +22,7 @@ def _setup_query(self): Run on initialization and at the end of chaining. Any attributes that would normally be set in __init__() should go here instead. """ - # We add the variables for `sibling_order` and `pre_filter` here so they + # We add the variables for `sibling_order` and `tree_filter` here so they # act as instance variables which do not persist between user queries # the way class variables do @@ -36,9 +36,9 @@ def _setup_query(self): else opts.pk.attname ) - # Only add the pre_filter attribute if the query doesn't already have one to preserve cloning behavior - if not hasattr(self, "pre_filter"): - self.pre_filter = [] + # Only add the tree_filter attribute if the query doesn't already have one to preserve cloning behavior + if not hasattr(self, "tree_filter"): + self.tree_filter = [] def get_compiler(self, using=None, connection=None, **kwargs): @@ -56,8 +56,8 @@ def get_compiler(self, using=None, connection=None, **kwargs): def get_sibling_order(self): return self.sibling_order - def get_pre_filter(self): - return self.pre_filter + def get_tree_filter(self): + return self.tree_filter class TreeCompiler(SQLCompiler): CTE_POSTGRESQL = """ @@ -71,7 +71,7 @@ class TreeCompiler(SQLCompiler): {rank_parent}, ROW_NUMBER() OVER (ORDER BY {rank_order_by}) FROM {rank_from} - {pre_filter} + {tree_filter} ), __tree ( "tree_depth", @@ -106,7 +106,7 @@ class TreeCompiler(SQLCompiler): {rank_parent}, ROW_NUMBER() OVER (ORDER BY {rank_order_by}) FROM {rank_from} - {pre_filter} + {tree_filter} ), __tree(tree_depth, tree_path, tree_ordering, tree_pk) AS ( SELECT @@ -138,7 +138,7 @@ class TreeCompiler(SQLCompiler): {rank_parent}, row_number() OVER (ORDER BY {rank_order_by}) FROM {rank_from} - {pre_filter} + {tree_filter} ), __tree(tree_depth, tree_path, tree_ordering, tree_pk) AS ( SELECT @@ -164,7 +164,7 @@ class TreeCompiler(SQLCompiler): def get_rank_table_params(self): """ This method uses a simple django queryset to generate sql - that can be used to create the __rank_table that pre-filters + that can be used to create the __rank_table that tree-filters and orders siblings. This is done so that any joins required by order_by or filter/exclude are pre-calculated by django """ @@ -180,20 +180,20 @@ def get_rank_table_params(self): "Sibling order must be a string or a list or tuple of strings." ) - # Get pre_filter - pre_filter = self.query.get_pre_filter() + # Get tree_filter + tree_filter = self.query.get_tree_filter() # Use Django to make a SQL query that can be repurposed for __rank_table base_query = _find_tree_model(self.query.model).objects.only("pk", "parent") - # Add pre_filters if they exist - if pre_filter: + # Add tree_filters if they exist + if tree_filter: # Apply filters and excludes to the query in the order provided by the user - for is_filter, filter_fields in pre_filter: + for is_filter, filter_Q, filter_fields in tree_filter: if is_filter: - base_query = base_query.filter(**filter_fields) + base_query = base_query.filter(*filter_Q, **filter_fields) else: - base_query = base_query.exclude(**filter_fields) + base_query = base_query.exclude(*filter_Q, **filter_fields) # Apply sibling_order base_query = base_query.order_by(*order_fields).query @@ -210,12 +210,12 @@ def get_rank_table_params(self): "rank_order_by": tail.strip(), } - # Split on the first WHERE if present to get the pre_filter param - if pre_filter: + # Split on the first WHERE if present to get the tree_filter param + if tree_filter: head, sep, tail = head.partition("WHERE") - rank_table_params["pre_filter"] = "WHERE " + tail.strip() # Note the space after WHERE + rank_table_params["tree_filter"] = "WHERE " + tail.strip() # Note the space after WHERE else: - rank_table_params["pre_filter"] = "" + rank_table_params["tree_filter"] = "" # Split on the first FROM to get any joins etc. head, sep, tail = head.partition("FROM") diff --git a/tree_queries/query.py b/tree_queries/query.py index b999ffd..3393228 100644 --- a/tree_queries/query.py +++ b/tree_queries/query.py @@ -50,28 +50,28 @@ def order_siblings_by(self, *order_by): self.query.sibling_order = order_by return self - def pre_filter(self, **filter): + def tree_filter(self, *Q_objects, **filter): """ - Sets TreeQuery pre_filter attribute + Sets TreeQuery tree_filter attribute Pass a dict of fields and their values to filter by """ self.query.__class__ = TreeQuery self.query._setup_query() - filter_tuple = (True, filter) - self.query.pre_filter.append(filter_tuple) + filter_tuple = (True, Q_objects, filter) + self.query.tree_filter.append(filter_tuple) return self - def pre_exclude(self, **filter): + def tree_exclude(self, *Q_objects, **filter): """ - Sets TreeQuery pre_filter attribute + Sets TreeQuery tree_filter attribute Pass a dict of fields and their values to filter by """ self.query.__class__ = TreeQuery self.query._setup_query() - exclude_tuple = (False, filter) - self.query.pre_filter.append(exclude_tuple) + exclude_tuple = (False, Q_objects, filter) + self.query.tree_filter.append(exclude_tuple) return self def as_manager(cls, *, with_tree_fields=False): # noqa: N805 From 6661b5eb399401b20ac99be44e0f965f92d32b2c Mon Sep 17 00:00:00 2001 From: rhomboss <54870528+rhomboss@users.noreply.github.com> Date: Tue, 16 Apr 2024 16:01:13 -0500 Subject: [PATCH 3/5] Generate the CTE rank_table entirely with Django's backend * Update compiler.py & query.py Generate the rank_table entirely with django's backend by providing a queryset instance variable that can be modified by the API --- tree_queries/compiler.py | 139 +++++++++++++++------------------------ tree_queries/query.py | 24 ++++--- 2 files changed, 68 insertions(+), 95 deletions(-) diff --git a/tree_queries/compiler.py b/tree_queries/compiler.py index fcb73e8..6b8f85c 100644 --- a/tree_queries/compiler.py +++ b/tree_queries/compiler.py @@ -1,5 +1,6 @@ from django.db import connections -from django.db.models import Value +from django.db.models import Value, F, Window, Expression, QuerySet +from django.db.models.functions import RowNumber from django.db.models.sql.compiler import SQLCompiler from django.db.models.sql.query import Query @@ -22,7 +23,7 @@ def _setup_query(self): Run on initialization and at the end of chaining. Any attributes that would normally be set in __init__() should go here instead. """ - # We add the variables for `sibling_order` and `tree_filter` here so they + # We add the variables for `sibling_order` and `rank_table_query` here so they # act as instance variables which do not persist between user queries # the way class variables do @@ -36,9 +37,13 @@ def _setup_query(self): else opts.pk.attname ) - # Only add the tree_filter attribute if the query doesn't already have one to preserve cloning behavior - if not hasattr(self, "tree_filter"): - self.tree_filter = [] + # Only add the rank_table_query attribute if the query doesn't already have one to preserve cloning behavior + if not hasattr(self, "rank_table_query"): + # Create a default QuerySet for the rank_table to use + # so we can avoid recursion + self.rank_table_query = QuerySet( + model=_find_tree_model(self.model) + ) def get_compiler(self, using=None, connection=None, **kwargs): @@ -56,8 +61,8 @@ def get_compiler(self, using=None, connection=None, **kwargs): def get_sibling_order(self): return self.sibling_order - def get_tree_filter(self): - return self.tree_filter + def get_rank_table_query(self): + return self.rank_table_query class TreeCompiler(SQLCompiler): CTE_POSTGRESQL = """ @@ -66,12 +71,7 @@ class TreeCompiler(SQLCompiler): "{parent}", "rank_order" ) AS ( - SELECT - {rank_pk}, - {rank_parent}, - ROW_NUMBER() OVER (ORDER BY {rank_order_by}) - FROM {rank_from} - {tree_filter} + {rank_table} ), __tree ( "tree_depth", @@ -101,12 +101,7 @@ class TreeCompiler(SQLCompiler): CTE_MYSQL = """ WITH RECURSIVE __rank_table({pk}, {parent}, rank_order) AS ( - SELECT - {rank_pk}, - {rank_parent}, - ROW_NUMBER() OVER (ORDER BY {rank_order_by}) - FROM {rank_from} - {tree_filter} + {rank_table} ), __tree(tree_depth, tree_path, tree_ordering, tree_pk) AS ( SELECT @@ -133,12 +128,7 @@ class TreeCompiler(SQLCompiler): CTE_SQLITE3 = """ WITH RECURSIVE __rank_table({pk}, {parent}, rank_order) AS ( - SELECT - {rank_pk}, - {rank_parent}, - row_number() OVER (ORDER BY {rank_order_by}) - FROM {rank_from} - {tree_filter} + {rank_table} ), __tree(tree_depth, tree_path, tree_ordering, tree_pk) AS ( SELECT @@ -161,14 +151,8 @@ class TreeCompiler(SQLCompiler): ) """ - def get_rank_table_params(self): - """ - This method uses a simple django queryset to generate sql - that can be used to create the __rank_table that tree-filters - and orders siblings. This is done so that any joins required - by order_by or filter/exclude are pre-calculated by django - """ - # Get can validate sibling_order + def get_rank_table(self): + # Get and validate sibling_order sibling_order = self.query.get_sibling_order() if isinstance(sibling_order, (list, tuple)): @@ -180,56 +164,41 @@ def get_rank_table_params(self): "Sibling order must be a string or a list or tuple of strings." ) - # Get tree_filter - tree_filter = self.query.get_tree_filter() - - # Use Django to make a SQL query that can be repurposed for __rank_table - base_query = _find_tree_model(self.query.model).objects.only("pk", "parent") - - # Add tree_filters if they exist - if tree_filter: - # Apply filters and excludes to the query in the order provided by the user - for is_filter, filter_Q, filter_fields in tree_filter: - if is_filter: - base_query = base_query.filter(*filter_Q, **filter_fields) + # Convert strings to expressions. This is to maintain backwards compatibility + # with Django versions < 4.1 + base_order = [] + for field in order_fields: + if isinstance(field, Expression): + base_order.append(field) + elif isinstance(field, str): + if field[0] == "-": + base_order.append(F(field[1:]).desc()) else: - base_query = base_query.exclude(*filter_Q, **filter_fields) - - # Apply sibling_order - base_query = base_query.order_by(*order_fields).query - - # Get SQL and parameters - base_compiler = SQLCompiler(base_query, self.connection, None) - base_sql, base_params = base_compiler.as_sql() - - # Split sql on the last ORDER BY to get the rank_order param - head, sep, tail = base_sql.rpartition("ORDER BY") - - # Add rank_order_by to params - rank_table_params = { - "rank_order_by": tail.strip(), - } - - # Split on the first WHERE if present to get the tree_filter param - if tree_filter: - head, sep, tail = head.partition("WHERE") - rank_table_params["tree_filter"] = "WHERE " + tail.strip() # Note the space after WHERE - else: - rank_table_params["tree_filter"] = "" - - # Split on the first FROM to get any joins etc. - head, sep, tail = head.partition("FROM") - rank_table_params["rank_from"] = tail.strip() - - # Identify the parent and primary key fields - head, sep, tail = head.partition("SELECT") - for field in tail.split(","): - if "parent_id" in field: # XXX Taking advantage of Hardcoded. - rank_table_params["rank_parent"] = field.strip() - else: - rank_table_params["rank_pk"] = field.strip() + base_order.append(F(field).asc()) + order_fields = base_order + # End of back compat code + + # Get the rank table query + rank_table_query = self.query.get_rank_table_query() + + rank_table_query = ( + rank_table_query + .order_by() # Ensure there is no ORDER BY at the end of the SQL + # Values allows us to both limit and specify the order of + # the columns selected so that they match the CTE + .values( + "pk", + "parent", + rank_order=Window( + expression=RowNumber(), + order_by=order_fields, + ), + ) + ) + + rank_table_sql, rank_table_params = rank_table_query.query.sql_with_params() - return rank_table_params, base_params + return rank_table_sql, rank_table_params def as_sql(self, *args, **kwargs): # Try detecting if we're used in a EXISTS(1 as "a") subquery like @@ -274,9 +243,9 @@ def as_sql(self, *args, **kwargs): "sep": SEPARATOR, } - # Get params needed by the rank_table - rank_table_params, rank_table_sql_params = self.get_rank_table_params() - params.update(rank_table_params) + # Get the rank_table SQL and params + rank_table_sql, rank_table_params = self.get_rank_table() + params["rank_table"] = rank_table_sql if "__tree" not in self.query.extra_tables: # pragma: no branch - unlikely tree_params = params.copy() @@ -328,7 +297,7 @@ def as_sql(self, *args, **kwargs): # Pass any additional rank table sql paramaters so that the db backend can handle them. # This only works because we know that the CTE is at the start of the query. - return ("".join([explain, cte.format(**params), sql_0]), rank_table_sql_params + sql_1) + return ("".join([explain, cte.format(**params), sql_0]), rank_table_params + sql_1) def get_converters(self, expressions): converters = super().get_converters(expressions) diff --git a/tree_queries/query.py b/tree_queries/query.py index 3393228..df19cb6 100644 --- a/tree_queries/query.py +++ b/tree_queries/query.py @@ -50,28 +50,32 @@ def order_siblings_by(self, *order_by): self.query.sibling_order = order_by return self - def tree_filter(self, *Q_objects, **filter): + def tree_filter(self, *args, **kwargs): """ - Sets TreeQuery tree_filter attribute + Adds a filter to the TreeQuery rank_table_query - Pass a dict of fields and their values to filter by + Takes the same arguements as a Django QuerySet .filter() """ self.query.__class__ = TreeQuery self.query._setup_query() - filter_tuple = (True, Q_objects, filter) - self.query.tree_filter.append(filter_tuple) + self.query.rank_table_query = ( + self.query.rank_table_query + .filter(*args, **kwargs) + ) return self - def tree_exclude(self, *Q_objects, **filter): + def tree_exclude(self, *args, **kwargs): """ - Sets TreeQuery tree_filter attribute + Adds a filter to the TreeQuery rank_table_query - Pass a dict of fields and their values to filter by + Takes the same arguements as a Django QuerySet .exclude() """ self.query.__class__ = TreeQuery self.query._setup_query() - exclude_tuple = (False, Q_objects, filter) - self.query.tree_filter.append(exclude_tuple) + self.query.rank_table_query = ( + self.query.rank_table_query + .exclude(*args, **kwargs) + ) return self def as_manager(cls, *, with_tree_fields=False): # noqa: N805 From 2fcbc05dabb44d6044e56cfb6af440890dd0ce2a Mon Sep 17 00:00:00 2001 From: rhomboss <54870528+rhomboss@users.noreply.github.com> Date: Wed, 24 Apr 2024 11:54:00 -0500 Subject: [PATCH 4/5] Update compiler.py (#9) Add django version checking to back compat code --- tree_queries/compiler.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/tree_queries/compiler.py b/tree_queries/compiler.py index 6b8f85c..b5177d1 100644 --- a/tree_queries/compiler.py +++ b/tree_queries/compiler.py @@ -3,6 +3,7 @@ from django.db.models.functions import RowNumber from django.db.models.sql.compiler import SQLCompiler from django.db.models.sql.query import Query +import django SEPARATOR = "\x1f" @@ -166,17 +167,17 @@ def get_rank_table(self): # Convert strings to expressions. This is to maintain backwards compatibility # with Django versions < 4.1 - base_order = [] - for field in order_fields: - if isinstance(field, Expression): - base_order.append(field) - elif isinstance(field, str): - if field[0] == "-": - base_order.append(F(field[1:]).desc()) - else: - base_order.append(F(field).asc()) - order_fields = base_order - # End of back compat code + if django.VERSION < (4, 1): + base_order = [] + for field in order_fields: + if isinstance(field, Expression): + base_order.append(field) + elif isinstance(field, str): + if field[0] == "-": + base_order.append(F(field[1:]).desc()) + else: + base_order.append(F(field).asc()) + order_fields = base_order # Get the rank table query rank_table_query = self.query.get_rank_table_query() From 6bdb2cc1ef4fbaf4cb80e11852602712c0b2cf69 Mon Sep 17 00:00:00 2001 From: rhomboss <54870528+rhomboss@users.noreply.github.com> Date: Wed, 24 Apr 2024 12:19:32 -0500 Subject: [PATCH 5/5] Update test_queries.py Remove Python 2 support and use SimpleNamespace --- tests/testapp/test_queries.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/tests/testapp/test_queries.py b/tests/testapp/test_queries.py index 9458bd6..2648e2a 100644 --- a/tests/testapp/test_queries.py +++ b/tests/testapp/test_queries.py @@ -25,11 +25,13 @@ from tree_queries.compiler import SEPARATOR, TreeQuery from tree_queries.query import pk +from types import SimpleNamespace + @override_settings(DEBUG=True) class Test(TestCase): def create_tree(self): - tree = type("Namespace", (), {})() # SimpleNamespace for PY2... + tree = SimpleNamespace() tree.root = Model.objects.create(name="root") tree.child1 = Model.objects.create(parent=tree.root, order=0, name="1") tree.child2 = Model.objects.create(parent=tree.root, order=1, name="2") @@ -257,7 +259,7 @@ class OtherForm(forms.Form): self.assertNotIn("root", html) def test_string_ordering(self): - tree = type("Namespace", (), {})() # SimpleNamespace for PY2... + tree = SimpleNamespace() tree.americas = StringOrderedModel.objects.create(name="Americas") tree.europe = StringOrderedModel.objects.create(name="Europe") @@ -373,7 +375,7 @@ def test_always_tree_query_relations(self): def test_reference(self): tree = self.create_tree() - references = type("Namespace", (), {})() # SimpleNamespace for PY2... + references = SimpleNamespace() references.none = ReferenceModel.objects.create(position=0) references.root = ReferenceModel.objects.create( position=1, tree_field=tree.root @@ -534,7 +536,7 @@ def test_uuid_queries(self): ) def test_sibling_ordering(self): - tree = type("Namespace", (), {})() # SimpleNamespace for PY2... + tree = SimpleNamespace() tree.root = MultiOrderedModel.objects.create(name="root") tree.child1 = MultiOrderedModel.objects.create( @@ -682,7 +684,7 @@ def test_descending_order(self): ) def test_multi_field_order(self): - tree = type("Namespace", (), {})() # SimpleNamespace for PY2... + tree = SimpleNamespace() tree.root = MultiOrderedModel.objects.create(name="root") tree.child1 = MultiOrderedModel.objects.create( @@ -717,7 +719,7 @@ def test_multi_field_order(self): ) def test_order_by_related(self): - tree = type("Namespace", (), {})() # SimpleNamespace for PY2... + tree = SimpleNamespace() tree.root = RelatedOrderModel.objects.create(name="root") tree.child1 = RelatedOrderModel.objects.create(parent=tree.root, name="1") @@ -798,7 +800,7 @@ def test_tree_filter_chaining(self): ) def test_tree_filter_related(self): - tree = type("Namespace", (), {})() # SimpleNamespace for PY2... + tree = SimpleNamespace() tree.root = RelatedOrderModel.objects.create(name="root") tree.root_related = OneToOneRelatedOrder.objects.create( @@ -836,7 +838,7 @@ def test_tree_filter_related(self): ) def test_tree_filter_with_order(self): - tree = type("Namespace", (), {})() # SimpleNamespace for PY2... + tree = SimpleNamespace() tree.root = MultiOrderedModel.objects.create( name="root", first_position=1, @@ -888,7 +890,7 @@ def test_tree_filter_Q_objects(self): ) def test_tree_filter_Q_mix(self): - tree = type("Namespace", (), {})() # SimpleNamespace for PY2... + tree = SimpleNamespace() tree.root = MultiOrderedModel.objects.create( name="root", first_position=1, second_position=2