Skip to content

Commit

Permalink
Rename the functionality from extra fields to tree fields
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiask committed Apr 25, 2024
1 parent 228a633 commit 873fe07
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 27 deletions.
10 changes: 5 additions & 5 deletions tests/testapp/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,12 +926,12 @@ def test_tree_filter_q_mix(self):
],
)

@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL tests")
def test_extra_fields(self):
@unittest.skipUnless(
connection.vendor in {"postgresql", "sqlite"}, "Not all DB engines supported"
)
def test_tree_fields(self):
self.create_tree()
names = [
obj.tree_names for obj in Model.objects.extra_fields(tree_names="name")
]
names = [obj.tree_names for obj in Model.objects.tree_fields(tree_names="name")]
self.assertEqual(
names,
[
Expand Down
40 changes: 20 additions & 20 deletions tree_queries/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def _setup_query(self):
# so we can avoid recursion
self.rank_table_query = QuerySet(model=_find_tree_model(self.model))

if not hasattr(self, "extra_fields"):
self.extra_fields = {}
if not hasattr(self, "tree_fields"):
self.tree_fields = {}

def get_compiler(self, using=None, connection=None, **kwargs):
# Copied from django/db/models/sql/query.py
Expand All @@ -60,29 +60,29 @@ def get_sibling_order(self):
def get_rank_table_query(self):
return self.rank_table_query

def get_extra_fields(self):
return self.extra_fields
def get_tree_fields(self):
return self.tree_fields


class TreeCompiler(SQLCompiler):
CTE_POSTGRESQL = """
WITH RECURSIVE __rank_table(
{extra_fields_columns}
{tree_fields_columns}
"{pk}",
"{parent}",
"rank_order"
) AS (
{rank_table}
),
__tree (
{extra_fields_names}
{tree_fields_names}
"tree_depth",
"tree_path",
"tree_ordering",
"tree_pk"
) AS (
SELECT
{extra_fields_initial}
{tree_fields_initial}
0 AS tree_depth,
array[T.{pk}] AS tree_path,
array[T.rank_order] AS tree_ordering,
Expand All @@ -93,7 +93,7 @@ class TreeCompiler(SQLCompiler):
UNION ALL
SELECT
{extra_fields_recursive}
{tree_fields_recursive}
__tree.tree_depth + 1 AS tree_depth,
__tree.tree_path || T.{pk},
__tree.tree_ordering || T.rank_order,
Expand Down Expand Up @@ -130,7 +130,7 @@ class TreeCompiler(SQLCompiler):
)
"""

CTE_SQLITE3 = """
CTE_SQLITE = """
WITH RECURSIVE __rank_table({pk}, {parent}, rank_order) AS (
{rank_table}
),
Expand Down Expand Up @@ -190,7 +190,7 @@ def get_rank_table(self):
# Values allows us to both limit and specify the order of
# the columns selected so that they match the CTE
.values(
*self.query.get_extra_fields().values(),
*self.query.get_tree_fields().values(),
"pk",
"parent",
rank_order=Window(
Expand Down Expand Up @@ -251,20 +251,20 @@ def as_sql(self, *args, **kwargs):
rank_table_sql, rank_table_params = self.get_rank_table()
params["rank_table"] = rank_table_sql

extra_fields = self.query.get_extra_fields()
tree_fields = self.query.get_tree_fields()
qn = self.connection.ops.quote_name
params.update({
"extra_fields_columns": "".join(
f"{qn(column)}, " for column in extra_fields.values()
"tree_fields_columns": "".join(
f"{qn(column)}, " for column in tree_fields.values()
),
"extra_fields_names": "".join(f"{qn(name)}, " for name in extra_fields),
"extra_fields_initial": "".join(
"tree_fields_names": "".join(f"{qn(name)}, " for name in tree_fields),
"tree_fields_initial": "".join(
f"array[T.{qn(column)}]::text[] AS {qn(name)}, "
for name, column in extra_fields.items()
for name, column in tree_fields.items()
),
"extra_fields_recursive": "".join(
"tree_fields_recursive": "".join(
f"__tree.{qn(name)} || T.{qn(column)}, "
for name, column in extra_fields.items()
for name, column in tree_fields.items()
),
})

Expand All @@ -287,7 +287,7 @@ def as_sql(self, *args, **kwargs):
"tree_path": "__tree.tree_path",
"tree_ordering": "__tree.tree_ordering",
}
select.update({name: f"__tree.{name}" for name in extra_fields})
select.update({name: f"__tree.{name}" for name in tree_fields})
self.query.add_extra(
# Do not add extra fields to the select statement when it is a
# summary query or when using .values() or .values_list()
Expand All @@ -308,7 +308,7 @@ def as_sql(self, *args, **kwargs):
if self.connection.vendor == "postgresql":
cte = self.CTE_POSTGRESQL
elif self.connection.vendor == "sqlite":
cte = self.CTE_SQLITE3
cte = self.CTE_SQLITE
elif self.connection.vendor == "mysql":
cte = self.CTE_MYSQL
sql_0, sql_1 = super().as_sql(*args, **kwargs)
Expand Down
4 changes: 2 additions & 2 deletions tree_queries/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ def tree_exclude(self, *args, **kwargs):
)
return self

def extra_fields(self, **extra_fields):
def tree_fields(self, **tree_fields):
self.query.__class__ = TreeQuery
self.query._setup_query()
self.query.extra_fields = extra_fields
self.query.tree_fields = tree_fields
return self

@classmethod
Expand Down

0 comments on commit 873fe07

Please sign in to comment.