From c5946327c8324e28467f71653a4e2f50adfa356d Mon Sep 17 00:00:00 2001 From: qasimgulzar Date: Mon, 18 Nov 2024 14:34:04 +0500 Subject: [PATCH] fix: make database function compatible with postgresql and mysql both --- openedx_tagging/core/tagging/models/utils.py | 42 ++++++++++++++++---- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/openedx_tagging/core/tagging/models/utils.py b/openedx_tagging/core/tagging/models/utils.py index 86a5f128..e653df2b 100644 --- a/openedx_tagging/core/tagging/models/utils.py +++ b/openedx_tagging/core/tagging/models/utils.py @@ -3,6 +3,8 @@ """ from django.db.models import Aggregate, CharField from django.db.models.expressions import Func +from django.db import connection + RESERVED_TAG_CHARS = [ '\t', # Used in the database to separate tag levels in the "lineage" field @@ -34,21 +36,47 @@ def as_sqlite(self, compiler, connection, **extra_context): ) -class StringAgg(Aggregate): # pylint: disable=abstract-method +class StringAgg(Aggregate): """ Aggregate function that collects the values of some column across all rows, - and creates a string by concatenating those values, with "," as a separator. + and creates a string by concatenating those values, with a specified separator. - This is the same as Django's django.contrib.postgres.aggregates.StringAgg, - but this version works with MySQL and SQLite. + This version supports PostgreSQL (STRING_AGG), MySQL (GROUP_CONCAT), and SQLite. """ + # Default function is for MySQL (GROUP_CONCAT) function = 'GROUP_CONCAT' - template = '%(function)s(%(distinct)s%(expressions)s)' + template = '%(function)s(%(distinct)s%(expressions)s SEPARATOR %(delimiter)s)' + + def __init__(self, expression, distinct=False, delimiter=',', **extra): + + self.delimiter=delimiter + # Handle the distinct option and output type + distinct_str = 'DISTINCT ' if distinct else '' - def __init__(self, expression, distinct=False, **extra): + # Check the database backend (PostgreSQL, MySQL, or SQLite) + if 'postgresql' in connection.vendor.lower(): + self.function = 'STRING_AGG' + self.template = '%(function)s(%(distinct)s%(expressions)s, %(delimiter)s)' + elif 'mysql' in connection.vendor.lower() or 'sqlite' in connection.vendor.lower(): + self.function = 'GROUP_CONCAT' + self.template = '%(function)s(%(distinct)s%(expressions)s SEPARATOR %(delimiter)s)' + + # Initialize the parent class with the necessary parameters super().__init__( expression, - distinct='DISTINCT ' if distinct else '', + distinct=distinct_str, + delimiter=delimiter, output_field=CharField(), **extra, ) + + def as_sql(self, compiler, connection, **extra_context): + # If PostgreSQL, we use STRING_AGG with a separator + if 'postgresql' in connection.vendor.lower(): + # Ensure that expressions are cast to TEXT for PostgreSQL + expressions_sql, params = compiler.compile(self.source_expressions[0]) + expressions_sql = f"({expressions_sql})::TEXT" # Cast to TEXT for PostgreSQL + return f"{self.function}({expressions_sql}, {self.delimiter!r})", params + else: + # MySQL/SQLite handles GROUP_CONCAT with SEPARATOR + return super().as_sql(compiler, connection, **extra_context)