Skip to content

Commit

Permalink
Refactor for mypy & coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
ewjoachim committed Dec 27, 2023
1 parent 74e6a8b commit 3626347
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 13 deletions.
25 changes: 12 additions & 13 deletions procrastinate/psycopg_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,29 @@
import asyncio
import functools
import logging
from typing import Any, Callable, Coroutine, Iterable
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Iterable

from typing_extensions import LiteralString

from procrastinate import connector, exceptions, sql
from procrastinate import connector, exceptions, sql, utils

try:
if TYPE_CHECKING:
import psycopg
import psycopg.errors
import psycopg.rows
import psycopg.sql
import psycopg.types.json
import psycopg_pool
except ImportError as exc:
# In case psycopg is not installed, we'll raise an explicit error
# only when the connector is used.
exception = exc
else:
psycopg, *_ = utils.import_or_wrapper(
"psycopg",
"psycopg.errors",
"psycopg.rows",
"psycopg.sql",
"psycopg.types.json",
)
(psycopg_pool,) = utils.import_or_wrapper("psycopg_pool")

class Wrapper:
def __getattr__(self, item):
raise exception

locals()["psycopg"] = Wrapper()
locals()["psycopg_pool"] = Wrapper()

logger = logging.getLogger(__name__)

Expand Down
20 changes: 20 additions & 0 deletions procrastinate/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,3 +470,23 @@ async def _main():

def add_namespace(name: str, namespace: str) -> str:
return f"{namespace}:{name}"


def import_or_wrapper(*names: str) -> Iterable[types.ModuleType]:
"""
Import given modules, or return a dummy wrapper that will raise an
ImportError when used.
"""
try:
for name in names:
yield importlib.import_module(name)
except ImportError as exc:
# In case psycopg is not installed, we'll raise an explicit error
# only when the connector is used.
exception = exc

class Wrapper:
def __getattr__(self, item):
raise exception

yield Wrapper() # type: ignore
15 changes: 15 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,3 +527,18 @@ def test_check_stack_failure(mocker):
mocker.patch("inspect.currentframe", return_value=None)
with pytest.raises(exceptions.CallerModuleUnknown):
assert utils.caller_module_name()


def test_import_or_wrapper__ok():
result = list(utils.import_or_wrapper("json", "csv"))
import csv
import json

assert result == [json, csv]


def test_import_or_wrapper__fail():
(result,) = utils.import_or_wrapper("a" * 30)

with pytest.raises(ImportError):
assert result.foo

0 comments on commit 3626347

Please sign in to comment.