Skip to content

Commit

Permalink
fix: make database function compatible with postgresql and mysql both
Browse files Browse the repository at this point in the history
  • Loading branch information
qasimgulzar committed Nov 18, 2024
1 parent fc4c1ff commit c594632
Showing 1 changed file with 35 additions and 7 deletions.
42 changes: 35 additions & 7 deletions openedx_tagging/core/tagging/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
"""
from django.db.models import Aggregate, CharField
from django.db.models.expressions import Func
from django.db import connection


RESERVED_TAG_CHARS = [
'\t', # Used in the database to separate tag levels in the "lineage" field
Expand Down Expand Up @@ -34,21 +36,47 @@ def as_sqlite(self, compiler, connection, **extra_context):
)


class StringAgg(Aggregate): # pylint: disable=abstract-method
class StringAgg(Aggregate):
"""
Aggregate function that collects the values of some column across all rows,
and creates a string by concatenating those values, with "," as a separator.
and creates a string by concatenating those values, with a specified separator.
This is the same as Django's django.contrib.postgres.aggregates.StringAgg,
but this version works with MySQL and SQLite.
This version supports PostgreSQL (STRING_AGG), MySQL (GROUP_CONCAT), and SQLite.
"""
# Default function is for MySQL (GROUP_CONCAT)
function = 'GROUP_CONCAT'
template = '%(function)s(%(distinct)s%(expressions)s)'
template = '%(function)s(%(distinct)s%(expressions)s SEPARATOR %(delimiter)s)'

def __init__(self, expression, distinct=False, delimiter=',', **extra):

self.delimiter=delimiter
# Handle the distinct option and output type
distinct_str = 'DISTINCT ' if distinct else ''

def __init__(self, expression, distinct=False, **extra):
# Check the database backend (PostgreSQL, MySQL, or SQLite)
if 'postgresql' in connection.vendor.lower():
self.function = 'STRING_AGG'
self.template = '%(function)s(%(distinct)s%(expressions)s, %(delimiter)s)'
elif 'mysql' in connection.vendor.lower() or 'sqlite' in connection.vendor.lower():
self.function = 'GROUP_CONCAT'
self.template = '%(function)s(%(distinct)s%(expressions)s SEPARATOR %(delimiter)s)'

# Initialize the parent class with the necessary parameters
super().__init__(
expression,
distinct='DISTINCT ' if distinct else '',
distinct=distinct_str,
delimiter=delimiter,
output_field=CharField(),
**extra,
)

def as_sql(self, compiler, connection, **extra_context):
# If PostgreSQL, we use STRING_AGG with a separator
if 'postgresql' in connection.vendor.lower():
# Ensure that expressions are cast to TEXT for PostgreSQL
expressions_sql, params = compiler.compile(self.source_expressions[0])
expressions_sql = f"({expressions_sql})::TEXT" # Cast to TEXT for PostgreSQL
return f"{self.function}({expressions_sql}, {self.delimiter!r})", params
else:
# MySQL/SQLite handles GROUP_CONCAT with SEPARATOR
return super().as_sql(compiler, connection, **extra_context)

0 comments on commit c594632

Please sign in to comment.