diff --git a/openedx_tagging/core/tagging/models/utils.py b/openedx_tagging/core/tagging/models/utils.py index e653df2b..c9539299 100644 --- a/openedx_tagging/core/tagging/models/utils.py +++ b/openedx_tagging/core/tagging/models/utils.py @@ -3,8 +3,7 @@ """ from django.db.models import Aggregate, CharField from django.db.models.expressions import Func -from django.db import connection - +from django.db import connection as db_connection RESERVED_TAG_CHARS = [ '\t', # Used in the database to separate tag levels in the "lineage" field @@ -36,7 +35,10 @@ def as_sqlite(self, compiler, connection, **extra_context): ) -class StringAgg(Aggregate): +from django.db.models import Aggregate, CharField +from django.db.models.expressions import Combinable + +class StringAgg(Aggregate, Combinable): """ Aggregate function that collects the values of some column across all rows, and creates a string by concatenating those values, with a specified separator. @@ -45,28 +47,27 @@ class StringAgg(Aggregate): """ # Default function is for MySQL (GROUP_CONCAT) function = 'GROUP_CONCAT' - template = '%(function)s(%(distinct)s%(expressions)s SEPARATOR %(delimiter)s)' + template = '%(function)s(%(distinct)s%(expressions)s)' def __init__(self, expression, distinct=False, delimiter=',', **extra): - - self.delimiter=delimiter + self.delimiter = delimiter # Handle the distinct option and output type distinct_str = 'DISTINCT ' if distinct else '' + extra.update(dict( + distinct=distinct_str, + output_field=CharField() + )) + # Check the database backend (PostgreSQL, MySQL, or SQLite) - if 'postgresql' in connection.vendor.lower(): + if 'postgresql' in db_connection.vendor.lower(): self.function = 'STRING_AGG' self.template = '%(function)s(%(distinct)s%(expressions)s, %(delimiter)s)' - elif 'mysql' in connection.vendor.lower() or 'sqlite' in connection.vendor.lower(): - self.function = 'GROUP_CONCAT' - self.template = '%(function)s(%(distinct)s%(expressions)s SEPARATOR %(delimiter)s)' + extra.update({"delimiter": delimiter}) # Initialize the parent class with the necessary parameters super().__init__( expression, - distinct=distinct_str, - delimiter=delimiter, - output_field=CharField(), **extra, ) @@ -80,3 +81,13 @@ def as_sql(self, compiler, connection, **extra_context): else: # MySQL/SQLite handles GROUP_CONCAT with SEPARATOR return super().as_sql(compiler, connection, **extra_context) + + # Implementing abstract methods from Combinable + def __rand__(self, other): + return self._combine(other, 'AND', is_combinable=True) + + def __ror__(self, other): + return self._combine(other, 'OR', is_combinable=True) + + def __rxor__(self, other): + return self._combine(other, 'XOR', is_combinable=True)