From 3c0ef99b6d6246bd2a674e2400a4376b043de1c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Fri, 24 Jan 2025 11:29:21 -0800 Subject: [PATCH] improve task creation and cancellation If a FrameProcessor needs to create a task it should use FrameProcessor.create_task() and FrameProcessor.cancel_task(). This gives Pipecat more control over all the tasks that are created in Pipecat. Both functions internally use the utils module: utils.create_task() and utils.cancel_task() which should also be used outside of FrameProcessors. That is, unless strictly necessary, we should avoid using asyncio.create_task(). --- .../22b-natural-conversation-proposal.py | 3 +- .../22c-natural-conversation-mixed-llms.py | 39 ++-- .../22d-natural-conversation-gemini-audio.py | 9 +- src/pipecat/audio/vad/silero.py | 2 +- src/pipecat/pipeline/parallel_pipeline.py | 50 +++--- src/pipecat/pipeline/task.py | 166 +++++++++--------- src/pipecat/pipeline/task_observer.py | 26 +-- .../aggregators/gated_openai_llm_context.py | 18 +- src/pipecat/processors/frame_processor.py | 79 ++++----- src/pipecat/processors/frameworks/rtvi.py | 28 ++- .../processors/idle_frame_processor.py | 7 +- src/pipecat/processors/user_idle_processor.py | 10 +- src/pipecat/services/ai_services.py | 83 ++++----- src/pipecat/services/cartesia.py | 7 +- src/pipecat/services/elevenlabs.py | 21 +-- src/pipecat/services/fish.py | 7 +- .../services/gemini_multimodal_live/gemini.py | 9 +- src/pipecat/services/gladia.py | 2 +- src/pipecat/services/lmnt.py | 7 +- .../services/openai_realtime_beta/openai.py | 68 +++---- src/pipecat/services/playht.py | 7 +- src/pipecat/services/riva.py | 19 +- src/pipecat/services/simli.py | 64 +++---- src/pipecat/services/websocket_service.py | 4 - src/pipecat/transports/base_input.py | 51 +++--- src/pipecat/transports/base_output.py | 87 ++++----- .../transports/network/fastapi_websocket.py | 13 +- .../transports/network/websocket_server.py | 3 +- src/pipecat/transports/services/daily.py | 32 ++-- src/pipecat/transports/services/livekit.py | 34 ++-- src/pipecat/utils/utils.py | 39 ++++ 31 files changed, 439 insertions(+), 555 deletions(-) diff --git a/examples/foundational/22b-natural-conversation-proposal.py b/examples/foundational/22b-natural-conversation-proposal.py index 328d691c4..c188ae7d4 100644 --- a/examples/foundational/22b-natural-conversation-proposal.py +++ b/examples/foundational/22b-natural-conversation-proposal.py @@ -169,8 +169,7 @@ async def _start(self): self._gate_task = self.get_event_loop().create_task(self._gate_task_handler()) async def _stop(self): - self._gate_task.cancel() - await self._gate_task + await self.cancel_task(self._gate_task) async def _gate_task_handler(self): while True: diff --git a/examples/foundational/22c-natural-conversation-mixed-llms.py b/examples/foundational/22c-natural-conversation-mixed-llms.py index 6281175b7..ff53cd838 100644 --- a/examples/foundational/22c-natural-conversation-mixed-llms.py +++ b/examples/foundational/22c-natural-conversation-mixed-llms.py @@ -101,12 +101,12 @@ Examples: # Complete Wh-question -[{"role": "assistant", "content": "I can help you learn."}, +[{"role": "assistant", "content": "I can help you learn."}, {"role": "user", "content": "What's the fastest way to learn Spanish"}] Output: YES # Complete Yes/No question despite STT error -[{"role": "assistant", "content": "I know about planets."}, +[{"role": "assistant", "content": "I know about planets."}, {"role": "user", "content": "Is is Jupiter the biggest planet"}] Output: YES @@ -118,12 +118,12 @@ Examples: # Direct instruction -[{"role": "assistant", "content": "I can explain many topics."}, +[{"role": "assistant", "content": "I can explain many topics."}, {"role": "user", "content": "Tell me about black holes"}] Output: YES # Action demand -[{"role": "assistant", "content": "I can help with math."}, +[{"role": "assistant", "content": "I can help with math."}, {"role": "user", "content": "Solve this equation x plus 5 equals 12"}] Output: YES @@ -134,12 +134,12 @@ Examples: # Specific answer -[{"role": "assistant", "content": "What's your favorite color?"}, +[{"role": "assistant", "content": "What's your favorite color?"}, {"role": "user", "content": "I really like blue"}] Output: YES # Option selection -[{"role": "assistant", "content": "Would you prefer morning or evening?"}, +[{"role": "assistant", "content": "Would you prefer morning or evening?"}, {"role": "user", "content": "Morning"}] Output: YES @@ -153,17 +153,17 @@ Examples: # Self-correction reaching completion -[{"role": "assistant", "content": "What would you like to know?"}, +[{"role": "assistant", "content": "What would you like to know?"}, {"role": "user", "content": "Tell me about... no wait, explain how rainbows form"}] Output: YES # Topic change with complete thought -[{"role": "assistant", "content": "The weather is nice today."}, +[{"role": "assistant", "content": "The weather is nice today."}, {"role": "user", "content": "Actually can you tell me who invented the telephone"}] Output: YES # Mid-sentence completion -[{"role": "assistant", "content": "Hello I'm ready."}, +[{"role": "assistant", "content": "Hello I'm ready."}, {"role": "user", "content": "What's the capital of? France"}] Output: YES @@ -175,12 +175,12 @@ Examples: # Acknowledgment -[{"role": "assistant", "content": "Should we talk about history?"}, +[{"role": "assistant", "content": "Should we talk about history?"}, {"role": "user", "content": "Sure"}] Output: YES # Disagreement with completion -[{"role": "assistant", "content": "Is that what you meant?"}, +[{"role": "assistant", "content": "Is that what you meant?"}, {"role": "user", "content": "No not really"}] Output: YES @@ -194,12 +194,12 @@ Examples: # Word repetition but complete -[{"role": "assistant", "content": "I can help with that."}, +[{"role": "assistant", "content": "I can help with that."}, {"role": "user", "content": "What what is the time right now"}] Output: YES # Missing punctuation but complete -[{"role": "assistant", "content": "I can explain that."}, +[{"role": "assistant", "content": "I can explain that."}, {"role": "user", "content": "Please tell me how computers work"}] Output: YES @@ -211,12 +211,12 @@ Examples: # Filler words but complete -[{"role": "assistant", "content": "What would you like to know?"}, +[{"role": "assistant", "content": "What would you like to know?"}, {"role": "user", "content": "Um uh how do airplanes fly"}] Output: YES # Thinking pause but incomplete -[{"role": "assistant", "content": "I can explain anything."}, +[{"role": "assistant", "content": "I can explain anything."}, {"role": "user", "content": "Well um I want to know about the"}] Output: NO @@ -241,17 +241,17 @@ Examples: # Incomplete despite corrections -[{"role": "assistant", "content": "What would you like to know about?"}, +[{"role": "assistant", "content": "What would you like to know about?"}, {"role": "user", "content": "Can you tell me about"}] Output: NO # Complete despite multiple artifacts -[{"role": "assistant", "content": "I can help you learn."}, +[{"role": "assistant", "content": "I can help you learn."}, {"role": "user", "content": "How do you I mean what's the best way to learn programming"}] Output: YES # Trailing off incomplete -[{"role": "assistant", "content": "I can explain anything."}, +[{"role": "assistant", "content": "I can explain anything."}, {"role": "user", "content": "I was wondering if you could tell me why"}] Output: NO """ @@ -374,8 +374,7 @@ async def _start(self): self._gate_task = self.get_event_loop().create_task(self._gate_task_handler()) async def _stop(self): - self._gate_task.cancel() - await self._gate_task + await cancel_task(self._gate_task) async def _gate_task_handler(self): while True: diff --git a/examples/foundational/22d-natural-conversation-gemini-audio.py b/examples/foundational/22d-natural-conversation-gemini-audio.py index 921aa0f3b..83ccef2ee 100644 --- a/examples/foundational/22d-natural-conversation-gemini-audio.py +++ b/examples/foundational/22d-natural-conversation-gemini-audio.py @@ -44,9 +44,7 @@ ) from pipecat.processors.filters.function_filter import FunctionFilter from pipecat.processors.frame_processor import FrameDirection, FrameProcessor -from pipecat.processors.user_idle_processor import UserIdleProcessor from pipecat.services.cartesia import CartesiaTTSService -from pipecat.services.deepgram import DeepgramSTTService from pipecat.services.google import GoogleLLMContext, GoogleLLMService from pipecat.sync.base_notifier import BaseNotifier from pipecat.sync.event_notifier import EventNotifier @@ -440,11 +438,11 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): if isinstance(frame, UserStartedSpeakingFrame): if self._idle_task: - self._idle_task.cancel() + await self.cancel_task(self._idle_task) elif isinstance(frame, TextFrame) and frame.text.startswith("YES"): logger.debug("Completeness check YES") if self._idle_task: - self._idle_task.cancel() + await self.cancel_task(self._idle_task) await self.push_frame(UserStoppedSpeakingFrame()) await self._audio_accumulator.reset() await self._notifier.notify() @@ -602,8 +600,7 @@ async def _start(self): self._gate_task = self.get_event_loop().create_task(self._gate_task_handler()) async def _stop(self): - self._gate_task.cancel() - await self._gate_task + await self.cancel_task(self._gate_task) async def _gate_task_handler(self): while True: diff --git a/src/pipecat/audio/vad/silero.py b/src/pipecat/audio/vad/silero.py index fab01ae70..00a1358a5 100644 --- a/src/pipecat/audio/vad/silero.py +++ b/src/pipecat/audio/vad/silero.py @@ -159,5 +159,5 @@ def voice_confidence(self, buffer) -> float: return new_confidence except Exception as e: # This comes from an empty audio array - logger.exception(f"Error analyzing audio with Silero VAD: {e}") + logger.error(f"Error analyzing audio with Silero VAD: {e}") return 0 diff --git a/src/pipecat/pipeline/parallel_pipeline.py b/src/pipecat/pipeline/parallel_pipeline.py index 874a8c3a5..2d5bd7761 100644 --- a/src/pipecat/pipeline/parallel_pipeline.py +++ b/src/pipecat/pipeline/parallel_pipeline.py @@ -150,22 +150,18 @@ async def _start(self): async def _stop(self): # The up task doesn't receive an EndFrame, so we just cancel it. - self._up_task.cancel() - await self._up_task - # The down tasks waits for the last EndFrame send by the internal + await self.cancel_task(self._up_task) + # The down tasks waits for the last EndFrame sent by the internal # pipelines. await self._down_task async def _cancel(self): - self._up_task.cancel() - await self._up_task - self._down_task.cancel() - await self._down_task + await self.cancel_task(self._up_task) + await self.cancel_task(self._down_task) async def _create_tasks(self): - loop = self.get_event_loop() - self._up_task = loop.create_task(self._process_up_queue()) - self._down_task = loop.create_task(self._process_down_queue()) + self._up_task = self.create_task(self._process_up_queue()) + self._down_task = self.create_task(self._process_down_queue()) async def _drain_queues(self): while not self._up_queue.empty: @@ -185,32 +181,26 @@ async def _parallel_push_frame(self, frame: Frame, direction: FrameDirection): async def _process_up_queue(self): while True: - try: - frame = await self._up_queue.get() - await self._parallel_push_frame(frame, FrameDirection.UPSTREAM) - self._up_queue.task_done() - except asyncio.CancelledError: - break + frame = await self._up_queue.get() + await self._parallel_push_frame(frame, FrameDirection.UPSTREAM) + self._up_queue.task_done() async def _process_down_queue(self): running = True while running: - try: - frame = await self._down_queue.get() + frame = await self._down_queue.get() - endframe_counter = self._endframe_counter.get(frame.id, 0) + endframe_counter = self._endframe_counter.get(frame.id, 0) - # If we have a counter, decrement it. - if endframe_counter > 0: - self._endframe_counter[frame.id] -= 1 - endframe_counter = self._endframe_counter[frame.id] + # If we have a counter, decrement it. + if endframe_counter > 0: + self._endframe_counter[frame.id] -= 1 + endframe_counter = self._endframe_counter[frame.id] - # If we don't have a counter or we reached 0, push the frame. - if endframe_counter == 0: - await self._parallel_push_frame(frame, FrameDirection.DOWNSTREAM) + # If we don't have a counter or we reached 0, push the frame. + if endframe_counter == 0: + await self._parallel_push_frame(frame, FrameDirection.DOWNSTREAM) - running = not (endframe_counter == 0 and isinstance(frame, EndFrame)) + running = not (endframe_counter == 0 and isinstance(frame, EndFrame)) - self._down_queue.task_done() - except asyncio.CancelledError: - break + self._down_queue.task_done() diff --git a/src/pipecat/pipeline/task.py b/src/pipecat/pipeline/task.py index 82918c239..34a7bb5eb 100644 --- a/src/pipecat/pipeline/task.py +++ b/src/pipecat/pipeline/task.py @@ -30,7 +30,7 @@ from pipecat.pipeline.base_task import BaseTask from pipecat.pipeline.task_observer import TaskObserver from pipecat.processors.frame_processor import FrameDirection, FrameProcessor -from pipecat.utils.utils import obj_count, obj_id +from pipecat.utils.utils import cancel_task, create_task, obj_count, obj_id HEARTBEAT_SECONDS = 1.0 HEARTBEAT_MONITOR_SECONDS = HEARTBEAT_SECONDS * 5 @@ -49,7 +49,7 @@ class PipelineParams(BaseModel): heartbeats_period_secs: float = HEARTBEAT_SECONDS -class Source(FrameProcessor): +class PipelineTaskSource(FrameProcessor): """This is the source processor that is linked at the beginning of the pipeline given to the pipeline task. It allows us to easily push frames downstream to the pipeline and also receive upstream frames coming from the @@ -57,8 +57,8 @@ class Source(FrameProcessor): """ - def __init__(self, up_queue: asyncio.Queue): - super().__init__() + def __init__(self, up_queue: asyncio.Queue, **kwargs): + super().__init__(**kwargs) self._up_queue = up_queue async def process_frame(self, frame: Frame, direction: FrameDirection): @@ -71,15 +71,15 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): await self.push_frame(frame, direction) -class Sink(FrameProcessor): +class PipelineTaskSink(FrameProcessor): """This is the sink processor that is linked at the end of the pipeline given to the pipeline task. It allows us to receive downstream frames and act on them, for example, waiting to receive an EndFrame. """ - def __init__(self, down_queue: asyncio.Queue): - super().__init__() + def __init__(self, down_queue: asyncio.Queue, **kwargs): + super().__init__(**kwargs) self._down_queue = down_queue async def process_frame(self, frame: Frame, direction: FrameDirection): @@ -115,10 +115,10 @@ def __init__( # down queue. self._endframe_event = asyncio.Event() - self._source = Source(self._up_queue) + self._source = PipelineTaskSource(self._up_queue) self._source.link(pipeline) - self._sink = Sink(self._down_queue) + self._sink = PipelineTaskSink(self._down_queue) pipeline.link(self._sink) self._observer = TaskObserver(params.observers) @@ -148,13 +148,22 @@ async def cancel(self): # we want to cancel right away. await self._source.push_frame(CancelFrame()) await self._cancel_tasks(True) + await self._cleanup() async def run(self): """ Starts running the given pipeline. """ - tasks = self._create_tasks() - await asyncio.gather(*tasks) + try: + push_task = self._create_tasks() + await asyncio.gather(push_task) + except asyncio.CancelledError: + # We are awaiting on the push task and it might be cancelled + # (e.g. Ctrl-C). This means we will get a CancelledError here as + # well, because you get a CancelledError in every place you are + # awaiting a task. + pass + await self._cancel_tasks(False) self._finished = True async def queue_frame(self, frame: Frame): @@ -175,41 +184,44 @@ async def queue_frames(self, frames: Iterable[Frame] | AsyncIterable[Frame]): await self.queue_frame(frame) def _create_tasks(self): - tasks = [] - self._process_up_task = asyncio.create_task(self._process_up_queue()) - self._process_down_task = asyncio.create_task(self._process_down_queue()) - self._process_push_task = asyncio.create_task(self._process_push_queue()) - - tasks = [self._process_up_task, self._process_down_task, self._process_push_task] + loop = asyncio.get_running_loop() + self._process_up_task = create_task( + loop, self._process_up_queue(), f"{self}::_process_up_queue" + ) + self._process_down_task = create_task( + loop, self._process_down_queue(), f"{self}::_process_down_queue" + ) + self._process_push_task = create_task( + loop, self._process_push_queue(), f"{self}::_process_push_queue" + ) - return tasks + return self._process_push_task def _maybe_start_heartbeat_tasks(self): if self._params.enable_heartbeats: - self._heartbeat_push_task = asyncio.create_task(self._heartbeat_push_handler()) - self._heartbeat_monitor_task = asyncio.create_task(self._heartbeat_monitor_handler()) + loop = asyncio.get_running_loop() + self._heartbeat_push_task = create_task( + loop, self._heartbeat_push_handler(), f"{self}::_heartbeat_push_handler" + ) + self._heartbeat_monitor_task = create_task( + loop, self._heartbeat_monitor_handler(), f"{self}::_heartbeat_monitor_handler" + ) async def _cancel_tasks(self, cancel_push: bool): await self._maybe_cancel_heartbeat_tasks() if cancel_push: - self._process_push_task.cancel() - await self._process_push_task - - self._process_up_task.cancel() - await self._process_up_task + await cancel_task(self._process_push_task) - self._process_down_task.cancel() - await self._process_down_task + await cancel_task(self._process_up_task) + await cancel_task(self._process_down_task) await self._observer.stop() async def _maybe_cancel_heartbeat_tasks(self): if self._params.enable_heartbeats: - self._heartbeat_push_task.cancel() - await self._heartbeat_push_task - self._heartbeat_monitor_task.cancel() - await self._heartbeat_monitor_task + await cancel_task(self._heartbeat_push_task) + await cancel_task(self._heartbeat_monitor_task) def _initial_metrics_frame(self) -> MetricsFrame: processors = self._pipeline.processors_with_metrics() @@ -223,6 +235,11 @@ async def _wait_for_endframe(self): await self._endframe_event.wait() self._endframe_event.clear() + async def _cleanup(self): + await self._source.cleanup() + await self._pipeline.cleanup() + await self._sink.cleanup() + async def _process_push_queue(self): """This is the task that runs the pipeline for the first time by sending a StartFrame and by pushing any other frames queued by the user. It runs @@ -249,24 +266,16 @@ async def _process_push_queue(self): running = True should_cleanup = True while running: - try: - frame = await self._push_queue.get() - await self._source.queue_frame(frame, FrameDirection.DOWNSTREAM) - if isinstance(frame, EndFrame): - await self._wait_for_endframe() - running = not isinstance(frame, (StopTaskFrame, EndFrame)) - should_cleanup = not isinstance(frame, StopTaskFrame) - self._push_queue.task_done() - except asyncio.CancelledError: - break + frame = await self._push_queue.get() + await self._source.queue_frame(frame, FrameDirection.DOWNSTREAM) + if isinstance(frame, EndFrame): + await self._wait_for_endframe() + running = not isinstance(frame, (StopTaskFrame, EndFrame)) + should_cleanup = not isinstance(frame, StopTaskFrame) + self._push_queue.task_done() # Cleanup only if we need to. if should_cleanup: - await self._source.cleanup() - await self._pipeline.cleanup() - await self._sink.cleanup() - # Finally, cancel internal tasks. We don't cancel the push tasks because - # that's us. - await self._cancel_tasks(False) + await self._cleanup() async def _process_up_queue(self): """This is the task that processes frames coming upstream from the @@ -276,26 +285,23 @@ async def _process_up_queue(self): """ while True: - try: - frame = await self._up_queue.get() - if isinstance(frame, EndTaskFrame): - # Tell the task we should end nicely. - await self.queue_frame(EndFrame()) - elif isinstance(frame, CancelTaskFrame): - # Tell the task we should end right away. + frame = await self._up_queue.get() + if isinstance(frame, EndTaskFrame): + # Tell the task we should end nicely. + await self.queue_frame(EndFrame()) + elif isinstance(frame, CancelTaskFrame): + # Tell the task we should end right away. + await self.queue_frame(CancelFrame()) + elif isinstance(frame, StopTaskFrame): + await self.queue_frame(StopTaskFrame()) + elif isinstance(frame, ErrorFrame): + logger.error(f"Error running app: {frame}") + if frame.fatal: + # Cancel all tasks downstream. await self.queue_frame(CancelFrame()) - elif isinstance(frame, StopTaskFrame): + # Tell the task we should stop. await self.queue_frame(StopTaskFrame()) - elif isinstance(frame, ErrorFrame): - logger.error(f"Error running app: {frame}") - if frame.fatal: - # Cancel all tasks downstream. - await self.queue_frame(CancelFrame()) - # Tell the task we should stop. - await self.queue_frame(StopTaskFrame()) - self._up_queue.task_done() - except asyncio.CancelledError: - break + self._up_queue.task_done() async def _process_down_queue(self): """This tasks process frames coming downstream from the pipeline. For @@ -305,29 +311,23 @@ async def _process_down_queue(self): """ while True: - try: - frame = await self._down_queue.get() - if isinstance(frame, EndFrame): - self._endframe_event.set() - elif isinstance(frame, HeartbeatFrame): - await self._heartbeat_queue.put(frame) - self._down_queue.task_done() - except asyncio.CancelledError: - break + frame = await self._down_queue.get() + if isinstance(frame, EndFrame): + self._endframe_event.set() + elif isinstance(frame, HeartbeatFrame): + await self._heartbeat_queue.put(frame) + self._down_queue.task_done() async def _heartbeat_push_handler(self): """ This tasks pushes a heartbeat frame every heartbeat period. """ while True: - try: - # Don't use `queue_frame()` because if an EndFrame is queued the - # task will just stop waiting for the pipeline to finish not - # allowing more frames to be pushed. - await self._source.queue_frame(HeartbeatFrame(timestamp=self._clock.get_time())) - await asyncio.sleep(self._params.heartbeats_period_secs) - except asyncio.CancelledError: - break + # Don't use `queue_frame()` because if an EndFrame is queued the + # task will just stop waiting for the pipeline to finish not + # allowing more frames to be pushed. + await self._source.queue_frame(HeartbeatFrame(timestamp=self._clock.get_time())) + await asyncio.sleep(self._params.heartbeats_period_secs) async def _heartbeat_monitor_handler(self): """This tasks monitors heartbeat frames. If a heartbeat frame has not @@ -347,8 +347,6 @@ async def _heartbeat_monitor_handler(self): logger.warning( f"{self}: heartbeat frame not received for more than {wait_time} seconds" ) - except asyncio.CancelledError: - break def __str__(self): return self.name diff --git a/src/pipecat/pipeline/task_observer.py b/src/pipecat/pipeline/task_observer.py index 2fd13f517..1bf4ff0f9 100644 --- a/src/pipecat/pipeline/task_observer.py +++ b/src/pipecat/pipeline/task_observer.py @@ -12,6 +12,7 @@ from pipecat.frames.frames import Frame from pipecat.observers.base_observer import BaseObserver from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.utils.utils import cancel_task, create_task, obj_count @dataclass @@ -54,13 +55,13 @@ class TaskObserver(BaseObserver): """ def __init__(self, observers: List[BaseObserver] = []): + self.name: str = f"{self.__class__.__name__}#{obj_count(self)}" self._proxies: List[Proxy] = self._create_proxies(observers) async def stop(self): """Stops all proxy observer tasks.""" for proxy in self._proxies: - proxy.task.cancel() - await proxy.task + await cancel_task(proxy.task) async def on_push_frame( self, @@ -79,19 +80,24 @@ async def on_push_frame( def _create_proxies(self, observers) -> List[Proxy]: proxies = [] + loop = asyncio.get_running_loop() for observer in observers: queue = asyncio.Queue() - task = asyncio.create_task(self._proxy_task_handler(queue, observer)) + task = create_task( + loop, + self._proxy_task_handler(queue, observer), + f"{self}::{observer.__class__.__name__}", + ) proxy = Proxy(queue=queue, task=task, observer=observer) proxies.append(proxy) return proxies async def _proxy_task_handler(self, queue: asyncio.Queue, observer: BaseObserver): while True: - try: - data = await queue.get() - await observer.on_push_frame( - data.src, data.dst, data.frame, data.direction, data.timestamp - ) - except asyncio.CancelledError: - break + data = await queue.get() + await observer.on_push_frame( + data.src, data.dst, data.frame, data.direction, data.timestamp + ) + + def __str__(self): + return self.name diff --git a/src/pipecat/processors/aggregators/gated_openai_llm_context.py b/src/pipecat/processors/aggregators/gated_openai_llm_context.py index a185528f2..8c3f496eb 100644 --- a/src/pipecat/processors/aggregators/gated_openai_llm_context.py +++ b/src/pipecat/processors/aggregators/gated_openai_llm_context.py @@ -4,8 +4,6 @@ # SPDX-License-Identifier: BSD 2-Clause License # -import asyncio - from pipecat.frames.frames import CancelFrame, EndFrame, Frame, StartFrame from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame from pipecat.processors.frame_processor import FrameDirection, FrameProcessor @@ -38,18 +36,14 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): await self.push_frame(frame, direction) async def _start(self): - self._gate_task = self.get_event_loop().create_task(self._gate_task_handler()) + self._gate_task = self.create_task(self._gate_task_handler()) async def _stop(self): - self._gate_task.cancel() - await self._gate_task + await self.cancel_task(self._gate_task) async def _gate_task_handler(self): while True: - try: - await self._notifier.wait() - if self._last_context_frame: - await self.push_frame(self._last_context_frame) - self._last_context_frame = None - except asyncio.CancelledError: - break + await self._notifier.wait() + if self._last_context_frame: + await self.push_frame(self._last_context_frame) + self._last_context_frame = None diff --git a/src/pipecat/processors/frame_processor.py b/src/pipecat/processors/frame_processor.py index cc483465a..3159d6011 100644 --- a/src/pipecat/processors/frame_processor.py +++ b/src/pipecat/processors/frame_processor.py @@ -6,15 +6,15 @@ import asyncio import inspect +import sys from enum import Enum -from typing import Awaitable, Callable, Optional +from typing import Awaitable, Callable, Coroutine, Optional from loguru import logger from pipecat.clocks.base_clock import BaseClock from pipecat.frames.frames import ( CancelFrame, - EndFrame, ErrorFrame, Frame, StartFrame, @@ -24,7 +24,7 @@ ) from pipecat.metrics.metrics import LLMTokenUsage, MetricsData from pipecat.processors.metrics.frame_processor_metrics import FrameProcessorMetrics -from pipecat.utils.utils import obj_count, obj_id +from pipecat.utils.utils import cancel_task, create_task, obj_count, obj_id class FrameDirection(Enum): @@ -141,6 +141,13 @@ async def stop_all_metrics(self): await self.stop_ttfb_metrics() await self.stop_processing_metrics() + def create_task(self, coroutine: Coroutine) -> asyncio.Task: + debug = f"{self}::{coroutine.cr_code.co_name}" + return create_task(self.get_event_loop(), coroutine, debug) + + async def cancel_task(self, task: asyncio.Task, timeout: Optional[float] = None): + await cancel_task(task, timeout) + async def cleanup(self): await self.__cancel_input_task() await self.__cancel_push_task() @@ -188,7 +195,6 @@ async def pause_processing_frames(self): async def resume_processing_frames(self): logger.trace(f"{self}: resuming frame processing") self.__input_event.set() - self.__should_block_frames = False async def process_frame(self, frame: Frame, direction: FrameDirection): if isinstance(frame, StartFrame): @@ -283,61 +289,44 @@ async def __internal_push_frame(self, frame: Frame, direction: FrameDirection): def __create_input_task(self): self.__should_block_frames = False self.__input_queue = asyncio.Queue() - self.__input_frame_task = self.get_event_loop().create_task( - self.__input_frame_task_handler() - ) self.__input_event = asyncio.Event() + self.__input_frame_task = self.create_task(self.__input_frame_task_handler()) async def __cancel_input_task(self): - self.__input_frame_task.cancel() - await self.__input_frame_task + await self.cancel_task(self.__input_frame_task) async def __input_frame_task_handler(self): while True: - try: - if self.__should_block_frames: - logger.trace(f"{self}: frame processing paused") - await self.__input_event.wait() - self.__input_event.clear() - logger.trace(f"{self}: frame processing resumed") - - (frame, direction, callback) = await self.__input_queue.get() - - # Process the frame. - await self.process_frame(frame, direction) - - # If this frame has an associated callback, call it now. - if callback: - await callback(self, frame, direction) - - self.__input_queue.task_done() - except asyncio.CancelledError: - logger.trace(f"{self}: cancelled input task") - break - except Exception as e: - logger.exception(f"{self}: Uncaught exception {e}") - await self.push_error(ErrorFrame(str(e))) + if self.__should_block_frames: + logger.trace(f"{self}: frame processing paused") + await self.__input_event.wait() + self.__input_event.clear() + self.__should_block_frames = False + logger.trace(f"{self}: frame processing resumed") + + (frame, direction, callback) = await self.__input_queue.get() + + # Process the frame. + await self.process_frame(frame, direction) + + # If this frame has an associated callback, call it now. + if callback: + await callback(self, frame, direction) + + self.__input_queue.task_done() def __create_push_task(self): self.__push_queue = asyncio.Queue() - self.__push_frame_task = self.get_event_loop().create_task(self.__push_frame_task_handler()) + self.__push_frame_task = self.create_task(self.__push_frame_task_handler()) async def __cancel_push_task(self): - self.__push_frame_task.cancel() - await self.__push_frame_task + await self.cancel_task(self.__push_frame_task) async def __push_frame_task_handler(self): while True: - try: - (frame, direction) = await self.__push_queue.get() - await self.__internal_push_frame(frame, direction) - self.__push_queue.task_done() - except asyncio.CancelledError: - logger.trace(f"{self}: cancelled push task") - break - except Exception as e: - logger.exception(f"{self}: Uncaught exception {e}") - await self.push_error(ErrorFrame(str(e))) + (frame, direction) = await self.__push_queue.get() + await self.__internal_push_frame(frame, direction) + self.__push_queue.task_done() async def _call_event_handler(self, event_name: str, *args, **kwargs): try: diff --git a/src/pipecat/processors/frameworks/rtvi.py b/src/pipecat/processors/frameworks/rtvi.py index fcfb73237..f59b6d48f 100644 --- a/src/pipecat/processors/frameworks/rtvi.py +++ b/src/pipecat/processors/frameworks/rtvi.py @@ -764,11 +764,11 @@ def __init__( # A task to process incoming action frames. self._action_queue = asyncio.Queue() - self._action_task = self.get_event_loop().create_task(self._action_task_handler()) + self._action_task = self.create_task(self._action_task_handler()) # A task to process incoming transport messages. self._message_queue = asyncio.Queue() - self._message_task = self.get_event_loop().create_task(self._message_task_handler()) + self._message_task = self.create_task(self._message_task_handler()) self._register_event_handler("on_bot_started") self._register_event_handler("on_client_ready") @@ -873,13 +873,11 @@ async def _cancel(self, frame: CancelFrame): async def _cancel_tasks(self): if self._action_task: - self._action_task.cancel() - await self._action_task + await self.cancel_task(self._action_task) self._action_task = None if self._message_task: - self._message_task.cancel() - await self._message_task + await self.cancel_task(self._message_task) self._message_task = None async def _push_transport_message(self, model: BaseModel, exclude_none: bool = True): @@ -888,21 +886,15 @@ async def _push_transport_message(self, model: BaseModel, exclude_none: bool = T async def _action_task_handler(self): while True: - try: - frame = await self._action_queue.get() - await self._handle_action(frame.message_id, frame.rtvi_action_run) - self._action_queue.task_done() - except asyncio.CancelledError: - break + frame = await self._action_queue.get() + await self._handle_action(frame.message_id, frame.rtvi_action_run) + self._action_queue.task_done() async def _message_task_handler(self): while True: - try: - message = await self._message_queue.get() - await self._handle_message(message) - self._message_queue.task_done() - except asyncio.CancelledError: - break + message = await self._message_queue.get() + await self._handle_message(message) + self._message_queue.task_done() async def _handle_transport_message(self, frame: TransportMessageUrgentFrame): try: diff --git a/src/pipecat/processors/idle_frame_processor.py b/src/pipecat/processors/idle_frame_processor.py index 3f5f51e45..85ea215cf 100644 --- a/src/pipecat/processors/idle_frame_processor.py +++ b/src/pipecat/processors/idle_frame_processor.py @@ -49,12 +49,11 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): self._idle_event.set() async def cleanup(self): - self._idle_task.cancel() - await self._idle_task + await self.cancel_task(self._idle_task) def _create_idle_task(self): self._idle_event = asyncio.Event() - self._idle_task = self.get_event_loop().create_task(self._idle_task_handler()) + self._idle_task = self.create_task(self._idle_task_handler()) async def _idle_task_handler(self): while True: @@ -62,7 +61,5 @@ async def _idle_task_handler(self): await asyncio.wait_for(self._idle_event.wait(), timeout=self._timeout) except asyncio.TimeoutError: await self._callback(self) - except asyncio.CancelledError: - break finally: self._idle_event.clear() diff --git a/src/pipecat/processors/user_idle_processor.py b/src/pipecat/processors/user_idle_processor.py index 3a7202c80..fa2ca515b 100644 --- a/src/pipecat/processors/user_idle_processor.py +++ b/src/pipecat/processors/user_idle_processor.py @@ -57,16 +57,12 @@ def __init__( def _create_idle_task(self): """Create the idle task if it hasn't been created yet.""" if self._idle_task is None: - self._idle_task = self.get_event_loop().create_task(self._idle_task_handler()) + self._idle_task = self.create_task(self._idle_task_handler()) async def _stop(self): """Stops and cleans up the idle monitoring task.""" if self._idle_task is not None: - self._idle_task.cancel() - try: - await self._idle_task - except asyncio.CancelledError: - pass # Expected when task is cancelled + await self.cancel_task(self._idle_task) self._idle_task = None async def process_frame(self, frame: Frame, direction: FrameDirection): @@ -122,7 +118,5 @@ async def _idle_task_handler(self): except asyncio.TimeoutError: if not self._interrupted: await self._callback(self) - except asyncio.CancelledError: - break finally: self._idle_event.clear() diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py index bfd155d38..2b0485fdd 100644 --- a/src/pipecat/services/ai_services.py +++ b/src/pipecat/services/ai_services.py @@ -253,20 +253,18 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: async def start(self, frame: StartFrame): await super().start(frame) if self._push_stop_frames: - self._stop_frame_task = self.get_event_loop().create_task(self._stop_frame_handler()) + self._stop_frame_task = self.create_task(self._stop_frame_handler()) async def stop(self, frame: EndFrame): await super().stop(frame) if self._stop_frame_task: - self._stop_frame_task.cancel() - await self._stop_frame_task + await self.cancel_task(self._stop_frame_task) self._stop_frame_task = None async def cancel(self, frame: CancelFrame): await super().cancel(frame) if self._stop_frame_task: - self._stop_frame_task.cancel() - await self._stop_frame_task + await self.cancel_task(self._stop_frame_task) self._stop_frame_task = None async def _update_settings(self, settings: Dict[str, Any]): @@ -364,23 +362,20 @@ async def _push_tts_frames(self, text: str): await self.push_frame(TTSTextFrame(text)) async def _stop_frame_handler(self): - try: - has_started = False - while True: - try: - frame = await asyncio.wait_for( - self._stop_frame_queue.get(), self._stop_frame_timeout_s - ) - if isinstance(frame, TTSStartedFrame): - has_started = True - elif isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)): - has_started = False - except asyncio.TimeoutError: - if has_started: - await self.push_frame(TTSStoppedFrame()) - has_started = False - except asyncio.CancelledError: - pass + has_started = False + while True: + try: + frame = await asyncio.wait_for( + self._stop_frame_queue.get(), self._stop_frame_timeout_s + ) + if isinstance(frame, TTSStartedFrame): + has_started = True + elif isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)): + has_started = False + except asyncio.TimeoutError: + if has_started: + await self.push_frame(TTSStoppedFrame()) + has_started = False class WordTTSService(TTSService): @@ -388,7 +383,7 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self._initial_word_timestamp = -1 self._words_queue = asyncio.Queue() - self._words_task = self.get_event_loop().create_task(self._words_task_handler()) + self._words_task = self.create_task(self._words_task_handler()) def start_word_timestamps(self): if self._initial_word_timestamp == -1: @@ -421,35 +416,29 @@ async def _handle_interruption(self, frame: StartInterruptionFrame, direction: F async def _stop_words_task(self): if self._words_task: - self._words_task.cancel() - await self._words_task + await self.cancel_task(self._words_task) self._words_task = None async def _words_task_handler(self): last_pts = 0 while True: - try: - (word, timestamp) = await self._words_queue.get() - if word == "Reset" and timestamp == 0: - self.reset_word_timestamps() - frame = None - elif word == "LLMFullResponseEndFrame" and timestamp == 0: - frame = LLMFullResponseEndFrame() - frame.pts = last_pts - elif word == "TTSStoppedFrame" and timestamp == 0: - frame = TTSStoppedFrame() - frame.pts = last_pts - else: - frame = TTSTextFrame(word) - frame.pts = self._initial_word_timestamp + timestamp - if frame: - last_pts = frame.pts - await self.push_frame(frame) - self._words_queue.task_done() - except asyncio.CancelledError: - break - except Exception as e: - logger.exception(f"{self} exception: {e}") + (word, timestamp) = await self._words_queue.get() + if word == "Reset" and timestamp == 0: + self.reset_word_timestamps() + frame = None + elif word == "LLMFullResponseEndFrame" and timestamp == 0: + frame = LLMFullResponseEndFrame() + frame.pts = last_pts + elif word == "TTSStoppedFrame" and timestamp == 0: + frame = TTSStoppedFrame() + frame.pts = last_pts + else: + frame = TTSTextFrame(word) + frame.pts = self._initial_word_timestamp + timestamp + if frame: + last_pts = frame.pts + await self.push_frame(frame) + self._words_queue.task_done() class STTService(AIService): diff --git a/src/pipecat/services/cartesia.py b/src/pipecat/services/cartesia.py index 35856ea52..75da116d5 100644 --- a/src/pipecat/services/cartesia.py +++ b/src/pipecat/services/cartesia.py @@ -187,16 +187,13 @@ async def cancel(self, frame: CancelFrame): async def _connect(self): await self._connect_websocket() - self._receive_task = self.get_event_loop().create_task( - self._receive_task_handler(self.push_error) - ) + self._receive_task = self.create_task(self._receive_task_handler(self.push_error)) async def _disconnect(self): await self._disconnect_websocket() if self._receive_task: - self._receive_task.cancel() - await self._receive_task + await self.cancel_task(self._receive_task) self._receive_task = None async def _connect_websocket(self): diff --git a/src/pipecat/services/elevenlabs.py b/src/pipecat/services/elevenlabs.py index 2677accaa..81761da71 100644 --- a/src/pipecat/services/elevenlabs.py +++ b/src/pipecat/services/elevenlabs.py @@ -298,20 +298,16 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): async def _connect(self): await self._connect_websocket() - self._receive_task = self.get_event_loop().create_task( - self._receive_task_handler(self.push_error) - ) - self._keepalive_task = self.get_event_loop().create_task(self._keepalive_task_handler()) + self._receive_task = self.create_task(self._receive_task_handler(self.push_error)) + self._keepalive_task = self.create_task(self._keepalive_task_handler()) async def _disconnect(self): if self._receive_task: - self._receive_task.cancel() - await self._receive_task + await self.cancel_task(self._receive_task) self._receive_task = None if self._keepalive_task: - self._keepalive_task.cancel() - await self._keepalive_task + await self.cancel_task(self._keepalive_task) self._keepalive_task = None await self._disconnect_websocket() @@ -382,13 +378,8 @@ async def _receive_messages(self): async def _keepalive_task_handler(self): while True: - try: - await asyncio.sleep(10) - await self._send_text("") - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"{self} exception: {e}") + await asyncio.sleep(10) + await self._send_text("") async def _send_text(self, text: str): if self._websocket: diff --git a/src/pipecat/services/fish.py b/src/pipecat/services/fish.py index 710ab04ec..6fadab683 100644 --- a/src/pipecat/services/fish.py +++ b/src/pipecat/services/fish.py @@ -104,15 +104,12 @@ async def cancel(self, frame: CancelFrame): async def _connect(self): await self._connect_websocket() - self._receive_task = self.get_event_loop().create_task( - self._receive_task_handler(self.push_error) - ) + self._receive_task = self.create_task(self._receive_task_handler(self.push_error)) async def _disconnect(self): await self._disconnect_websocket() if self._receive_task: - self._receive_task.cancel() - await self._receive_task + await self.cancel_task(self._receive_task) self._receive_task = None async def _connect_websocket(self): diff --git a/src/pipecat/services/gemini_multimodal_live/gemini.py b/src/pipecat/services/gemini_multimodal_live/gemini.py index 0241b3b64..aeb263323 100644 --- a/src/pipecat/services/gemini_multimodal_live/gemini.py +++ b/src/pipecat/services/gemini_multimodal_live/gemini.py @@ -391,7 +391,7 @@ async def _connect(self): uri = f"wss://{self.base_url}/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={self.api_key}" logger.info(f"Connecting to {uri}") self._websocket = await websockets.connect(uri=uri) - self._receive_task = self.get_event_loop().create_task(self._receive_task_handler()) + self._receive_task = self.create_task(self._receive_task_handler()) config = events.Config.model_validate( { "setup": { @@ -441,11 +441,7 @@ async def _disconnect(self): await self._websocket.close() self._websocket = None if self._receive_task: - self._receive_task.cancel() - try: - await asyncio.wait_for(self._receive_task, timeout=1.0) - except asyncio.TimeoutError: - logger.warning("Timed out waiting for receive task to finish") + await self.cancel_task(self._receive_task, timeout=1.0) self._receive_task = None self._disconnecting = False except Exception as e: @@ -497,6 +493,7 @@ async def _receive_task_handler(self): pass except asyncio.CancelledError: logger.debug("websocket receive task cancelled") + raise except Exception as e: logger.error(f"{self} exception: {e}") diff --git a/src/pipecat/services/gladia.py b/src/pipecat/services/gladia.py index 3487be5ae..e240252a2 100644 --- a/src/pipecat/services/gladia.py +++ b/src/pipecat/services/gladia.py @@ -180,7 +180,7 @@ async def start(self, frame: StartFrame): await super().start(frame) response = await self._setup_gladia() self._websocket = await websockets.connect(response["url"]) - self._receive_task = self.get_event_loop().create_task(self._receive_task_handler()) + self._receive_task = self.create_task(self._receive_task_handler()) async def stop(self, frame: EndFrame): await super().stop(frame) diff --git a/src/pipecat/services/lmnt.py b/src/pipecat/services/lmnt.py index 6f654293a..327901423 100644 --- a/src/pipecat/services/lmnt.py +++ b/src/pipecat/services/lmnt.py @@ -113,16 +113,13 @@ async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirect async def _connect(self): await self._connect_websocket() - self._receive_task = self.get_event_loop().create_task( - self._receive_task_handler(self.push_error) - ) + self._receive_task = self.create_task(self._receive_task_handler(self.push_error)) async def _disconnect(self): await self._disconnect_websocket() if self._receive_task: - self._receive_task.cancel() - await self._receive_task + await self.cancel_task(self._receive_task) self._receive_task = None async def _connect_websocket(self): diff --git a/src/pipecat/services/openai_realtime_beta/openai.py b/src/pipecat/services/openai_realtime_beta/openai.py index c1d7c4a4d..c64c88a5f 100644 --- a/src/pipecat/services/openai_realtime_beta/openai.py +++ b/src/pipecat/services/openai_realtime_beta/openai.py @@ -266,7 +266,7 @@ async def _connect(self): "OpenAI-Beta": "realtime=v1", }, ) - self._receive_task = self.get_event_loop().create_task(self._receive_task_handler()) + self._receive_task = self.create_task(self._receive_task_handler()) except Exception as e: logger.error(f"{self} initialization error: {e}") self._websocket = None @@ -280,11 +280,7 @@ async def _disconnect(self): await self._websocket.close() self._websocket = None if self._receive_task: - self._receive_task.cancel() - try: - await asyncio.wait_for(self._receive_task, timeout=1.0) - except asyncio.TimeoutError: - logger.warning("Timed out waiting for receive task to finish") + await self.cancel_task(self._receive_task, timeout=1.0) self._receive_task = None self._disconnecting = False except Exception as e: @@ -321,40 +317,32 @@ async def _update_settings(self): # async def _receive_task_handler(self): - try: - async for message in self._websocket: - evt = events.parse_server_event(message) - if evt.type == "session.created": - await self._handle_evt_session_created(evt) - elif evt.type == "session.updated": - await self._handle_evt_session_updated(evt) - elif evt.type == "response.audio.delta": - await self._handle_evt_audio_delta(evt) - elif evt.type == "response.audio.done": - await self._handle_evt_audio_done(evt) - elif evt.type == "conversation.item.created": - await self._handle_evt_conversation_item_created(evt) - elif evt.type == "conversation.item.input_audio_transcription.completed": - await self.handle_evt_input_audio_transcription_completed(evt) - elif evt.type == "response.done": - await self._handle_evt_response_done(evt) - elif evt.type == "input_audio_buffer.speech_started": - await self._handle_evt_speech_started(evt) - elif evt.type == "input_audio_buffer.speech_stopped": - await self._handle_evt_speech_stopped(evt) - elif evt.type == "response.audio_transcript.delta": - await self._handle_evt_audio_transcript_delta(evt) - elif evt.type == "error": - await self._handle_evt_error(evt) - # errors are fatal, so exit the receive loop - return - - else: - pass - except asyncio.CancelledError: - logger.debug("websocket receive task cancelled") - except Exception as e: - logger.error(f"{self} exception: {e}") + async for message in self._websocket: + evt = events.parse_server_event(message) + if evt.type == "session.created": + await self._handle_evt_session_created(evt) + elif evt.type == "session.updated": + await self._handle_evt_session_updated(evt) + elif evt.type == "response.audio.delta": + await self._handle_evt_audio_delta(evt) + elif evt.type == "response.audio.done": + await self._handle_evt_audio_done(evt) + elif evt.type == "conversation.item.created": + await self._handle_evt_conversation_item_created(evt) + elif evt.type == "conversation.item.input_audio_transcription.completed": + await self.handle_evt_input_audio_transcription_completed(evt) + elif evt.type == "response.done": + await self._handle_evt_response_done(evt) + elif evt.type == "input_audio_buffer.speech_started": + await self._handle_evt_speech_started(evt) + elif evt.type == "input_audio_buffer.speech_stopped": + await self._handle_evt_speech_stopped(evt) + elif evt.type == "response.audio_transcript.delta": + await self._handle_evt_audio_transcript_delta(evt) + elif evt.type == "error": + await self._handle_evt_error(evt) + # errors are fatal, so exit the receive loop + return async def _handle_evt_session_created(self, evt): # session.created is received right after connecting. Send a message diff --git a/src/pipecat/services/playht.py b/src/pipecat/services/playht.py index 46b724a4d..07e78ae52 100644 --- a/src/pipecat/services/playht.py +++ b/src/pipecat/services/playht.py @@ -165,16 +165,13 @@ async def cancel(self, frame: CancelFrame): async def _connect(self): await self._connect_websocket() - self._receive_task = self.get_event_loop().create_task( - self._receive_task_handler(self.push_error) - ) + self._receive_task = self.create_task(self._receive_task_handler(self.push_error)) async def _disconnect(self): await self._disconnect_websocket() if self._receive_task: - self._receive_task.cancel() - await self._receive_task + await self.cancel_task(self._receive_task) self._receive_task = None async def _connect_websocket(self): diff --git a/src/pipecat/services/riva.py b/src/pipecat/services/riva.py index 662ac21f5..cac1946d5 100644 --- a/src/pipecat/services/riva.py +++ b/src/pipecat/services/riva.py @@ -202,8 +202,8 @@ def can_generate_metrics(self) -> bool: async def start(self, frame: StartFrame): await super().start(frame) - self._thread_task = self.get_event_loop().create_task(self._thread_task_handler()) - self._response_task = self.get_event_loop().create_task(self._response_task_handler()) + self._thread_task = self.create_task(self._thread_task_handler()) + self._response_task = self.create_task(self._response_task_handler()) self._response_queue = asyncio.Queue() async def stop(self, frame: EndFrame): @@ -215,10 +215,8 @@ async def cancel(self, frame: CancelFrame): await self._stop_tasks() async def _stop_tasks(self): - self._thread_task.cancel() - await self._thread_task - self._response_task.cancel() - await self._response_task + await self.cancel_task(self._thread_task) + await self.cancel_task(self._response_task) def _response_handler(self): responses = self._asr_service.streaming_response_generator( @@ -238,7 +236,7 @@ async def _thread_task_handler(self): await asyncio.to_thread(self._response_handler) except asyncio.CancelledError: self._thread_running = False - pass + raise async def _handle_response(self, response): for result in response.results: @@ -260,11 +258,8 @@ async def _handle_response(self, response): async def _response_task_handler(self): while True: - try: - response = await self._response_queue.get() - await self._handle_response(response) - except asyncio.CancelledError: - break + response = await self._response_queue.get() + await self._handle_response(response) async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: await self._queue.put(audio) diff --git a/src/pipecat/services/simli.py b/src/pipecat/services/simli.py index be7edfdbd..5a6aecbe4 100644 --- a/src/pipecat/services/simli.py +++ b/src/pipecat/services/simli.py @@ -49,45 +49,33 @@ def __init__( async def _start_connection(self): await self._simli_client.Initialize() # Create task to consume and process audio and video - self._audio_task = asyncio.create_task(self._consume_and_process_audio()) - self._video_task = asyncio.create_task(self._consume_and_process_video()) + self._audio_task = self.create_task(self._consume_and_process_audio()) + self._video_task = self.create_task(self._consume_and_process_video()) async def _consume_and_process_audio(self): - try: - await self._pipecat_resampler_event.wait() - async for audio_frame in self._simli_client.getAudioStreamIterator(): - resampled_frames = self._pipecat_resampler.resample(audio_frame) - for resampled_frame in resampled_frames: - await self.push_frame( - TTSAudioRawFrame( - audio=resampled_frame.to_ndarray().tobytes(), - sample_rate=self._pipecat_resampler.rate, - num_channels=1, - ), - ) - except Exception as e: - logger.exception(f"{self} exception: {e}") - except asyncio.CancelledError: - pass + await self._pipecat_resampler_event.wait() + async for audio_frame in self._simli_client.getAudioStreamIterator(): + resampled_frames = self._pipecat_resampler.resample(audio_frame) + for resampled_frame in resampled_frames: + await self.push_frame( + TTSAudioRawFrame( + audio=resampled_frame.to_ndarray().tobytes(), + sample_rate=self._pipecat_resampler.rate, + num_channels=1, + ), + ) async def _consume_and_process_video(self): - try: - await self._pipecat_resampler_event.wait() - async for video_frame in self._simli_client.getVideoStreamIterator( - targetFormat="rgb24" - ): - # Process the video frame - convertedFrame: OutputImageRawFrame = OutputImageRawFrame( - image=video_frame.to_rgb().to_image().tobytes(), - size=(video_frame.width, video_frame.height), - format="RGB", - ) - convertedFrame.pts = video_frame.pts - await self.push_frame(convertedFrame) - except Exception as e: - logger.exception(f"{self} exception: {e}") - except asyncio.CancelledError: - pass + await self._pipecat_resampler_event.wait() + async for video_frame in self._simli_client.getVideoStreamIterator(targetFormat="rgb24"): + # Process the video frame + convertedFrame: OutputImageRawFrame = OutputImageRawFrame( + image=video_frame.to_rgb().to_image().tobytes(), + size=(video_frame.width, video_frame.height), + format="RGB", + ) + convertedFrame.pts = video_frame.pts + await self.push_frame(convertedFrame) async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) @@ -128,8 +116,6 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): async def _stop(self): await self._simli_client.stop() if self._audio_task: - self._audio_task.cancel() - await self._audio_task + await self.cancel_task(self._audio_task) if self._video_task: - self._video_task.cancel() - await self._video_task + await self.cancel_task(self._video_task) diff --git a/src/pipecat/services/websocket_service.py b/src/pipecat/services/websocket_service.py index 365f5a7c8..7e480f2b7 100644 --- a/src/pipecat/services/websocket_service.py +++ b/src/pipecat/services/websocket_service.py @@ -85,10 +85,6 @@ async def _receive_task_handler(self, report_error: Callable[[ErrorFrame], Await await self._receive_messages() logger.debug(f"{self} connection established successfully") retry_count = 0 # Reset counter on successful message receive - - except asyncio.CancelledError: - break - except Exception as e: retry_count += 1 if retry_count >= MAX_RETRIES: diff --git a/src/pipecat/transports/base_input.py b/src/pipecat/transports/base_input.py index 93442f90a..192101e62 100644 --- a/src/pipecat/transports/base_input.py +++ b/src/pipecat/transports/base_input.py @@ -50,13 +50,12 @@ async def start(self, frame: StartFrame): # Create audio input queue and task if needed. if self._params.audio_in_enabled or self._params.vad_enabled: self._audio_in_queue = asyncio.Queue() - self._audio_task = self.get_event_loop().create_task(self._audio_task_handler()) + self._audio_task = self.create_task(self._audio_task_handler()) async def stop(self, frame: EndFrame): # Cancel and wait for the audio input task to finish. if self._audio_task and (self._params.audio_in_enabled or self._params.vad_enabled): - self._audio_task.cancel() - await self._audio_task + await self.cancel_task(self._audio_task) self._audio_task = None # Stop audio filter. if self._params.audio_in_filter: @@ -65,8 +64,7 @@ async def stop(self, frame: EndFrame): async def cancel(self, frame: CancelFrame): # Cancel and wait for the audio input task to finish. if self._audio_task and (self._params.audio_in_enabled or self._params.vad_enabled): - self._audio_task.cancel() - await self._audio_task + await self.cancel_task(self._audio_task) self._audio_task = None def vad_analyzer(self) -> VADAnalyzer | None: @@ -173,27 +171,22 @@ async def _handle_vad(self, audio_frame: InputAudioRawFrame, vad_state: VADState async def _audio_task_handler(self): vad_state: VADState = VADState.QUIET while True: - try: - frame: InputAudioRawFrame = await self._audio_in_queue.get() - - audio_passthrough = True - - # If an audio filter is available, run it before VAD. - if self._params.audio_in_filter: - frame.audio = await self._params.audio_in_filter.filter(frame.audio) - - # Check VAD and push event if necessary. We just care about - # changes from QUIET to SPEAKING and vice versa. - if self._params.vad_enabled: - vad_state = await self._handle_vad(frame, vad_state) - audio_passthrough = self._params.vad_audio_passthrough - - # Push audio downstream if passthrough. - if audio_passthrough: - await self.push_frame(frame) - - self._audio_in_queue.task_done() - except asyncio.CancelledError: - break - except Exception as e: - logger.exception(f"{self} error reading audio frames: {e}") + frame: InputAudioRawFrame = await self._audio_in_queue.get() + + audio_passthrough = True + + # If an audio filter is available, run it before VAD. + if self._params.audio_in_filter: + frame.audio = await self._params.audio_in_filter.filter(frame.audio) + + # Check VAD and push event if necessary. We just care about + # changes from QUIET to SPEAKING and vice versa. + if self._params.vad_enabled: + vad_state = await self._handle_vad(frame, vad_state) + audio_passthrough = self._params.vad_audio_passthrough + + # Push audio downstream if passthrough. + if audio_passthrough: + await self.push_frame(frame) + + self._audio_in_queue.task_done() diff --git a/src/pipecat/transports/base_output.py b/src/pipecat/transports/base_output.py index ef073e22a..5181914ed 100644 --- a/src/pipecat/transports/base_output.py +++ b/src/pipecat/transports/base_output.py @@ -217,22 +217,19 @@ async def _bot_stopped_speaking(self): # def _create_sink_tasks(self): - loop = self.get_event_loop() self._sink_queue = asyncio.Queue() - self._sink_task = loop.create_task(self._sink_task_handler()) self._sink_clock_queue = asyncio.PriorityQueue() - self._sink_clock_task = loop.create_task(self._sink_clock_task_handler()) + self._sink_task = self.create_task(self._sink_task_handler()) + self._sink_clock_task = self.create_task(self._sink_clock_task_handler()) async def _cancel_sink_tasks(self): # Stop sink tasks. if self._sink_task: - self._sink_task.cancel() - await self._sink_task + await self.cancel_task(self._sink_task) self._sink_task = None # Stop sink clock tasks. if self._sink_clock_task: - self._sink_clock_task.cancel() - await self._sink_clock_task + await self.cancel_task(self._sink_clock_task) self._sink_clock_task = None async def _sink_frame_handler(self, frame: Frame): @@ -269,7 +266,7 @@ async def _sink_clock_task_handler(self): self._sink_clock_queue.task_done() except asyncio.CancelledError: - break + raise except Exception as e: logger.exception(f"{self} error processing sink clock queue: {e}") @@ -317,49 +314,42 @@ async def with_mixer(vad_stop_secs: float) -> AsyncGenerator[Frame, None]: return without_mixer(vad_stop_secs) async def _sink_task_handler(self): - try: - async for frame in self._next_frame(): - # Notify the bot started speaking upstream if necessary and that - # it's actually speaking. - if isinstance(frame, TTSAudioRawFrame): - await self._bot_started_speaking() - await self.push_frame(BotSpeakingFrame()) - await self.push_frame(BotSpeakingFrame(), FrameDirection.UPSTREAM) - - # No need to push EndFrame, it's pushed from process_frame(). - if isinstance(frame, EndFrame): - break - - # Handle frame. - await self._sink_frame_handler(frame) - - # Also, push frame downstream in case anyone else needs it. - await self.push_frame(frame) - - # Send audio. - if isinstance(frame, OutputAudioRawFrame): - await self.write_raw_audio_frames(frame.audio) - except asyncio.CancelledError: - pass - except Exception as e: - logger.exception(f"{self} error writing to microphone: {e}") + async for frame in self._next_frame(): + # Notify the bot started speaking upstream if necessary and that + # it's actually speaking. + if isinstance(frame, TTSAudioRawFrame): + await self._bot_started_speaking() + await self.push_frame(BotSpeakingFrame()) + await self.push_frame(BotSpeakingFrame(), FrameDirection.UPSTREAM) + + # No need to push EndFrame, it's pushed from process_frame(). + if isinstance(frame, EndFrame): + break + + # Handle frame. + await self._sink_frame_handler(frame) + + # Also, push frame downstream in case anyone else needs it. + await self.push_frame(frame) + + # Send audio. + if isinstance(frame, OutputAudioRawFrame): + await self.write_raw_audio_frames(frame.audio) # # Camera task # def _create_camera_task(self): - loop = self.get_event_loop() # Create camera output queue and task if needed. if self._params.camera_out_enabled: self._camera_out_queue = asyncio.Queue() - self._camera_out_task = loop.create_task(self._camera_out_task_handler()) + self._camera_out_task = self.create_task(self._camera_out_task_handler()) async def _cancel_camera_task(self): # Stop camera output task. if self._camera_out_task and self._params.camera_out_enabled: - self._camera_out_task.cancel() - await self._camera_out_task + await self.cancel_task(self._camera_out_task) self._camera_out_task = None async def _draw_image(self, frame: OutputImageRawFrame): @@ -387,19 +377,14 @@ async def _camera_out_task_handler(self): self._camera_out_frame_duration = 1 / self._params.camera_out_framerate self._camera_out_frame_reset = self._camera_out_frame_duration * 5 while True: - try: - if self._params.camera_out_is_live: - await self._camera_out_is_live_handler() - elif self._camera_images: - image = next(self._camera_images) - await self._draw_image(image) - await asyncio.sleep(self._camera_out_frame_duration) - else: - await asyncio.sleep(self._camera_out_frame_duration) - except asyncio.CancelledError: - break - except Exception as e: - logger.exception(f"{self} error writing to camera: {e}") + if self._params.camera_out_is_live: + await self._camera_out_is_live_handler() + elif self._camera_images: + image = next(self._camera_images) + await self._draw_image(image) + await asyncio.sleep(self._camera_out_frame_duration) + else: + await asyncio.sleep(self._camera_out_frame_duration) async def _camera_out_is_live_handler(self): image = await self._camera_out_queue.get() diff --git a/src/pipecat/transports/network/fastapi_websocket.py b/src/pipecat/transports/network/fastapi_websocket.py index fc5804611..42f1d3171 100644 --- a/src/pipecat/transports/network/fastapi_websocket.py +++ b/src/pipecat/transports/network/fastapi_websocket.py @@ -68,11 +68,9 @@ def __init__( async def start(self, frame: StartFrame): await super().start(frame) if self._params.session_timeout: - self._monitor_websocket_task = self.get_event_loop().create_task( - self._monitor_websocket() - ) + self._monitor_websocket_task = self.create_task(self._monitor_websocket()) await self._callbacks.on_client_connected(self._websocket) - self._receive_task = self.get_event_loop().create_task(self._receive_messages()) + self._receive_task = self.create_task(self._receive_messages()) def _iter_data(self) -> typing.AsyncIterator[bytes | str]: if self._params.serializer.type == FrameSerializerType.BINARY: @@ -96,11 +94,8 @@ async def _receive_messages(self): async def _monitor_websocket(self): """Wait for self._params.session_timeout seconds, if the websocket is still open, trigger timeout event.""" - try: - await asyncio.sleep(self._params.session_timeout) - await self._callbacks.on_session_timeout(self._websocket) - except asyncio.CancelledError: - logger.info(f"Monitoring task cancelled for: {self._websocket}") + await asyncio.sleep(self._params.session_timeout) + await self._callbacks.on_session_timeout(self._websocket) class FastAPIWebsocketOutputTransport(BaseOutputTransport): diff --git a/src/pipecat/transports/network/websocket_server.py b/src/pipecat/transports/network/websocket_server.py index 9bd344cb1..6fe5f91ef 100644 --- a/src/pipecat/transports/network/websocket_server.py +++ b/src/pipecat/transports/network/websocket_server.py @@ -71,7 +71,7 @@ def __init__( async def start(self, frame: StartFrame): await super().start(frame) - self._server_task = self.get_event_loop().create_task(self._server_task_handler()) + self._server_task = self.create_task(self._server_task_handler()) async def stop(self, frame: EndFrame): await super().stop(frame) @@ -131,6 +131,7 @@ async def _monitor_websocket(self, websocket: websockets.WebSocketServerProtocol await self._callbacks.on_session_timeout(websocket) except asyncio.CancelledError: logger.info(f"Monitoring task cancelled for: {websocket.remote_address}") + raise class WebsocketServerOutputTransport(BaseOutputTransport): diff --git a/src/pipecat/transports/services/daily.py b/src/pipecat/transports/services/daily.py index 47575bcce..f3915f662 100644 --- a/src/pipecat/transports/services/daily.py +++ b/src/pipecat/transports/services/daily.py @@ -46,6 +46,7 @@ from pipecat.transports.base_input import BaseInputTransport from pipecat.transports.base_output import BaseOutputTransport from pipecat.transports.base_transport import BaseTransport, TransportParams +from pipecat.utils.utils import cancel_task, create_task try: from daily import CallClient, Daily, EventHandler @@ -218,7 +219,9 @@ def __init__( # future) we will deadlock because completions use event handlers (which # are holding the GIL). self._callback_queue = asyncio.Queue() - self._callback_task = self._loop.create_task(self._callback_task_handler()) + self._callback_task = create_task( + self._loop, self._callback_task_handler(), "DailyTransportClient::callback_task" + ) self._camera: VirtualCameraDevice | None = None if self._params.camera_out_enabled: @@ -469,8 +472,7 @@ async def _leave(self): return await asyncio.wait_for(future, timeout=10) async def cleanup(self): - self._callback_task.cancel() - await self._callback_task + await cancel_task(self._callback_task) # Make sure we don't block the event loop in case `client.release()` # takes extra time. await self._loop.run_in_executor(self._executor, self._cleanup) @@ -687,11 +689,8 @@ def _call_async_callback(self, callback, *args): async def _callback_task_handler(self): while True: - try: - (callback, *args) = await self._callback_queue.get() - await callback(*args) - except asyncio.CancelledError: - break + (callback, *args) = await self._callback_queue.get() + await callback(*args) class DailyInputTransport(BaseInputTransport): @@ -721,7 +720,7 @@ async def start(self, frame: StartFrame): # Create audio task. It reads audio frames from Daily and push them # internally for VAD processing. if self._params.audio_in_enabled or self._params.vad_enabled: - self._audio_in_task = self.get_event_loop().create_task(self._audio_in_task_handler()) + self._audio_in_task = self.create_task(self._audio_in_task_handler()) async def stop(self, frame: EndFrame): # Parent stop. @@ -730,8 +729,7 @@ async def stop(self, frame: EndFrame): await self._client.leave() # Stop audio thread. if self._audio_in_task and (self._params.audio_in_enabled or self._params.vad_enabled): - self._audio_in_task.cancel() - await self._audio_in_task + await self.cancel_task(self._audio_in_task) self._audio_in_task = None async def cancel(self, frame: CancelFrame): @@ -741,8 +739,7 @@ async def cancel(self, frame: CancelFrame): await self._client.leave() # Stop audio thread. if self._audio_in_task and (self._params.audio_in_enabled or self._params.vad_enabled): - self._audio_in_task.cancel() - await self._audio_in_task + await self.cancel_task(self._audio_in_task) self._audio_in_task = None async def cleanup(self): @@ -779,12 +776,9 @@ async def push_app_message(self, message: Any, sender: str): async def _audio_in_task_handler(self): while True: - try: - frame = await self._client.read_next_audio_frame() - if frame: - await self.push_audio_frame(frame) - except asyncio.CancelledError: - break + frame = await self._client.read_next_audio_frame() + if frame: + await self.push_audio_frame(frame) # # Camera in diff --git a/src/pipecat/transports/services/livekit.py b/src/pipecat/transports/services/livekit.py index 0b036cecb..172f5b473 100644 --- a/src/pipecat/transports/services/livekit.py +++ b/src/pipecat/transports/services/livekit.py @@ -319,23 +319,21 @@ async def start(self, frame: StartFrame): await super().start(frame) await self._client.connect() if self._params.audio_in_enabled or self._params.vad_enabled: - self._audio_in_task = asyncio.create_task(self._audio_in_task_handler()) + self._audio_in_task = self.create_task(self._audio_in_task_handler()) logger.info("LiveKitInputTransport started") async def stop(self, frame: EndFrame): await super().stop(frame) await self._client.disconnect() if self._audio_in_task: - self._audio_in_task.cancel() - await self._audio_in_task + await self.cancel_task(self._audio_in_task) logger.info("LiveKitInputTransport stopped") async def cancel(self, frame: CancelFrame): await super().cancel(frame) await self._client.disconnect() if self._audio_in_task and (self._params.audio_in_enabled or self._params.vad_enabled): - self._audio_in_task.cancel() - await self._audio_in_task + await self.cancel_task(self._audio_in_task) def vad_analyzer(self) -> VADAnalyzer | None: return self._vad_analyzer @@ -347,22 +345,16 @@ async def push_app_message(self, message: Any, sender: str): async def _audio_in_task_handler(self): logger.info("Audio input task started") while True: - try: - audio_data = await self._client.get_next_audio_frame() - if audio_data: - audio_frame_event, participant_id = audio_data - pipecat_audio_frame = self._convert_livekit_audio_to_pipecat(audio_frame_event) - input_audio_frame = InputAudioRawFrame( - audio=pipecat_audio_frame.audio, - sample_rate=pipecat_audio_frame.sample_rate, - num_channels=pipecat_audio_frame.num_channels, - ) - await self.push_audio_frame(input_audio_frame) - except asyncio.CancelledError: - logger.info("Audio input task cancelled") - break - except Exception as e: - logger.error(f"Error in audio input task: {e}") + audio_data = await self._client.get_next_audio_frame() + if audio_data: + audio_frame_event, participant_id = audio_data + pipecat_audio_frame = self._convert_livekit_audio_to_pipecat(audio_frame_event) + input_audio_frame = InputAudioRawFrame( + audio=pipecat_audio_frame.audio, + sample_rate=pipecat_audio_frame.sample_rate, + num_channels=pipecat_audio_frame.num_channels, + ) + await self.push_audio_frame(input_audio_frame) def _convert_livekit_audio_to_pipecat( self, audio_frame_event: rtc.AudioFrameEvent diff --git a/src/pipecat/utils/utils.py b/src/pipecat/utils/utils.py index 5470e63a5..59266601b 100644 --- a/src/pipecat/utils/utils.py +++ b/src/pipecat/utils/utils.py @@ -3,8 +3,13 @@ # # SPDX-License-Identifier: BSD 2-Clause License # + +import asyncio import collections import itertools +from typing import Coroutine, Optional + +from loguru import logger _COUNTS = collections.defaultdict(itertools.count) _ID = itertools.count() @@ -35,3 +40,37 @@ def obj_count(obj) -> int: 0 """ return next(_COUNTS[obj.__class__.__name__]) + + +def create_task(loop: asyncio.AbstractEventLoop, coroutine: Coroutine, name: str) -> asyncio.Task: + async def run_coroutine(): + try: + await coroutine + except asyncio.CancelledError: + logger.trace(f"{name}: cancelling task") + # Re-raise the exception to ensure the task is cancelled. + raise + except Exception as e: + logger.exception(f"{name}: unexpected exception: {e}") + + task = loop.create_task(run_coroutine()) + task.set_name(name) + logger.trace(f"{name}: task created") + return task + + +async def cancel_task(task: asyncio.Task, timeout: Optional[float] = None): + name = task.get_name() + task.cancel() + try: + if timeout: + await asyncio.wait_for(task, timeout=timeout) + else: + await task + except asyncio.TimeoutError: + logger.warning(f"{name}: timed out waiting for task to finish") + except asyncio.CancelledError: + # Here are sure the task is cancelled properly. + logger.trace(f"{name}: task cancelled") + except Exception as e: + logger.exception(f"{name}: unexpected exception while cancelling task: {e}")