-
-
Notifications
You must be signed in to change notification settings - Fork 30.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
bpo-29679: Implement @contextlib.asynccontextmanager #360
Changes from 10 commits
904e8a8
b3d59f1
a1d5b3f
c5b8b43
ca77cd2
689f4a5
299d968
e974d48
5808a4c
9caa243
64e6908
178433b
6d0dddb
ad65b4d
737fd0f
3fc20a7
06697a8
bb8de0d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -80,6 +80,35 @@ Functions and classes provided: | |
Use of :class:`ContextDecorator`. | ||
|
||
|
||
.. decorator:: asynccontextmanager | ||
|
||
Similar to :func:`~contextlib.contextmanager`, but works with | ||
:term:`coroutines <coroutine>`. | ||
|
||
This function is a :term:`decorator` that can be used to define a factory | ||
function for :keyword:`async with` statement asynchronous context managers, | ||
without needing to create a class or separate :meth:`__aenter__` and | ||
:meth:`__aexit__` methods. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd add that the decorator expects to be applied to asynchronous generator functions. |
||
|
||
A simple example:: | ||
|
||
from contextlib import asynccontextmanager | ||
|
||
@asynccontextmanager | ||
async def get_connection(): | ||
conn = await acquire_db_connection() | ||
try: | ||
yield | ||
finally: | ||
await release_db_connection(conn) | ||
|
||
async def get_all_users(): | ||
async with get_connection() as conn: | ||
return conn.query('SELECT ...') | ||
|
||
.. versionadded:: 3.7 | ||
|
||
|
||
.. function:: closing(thing) | ||
|
||
Return a context manager that closes *thing* upon completion of the block. This | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,9 +4,9 @@ | |
from collections import deque | ||
from functools import wraps | ||
|
||
__all__ = ["contextmanager", "closing", "AbstractContextManager", | ||
"ContextDecorator", "ExitStack", "redirect_stdout", | ||
"redirect_stderr", "suppress"] | ||
__all__ = ["asynccontextmanager", "contextmanager", "closing", | ||
"AbstractContextManager", "ContextDecorator", "ExitStack", | ||
"redirect_stdout", "redirect_stderr", "suppress"] | ||
|
||
|
||
class AbstractContextManager(abc.ABC): | ||
|
@@ -54,8 +54,9 @@ def inner(*args, **kwds): | |
return inner | ||
|
||
|
||
class _GeneratorContextManager(ContextDecorator, AbstractContextManager): | ||
"""Helper for @contextmanager decorator.""" | ||
class _GeneratorContextManagerBase: | ||
"""Shared functionality for the @contextmanager and @asynccontextmanager | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. First line of docstring must be a single sentence. Just make it """Shared functionality for the @contextmanager and @asynccontextmanager.""" |
||
implementations.""" | ||
|
||
def __init__(self, func, args, kwds): | ||
self.gen = func(*args, **kwds) | ||
|
@@ -71,6 +72,12 @@ def __init__(self, func, args, kwds): | |
# for the class instead. | ||
# See http://bugs.python.org/issue19404 for more details. | ||
|
||
|
||
class _GeneratorContextManager(_GeneratorContextManagerBase, | ||
AbstractContextManager, | ||
ContextDecorator): | ||
"""Helper for @contextmanager decorator.""" | ||
|
||
def _recreate_cm(self): | ||
# _GCM instances are one-shot context managers, so the | ||
# CM must be recreated each time a decorated function is | ||
|
@@ -126,6 +133,51 @@ def __exit__(self, type, value, traceback): | |
raise | ||
|
||
|
||
class _AsyncGeneratorContextManager(_GeneratorContextManagerBase): | ||
"""Helper for @asynccontextmanager.""" | ||
|
||
async def __aenter__(self): | ||
try: | ||
return await self.gen.__anext__() | ||
except StopAsyncIteration: | ||
raise RuntimeError("generator didn't yield") from None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Diff coverage shows a missing test case for this line. |
||
|
||
async def __aexit__(self, type, value, traceback): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
if type is None: | ||
try: | ||
await self.gen.__anext__() | ||
except StopAsyncIteration: | ||
return | ||
else: | ||
raise RuntimeError("generator didn't stop") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing test case here as well. |
||
else: | ||
if value is None: | ||
value = type() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You won't be able to hit this line via the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could also write a small C helper to do that, but your idea is much better. |
||
# See _GeneratorContextManager.__exit__ for comments on subtleties | ||
# in this implementation | ||
try: | ||
await self.gen.athrow(type, value, traceback) | ||
raise RuntimeError("generator didn't stop after throw()") | ||
except StopAsyncIteration as exc: | ||
return exc is not value | ||
except RuntimeError as exc: | ||
if exc is value: | ||
return False | ||
# Avoid suppressing if a StopIteration exception | ||
# was passed to throw() and later wrapped into a RuntimeError | ||
# (see PEP 479 for sync generators; async generators also | ||
# have this behavior). But do this only if the exception wrapped | ||
# by the RuntimeError is actully Stop(Async)Iteration (see | ||
# issue29692). | ||
if isinstance(value, (StopIteration, StopAsyncIteration)): | ||
if exc.__cause__ is value: | ||
return False | ||
raise | ||
except: | ||
if sys.exc_info()[1] is not value: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why we use except BaseException as exc:
if exc is not value:
raise There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I did this in order to not make unnecessary changes relative to the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It uses sys.exc_info() because the corresponding _GeneratorContextManager code predates the acceptance and implementation of PEP 352, and is also common between 2.7 and 3.x (and hence needs to handle exceptions that don't inherit from I'm fine with switching this to using the |
||
raise | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hitting this line means adding a test case that replaces the thrown in exception with an entirely unrelated one that is neither |
||
|
||
|
||
def contextmanager(func): | ||
"""@contextmanager decorator. | ||
|
||
|
@@ -160,6 +212,40 @@ def helper(*args, **kwds): | |
return helper | ||
|
||
|
||
def asynccontextmanager(func): | ||
"""@asynccontextmanager decorator. | ||
|
||
Typical usage: | ||
|
||
@asynccontextmanager | ||
async def some_async_generator(<arguments>): | ||
<setup> | ||
try: | ||
yield <value> | ||
finally: | ||
<cleanup> | ||
|
||
This makes this: | ||
|
||
async with some_async_generator(<arguments>) as <variable>: | ||
<body> | ||
|
||
equivalent to this: | ||
|
||
<setup> | ||
try: | ||
<variable> = <value> | ||
<body> | ||
finally: | ||
<cleanup> | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Trailing empty newline. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you mean I should remove or add a newline? The formatting here is the same as in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd remove both empty newlines then. They aren't needed. |
||
""" | ||
@wraps(func) | ||
def helper(*args, **kwds): | ||
return _AsyncGeneratorContextManager(func, args, kwds) | ||
return helper | ||
|
||
|
||
class closing(AbstractContextManager): | ||
"""Context to automatically close something at the end of a block. | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,205 @@ | ||
import asyncio | ||
from contextlib import asynccontextmanager | ||
from test import support | ||
import unittest | ||
|
||
|
||
def _async_test(func): | ||
"""Decorator to turn an async function into a test case.""" | ||
def wrapper(*args, **kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, I know it's just a test, but please add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add a |
||
loop = asyncio.new_event_loop() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You need to close the loop (1), and make sure it's the default loop (2). Please rewrite to: loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(coro)
finally:
loop.close()
asyncio.set_event_loop(None) |
||
coro = func(*args, **kwargs) | ||
return loop.run_until_complete(coro) | ||
return wrapper | ||
|
||
|
||
class AsyncContextManagerTestCase(unittest.TestCase): | ||
|
||
@_async_test | ||
async def test_contextmanager_plain(self): | ||
state = [] | ||
@asynccontextmanager | ||
async def woohoo(): | ||
state.append(1) | ||
yield 42 | ||
state.append(999) | ||
async with woohoo() as x: | ||
self.assertEqual(state, [1]) | ||
self.assertEqual(x, 42) | ||
state.append(x) | ||
self.assertEqual(state, [1, 42, 999]) | ||
|
||
@_async_test | ||
async def test_contextmanager_finally(self): | ||
state = [] | ||
@asynccontextmanager | ||
async def woohoo(): | ||
state.append(1) | ||
try: | ||
yield 42 | ||
finally: | ||
state.append(999) | ||
with self.assertRaises(ZeroDivisionError): | ||
async with woohoo() as x: | ||
self.assertEqual(state, [1]) | ||
self.assertEqual(x, 42) | ||
state.append(x) | ||
raise ZeroDivisionError() | ||
self.assertEqual(state, [1, 42, 999]) | ||
|
||
@_async_test | ||
async def test_contextmanager_no_reraise(self): | ||
@asynccontextmanager | ||
async def whee(): | ||
yield | ||
ctx = whee() | ||
await ctx.__aenter__() | ||
# Calling __aexit__ should not result in an exception | ||
self.assertFalse(await ctx.__aexit__(TypeError, TypeError("foo"), None)) | ||
|
||
@_async_test | ||
async def test_contextmanager_trap_yield_after_throw(self): | ||
@asynccontextmanager | ||
async def whoo(): | ||
try: | ||
yield | ||
except: | ||
yield | ||
ctx = whoo() | ||
await ctx.__aenter__() | ||
with self.assertRaises(RuntimeError): | ||
await ctx.__aexit__(TypeError, TypeError('foo'), None) | ||
|
||
@_async_test | ||
async def test_contextmanager_trap_no_yield(self): | ||
@asynccontextmanager | ||
async def whoo(): | ||
if False: | ||
yield | ||
ctx = whoo() | ||
with self.assertRaises(RuntimeError): | ||
await ctx.__aenter__() | ||
|
||
@_async_test | ||
async def test_contextmanager_trap_second_yield(self): | ||
@asynccontextmanager | ||
async def whoo(): | ||
yield | ||
yield | ||
ctx = whoo() | ||
await ctx.__aenter__() | ||
with self.assertRaises(RuntimeError): | ||
await ctx.__aexit__(None, None, None) | ||
|
||
@_async_test | ||
async def test_contextmanager_non_normalised(self): | ||
@asynccontextmanager | ||
async def whoo(): | ||
try: | ||
yield | ||
except RuntimeError: | ||
raise SyntaxError | ||
|
||
ctx = whoo() | ||
await ctx.__aenter__() | ||
with self.assertRaises(SyntaxError): | ||
await ctx.__aexit__(RuntimeError, None, None) | ||
|
||
@_async_test | ||
async def test_contextmanager_except(self): | ||
state = [] | ||
@asynccontextmanager | ||
async def woohoo(): | ||
state.append(1) | ||
try: | ||
yield 42 | ||
except ZeroDivisionError as e: | ||
state.append(e.args[0]) | ||
self.assertEqual(state, [1, 42, 999]) | ||
async with woohoo() as x: | ||
self.assertEqual(state, [1]) | ||
self.assertEqual(x, 42) | ||
state.append(x) | ||
raise ZeroDivisionError(999) | ||
self.assertEqual(state, [1, 42, 999]) | ||
|
||
@_async_test | ||
async def test_contextmanager_except_stopiter(self): | ||
@asynccontextmanager | ||
async def woohoo(): | ||
yield | ||
|
||
for stop_exc in (StopIteration('spam'), StopAsyncIteration('ham')): | ||
with self.subTest(type=type(stop_exc)): | ||
try: | ||
async with woohoo(): | ||
raise stop_exc | ||
except Exception as ex: | ||
self.assertIs(ex, stop_exc) | ||
else: | ||
self.fail(f'{stop_exc} was suppressed') | ||
|
||
@_async_test | ||
async def test_contextmanager_wrap_runtimeerror(self): | ||
@asynccontextmanager | ||
async def woohoo(): | ||
try: | ||
yield | ||
except Exception as exc: | ||
raise RuntimeError(f'caught {exc}') from exc | ||
|
||
with self.assertRaises(RuntimeError): | ||
async with woohoo(): | ||
1 / 0 | ||
|
||
# If the context manager wrapped StopAsyncIteration in a RuntimeError, | ||
# we also unwrap it, because we can't tell whether the wrapping was | ||
# done by the generator machinery or by the generator itself. | ||
with self.assertRaises(StopAsyncIteration): | ||
async with woohoo(): | ||
raise StopAsyncIteration | ||
|
||
def _create_contextmanager_attribs(self): | ||
def attribs(**kw): | ||
def decorate(func): | ||
for k,v in kw.items(): | ||
setattr(func,k,v) | ||
return func | ||
return decorate | ||
@asynccontextmanager | ||
@attribs(foo='bar') | ||
async def baz(spam): | ||
"""Whee!""" | ||
yield | ||
return baz | ||
|
||
def test_contextmanager_attribs(self): | ||
baz = self._create_contextmanager_attribs() | ||
self.assertEqual(baz.__name__,'baz') | ||
self.assertEqual(baz.foo, 'bar') | ||
|
||
@support.requires_docstrings | ||
def test_contextmanager_doc_attrib(self): | ||
baz = self._create_contextmanager_attribs() | ||
self.assertEqual(baz.__doc__, "Whee!") | ||
|
||
@support.requires_docstrings | ||
@_async_test | ||
async def test_instance_docstring_given_cm_docstring(self): | ||
baz = self._create_contextmanager_attribs()(None) | ||
self.assertEqual(baz.__doc__, "Whee!") | ||
async with baz: | ||
pass # suppress warning | ||
|
||
@_async_test | ||
async def test_keywords(self): | ||
# Ensure no keyword arguments are inhibited | ||
@asynccontextmanager | ||
async def woohoo(self, func, args, kwds): | ||
yield (self, func, args, kwds) | ||
async with woohoo(self=11, func=22, args=33, kwds=44) as target: | ||
self.assertEqual(target, (11, 22, 33, 44)) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"works with coroutine" isn't really descriptive. I'd rephrase to something akin to "but creates asynchronous context managers".