From 3b1b7e0d749cffa773fc473b57840720a034b890 Mon Sep 17 00:00:00 2001 From: Sune Debel Date: Sat, 22 Aug 2020 18:26:43 +0200 Subject: [PATCH] allow effect.catch to use async functions --- docs/effectful_but_side_effect_free.md | 4 +- pfun/effect.py | 50 +++++++++++------- pfun/mypy_plugin.py | 11 ++-- pfun/sql.py | 70 +++++++++----------------- 4 files changed, 63 insertions(+), 72 deletions(-) diff --git a/docs/effectful_but_side_effect_free.md b/docs/effectful_but_side_effect_free.md index 7c265a93..7c1b1360 100644 --- a/docs/effectful_but_side_effect_free.md +++ b/docs/effectful_but_side_effect_free.md @@ -490,9 +490,7 @@ async def f(r: str) -> Either[Exception, float]: effect: Effect[str, Exception, float] = from_callable(f) ``` -`pfun.effect.catch` is used to decorate functions -that may raise exceptions. If the decorated function performs side effects, they -are not carried out until the effect is run +`pfun.effect.catch` is used to decorate sync and async functions that may raise exceptions. If the decorated function performs side effects, they are not carried out until the effect is run ```python from pfun.effect import catch, Effect diff --git a/pfun/effect.py b/pfun/effect.py index 8512e5b1..abaf04bf 100644 --- a/pfun/effect.py +++ b/pfun/effect.py @@ -631,8 +631,7 @@ async def run_e(env): return Done(Right(env.r)) return Effect( - run_e, - f'depend({r_type.__name__ if r_type is not None else ""})' + run_e, f'depend({r_type.__name__ if r_type is not None else ""})' ) @@ -684,9 +683,7 @@ async def thunk() -> Trampoline[Either[E1, Iterable[A1]]]: # TODO should this be run in an executor to avoid blocking? # maybe depending on the number of effects? trampoline = sequence_trampolines(trampolines) - return trampoline.map( - lambda eithers: sequence_eithers(eithers) - ) + return trampoline.map(lambda eithers: sequence_eithers(eithers)) return Call(thunk) @@ -702,9 +699,8 @@ async def run_e(r: RuntimeEnv[R1]) -> Trampoline[Either[E1, Iterable[A1]]]: async def thunk() -> Trampoline[Either[E1, Iterable[A1]]]: trampolines = [await e.run_e(r) for e in iterable] # type: ignore trampoline = sequence_trampolines(trampolines) - return trampoline.map( - lambda eithers: sequence_eithers(eithers) - ) + return trampoline.map(lambda eithers: sequence_eithers(eithers)) + return Call(thunk) return Effect(run_e) @@ -712,8 +708,8 @@ async def thunk() -> Trampoline[Either[E1, Iterable[A1]]]: @curry @add_repr -def for_each(f: Callable[[A1], Effect[R1, E1, B]], iterable: Iterable[A1] - ) -> Effect[R1, E1, Iterable[B]]: +def for_each(f: Callable[[A1], Effect[R1, E1, B]], + iterable: Iterable[A1]) -> Effect[R1, E1, Iterable[B]]: """ Map each in element in ``iterable`` to an `Effect` by applying ``f``, @@ -735,8 +731,8 @@ def for_each(f: Callable[[A1], Effect[R1, E1, B]], iterable: Iterable[A1] @curry @add_repr -def filter_(f: Callable[[A], Effect[R1, E1, bool]], iterable: Iterable[A] - ) -> Effect[R1, E1, Iterable[A]]: +def filter_(f: Callable[[A], Effect[R1, E1, bool]], + iterable: Iterable[A]) -> Effect[R1, E1, Iterable[A]]: """ Map each element in ``iterable`` by applying ``f``, filter the results by the value returned by ``f`` @@ -755,9 +751,7 @@ def filter_(f: Callable[[A], Effect[R1, E1, bool]], iterable: Iterable[A] """ iterable = tuple(iterable) bools = sequence(f(a) for a in iterable) - return bools.map( - lambda bs: tuple(a for b, a in zip(bs, iterable) if b) - ) + return bools.map(lambda bs: tuple(a for b, a in zip(bs, iterable) if b)) @add_repr @@ -1054,7 +1048,18 @@ def __init__(self, error: Type[EX], *errors: Type[EX]): """ object.__setattr__(self, 'errors', (error, ) + errors) - def __call__(self, f: Callable[..., B]) -> Callable[..., Try[EX, B]]: + @overload + def __call__(self, f: Callable[..., Awaitable[B]] + ) -> Callable[..., Try[EX, B]]: + ... + + @overload + def __call__(self, f: Callable[..., B] + ) -> Callable[..., Try[EX, B]]: + ... + + def __call__(self, f: Union[Callable[..., Awaitable[B]], Callable[..., B]] + ) -> Callable[..., Try[EX, B]]: """ Decorate `f` to catch exceptions as an `Effect` """ @@ -1067,17 +1072,24 @@ async def run_e(r: RuntimeEnv[object] return Done( Right( await - r.run_in_process_executor(f, *args, *kwargs) + r.run_in_process_executor( + f, *args, *kwargs # type: ignore + ) ) ) elif is_io_bound(f): return Done( Right( await - r.run_in_thread_executor(f, *args, **kwargs) + r.run_in_thread_executor( + f, *args, **kwargs # type: ignore + ) ) ) - return Done(Right(f(*args, **kwargs))) + result = f(*args, **kwargs) + if asyncio.iscoroutine(result): + result = await result # type: ignore + return Done(Right(result)) # type: ignore except Exception as e: if any(isinstance(e, t) for t in self.errors): return Done(Left(e)) # type: ignore diff --git a/pfun/mypy_plugin.py b/pfun/mypy_plugin.py index 41bdcde6..cc20309e 100644 --- a/pfun/mypy_plugin.py +++ b/pfun/mypy_plugin.py @@ -461,10 +461,13 @@ def _lift_call_hook(context: MethodContext) -> Type: def _effect_catch_hook(context: FunctionContext) -> Type: - error_types = [ - arg_type[0].ret_type for arg_type in context.arg_types if arg_type - ] - return context.default_return_type.copy_modified(args=error_types) + try: + error_types = [ + arg_type[0].ret_type for arg_type in context.arg_types if arg_type + ] + return context.default_return_type.copy_modified(args=error_types) + except AttributeError: + return context.default_return_type def _effect_catch_call_hook(context: MethodContext) -> Type: diff --git a/pfun/sql.py b/pfun/sql.py index 63541f7a..208c220f 100644 --- a/pfun/sql.py +++ b/pfun/sql.py @@ -4,7 +4,8 @@ from typing_extensions import Protocol from .dict import Dict -from .effect import Effect, Resource, Try, add_repr, depend, error, success +from .effect import (Effect, Resource, Try, add_repr, catch, depend, error, + success) from .either import Either, Left, Right from .functions import curry from .immutable import Immutable @@ -154,17 +155,11 @@ def execute(self, query: str, *args: Any, Return: `Effect` that executes `query` and produces the database response """ - async def _execute(connection: asyncpg.Connection - ) -> Try[asyncpg.PostgresError, str]: - try: - result = await connection.execute( - query, *args, timeout=timeout - ) - return success(result) - except asyncpg.PostgresError as e: - return error(e) + @catch(asyncpg.PostgresError) + async def execute(connection: asyncpg.Connection) -> str: + return await connection.execute(query, *args, timeout=timeout) - return self.get_connection().and_then(_execute) + return self.get_connection().and_then(execute) def execute_many( self, query: str, args: Iterable[Any], timeout: float = None @@ -190,17 +185,11 @@ def execute_many( `Effect` that executes `query` with all args in `args` and \ produces a database response for each query """ - async def _execute_many(connection: asyncpg.Connection - ) -> Try[asyncpg.PostgresError, str]: - try: - result = await connection.executemany( - query, *args, timeout=timeout - ) - return success(result) - except asyncpg.PostgresError as e: - return error(e) + @catch(asyncpg.PostgresError) + async def execute_many(connection: asyncpg.Connection) -> str: + return await connection.executemany(query, *args, timeout=timeout) - return self.get_connection().and_then(_execute_many) + return self.get_connection().and_then(execute_many) def fetch(self, query: str, *args: Any, timeout: float = None) -> Try[asyncpg.PostgresError, Results]: @@ -221,16 +210,12 @@ def fetch(self, query: str, *args: Any, Return: `Effect` that retrieves rows returned by `query` as `Results` """ - async def _fetch(connection: asyncpg.Connection - ) -> Try[asyncpg.PostgresError, Results]: - try: - result = await connection.fetch(query, *args, timeout=timeout) - result = List(Dict(record) for record in result) - return success(result) - except asyncpg.PostgresError as e: - return error(e) + @catch(asyncpg.PostgresError) + async def fetch(connection: asyncpg.Connection) -> Results: + result = await connection.fetch(query, *args, timeout=timeout) + return List(Dict(record) for record in result) - return self.get_connection().and_then(_fetch) + return self.get_connection().and_then(fetch) def fetch_one(self, query: str, *args: Any, timeout: float = None) -> Try[SQLError, Dict[str, Any]]: @@ -253,24 +238,17 @@ def fetch_one(self, query: str, *args: Any, `Effect` that retrieves the first row returned by `query` as \ `pfun.dict.Dict[str, Any]` """ - async def _fetch_row(connection: asyncpg.Connection - ) -> Try[asyncpg.PostgresError, Dict[str, Any]]: - try: - result = await connection.fetchrow( - query, *args, timeout=timeout + @catch(asyncpg.PostgresError, EmptyResultSetError) + async def fetch_row(connection: asyncpg.Connection) -> Dict[str, Any]: + result = await connection.fetchrow(query, *args, timeout=timeout) + if result is None: + raise EmptyResultSetError( + f'query "{query}" with args "{args}" ' + 'returned no results' ) - if result is None: - return error( - EmptyResultSetError( - f'query "{query}" with args "{args}" ' - 'returned no results' - ) - ) - return success(Dict(result)) - except asyncpg.PostgresError as e: - return error(e) + return Dict(result) - return self.get_connection().and_then(_fetch_row) + return self.get_connection().and_then(fetch_row) class HasSQL(Protocol):