-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtask_manager.py
72 lines (59 loc) · 1.96 KB
/
task_manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import asyncio
import logging
from enum import Enum
import pdb
log = logging.getLogger(name=__name__)  # module logger used by TaskManager's debug messages

# TODO: support naming tasks and looking a task up by its name, using
# asyncio.create_task(coro, *, name=None) together with Task.get_name()/set_name().
class TaskManager:
    """Manage a collection of asyncio tasks.

    Every task created through ``create_task`` is tracked until it finishes;
    a done-callback removes it from the collection automatically. As long as
    you make sure the tasks exit, you do not necessarily need to
    ``await tmgr.close()`` (or similar): the cleanup happens automatically.

    Attributes:
        callback_exception: When true (the default), a task that finishes
            with a genuine exception (not a cancellation) has that exception
            re-raised inside its done-callback, so asyncio hands it to the
            event loop's exception handler.
        gather_exception: When true (default false), ``gather`` re-raises the
            first genuine exception, in task-creation order, among the tasks
            it awaited. Cancelled tasks are ignored.
    """

    def __init__(self):
        # Tasks still running, in creation order; pruned by _on_done.
        self._tasks = []
        self.callback_exception = True
        self.gather_exception = False

    def create_task(self, coro, *, name=None):
        """Schedule *coro* as an asyncio task and track it until it finishes.

        Args:
            coro: The coroutine to run.
            name: Optional task name forwarded to ``asyncio.create_task``
                (the ``name`` parameter needs Python 3.8+).

        Returns:
            The created ``asyncio.Task``.
        """
        if name is None:
            # Keep the plain call so older Pythons without ``name=`` still work.
            task = asyncio.create_task(coro)
        else:
            task = asyncio.create_task(coro, name=name)
        task.add_done_callback(self._on_done)
        self._tasks.append(task)
        return task

    def _on_done(self, future):
        """Done-callback: stop tracking *future* and optionally surface its error."""
        self._tasks.remove(future)
        if not self.callback_exception or future.cancelled():
            return
        if future.exception() is not None:
            log.debug("Exception in on_done")
            # Re-raise inside the callback; exceptions raised by callbacks are
            # passed to the event loop's (by default global) exception handler.
            future.result()

    def cancel(self):
        """Request cancellation of every tracked task without waiting.

        Cancellation is only a request: each task finishes (and is untracked
        by its done-callback) later, when the event loop runs it.
        """
        for task in self._tasks[:]:
            task.cancel()

    async def close(self):
        """Cancel every tracked task and wait until all of them have finished."""
        self.cancel()
        await self.gather()

    async def gather(self):
        """Wait for all currently tracked tasks to finish.

        Tasks created after this call starts are not waited for.

        Returns:
            A list with one entry per awaited task, in creation order: the
            task's return value, or the exception (including
            ``CancelledError``) it finished with.

        Raises:
            Whatever genuine (non-cancellation) exception the first failing
            task raised, but only when ``gather_exception`` is true.
        """
        # Done-callbacks remove finished tasks from self._tasks while we
        # await, so work on a snapshot taken up front.
        tasks = self._tasks[:]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        if self.gather_exception:
            for task in tasks:
                # Ignore tasks that were cancelled (unhandled CancelledError);
                # only a genuine exception is re-raised.
                # See test/task_manager/cancel.py
                if not task.cancelled() and task.exception() is not None:
                    log.debug("Exception in %s", task)
                    task.result()
        return results