Skip to content

Commit

Permalink
fix: tests and quality issues
Browse files Browse the repository at this point in the history
  • Loading branch information
qasimgulzar committed Dec 12, 2024
1 parent c594632 commit bc7a493
Showing 1 changed file with 24 additions and 13 deletions.
37 changes: 24 additions & 13 deletions openedx_tagging/core/tagging/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
"""
from django.db.models import Aggregate, CharField
from django.db.models.expressions import Func
from django.db import connection

from django.db import connection as db_connection

RESERVED_TAG_CHARS = [
'\t', # Used in the database to separate tag levels in the "lineage" field
Expand Down Expand Up @@ -36,7 +35,10 @@ def as_sqlite(self, compiler, connection, **extra_context):
)


class StringAgg(Aggregate):
from django.db.models import Aggregate, CharField
from django.db.models.expressions import Combinable

class StringAgg(Aggregate, Combinable):
"""
Aggregate function that collects the values of some column across all rows,
and creates a string by concatenating those values, with a specified separator.
Expand All @@ -45,28 +47,27 @@ class StringAgg(Aggregate):
"""
# Default function is for MySQL (GROUP_CONCAT)
function = 'GROUP_CONCAT'
template = '%(function)s(%(distinct)s%(expressions)s SEPARATOR %(delimiter)s)'
template = '%(function)s(%(distinct)s%(expressions)s)'

def __init__(self, expression, distinct=False, delimiter=',', **extra):

self.delimiter=delimiter
self.delimiter = delimiter
# Handle the distinct option and output type
distinct_str = 'DISTINCT ' if distinct else ''

extra.update(dict(
distinct=distinct_str,
output_field=CharField()
))

# Check the database backend (PostgreSQL, MySQL, or SQLite)
if 'postgresql' in connection.vendor.lower():
if 'postgresql' in db_connection.vendor.lower():
self.function = 'STRING_AGG'
self.template = '%(function)s(%(distinct)s%(expressions)s, %(delimiter)s)'
elif 'mysql' in connection.vendor.lower() or 'sqlite' in connection.vendor.lower():
self.function = 'GROUP_CONCAT'
self.template = '%(function)s(%(distinct)s%(expressions)s SEPARATOR %(delimiter)s)'
extra.update({"delimiter": delimiter})

# Initialize the parent class with the necessary parameters
super().__init__(
expression,
distinct=distinct_str,
delimiter=delimiter,
output_field=CharField(),
**extra,
)

Expand All @@ -80,3 +81,13 @@ def as_sql(self, compiler, connection, **extra_context):
else:
# MySQL/SQLite handles GROUP_CONCAT with SEPARATOR
return super().as_sql(compiler, connection, **extra_context)

# Implementing abstract methods from Combinable
def __rand__(self, other):
return self._combine(other, 'AND', is_combinable=True)

def __ror__(self, other):
return self._combine(other, 'OR', is_combinable=True)

def __rxor__(self, other):
return self._combine(other, 'XOR', is_combinable=True)

0 comments on commit bc7a493

Please sign in to comment.