From 873fe07cc10bf3964a23cbc512821da46570b898 Mon Sep 17 00:00:00 2001 From: Matthias Kestenholz Date: Thu, 25 Apr 2024 14:32:49 +0200 Subject: [PATCH] Rename the functionality from extra fields to tree fields --- tests/testapp/test_queries.py | 10 ++++----- tree_queries/compiler.py | 40 +++++++++++++++++------------------ tree_queries/query.py | 4 ++-- 3 files changed, 27 insertions(+), 27 deletions(-) diff --git a/tests/testapp/test_queries.py b/tests/testapp/test_queries.py index 8d034bd..6384be7 100644 --- a/tests/testapp/test_queries.py +++ b/tests/testapp/test_queries.py @@ -926,12 +926,12 @@ def test_tree_filter_q_mix(self): ], ) - @unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL tests") - def test_extra_fields(self): + @unittest.skipUnless( + connection.vendor in {"postgresql", "sqlite"}, "Not all DB engines supported" + ) + def test_tree_fields(self): self.create_tree() - names = [ - obj.tree_names for obj in Model.objects.extra_fields(tree_names="name") - ] + names = [obj.tree_names for obj in Model.objects.tree_fields(tree_names="name")] self.assertEqual( names, [ diff --git a/tree_queries/compiler.py b/tree_queries/compiler.py index b5578e4..86fbdcf 100644 --- a/tree_queries/compiler.py +++ b/tree_queries/compiler.py @@ -39,8 +39,8 @@ def _setup_query(self): # so we can avoid recursion self.rank_table_query = QuerySet(model=_find_tree_model(self.model)) - if not hasattr(self, "extra_fields"): - self.extra_fields = {} + if not hasattr(self, "tree_fields"): + self.tree_fields = {} def get_compiler(self, using=None, connection=None, **kwargs): # Copied from django/db/models/sql/query.py @@ -60,14 +60,14 @@ def get_sibling_order(self): def get_rank_table_query(self): return self.rank_table_query - def get_extra_fields(self): - return self.extra_fields + def get_tree_fields(self): + return self.tree_fields class TreeCompiler(SQLCompiler): CTE_POSTGRESQL = """ WITH RECURSIVE __rank_table( - {extra_fields_columns} + {tree_fields_columns} "{pk}", "{parent}", "rank_order" @@ -75,14 +75,14 @@ class TreeCompiler(SQLCompiler): {rank_table} ), __tree ( - {extra_fields_names} + {tree_fields_names} "tree_depth", "tree_path", "tree_ordering", "tree_pk" ) AS ( SELECT - {extra_fields_initial} + {tree_fields_initial} 0 AS tree_depth, array[T.{pk}] AS tree_path, array[T.rank_order] AS tree_ordering, @@ -93,7 +93,7 @@ class TreeCompiler(SQLCompiler): UNION ALL SELECT - {extra_fields_recursive} + {tree_fields_recursive} __tree.tree_depth + 1 AS tree_depth, __tree.tree_path || T.{pk}, __tree.tree_ordering || T.rank_order, @@ -130,7 +130,7 @@ class TreeCompiler(SQLCompiler): ) """ - CTE_SQLITE3 = """ + CTE_SQLITE = """ WITH RECURSIVE __rank_table({pk}, {parent}, rank_order) AS ( {rank_table} ), @@ -190,7 +190,7 @@ def get_rank_table(self): # Values allows us to both limit and specify the order of # the columns selected so that they match the CTE .values( - *self.query.get_extra_fields().values(), + *self.query.get_tree_fields().values(), "pk", "parent", rank_order=Window( @@ -251,20 +251,20 @@ def as_sql(self, *args, **kwargs): rank_table_sql, rank_table_params = self.get_rank_table() params["rank_table"] = rank_table_sql - extra_fields = self.query.get_extra_fields() + tree_fields = self.query.get_tree_fields() qn = self.connection.ops.quote_name params.update({ - "extra_fields_columns": "".join( - f"{qn(column)}, " for column in extra_fields.values() + "tree_fields_columns": "".join( + f"{qn(column)}, " for column in tree_fields.values() ), - "extra_fields_names": "".join(f"{qn(name)}, " for name in extra_fields), - "extra_fields_initial": "".join( + "tree_fields_names": "".join(f"{qn(name)}, " for name in tree_fields), + "tree_fields_initial": "".join( f"array[T.{qn(column)}]::text[] AS {qn(name)}, " - for name, column in extra_fields.items() + for name, column in tree_fields.items() ), - "extra_fields_recursive": "".join( + "tree_fields_recursive": "".join( f"__tree.{qn(name)} || T.{qn(column)}, " - for name, column in extra_fields.items() + for name, column in tree_fields.items() ), }) @@ -287,7 +287,7 @@ def as_sql(self, *args, **kwargs): "tree_path": "__tree.tree_path", "tree_ordering": "__tree.tree_ordering", } - select.update({name: f"__tree.{name}" for name in extra_fields}) + select.update({name: f"__tree.{name}" for name in tree_fields}) self.query.add_extra( # Do not add extra fields to the select statement when it is a # summary query or when using .values() or .values_list() @@ -308,7 +308,7 @@ def as_sql(self, *args, **kwargs): if self.connection.vendor == "postgresql": cte = self.CTE_POSTGRESQL elif self.connection.vendor == "sqlite": - cte = self.CTE_SQLITE3 + cte = self.CTE_SQLITE elif self.connection.vendor == "mysql": cte = self.CTE_MYSQL sql_0, sql_1 = super().as_sql(*args, **kwargs) diff --git a/tree_queries/query.py b/tree_queries/query.py index 3fac05d..f408ef4 100644 --- a/tree_queries/query.py +++ b/tree_queries/query.py @@ -76,10 +76,10 @@ def tree_exclude(self, *args, **kwargs): ) return self - def extra_fields(self, **extra_fields): + def tree_fields(self, **tree_fields): self.query.__class__ = TreeQuery self.query._setup_query() - self.query.extra_fields = extra_fields + self.query.tree_fields = tree_fields return self @classmethod