Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bpo-29679: Implement @contextlib.asynccontextmanager #360

Merged
merged 18 commits into from
May 1, 2017
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions Doc/library/contextlib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,35 @@ Functions and classes provided:
Use of :class:`ContextDecorator`.


.. decorator:: asynccontextmanager

Similar to :func:`~contextlib.contextmanager`, but works with
:term:`coroutines <coroutine>`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"works with coroutine" isn't really descriptive. I'd rephrase to something akin to "but creates asynchronous context managers".


This function is a :term:`decorator` that can be used to define a factory
function for :keyword:`async with` statement asynchronous context managers,
without needing to create a class or separate :meth:`__aenter__` and
:meth:`__aexit__` methods.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd add that the decorator expects to be applied to asynchronous generator functions.


A simple example::

from contextlib import asynccontextmanager

@asynccontextmanager
async def get_connection():
conn = await acquire_db_connection()
try:
yield
finally:
await release_db_connection(conn)

async def get_all_users():
async with get_connection() as conn:
return conn.query('SELECT ...')

.. versionadded:: 3.7


.. function:: closing(thing)

Return a context manager that closes *thing* upon completion of the block. This
Expand Down
96 changes: 91 additions & 5 deletions Lib/contextlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from collections import deque
from functools import wraps

__all__ = ["contextmanager", "closing", "AbstractContextManager",
"ContextDecorator", "ExitStack", "redirect_stdout",
"redirect_stderr", "suppress"]
__all__ = ["asynccontextmanager", "contextmanager", "closing",
"AbstractContextManager", "ContextDecorator", "ExitStack",
"redirect_stdout", "redirect_stderr", "suppress"]


class AbstractContextManager(abc.ABC):
Expand Down Expand Up @@ -54,8 +54,9 @@ def inner(*args, **kwds):
return inner


class _GeneratorContextManager(ContextDecorator, AbstractContextManager):
"""Helper for @contextmanager decorator."""
class _GeneratorContextManagerBase:
"""Shared functionality for the @contextmanager and @asynccontextmanager
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First line of docstring must be a single sentence. Just make it

"""Shared functionality for the @contextmanager and @asynccontextmanager."""

implementations."""

def __init__(self, func, args, kwds):
self.gen = func(*args, **kwds)
Expand All @@ -71,6 +72,12 @@ def __init__(self, func, args, kwds):
# for the class instead.
# See http://bugs.python.org/issue19404 for more details.


class _GeneratorContextManager(_GeneratorContextManagerBase,
AbstractContextManager,
ContextDecorator):
"""Helper for @contextmanager decorator."""

def _recreate_cm(self):
# _GCM instances are one-shot context managers, so the
# CM must be recreated each time a decorated function is
Expand Down Expand Up @@ -126,6 +133,51 @@ def __exit__(self, type, value, traceback):
raise


class _AsyncGeneratorContextManager(_GeneratorContextManagerBase):
"""Helper for @asynccontextmanager."""

async def __aenter__(self):
try:
return await self.gen.__anext__()
except StopAsyncIteration:
raise RuntimeError("generator didn't yield") from None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Diff coverage shows a missing test case for this line.


async def __aexit__(self, type, value, traceback):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type -> err_type or ex_type or typ. Don't mask globals.

if type is None:
try:
await self.gen.__anext__()
except StopAsyncIteration:
return
else:
raise RuntimeError("generator didn't stop")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing test case here as well.

else:
if value is None:
value = type()
Copy link
Contributor

@ncoghlan ncoghlan Mar 2, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You won't be able to hit this line via the async with syntax, but it can be exercised by awaiting __aexit__ directly with a non-normalised exception (i.e. only the exception type, with both the exception value and the traceback as None)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also write a small C helper to do that, but your idea is much better.

# See _GeneratorContextManager.__exit__ for comments on subtleties
# in this implementation
try:
await self.gen.athrow(type, value, traceback)
raise RuntimeError("generator didn't stop after throw()")
except StopAsyncIteration as exc:
return exc is not value
except RuntimeError as exc:
if exc is value:
return False
# Avoid suppressing if a StopIteration exception
# was passed to throw() and later wrapped into a RuntimeError
# (see PEP 479 for sync generators; async generators also
# have this behavior). But do this only if the exception wrapped
# by the RuntimeError is actully Stop(Async)Iteration (see
# issue29692).
if isinstance(value, (StopIteration, StopAsyncIteration)):
if exc.__cause__ is value:
return False
raise
except:
if sys.exc_info()[1] is not value:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we use sys.exc_info() here instead of

except BaseException as exc:
    if exc is not value:
        raise

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did this in order to not make unnecessary changes relative to the @contextmanager implementation, although I agree your code is better. I'm happy to change this if @ncoghlan is OK with it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It uses sys.exc_info() because the corresponding _GeneratorContextManager code predates the acceptance and implementation of PEP 352, and is also common between 2.7 and 3.x (and hence needs to handle exceptions that don't inherit from BaseException)

I'm fine with switching this to using the BaseException form, but add a comment to the synchronous version saying it's written that way on purpose to keep the code consistent with the 2.7 branch and the contextlib2 backport.

raise
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hitting this line means adding a test case that replaces the thrown in exception with an entirely unrelated one that is neither StopAsyncIteration nor RuntimeError



def contextmanager(func):
"""@contextmanager decorator.

Expand Down Expand Up @@ -160,6 +212,40 @@ def helper(*args, **kwds):
return helper


def asynccontextmanager(func):
"""@asynccontextmanager decorator.

Typical usage:

@asynccontextmanager
async def some_async_generator(<arguments>):
<setup>
try:
yield <value>
finally:
<cleanup>

This makes this:

async with some_async_generator(<arguments>) as <variable>:
<body>

equivalent to this:

<setup>
try:
<variable> = <value>
<body>
finally:
<cleanup>

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trailing empty newline.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean I should remove or add a newline? The formatting here is the same as in @contextmanager.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd remove both empty newlines then. They aren't needed.

"""
@wraps(func)
def helper(*args, **kwds):
return _AsyncGeneratorContextManager(func, args, kwds)
return helper


class closing(AbstractContextManager):
"""Context to automatically close something at the end of a block.

Expand Down
205 changes: 205 additions & 0 deletions Lib/test/test_contextlib_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
import asyncio
from contextlib import asynccontextmanager
from test import support
import unittest


def _async_test(func):
"""Decorator to turn an async function into a test case."""
def wrapper(*args, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, I know it's just a test, but please add @functools.wraps(func)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a wraps decorator.

loop = asyncio.new_event_loop()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to close the loop (1), and make sure it's the default loop (2).

Please rewrite to:

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
    loop.run_until_complete(coro)
finally:
    loop.close()
    asyncio.set_event_loop(None)

coro = func(*args, **kwargs)
return loop.run_until_complete(coro)
return wrapper


class AsyncContextManagerTestCase(unittest.TestCase):

@_async_test
async def test_contextmanager_plain(self):
state = []
@asynccontextmanager
async def woohoo():
state.append(1)
yield 42
state.append(999)
async with woohoo() as x:
self.assertEqual(state, [1])
self.assertEqual(x, 42)
state.append(x)
self.assertEqual(state, [1, 42, 999])

@_async_test
async def test_contextmanager_finally(self):
state = []
@asynccontextmanager
async def woohoo():
state.append(1)
try:
yield 42
finally:
state.append(999)
with self.assertRaises(ZeroDivisionError):
async with woohoo() as x:
self.assertEqual(state, [1])
self.assertEqual(x, 42)
state.append(x)
raise ZeroDivisionError()
self.assertEqual(state, [1, 42, 999])

@_async_test
async def test_contextmanager_no_reraise(self):
@asynccontextmanager
async def whee():
yield
ctx = whee()
await ctx.__aenter__()
# Calling __aexit__ should not result in an exception
self.assertFalse(await ctx.__aexit__(TypeError, TypeError("foo"), None))

@_async_test
async def test_contextmanager_trap_yield_after_throw(self):
@asynccontextmanager
async def whoo():
try:
yield
except:
yield
ctx = whoo()
await ctx.__aenter__()
with self.assertRaises(RuntimeError):
await ctx.__aexit__(TypeError, TypeError('foo'), None)

@_async_test
async def test_contextmanager_trap_no_yield(self):
@asynccontextmanager
async def whoo():
if False:
yield
ctx = whoo()
with self.assertRaises(RuntimeError):
await ctx.__aenter__()

@_async_test
async def test_contextmanager_trap_second_yield(self):
@asynccontextmanager
async def whoo():
yield
yield
ctx = whoo()
await ctx.__aenter__()
with self.assertRaises(RuntimeError):
await ctx.__aexit__(None, None, None)

@_async_test
async def test_contextmanager_non_normalised(self):
@asynccontextmanager
async def whoo():
try:
yield
except RuntimeError:
raise SyntaxError

ctx = whoo()
await ctx.__aenter__()
with self.assertRaises(SyntaxError):
await ctx.__aexit__(RuntimeError, None, None)

@_async_test
async def test_contextmanager_except(self):
state = []
@asynccontextmanager
async def woohoo():
state.append(1)
try:
yield 42
except ZeroDivisionError as e:
state.append(e.args[0])
self.assertEqual(state, [1, 42, 999])
async with woohoo() as x:
self.assertEqual(state, [1])
self.assertEqual(x, 42)
state.append(x)
raise ZeroDivisionError(999)
self.assertEqual(state, [1, 42, 999])

@_async_test
async def test_contextmanager_except_stopiter(self):
@asynccontextmanager
async def woohoo():
yield

for stop_exc in (StopIteration('spam'), StopAsyncIteration('ham')):
with self.subTest(type=type(stop_exc)):
try:
async with woohoo():
raise stop_exc
except Exception as ex:
self.assertIs(ex, stop_exc)
else:
self.fail(f'{stop_exc} was suppressed')

@_async_test
async def test_contextmanager_wrap_runtimeerror(self):
@asynccontextmanager
async def woohoo():
try:
yield
except Exception as exc:
raise RuntimeError(f'caught {exc}') from exc

with self.assertRaises(RuntimeError):
async with woohoo():
1 / 0

# If the context manager wrapped StopAsyncIteration in a RuntimeError,
# we also unwrap it, because we can't tell whether the wrapping was
# done by the generator machinery or by the generator itself.
with self.assertRaises(StopAsyncIteration):
async with woohoo():
raise StopAsyncIteration

def _create_contextmanager_attribs(self):
def attribs(**kw):
def decorate(func):
for k,v in kw.items():
setattr(func,k,v)
return func
return decorate
@asynccontextmanager
@attribs(foo='bar')
async def baz(spam):
"""Whee!"""
yield
return baz

def test_contextmanager_attribs(self):
baz = self._create_contextmanager_attribs()
self.assertEqual(baz.__name__,'baz')
self.assertEqual(baz.foo, 'bar')

@support.requires_docstrings
def test_contextmanager_doc_attrib(self):
baz = self._create_contextmanager_attribs()
self.assertEqual(baz.__doc__, "Whee!")

@support.requires_docstrings
@_async_test
async def test_instance_docstring_given_cm_docstring(self):
baz = self._create_contextmanager_attribs()(None)
self.assertEqual(baz.__doc__, "Whee!")
async with baz:
pass # suppress warning

@_async_test
async def test_keywords(self):
# Ensure no keyword arguments are inhibited
@asynccontextmanager
async def woohoo(self, func, args, kwds):
yield (self, func, args, kwds)
async with woohoo(self=11, func=22, args=33, kwds=44) as target:
self.assertEqual(target, (11, 22, 33, 44))


if __name__ == '__main__':
unittest.main()