Skip to content

Commit

Permalink
[IMP] pg: add returning_id option to parallel_execute
Browse files Browse the repository at this point in the history
The default behavior is unchanged.

This adds the possibility to parallelize modifying queries that have a
`RETURNING id` clause. For those, return the resulting ids (in a defined order)
instead of the affected row count.

To avoid misuse, add a warning to the docstring and try to detect queries that
are not of the intended form. Raise an error if any are found.
  • Loading branch information
cawo-odoo committed Jan 2, 2025
1 parent 05aae80 commit 343aa7e
Showing 1 changed file with 28 additions and 16 deletions.
44 changes: 28 additions & 16 deletions src/util/pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,34 +73,34 @@ def savepoint(cr):
yield


def _parallel_execute_serial(cr, queries, logger=_logger):
cnt = 0
def _parallel_execute_serial(cr, queries, logger, returning_id):
res = [] if returning_id else 0
for query in log_progress(queries, logger, qualifier="queries", size=len(queries)):
cr.execute(query)
cnt += cr.rowcount
return cnt
res += cr.fetchall() if returning_id else cr.rowcount
return res


if ThreadPoolExecutor is not None:

def _parallel_execute_threaded(cr, queries, logger=_logger):
def _parallel_execute_threaded(cr, queries, logger, returning_id):
if not queries:
return None

if len(queries) == 1:
# No need to spawn other threads
cr.execute(queries[0])
return cr.rowcount
return cr.fetchall() if returning_id else cr.rowcount

max_workers = min(get_max_workers(), len(queries))
cursor = db_connect(cr.dbname).cursor

def execute(query):
with cursor() as tcr:
tcr.execute(query)
cnt = tcr.rowcount
res = tcr.fetchall() if returning_id else tcr.rowcount
tcr.commit()
return cnt
return res

cr.commit()

Expand All @@ -109,7 +109,7 @@ def execute(query):
errorcodes.SERIALIZATION_FAILURE,
}
failed_queries = []
tot_cnt = 0
tot_res = [] if returning_id else 0
with ThreadPoolExecutor(max_workers=max_workers) as executor:
future_queries = {executor.submit(execute, q): q for q in queries}
for future in log_progress(
Expand All @@ -121,7 +121,7 @@ def execute(query):
log_hundred_percent=True,
):
try:
tot_cnt += future.result() or 0
tot_res += future.result() or ([] if returning_id else 0)
except psycopg2.OperationalError as exc:
if exc.pgcode not in CONCURRENCY_ERRORCODES:
raise
Expand All @@ -131,16 +131,16 @@ def execute(query):

if failed_queries:
logger.warning("Serialize queries that failed due to concurrency issues")
tot_cnt += _parallel_execute_serial(cr, failed_queries, logger=logger)
tot_res += _parallel_execute_serial(cr, failed_queries, logger, returning_id)
cr.commit()

return tot_cnt
return tot_res

else:
_parallel_execute_threaded = _parallel_execute_serial


def parallel_execute(cr, queries, logger=_logger):
def parallel_execute(cr, queries, logger=_logger, returning_id=False):
"""
Execute queries in parallel.
Expand All @@ -154,15 +154,20 @@ def parallel_execute(cr, queries, logger=_logger):
:param list(str) queries: list of queries to execute concurrently
:param `~logging.Logger` logger: logger used to report the progress
:return: the sum of `cr.rowcount` for each query run
:param bool returning_id: whether to return a tuple of the affected ids (default: return the affected row count)
:return: the sum of `cr.rowcount` for each query run or, if `returning_id`, a sorted tuple of all returned ids
:rtype: int or tuple(int)
.. warning::
- As a side effect, the cursor will be committed.
- Due to the nature of `cr.rowcount`, the return value of this function may represent an
underestimate of the real number of affected records. For instance, when some records
are deleted/updated as a result of an `ondelete` clause, they won't be taken into account.
- As a side effect, the cursor will be committed.
- It would not be generally safe to use this function for selecting queries. Because of this,
`returning_id=True` is only accepted for `UPDATE/DELETE/INSERT/MERGE [...] RETURNING id` queries. Also, the
caller cannot influence the order of the returned ids: they are always sorted in ascending order.
.. note::
If a concurrency issue occurs, the *failing* queries will be retried sequentially.
Expand All @@ -172,7 +177,14 @@ def parallel_execute(cr, queries, logger=_logger):
if getattr(threading.current_thread(), "testing", False)
else _parallel_execute_threaded
)
return parallel_execute_impl(cr, queries, logger=_logger)

if returning_id:
returning_id_re = re.compile(r"(?s)(?:UPDATE|DELETE|INSERT|MERGE).*RETURNING\s+\S*\.?id\s*$")
if not all((bool(returning_id_re.search(q)) for q in queries)):
raise ValueError("The returning_id parameter can only be used with certain queries.")

res = parallel_execute_impl(cr, queries, logger, returning_id)
return tuple(sorted([id for (id,) in res])) if returning_id else res


def format_query(cr, query, *args, **kwargs):
Expand Down

0 comments on commit 343aa7e

Please sign in to comment.