From 3c903649af196084eb20854b8bc5c10302c5f59d Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Wed, 16 Nov 2016 18:50:48 -0500 Subject: [PATCH] Add asyncio.run_forever(); add tests. --- asyncio/__init__.py | 6 +- asyncio/run.py | 96 -------------------- asyncio/runners.py | 146 ++++++++++++++++++++++++++++++ runtests.py | 4 + tests/test_runner.py | 205 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 358 insertions(+), 99 deletions(-) delete mode 100644 asyncio/run.py create mode 100644 asyncio/runners.py create mode 100644 tests/test_runner.py diff --git a/asyncio/__init__.py b/asyncio/__init__.py index 3434cb05..30cfbdce 100644 --- a/asyncio/__init__.py +++ b/asyncio/__init__.py @@ -24,7 +24,7 @@ from .futures import * from .locks import * from .protocols import * -from .run import * +from .runners import * from .queues import * from .streams import * from .subprocess import * @@ -37,12 +37,12 @@ futures.__all__ + locks.__all__ + protocols.__all__ + + runners.__all__ + queues.__all__ + streams.__all__ + subprocess.__all__ + tasks.__all__ + - transports.__all__ + - ['run']) # Will fix this later. + transports.__all__) if sys.platform == 'win32': # pragma: no cover from .windows_events import * diff --git a/asyncio/run.py b/asyncio/run.py deleted file mode 100644 index f75a567f..00000000 --- a/asyncio/run.py +++ /dev/null @@ -1,96 +0,0 @@ -"""asyncio.run() function.""" - -__all__ = ['run'] - -import inspect -import threading - -from . import coroutines -from . import events - - -def _isasyncgen(obj): - if hasattr(inspect, 'isasyncgen'): - return inspect.isasyncgen(obj) - return False - - -def run(coro, *, debug=False): - """Run a coroutine. - - This function runs the passed coroutine, taking care of - managing the asyncio event loop and finalizing asynchronous - generators. - - This function must be called from the main thread, and it - cannot be called when another asyncio event loop is running. - - If debug is True, the event loop will be run in debug mode. - - This function should be used as a main entry point for - asyncio programs, and should not be used to call asynchronous - APIs. - - Example:: - - import asyncio - - async def main(): - await asyncio.sleep(1) - print('hello') - - asyncio.run(main()) - """ - if events._get_running_loop() is not None: - raise RuntimeError( - "asyncio.run() cannot be called from a running event loop") - if not isinstance(threading.current_thread(), threading._MainThread): - raise RuntimeError( - "asyncio.run() must be called from the main thread") - if not coroutines.iscoroutine(coro) and not _isasyncgen(coro): - raise ValueError( - "a coroutine or an asynchronous generator was expected, " - "got {!r}".format(coro)) - - loop = events.new_event_loop() - try: - events.set_event_loop(loop) - - if debug: - loop.set_debug(True) - - if _isasyncgen(coro): - result = None - loop.run_until_complete(coro.asend(None)) - try: - loop.run_forever() - except BaseException as ex: - try: - loop.run_until_complete(coro.athrow(ex)) - except StopAsyncIteration as ex: - if ex.args: - result = ex.args[0] - else: - try: - loop.run_until_complete(coro.asend(None)) - except StopAsyncIteration as ex: - if ex.args: - result = ex.args[0] - - else: - result = loop.run_until_complete(coro) - - try: - # `shutdown_asyncgens` was added in Python 3.6; not all - # event loops might support it. - shutdown_asyncgens = loop.shutdown_asyncgens - except AttributeError: - pass - else: - loop.run_until_complete(shutdown_asyncgens()) - - return result - - finally: - events.set_event_loop(None) - loop.close() diff --git a/asyncio/runners.py b/asyncio/runners.py new file mode 100644 index 00000000..e8fa22dd --- /dev/null +++ b/asyncio/runners.py @@ -0,0 +1,146 @@ +"""asyncio.run() and asyncio.run_forever() functions.""" + +__all__ = ['run', 'run_forever'] + +import inspect +import threading + +from . import coroutines +from . import events + + +def _cleanup(loop): + try: + # `shutdown_asyncgens` was added in Python 3.6; not all + # event loops might support it. + shutdown_asyncgens = loop.shutdown_asyncgens + except AttributeError: + pass + else: + loop.run_until_complete(shutdown_asyncgens()) + finally: + events.set_event_loop(None) + loop.close() + + +def run(main, *, debug=False): + """Run a coroutine. + + This function runs the passed coroutine, taking care of + managing the asyncio event loop and finalizing asynchronous + generators. + + This function must be called from the main thread, and it + cannot be called when another asyncio event loop is running. + + If debug is True, the event loop will be run in debug mode. + + This function should be used as a main entry point for + asyncio programs, and should not be used to call asynchronous + APIs. + + Example:: + + async def main(): + await asyncio.sleep(1) + print('hello') + + asyncio.run(main()) + """ + if events._get_running_loop() is not None: + raise RuntimeError( + "asyncio.run() cannot be called from a running event loop") + if not isinstance(threading.current_thread(), threading._MainThread): + raise RuntimeError( + "asyncio.run() must be called from the main thread") + if not coroutines.iscoroutine(main): + raise ValueError("a coroutine was expected, got {!r}".format(main)) + + loop = events.new_event_loop() + try: + events.set_event_loop(loop) + + if debug: + loop.set_debug(True) + + return loop.run_until_complete(main) + finally: + _cleanup(loop) + + +def run_forever(main, *, debug=False): + """Run asyncio loop. + + main must be an asynchronous generator with one yield, separating + program initialization from cleanup logic. + + If debug is True, the event loop will be run in debug mode. + + This function should be used as a main entry point for + asyncio programs, and should not be used to call asynchronous + APIs. + + Example: + + async def main(): + server = await asyncio.start_server(...) + try: + yield # <- Let event loop run forever. + except KeyboardInterrupt: + print('^C received; exiting.') + finally: + server.close() + await server.wait_closed() + + asyncio.run_forever(main()) + """ + if not hasattr(inspect, 'isasyncgen'): + raise NotImplementedError + + if events._get_running_loop() is not None: + raise RuntimeError( + "asyncio.run_forever() cannot be called from a running event loop") + if not isinstance(threading.current_thread(), threading._MainThread): + raise RuntimeError( + "asyncio.run() must be called from the main thread") + if not inspect.isasyncgen(main): + raise ValueError( + "an asynchronous generator was expected, got {!r}".format(main)) + + loop = events.new_event_loop() + try: + events.set_event_loop(loop) + if debug: + loop.set_debug(True) + + ret = None + try: + ret = loop.run_until_complete(main.asend(None)) + except StopAsyncIteration as ex: + return + if ret is not None: + raise RuntimeError("only empty yield is supported") + + yielded_twice = False + try: + loop.run_forever() + except BaseException as ex: + try: + loop.run_until_complete(main.athrow(ex)) + except StopAsyncIteration as ex: + pass + else: + yielded_twice = True + else: + try: + loop.run_until_complete(main.asend(None)) + except StopAsyncIteration as ex: + pass + else: + yielded_twice = True + + if yielded_twice: + raise RuntimeError("only one yield is supported") + + finally: + _cleanup(loop) diff --git a/runtests.py b/runtests.py index c4074624..8fa2db93 100644 --- a/runtests.py +++ b/runtests.py @@ -112,6 +112,10 @@ def list_dir(prefix, dir): print("Skipping '{0}': need at least Python 3.5".format(modname), file=sys.stderr) continue + if modname == 'test_runner' and (sys.version_info < (3, 6)): + print("Skipping '{0}': need at least Python 3.6".format(modname), + file=sys.stderr) + continue try: loader = importlib.machinery.SourceFileLoader(modname, sourcefile) mods.append((loader.load_module(), sourcefile)) diff --git a/tests/test_runner.py b/tests/test_runner.py new file mode 100644 index 00000000..e5d9244b --- /dev/null +++ b/tests/test_runner.py @@ -0,0 +1,205 @@ +"""Tests asyncio.run() and asyncio.run_forever().""" + +import asyncio +import unittest +import sys + +from unittest import mock + + +class TestPolicy(asyncio.AbstractEventLoopPolicy): + + def __init__(self, loop_factory): + self.loop_factory = loop_factory + self.loop = None + + def get_event_loop(self): + # shouldn't ever be called by asyncio.run() + # or asyncio.run_forever() + raise RuntimeError + + def new_event_loop(self): + return self.loop_factory() + + def set_event_loop(self, loop): + if loop is not None: + # we want to check if the loop is closed + # in BaseTest.tearDown + self.loop = loop + + +class BaseTest(unittest.TestCase): + + def new_loop(self): + loop = asyncio.BaseEventLoop() + loop._process_events = mock.Mock() + loop._selector = mock.Mock() + loop._selector.select.return_value = () + loop.shutdown_ag_run = False + + async def shutdown_asyncgens(): + loop.shutdown_ag_run = True + loop.shutdown_asyncgens = shutdown_asyncgens + + return loop + + def setUp(self): + super().setUp() + + policy = TestPolicy(self.new_loop) + asyncio.set_event_loop_policy(policy) + + def tearDown(self): + policy = asyncio.get_event_loop_policy() + if policy.loop is not None: + self.assertTrue(policy.loop.is_closed()) + self.assertTrue(policy.loop.shutdown_ag_run) + + asyncio.set_event_loop_policy(None) + super().tearDown() + + +class RunTests(BaseTest): + + def test_asyncio_run_return(self): + async def main(): + await asyncio.sleep(0) + return 42 + + self.assertEqual(asyncio.run(main()), 42) + + def test_asyncio_run_raises(self): + async def main(): + await asyncio.sleep(0) + raise ValueError('spam') + + with self.assertRaisesRegex(ValueError, 'spam'): + asyncio.run(main()) + + def test_asyncio_run_only_coro(self): + for o in {1, lambda: None}: + with self.subTest(obj=o), \ + self.assertRaisesRegex(ValueError, + 'a coroutine was expected'): + asyncio.run(o) + + def test_asyncio_run_debug(self): + async def main(expected): + loop = asyncio.get_event_loop() + self.assertIs(loop.get_debug(), expected) + + asyncio.run(main(False)) + asyncio.run(main(True), debug=True) + + def test_asyncio_run_from_running_loop(self): + async def main(): + asyncio.run(main()) + + with self.assertRaisesRegex(RuntimeError, + 'cannot be called from a running'): + asyncio.run(main()) + + +class RunForeverTests(BaseTest): + + def stop_soon(self, *, exc=None): + loop = asyncio.get_event_loop() + + if exc: + def throw(): + raise exc + loop.call_later(0.01, throw) + else: + loop.call_later(0.01, loop.stop) + + def test_asyncio_run_forever_return(self): + async def main(): + if 0: + yield + return + + self.assertIsNone(asyncio.run_forever(main())) + + def test_asyncio_run_forever_non_none_yield(self): + async def main(): + yield 1 + + with self.assertRaisesRegex(RuntimeError, 'only empty'): + self.assertIsNone(asyncio.run_forever(main())) + + def test_asyncio_run_forever_raises_before_yield(self): + async def main(): + await asyncio.sleep(0) + raise ValueError('spam') + yield + + with self.assertRaisesRegex(ValueError, 'spam'): + asyncio.run_forever(main()) + + def test_asyncio_run_forever_raises_after_yield(self): + async def main(): + self.stop_soon() + yield + raise ValueError('spam') + + with self.assertRaisesRegex(ValueError, 'spam'): + asyncio.run_forever(main()) + + def test_asyncio_run_forever_two_yields(self): + async def main(): + self.stop_soon() + yield + yield + raise ValueError('spam') + + with self.assertRaisesRegex(RuntimeError, 'only one yield'): + asyncio.run_forever(main()) + + def test_asyncio_run_forever_only_ag(self): + async def coro(): + pass + + for o in {1, lambda: None, coro()}: + with self.subTest(obj=o), \ + self.assertRaisesRegex(ValueError, + 'an asynchronous.*was expected'): + asyncio.run_forever(o) + + def test_asyncio_run_forever_debug(self): + async def main(expected): + loop = asyncio.get_event_loop() + self.assertIs(loop.get_debug(), expected) + if 0: + yield + + asyncio.run_forever(main(False)) + asyncio.run_forever(main(True), debug=True) + + def test_asyncio_run_forever_from_running_loop(self): + async def main(): + asyncio.run_forever(main()) + if 0: + yield + + with self.assertRaisesRegex(RuntimeError, + 'cannot be called from a running'): + asyncio.run_forever(main()) + + def test_asyncio_run_forever_base_exception(self): + vi = sys.version_info + if vi[:2] != (3, 6) or vi.releaselevel == 'beta' and vi.serial < 4: + # See http://bugs.python.org/issue28721 for details. + raise unittest.SkipTest( + 'this test requires Python 3.6b4 or greater') + + class MyExc(BaseException): + pass + + async def main(): + self.stop_soon(exc=MyExc) + try: + yield + except MyExc: + pass + + asyncio.run_forever(main())