Skip to content

Commit

Permalink
allow effect.catch to use async functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Sune Debel authored Aug 22, 2020
1 parent 3eca0be commit 3b1b7e0
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 72 deletions.
4 changes: 1 addition & 3 deletions docs/effectful_but_side_effect_free.md
Original file line number Diff line number Diff line change
Expand Up @@ -490,9 +490,7 @@ async def f(r: str) -> Either[Exception, float]:
effect: Effect[str, Exception, float] = from_callable(f)
```

`pfun.effect.catch` is used to decorate functions
that may raise exceptions. If the decorated function performs side effects, they
are not carried out until the effect is run
`pfun.effect.catch` is used to decorate sync and async functions that may raise exceptions. If the decorated function performs side effects, they are not carried out until the effect is run
```python
from pfun.effect import catch, Effect

Expand Down
50 changes: 31 additions & 19 deletions pfun/effect.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,8 +631,7 @@ async def run_e(env):
return Done(Right(env.r))

return Effect(
run_e,
f'depend({r_type.__name__ if r_type is not None else ""})'
run_e, f'depend({r_type.__name__ if r_type is not None else ""})'
)


Expand Down Expand Up @@ -684,9 +683,7 @@ async def thunk() -> Trampoline[Either[E1, Iterable[A1]]]:
# TODO should this be run in an executor to avoid blocking?
# maybe depending on the number of effects?
trampoline = sequence_trampolines(trampolines)
return trampoline.map(
lambda eithers: sequence_eithers(eithers)
)
return trampoline.map(lambda eithers: sequence_eithers(eithers))

return Call(thunk)

Expand All @@ -702,18 +699,17 @@ async def run_e(r: RuntimeEnv[R1]) -> Trampoline[Either[E1, Iterable[A1]]]:
async def thunk() -> Trampoline[Either[E1, Iterable[A1]]]:
trampolines = [await e.run_e(r) for e in iterable] # type: ignore
trampoline = sequence_trampolines(trampolines)
return trampoline.map(
lambda eithers: sequence_eithers(eithers)
)
return trampoline.map(lambda eithers: sequence_eithers(eithers))

return Call(thunk)

return Effect(run_e)


@curry
@add_repr
def for_each(f: Callable[[A1], Effect[R1, E1, B]], iterable: Iterable[A1]
) -> Effect[R1, E1, Iterable[B]]:
def for_each(f: Callable[[A1], Effect[R1, E1, B]],
iterable: Iterable[A1]) -> Effect[R1, E1, Iterable[B]]:
"""
Map each in element in ``iterable`` to
an `Effect` by applying ``f``,
Expand All @@ -735,8 +731,8 @@ def for_each(f: Callable[[A1], Effect[R1, E1, B]], iterable: Iterable[A1]

@curry
@add_repr
def filter_(f: Callable[[A], Effect[R1, E1, bool]], iterable: Iterable[A]
) -> Effect[R1, E1, Iterable[A]]:
def filter_(f: Callable[[A], Effect[R1, E1, bool]],
iterable: Iterable[A]) -> Effect[R1, E1, Iterable[A]]:
"""
Map each element in ``iterable`` by applying ``f``,
filter the results by the value returned by ``f``
Expand All @@ -755,9 +751,7 @@ def filter_(f: Callable[[A], Effect[R1, E1, bool]], iterable: Iterable[A]
"""
iterable = tuple(iterable)
bools = sequence(f(a) for a in iterable)
return bools.map(
lambda bs: tuple(a for b, a in zip(bs, iterable) if b)
)
return bools.map(lambda bs: tuple(a for b, a in zip(bs, iterable) if b))


@add_repr
Expand Down Expand Up @@ -1054,7 +1048,18 @@ def __init__(self, error: Type[EX], *errors: Type[EX]):
"""
object.__setattr__(self, 'errors', (error, ) + errors)

def __call__(self, f: Callable[..., B]) -> Callable[..., Try[EX, B]]:
@overload
def __call__(self, f: Callable[..., Awaitable[B]]
) -> Callable[..., Try[EX, B]]:
...

@overload
def __call__(self, f: Callable[..., B]
) -> Callable[..., Try[EX, B]]:
...

def __call__(self, f: Union[Callable[..., Awaitable[B]], Callable[..., B]]
) -> Callable[..., Try[EX, B]]:
"""
Decorate `f` to catch exceptions as an `Effect`
"""
Expand All @@ -1067,17 +1072,24 @@ async def run_e(r: RuntimeEnv[object]
return Done(
Right(
await
r.run_in_process_executor(f, *args, *kwargs)
r.run_in_process_executor(
f, *args, *kwargs # type: ignore
)
)
)
elif is_io_bound(f):
return Done(
Right(
await
r.run_in_thread_executor(f, *args, **kwargs)
r.run_in_thread_executor(
f, *args, **kwargs # type: ignore
)
)
)
return Done(Right(f(*args, **kwargs)))
result = f(*args, **kwargs)
if asyncio.iscoroutine(result):
result = await result # type: ignore
return Done(Right(result)) # type: ignore
except Exception as e:
if any(isinstance(e, t) for t in self.errors):
return Done(Left(e)) # type: ignore
Expand Down
11 changes: 7 additions & 4 deletions pfun/mypy_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,10 +461,13 @@ def _lift_call_hook(context: MethodContext) -> Type:


def _effect_catch_hook(context: FunctionContext) -> Type:
error_types = [
arg_type[0].ret_type for arg_type in context.arg_types if arg_type
]
return context.default_return_type.copy_modified(args=error_types)
try:
error_types = [
arg_type[0].ret_type for arg_type in context.arg_types if arg_type
]
return context.default_return_type.copy_modified(args=error_types)
except AttributeError:
return context.default_return_type


def _effect_catch_call_hook(context: MethodContext) -> Type:
Expand Down
70 changes: 24 additions & 46 deletions pfun/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from typing_extensions import Protocol

from .dict import Dict
from .effect import Effect, Resource, Try, add_repr, depend, error, success
from .effect import (Effect, Resource, Try, add_repr, catch, depend, error,
success)
from .either import Either, Left, Right
from .functions import curry
from .immutable import Immutable
Expand Down Expand Up @@ -154,17 +155,11 @@ def execute(self, query: str, *args: Any,
Return:
`Effect` that executes `query` and produces the database response
"""
async def _execute(connection: asyncpg.Connection
) -> Try[asyncpg.PostgresError, str]:
try:
result = await connection.execute(
query, *args, timeout=timeout
)
return success(result)
except asyncpg.PostgresError as e:
return error(e)
@catch(asyncpg.PostgresError)
async def execute(connection: asyncpg.Connection) -> str:
return await connection.execute(query, *args, timeout=timeout)

return self.get_connection().and_then(_execute)
return self.get_connection().and_then(execute)

def execute_many(
self, query: str, args: Iterable[Any], timeout: float = None
Expand All @@ -190,17 +185,11 @@ def execute_many(
`Effect` that executes `query` with all args in `args` and \
produces a database response for each query
"""
async def _execute_many(connection: asyncpg.Connection
) -> Try[asyncpg.PostgresError, str]:
try:
result = await connection.executemany(
query, *args, timeout=timeout
)
return success(result)
except asyncpg.PostgresError as e:
return error(e)
@catch(asyncpg.PostgresError)
async def execute_many(connection: asyncpg.Connection) -> str:
return await connection.executemany(query, *args, timeout=timeout)

return self.get_connection().and_then(_execute_many)
return self.get_connection().and_then(execute_many)

def fetch(self, query: str, *args: Any,
timeout: float = None) -> Try[asyncpg.PostgresError, Results]:
Expand All @@ -221,16 +210,12 @@ def fetch(self, query: str, *args: Any,
Return:
`Effect` that retrieves rows returned by `query` as `Results`
"""
async def _fetch(connection: asyncpg.Connection
) -> Try[asyncpg.PostgresError, Results]:
try:
result = await connection.fetch(query, *args, timeout=timeout)
result = List(Dict(record) for record in result)
return success(result)
except asyncpg.PostgresError as e:
return error(e)
@catch(asyncpg.PostgresError)
async def fetch(connection: asyncpg.Connection) -> Results:
result = await connection.fetch(query, *args, timeout=timeout)
return List(Dict(record) for record in result)

return self.get_connection().and_then(_fetch)
return self.get_connection().and_then(fetch)

def fetch_one(self, query: str, *args: Any,
timeout: float = None) -> Try[SQLError, Dict[str, Any]]:
Expand All @@ -253,24 +238,17 @@ def fetch_one(self, query: str, *args: Any,
`Effect` that retrieves the first row returned by `query` as \
`pfun.dict.Dict[str, Any]`
"""
async def _fetch_row(connection: asyncpg.Connection
) -> Try[asyncpg.PostgresError, Dict[str, Any]]:
try:
result = await connection.fetchrow(
query, *args, timeout=timeout
@catch(asyncpg.PostgresError, EmptyResultSetError)
async def fetch_row(connection: asyncpg.Connection) -> Dict[str, Any]:
result = await connection.fetchrow(query, *args, timeout=timeout)
if result is None:
raise EmptyResultSetError(
f'query "{query}" with args "{args}" '
'returned no results'
)
if result is None:
return error(
EmptyResultSetError(
f'query "{query}" with args "{args}" '
'returned no results'
)
)
return success(Dict(result))
except asyncpg.PostgresError as e:
return error(e)
return Dict(result)

return self.get_connection().and_then(_fetch_row)
return self.get_connection().and_then(fetch_row)


class HasSQL(Protocol):
Expand Down

0 comments on commit 3b1b7e0

Please sign in to comment.