Skip to content

Commit

Permalink
Merge pull request #881 from procrastinate-org/psycopg-binary
Browse files Browse the repository at this point in the history
  • Loading branch information
ewjoachim authored Dec 28, 2023
2 parents be3d302 + 3626347 commit 4b096a1
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 26 deletions.
77 changes: 76 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

64 changes: 40 additions & 24 deletions procrastinate/psycopg_connector.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,31 @@
from __future__ import annotations

import asyncio
import functools
import logging
from typing import Any, Callable, Coroutine, Dict, Iterable, List, Optional

import psycopg
import psycopg.errors
import psycopg.sql
import psycopg.types.json
import psycopg_pool
from psycopg.rows import DictRow, dict_row
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Iterable

from typing_extensions import LiteralString

from procrastinate import connector, exceptions, sql
from procrastinate import connector, exceptions, sql, utils

if TYPE_CHECKING:
import psycopg
import psycopg.errors
import psycopg.rows
import psycopg.sql
import psycopg.types.json
import psycopg_pool
else:
psycopg, *_ = utils.import_or_wrapper(
"psycopg",
"psycopg.errors",
"psycopg.rows",
"psycopg.sql",
"psycopg.types.json",
)
(psycopg_pool,) = utils.import_or_wrapper("psycopg_pool")


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -46,8 +60,8 @@ class PsycopgConnector(connector.BaseAsyncConnector):
def __init__(
self,
*,
json_dumps: Optional[Callable] = None,
json_loads: Optional[Callable] = None,
json_dumps: Callable | None = None,
json_loads: Callable | None = None,
**kwargs: Any,
):
"""
Expand Down Expand Up @@ -89,25 +103,25 @@ def __init__(
argument is passed, it will connect to localhost:5432 instead of a
Unix-domain local socket file.
"""
self._pool: Optional[psycopg_pool.AsyncConnectionPool] = None
self._pool: psycopg_pool.AsyncConnectionPool | None = None
self.json_dumps = json_dumps
self._pool_externally_set = False
self._pool_args = self._adapt_pool_args(kwargs, json_loads, json_dumps)
self.json_loads = json_loads

@staticmethod
def _adapt_pool_args(
pool_args: Dict[str, Any],
json_loads: Optional[Callable],
json_dumps: Optional[Callable],
) -> Dict[str, Any]:
pool_args: dict[str, Any],
json_loads: Callable | None,
json_dumps: Callable | None,
) -> dict[str, Any]:
"""
Adapt the pool args for ``psycopg``, using sensible defaults for Procrastinate.
"""
base_configure = pool_args.pop("configure", None)

@wrap_exceptions
async def configure(connection: psycopg.AsyncConnection[DictRow]):
async def configure(connection: psycopg.AsyncConnection[psycopg.rows.DictRow]):
if base_configure:
await base_configure(connection)

Expand All @@ -122,7 +136,7 @@ async def configure(connection: psycopg.AsyncConnection[DictRow]):
"min_size": 1,
"max_size": 10,
"kwargs": {
"row_factory": dict_row,
"row_factory": psycopg.rows.dict_row,
},
"configure": configure,
**pool_args,
Expand All @@ -131,13 +145,15 @@ async def configure(connection: psycopg.AsyncConnection[DictRow]):
@property
def pool(
self,
) -> psycopg_pool.AsyncConnectionPool[psycopg.AsyncConnection[DictRow]]:
) -> psycopg_pool.AsyncConnectionPool[
psycopg.AsyncConnection[psycopg.rows.DictRow]
]:
if self._pool is None: # Set by open_async
raise exceptions.AppNotOpen
return self._pool

async def open_async(
self, pool: Optional[psycopg_pool.AsyncConnectionPool] = None
self, pool: psycopg_pool.AsyncConnectionPool | None = None
) -> None:
"""
Instantiate the pool.
Expand All @@ -160,7 +176,7 @@ async def open_async(
@staticmethod
@wrap_exceptions
async def _create_pool(
pool_args: Dict[str, Any]
pool_args: dict[str, Any]
) -> psycopg_pool.AsyncConnectionPool:
return psycopg_pool.AsyncConnectionPool(
**pool_args,
Expand All @@ -184,7 +200,7 @@ async def close_async(self) -> None:
await self._pool.close()
self._pool = None

def _wrap_json(self, arguments: Dict[str, Any]):
def _wrap_json(self, arguments: dict[str, Any]):
return {
key: psycopg.types.json.Jsonb(value) if isinstance(value, dict) else value
for key, value in arguments.items()
Expand All @@ -198,7 +214,7 @@ async def execute_query_async(self, query: LiteralString, **arguments: Any) -> N
@wrap_exceptions
async def execute_query_one_async(
self, query: LiteralString, **arguments: Any
) -> DictRow:
) -> psycopg.rows.DictRow:
async with self.pool.connection() as connection:
async with connection.cursor() as cursor:
await cursor.execute(query, self._wrap_json(arguments))
Expand All @@ -212,7 +228,7 @@ async def execute_query_one_async(
@wrap_exceptions
async def execute_query_all_async(
self, query: LiteralString, **arguments: Any
) -> List[DictRow]:
) -> list[psycopg.rows.DictRow]:
async with self.pool.connection() as connection:
async with connection.cursor() as cursor:
await cursor.execute(query, self._wrap_json(arguments))
Expand Down
20 changes: 20 additions & 0 deletions procrastinate/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,3 +470,23 @@ async def _main():

def add_namespace(name: str, namespace: str) -> str:
return f"{namespace}:{name}"


def import_or_wrapper(*names: str) -> Iterable[types.ModuleType]:
"""
Import given modules, or return a dummy wrapper that will raise an
ImportError when used.
"""
try:
for name in names:
yield importlib.import_module(name)
except ImportError as exc:
# In case psycopg is not installed, we'll raise an explicit error
# only when the connector is used.
exception = exc

class Wrapper:
def __getattr__(self, item):
raise exception

yield Wrapper() # type: ignore
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ psycopg2-binary = "*"
python-dateutil = "*"
sqlalchemy = { version = "^2.0", optional = true }
typing-extensions = { version = "*", python = "<3.8" }
psycopg = {extras = ["pool"], version = "^3.1.13"}
psycopg = { extras = ["pool"], version = "^3.1.13" }

[tool.poetry.extras]
django = ["django"]
Expand All @@ -59,6 +59,7 @@ types-psycopg2 = "*"
types-python-dateutil = "*"
SQLAlchemy = { extras = ["mypy"], version = "^2.0.0" }
tomlkit = "*"
psycopg = { extras = ["binary"], version = "^3.1.13" }

[tool.poetry.group.docs.dependencies]
Sphinx = "*"
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,3 +527,18 @@ def test_check_stack_failure(mocker):
mocker.patch("inspect.currentframe", return_value=None)
with pytest.raises(exceptions.CallerModuleUnknown):
assert utils.caller_module_name()


def test_import_or_wrapper__ok():
result = list(utils.import_or_wrapper("json", "csv"))
import csv
import json

assert result == [json, csv]


def test_import_or_wrapper__fail():
(result,) = utils.import_or_wrapper("a" * 30)

with pytest.raises(ImportError):
assert result.foo

0 comments on commit 4b096a1

Please sign in to comment.