From a980a76ce1481a18a07971758948dff8cce47c55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Thu, 26 Dec 2024 16:10:13 +0100 Subject: [PATCH 01/19] initial flexible IO implementation (PipelineAgent) --- examples/minimal_worker.py | 15 +- livekit-agents/livekit/agents/__init__.py | 2 - livekit-agents/livekit/agents/cli/cli.py | 145 +++++--- livekit-agents/livekit/agents/cli/proto.py | 13 +- livekit-agents/livekit/agents/io.py | 35 ++ .../livekit/agents/pipeline/__init__.py | 13 +- .../agents/pipeline/audio_recognition.py | 219 +++++++++++ .../livekit/agents/pipeline/chat_cli.py | 219 +++++++++++ .../livekit/agents/pipeline/generation.py | 63 ++++ .../livekit/agents/pipeline/human_input.py | 8 +- .../livekit/agents/pipeline/impl.py | 183 ++++++++++ livekit-agents/livekit/agents/pipeline/io.py | 51 +++ .../livekit/agents/pipeline/multimodal.py | 13 + .../livekit/agents/pipeline/pipeline2.py | 342 ++++++++++++++++++ .../agents/voice_assistant/__init__.py | 11 - livekit-agents/livekit/agents/worker.py | 57 ++- 16 files changed, 1291 insertions(+), 98 deletions(-) create mode 100644 livekit-agents/livekit/agents/io.py create mode 100644 livekit-agents/livekit/agents/pipeline/audio_recognition.py create mode 100644 livekit-agents/livekit/agents/pipeline/chat_cli.py create mode 100644 livekit-agents/livekit/agents/pipeline/generation.py create mode 100644 livekit-agents/livekit/agents/pipeline/impl.py create mode 100644 livekit-agents/livekit/agents/pipeline/io.py create mode 100644 livekit-agents/livekit/agents/pipeline/multimodal.py create mode 100644 livekit-agents/livekit/agents/pipeline/pipeline2.py delete mode 100644 livekit-agents/livekit/agents/voice_assistant/__init__.py diff --git a/examples/minimal_worker.py b/examples/minimal_worker.py index e3a9ed3b9..2f8ada9ce 100644 --- a/examples/minimal_worker.py +++ b/examples/minimal_worker.py @@ -1,18 +1,25 @@ import logging +from dotenv import load_dotenv from livekit.agents import AutoSubscribe, JobContext, WorkerOptions, WorkerType, cli +from livekit.agents.pipeline import ChatCLI, PipelineAgent +from livekit.plugins import deepgram, openai logger = logging.getLogger("my-worker") logger.setLevel(logging.INFO) +load_dotenv() -async def entrypoint(ctx: JobContext): - logger.info("starting entrypoint") +async def entrypoint(ctx: JobContext): await ctx.connect(auto_subscribe=AutoSubscribe.SUBSCRIBE_ALL) - logger.info("connected to the room") - # add your agent logic here! + agent = PipelineAgent(llm=openai.LLM(), stt=deepgram.STT()) + agent.start() + + # start a chat inside the CLI + chat_cli = ChatCLI(agent) + await chat_cli.run() if __name__ == "__main__": diff --git a/livekit-agents/livekit/agents/__init__.py b/livekit-agents/livekit/agents/__init__.py index 780a5f546..5c504b162 100644 --- a/livekit-agents/livekit/agents/__init__.py +++ b/livekit-agents/livekit/agents/__init__.py @@ -25,7 +25,6 @@ tts, utils, vad, - voice_assistant, ) from ._exceptions import ( APIConnectionError, @@ -72,7 +71,6 @@ "transcription", "pipeline", "multimodal", - "voice_assistant", "cli", "AssignmentTimeoutError", "APIConnectionError", diff --git a/livekit-agents/livekit/agents/cli/cli.py b/livekit-agents/livekit/agents/cli/cli.py index 578ce5ced..a29ef17e8 100644 --- a/livekit-agents/livekit/agents/cli/cli.py +++ b/livekit-agents/livekit/agents/cli/cli.py @@ -2,14 +2,14 @@ import pathlib import signal import sys +import threading import click -from livekit.protocol import models from .. 
import utils from ..log import logger from ..plugin import Plugin -from ..worker import Worker, WorkerOptions +from ..worker import JobExecutorType, Worker, WorkerOptions from . import proto from .log import setup_logging @@ -58,6 +58,7 @@ def start( log_level=log_level, devmode=False, asyncio_debug=False, + register=True, watch=False, drain_timeout=drain_timeout, ) @@ -105,7 +106,63 @@ def dev( asyncio_debug: bool, watch: bool, ) -> None: - _run_dev(opts, log_level, url, api_key, api_secret, asyncio_debug, watch) + opts.ws_url = url or opts.ws_url + opts.api_key = api_key or opts.api_key + opts.api_secret = api_secret or opts.api_secret + args = proto.CliArgs( + opts=opts, + log_level=log_level, + devmode=True, + asyncio_debug=asyncio_debug, + watch=watch, + drain_timeout=0, + register=False, + ) + + _run_dev(args) + + @cli.command(help="Start a new chat") + @click.option( + "--url", + envvar="LIVEKIT_URL", + help="LiveKit server or Cloud project's websocket URL", + ) + @click.option( + "--api-key", + envvar="LIVEKIT_API_KEY", + help="LiveKit server or Cloud project's API key", + ) + @click.option( + "--api-secret", + envvar="LIVEKIT_API_SECRET", + help="LiveKit server or Cloud project's API secret", + ) + def chat( + url: str, + api_key: str, + api_secret: str, + ) -> None: + # keep everything inside the same process when using the chat mode + opts.job_executor_type = JobExecutorType.THREAD + opts.ws_url = url or opts.ws_url + opts.api_key = api_key or opts.api_key + opts.api_secret = api_secret or opts.api_secret + + chat_name = utils.shortuuid("chat_cli_") + + args = proto.CliArgs( + opts=opts, + log_level="ERROR", + devmode=True, + asyncio_debug=False, + watch=False, + drain_timeout=0, + register=False, + simulate_job=proto.SimulateJobArgs( + room=chat_name, + ), + ) + _run_dev(args) @cli.command(help="Connect to a specific room") @click.option( @@ -155,18 +212,25 @@ def connect( room: str, participant_identity: str, ) -> None: - _run_dev( - opts, - log_level, - url, - api_key, - api_secret, - asyncio_debug, - watch, - room, - participant_identity, + opts.ws_url = url or opts.ws_url + opts.api_key = api_key or opts.api_key + opts.api_secret = api_secret or opts.api_secret + args = proto.CliArgs( + opts=opts, + log_level=log_level, + devmode=True, + register=False, + asyncio_debug=asyncio_debug, + watch=watch, + drain_timeout=0, + simulate_job=proto.SimulateJobArgs( + room=room, + participant_identity=participant_identity, + ), ) + _run_dev(args) + @cli.command(help="Download plugin dependency files") @click.option( "--log-level", @@ -188,34 +252,12 @@ def download_files(log_level: str) -> None: def _run_dev( - opts: WorkerOptions, - log_level: str, - url: str, - api_key: str, - api_secret: str, - asyncio_debug: bool, - watch: bool, - room: str = "", - participant_identity: str = "", + args: proto.CliArgs, ): - opts.ws_url = url or opts.ws_url - opts.api_key = api_key or opts.api_key - opts.api_secret = api_secret or opts.api_secret - args = proto.CliArgs( - opts=opts, - log_level=log_level, - devmode=True, - asyncio_debug=asyncio_debug, - watch=watch, - drain_timeout=0, - room=room, - participant_identity=participant_identity, - ) - - if watch: + if args.watch: from .watcher import WatchServer - setup_logging(log_level, args.devmode) + setup_logging(args.log_level, args.devmode) main_file = pathlib.Path(sys.argv[0]).parent async def _run_loop(): @@ -236,27 +278,34 @@ def run_worker(args: proto.CliArgs) -> None: setup_logging(args.log_level, args.devmode) 
args.opts.validate_config(args.devmode) - loop = asyncio.get_event_loop() - worker = Worker(args.opts, devmode=args.devmode, loop=loop) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + worker = Worker(args.opts, devmode=args.devmode, register=args.register, loop=loop) loop.set_debug(args.asyncio_debug) loop.slow_callback_duration = 0.1 # 100ms utils.aio.debug.hook_slow_callbacks(2) - if args.room and args.reload_count == 0: - # directly connect to a specific room - @worker.once("worker_registered") - def _connect_on_register(worker_id: str, server_info: models.ServerInfo): - logger.info("connecting to room %s", args.room) - loop.create_task(worker.simulate_job(args.room, args.participant_identity)) + @worker.once("worker_started") + def _worker_started(): + if args.simulate_job and args.reload_count == 0: + logger.info("connecting to room %s", args.simulate_job.room) + loop.create_task( + worker.simulate_job( + args.simulate_job.room, args.simulate_job.participant_identity + ) + ) try: def _signal_handler(): raise KeyboardInterrupt - for sig in (signal.SIGINT, signal.SIGTERM): - loop.add_signal_handler(sig, _signal_handler) + if threading.current_thread() is threading.main_thread(): + for sig in (signal.SIGINT, signal.SIGTERM): + loop.add_signal_handler(sig, _signal_handler) + except NotImplementedError: # TODO(theomonnom): add_signal_handler is not implemented on win pass diff --git a/livekit-agents/livekit/agents/cli/proto.py b/livekit-agents/livekit/agents/cli/proto.py index f7753c579..6545920d9 100644 --- a/livekit-agents/livekit/agents/cli/proto.py +++ b/livekit-agents/livekit/agents/cli/proto.py @@ -12,6 +12,12 @@ from ..worker import WorkerOptions +@dataclass +class SimulateJobArgs: + room: str = "" + participant_identity: str = "" + + @dataclass class CliArgs: opts: WorkerOptions @@ -20,8 +26,11 @@ class CliArgs: asyncio_debug: bool watch: bool drain_timeout: int - room: str = "" - participant_identity: str = "" + + # register the worker to the worker pool + register: bool = True + + simulate_job: SimulateJobArgs | None = None # amount of time this worker has been reloaded reload_count: int = 0 diff --git a/livekit-agents/livekit/agents/io.py b/livekit-agents/livekit/agents/io.py new file mode 100644 index 000000000..3d4cc6820 --- /dev/null +++ b/livekit-agents/livekit/agents/io.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import AsyncIterable, Protocol + +from livekit import rtc + + +@dataclass +class TextChunk: + text: str + is_final: bool + + +AudioStream = AsyncIterable[rtc.AudioFrame] +VideoStream = AsyncIterable[rtc.VideoFrame] +TextStream = AsyncIterable[TextChunk] + + +class AudioSink(Protocol): + async def capture_frame(self, audio: rtc.AudioFrame) -> None: ... + + def flush(self) -> None: ... + + +class TextSink(Protocol): + async def capture_text(self, text: str) -> None: ... + + def flush(self) -> None: ... + + +class VideoSink(Protocol): + async def capture_frame(self, text: rtc.VideoFrame) -> None: ... + + def flush(self) -> None: ... 
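
As an illustration of how the sink interfaces introduced in io.py above are meant to be consumed (this sketch is not part of the patch): the sinks are plain typing.Protocol classes, so any object exposing the matching methods can be plugged into an agent's output without inheriting from anything. The BufferedTextSink name and its print-on-flush behaviour below are assumptions chosen only to show the shape of the API.

    class BufferedTextSink:
        """Accumulates streamed text deltas and prints the full segment on flush.

        Illustrative sketch only; satisfies the TextSink protocol structurally.
        """

        def __init__(self) -> None:
            self._buf: list[str] = []

        async def capture_text(self, text: str) -> None:
            # receives incremental text deltas as the agent generates them
            self._buf.append(text)

        def flush(self) -> None:
            # end of the current segment: emit the accumulated text and reset
            print("".join(self._buf))
            self._buf.clear()

Under the AgentOutput API introduced later in this patch, such a sink would be attached with agent.output.text = BufferedTextSink().
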
diff --git a/livekit-agents/livekit/agents/pipeline/__init__.py b/livekit-agents/livekit/agents/pipeline/__init__.py index 480dd7990..d6b34ab53 100644 --- a/livekit-agents/livekit/agents/pipeline/__init__.py +++ b/livekit-agents/livekit/agents/pipeline/__init__.py @@ -1,11 +1,4 @@ -from .pipeline_agent import ( - AgentCallContext, - AgentTranscriptionOptions, - VoicePipelineAgent, -) +from .chat_cli import ChatCLI +from .pipeline2 import PipelineAgent -__all__ = [ - "VoicePipelineAgent", - "AgentCallContext", - "AgentTranscriptionOptions", -] +__all__ = ["ChatCLI", "PipelineAgent"] diff --git a/livekit-agents/livekit/agents/pipeline/audio_recognition.py b/livekit-agents/livekit/agents/pipeline/audio_recognition.py new file mode 100644 index 000000000..9b238ba25 --- /dev/null +++ b/livekit-agents/livekit/agents/pipeline/audio_recognition.py @@ -0,0 +1,219 @@ +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Literal, Protocol + +from livekit import rtc + +from .. import llm, stt, utils, vad +from ..log import logger +from ..utils import aio +from . import io + +if TYPE_CHECKING: + from .pipeline2 import PipelineAgent + + +class _TurnDetector(Protocol): + # TODO: Move those two functions to EOU ctor (capabilities dataclass) + def unlikely_threshold(self) -> float: ... + def supports_language(self, language: str | None) -> bool: ... + + async def predict_end_of_turn(self, chat_ctx: llm.ChatContext) -> float: ... + + +EventTypes = Literal[ + "start_of_speech", + "vad_inference_done", + "end_of_speech", + "interim_transcript", + "final_transcript", + "end_of_turn", +] + + +class AudioRecognition(rtc.EventEmitter[EventTypes]): + """ + Audio recognition part of the PipelineAgent. + The class is always instantiated but no tasks may be running if STT/VAD is disabled + + This class is also responsible for the end of turn detection. 
+ """ + + UNLIKELY_END_OF_TURN_EXTRA_DELAY = 6.0 + + def __init__( + self, + *, + agent: PipelineAgent, + stt: io.STTNode, + vad: vad.VAD | None, + turn_detector: _TurnDetector | None, + min_endpointing_delay: float, + chat_ctx: llm.ChatContext, + loop: asyncio.AbstractEventLoop, + ) -> None: + super().__init__() + self._agent = agent + self._stt_atask: asyncio.Task[None] | None = None + self._vad_atask: asyncio.Task[None] | None = None + self._end_of_turn_task: asyncio.Task[None] | None = None + self._audio_input: io.AudioStream | None = None + self._min_endpointing_delay = min_endpointing_delay + self._chat_ctx = chat_ctx + self._loop = loop + self._stt = stt + self._vad = vad + self._turn_detector = turn_detector + + self._speaking = False + self._audio_transcript = "" + self._last_language: str | None = None + + def start(self) -> None: + self.update_stt(self._stt) + self.update_vad(self._vad) + + @property + def audio_input(self) -> io.AudioStream | None: + return self._audio_input + + @audio_input.setter + def audio_input(self, audio_input: io.AudioStream | None) -> None: + self._audio_input = audio_input + self.update_stt(self._stt) + self.update_vad(self._vad) + + async def aclose(self) -> None: + if self._stt_atask is not None: + await aio.gracefully_cancel(self._stt_atask) + + if self._vad_atask is not None: + await aio.gracefully_cancel(self._vad_atask) + + if self._end_of_turn_task is not None: + await aio.gracefully_cancel(self._end_of_turn_task) + + def update_stt(self, stt: io.STTNode | None) -> None: + self._stt = stt + + if self._audio_input and stt: + self._stt_atask = asyncio.create_task( + self._stt_task(stt, self._audio_input, self._stt_atask) + ) + elif self._stt_atask is not None: + self._stt_atask.cancel() + + def update_vad(self, vad: vad.VAD | None) -> None: + self._vad = vad + + if self._audio_input and vad: + self._vad_atask = asyncio.create_task( + self._vad_task(vad, self._audio_input, self._vad_atask) + ) + elif self._vad_atask is not None: + self._vad_atask.cancel() + + async def _on_stt_event(self, ev: stt.SpeechEvent) -> None: + if ev.type == stt.SpeechEventType.FINAL_TRANSCRIPT: + self.emit("final_transcript", ev) + transcript = ev.alternatives[0].text + if not transcript: + return + + logger.debug( + "received user transcript", + extra={"user_transcript": transcript}, + ) + + self._audio_transcript += f" {transcript}" + self._audio_transcript = self._audio_transcript.lstrip() + + if not self._speaking: + self._run_eou_detection(self._agent.chat_ctx, self._audio_transcript) + elif ev.type == stt.SpeechEventType.INTERIM_TRANSCRIPT: + self.emit("interim_transcript", ev) + + async def _on_vad_event(self, ev: vad.VADEvent) -> None: + if ev.type == vad.VADEventType.START_OF_SPEECH: + self.emit("start_of_speech", ev) + self._speaking = True + + if self._end_of_turn_task is not None: + self._end_of_turn_task.cancel() + + elif ev.type == vad.VADEventType.INFERENCE_DONE: + self.emit("vad_inference_done", ev) + + elif ev.type == vad.VADEventType.END_OF_SPEECH: + self.emit("end_of_speech", ev) + self._speaking = False + + def _run_eou_detection( + self, chat_ctx: llm.ChatContext, new_transcript: str + ) -> None: + chat_ctx = self._chat_ctx.copy() + chat_ctx.append(role="user", text=new_transcript) + turn_detector = self._turn_detector + + @utils.log_exceptions(logger=logger) + async def _bounce_eou_task() -> None: + await asyncio.sleep(self._min_endpointing_delay) + + if turn_detector is not None and turn_detector.supports_language( + self._last_language + ): + 
end_of_turn_probability = await turn_detector.predict_end_of_turn( + chat_ctx + ) + unlikely_threshold = turn_detector.unlikely_threshold() + if end_of_turn_probability > unlikely_threshold: + await asyncio.sleep(self.UNLIKELY_END_OF_TURN_EXTRA_DELAY) + + self.emit("end_of_turn", new_transcript) + + if self._end_of_turn_task is not None: + self._end_of_turn_task.cancel() + + self._end_of_turn_task = asyncio.create_task(_bounce_eou_task()) + + async def _stt_task( + self, + stt_node: io.STTNode, + audio_input: io.AudioStream, + task: asyncio.Task[None] | None, + ) -> None: + if task is not None: + await aio.gracefully_cancel(task) + + node = stt_node(audio_input) + if asyncio.iscoroutine(node): + node = await node + + if node is not None: + async for ev in node: + assert isinstance( + ev, stt.SpeechEvent + ), "STT node must yield SpeechEvent" + await self._on_stt_event(ev) + + async def _vad_task( + self, vad: vad.VAD, audio_input: io.AudioStream, task: asyncio.Task[None] | None + ) -> None: + if task is not None: + await aio.gracefully_cancel(task) + + stream = vad.stream() + + async def _forward() -> None: + async for frame in audio_input: + stream.push_frame(frame) + + forward_task = asyncio.create_task(_forward()) + + try: + async for ev in stream: + await self._on_vad_event(ev) + finally: + await stream.aclose() + await aio.gracefully_cancel(forward_task) diff --git a/livekit-agents/livekit/agents/pipeline/chat_cli.py b/livekit-agents/livekit/agents/pipeline/chat_cli.py new file mode 100644 index 000000000..2a98b0edf --- /dev/null +++ b/livekit-agents/livekit/agents/pipeline/chat_cli.py @@ -0,0 +1,219 @@ +from __future__ import annotations + +import asyncio +import sys +import termios +import threading +import time +import tty +from typing import Literal + +import click +import numpy as np +import sounddevice as sd +from livekit import rtc + +from ..utils import aio +from . 
import io +from .pipeline2 import PipelineAgent + +MAX_AUDIO_BAR = 30 +INPUT_DB_MIN = -70.0 +INPUT_DB_MAX = 0.0 +FPS = 20 + + +def _esc(*codes: int) -> str: + return "\033[" + ";".join(str(c) for c in codes) + "m" + + +def _normalize_db(amplitude_db: float, db_min: float, db_max: float) -> float: + amplitude_db = max(db_min, min(amplitude_db, db_max)) + return (amplitude_db - db_min) / (db_max - db_min) + + +class ChatCLI(io.TextSink): + def __init__( + self, + agent: PipelineAgent, + *, + loop: asyncio.AbstractEventLoop | None = None, + ) -> None: + self._loop = loop or asyncio.get_event_loop() + self._agent = agent + self._generation_done_ev = threading.Event() + self._done_fut = asyncio.Future() + self._micro_db = INPUT_DB_MIN + + self._input_ch = aio.Chan[rtc.AudioFrame](loop=self._loop) + self._input_stream: sd.InputStream | None = None + self._input_mode: Literal["audio", "text"] = "audio" + self._text_buffer = [] # in text mode + + self._text_capturing = False + + def _print_welcome(self): + print(_esc(34) + "=" * 50 + _esc(0)) + print(_esc(34) + " Livekit Agents - ChatCLI" + _esc(0)) + print(_esc(34) + "=" * 50 + _esc(0)) + print("Press [Ctrl+B] to toggle between Text/Audio mode, [Q] to quit.\n") + + async def run(self) -> None: + self._print_welcome() + + fd = sys.stdin.fileno() + stdin_ch = aio.Chan[str](loop=self._loop) + + def _on_input(): + try: + ch = sys.stdin.read(1) + stdin_ch.send_nowait(ch) + except Exception: + stdin_ch.close() + + self._loop.add_reader(fd, _on_input) + old_settings = termios.tcgetattr(fd) + + self._update_microphone(enable=True) + + try: + tty.setcbreak(fd) + input_cli_task = asyncio.create_task(self._input_cli_task(stdin_ch)) + input_cli_task.add_done_callback(lambda _: self._done_fut.set_result(None)) + + render_cli_task = asyncio.create_task(self._render_cli_task()) + + await self._done_fut + await aio.gracefully_cancel(render_cli_task) + + self._update_microphone(enable=False) + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + self._loop.remove_reader(fd) + + def _update_microphone(self, *, enable: bool) -> None: + input_device, _ = sd.default.device + if input_device is not None and enable: + device_info = sd.query_devices(input_device) + assert isinstance(device_info, dict) + + self._input_device_name = device_info.get("name", "Microphone") + self._input_stream = sd.InputStream( + callback=self._input_sd_callback, + dtype="int16", + channels=1, + device=input_device, + samplerate=24000, + ) + self._input_stream.start() + self._agent.input.audio = self._input_ch + elif self._input_stream is not None: + self._input_stream.stop() + self._input_stream.close() + self._input_stream = None + self._agent.input.audio = None + + def _update_text_output(self, *, enable: bool) -> None: + if enable: + self._agent.output.text = self + else: + self._agent.output.text = None + self._text_buffer = [] + self._text_capturing = False + + def _input_sd_callback(self, indata: np.ndarray, frame_count: int, *_) -> None: + rms = np.sqrt(np.mean(indata.astype(np.float32) ** 2)) + max_int16 = np.iinfo(np.int16).max + self._micro_db = 20.0 * np.log10(rms / max_int16 + 1e-6) + self._loop.call_soon_threadsafe( + self._input_ch.send_nowait, + rtc.AudioFrame( + data=indata.tobytes(), + samples_per_channel=frame_count, + sample_rate=24000, + num_channels=1, + ), + ) + + async def _input_cli_task(self, in_ch: aio.Chan[str]) -> None: + while True: + char = await in_ch.recv() + if char is None: + break + + if char == "\x02": # Ctrl+B + if self._input_mode == 
"audio": + self._input_mode = "text" + self._update_text_output(enable=True) + self._update_microphone(enable=False) + click.echo("\nSwitched to Text Input Mode.", nl=False) + else: + self._input_mode = "audio" + self._update_text_output(enable=False) + self._update_microphone(enable=True) + self._text_buffer = [] + click.echo("\nSwitched to Audio Input Mode.", nl=False) + + if self._input_mode == "text": # Read input + if char in ("\r", "\n"): + text = "".join(self._text_buffer) + if text: + self._text_buffer = [] + self._agent.generate_reply(text) + click.echo("\n", nl=False) + elif char == "\x7f": # Backspace + if self._text_buffer: + self._text_buffer.pop() + sys.stdout.write("\b \b") + sys.stdout.flush() + elif char.isprintable(): + self._text_buffer.append(char) + click.echo(char, nl=False) + sys.stdout.flush() + + async def _render_cli_task(self) -> None: + next_frame = time.perf_counter() + while True: + next_frame += 1 / FPS + if self._input_mode == "audio": + self._print_audio_mode() + elif self._input_mode == "text" and not self._text_capturing: + self._print_text_mode() + + await asyncio.sleep(max(0, next_frame - time.perf_counter())) + + def _print_audio_mode(self): + amplitude_db = _normalize_db( + self._micro_db, db_min=INPUT_DB_MIN, db_max=INPUT_DB_MAX + ) + nb_bar = round(amplitude_db * MAX_AUDIO_BAR) + + color_code = 31 if amplitude_db > 0.75 else 33 if amplitude_db > 0.5 else 32 + bar = "#" * nb_bar + "-" * (MAX_AUDIO_BAR - nb_bar) + sys.stdout.write( + f"\r[Audio] {self._input_device_name[-20:]} [{self._micro_db:6.2f} dBFS] {_esc(color_code)}[{bar}]{_esc(0)}" + ) + sys.stdout.flush() + + def _print_text_mode(self): + sys.stdout.write("\r") + sys.stdout.flush() + prompt = "Enter your message: " + sys.stdout.write(f"[Text] {prompt}{''.join(self._text_buffer)}") + sys.stdout.flush() + + # io.Text Sink implementation + + async def capture_text(self, text: str) -> None: + if not self._text_capturing: + self._text_capturing = True + sys.stdout.write("\r") + sys.stdout.flush() + click.echo(_esc(36), nl=False) + + click.echo(text, nl=False) + + def flush(self) -> None: + if self._text_capturing: + click.echo(_esc(0)) + self._text_capturing = False diff --git a/livekit-agents/livekit/agents/pipeline/generation.py b/livekit-agents/livekit/agents/pipeline/generation.py new file mode 100644 index 000000000..1c1608184 --- /dev/null +++ b/livekit-agents/livekit/agents/pipeline/generation.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from typing import AsyncIterable + +from ..llm import ChatChunk, ChatContext, FunctionCallInfo, FunctionContext +from ..utils import aio +from . 
import io + + +@dataclass +class _LLMGenerationData: + chat_ctx: ChatContext + fnc_ctx: FunctionContext | None + text_ch: aio.Chan[str] + tools_ch: aio.Chan[FunctionCallInfo] + + +async def do_llm_inference(*, node: io.LLMNode, data: _LLMGenerationData) -> bool: + llm_node = node(data.chat_ctx, data.fnc_ctx) + if asyncio.iscoroutine(llm_node): + llm_node = await llm_node + + if isinstance(llm_node, str): + data.text_ch.send_nowait(llm_node) + return True + + if isinstance(llm_node, AsyncIterable): + # forward llm stream to output channels + async for chunk in llm_node: + # io.LLMNode can either return a string or a ChatChunk + if isinstance(chunk, str): + data.text_ch.send_nowait(chunk) + + elif isinstance(chunk, ChatChunk): + if not chunk.choices: + continue # this can happens if we receive the stats chunk + + delta = chunk.choices[0].delta + + if delta.tool_calls: + for tool in delta.tool_calls: + data.tools_ch.send_nowait(tool) + + if delta.content: + data.text_ch.send_nowait(delta.content) + + return True + + return False + + +@dataclass +class _TTSGenerationData: + input_ch: AsyncIterable[str] + audio_ch: aio.Chan[bytes] + +async def do_tts_inference(*, node: io.TTSNode, data: _TTSGenerationData) -> bool: + tts_node = node(data.input_ch) + + + return False diff --git a/livekit-agents/livekit/agents/pipeline/human_input.py b/livekit-agents/livekit/agents/pipeline/human_input.py index b54ba6f28..bd875cf99 100644 --- a/livekit-agents/livekit/agents/pipeline/human_input.py +++ b/livekit-agents/livekit/agents/pipeline/human_input.py @@ -19,15 +19,9 @@ ] -class HumanInput(utils.EventEmitter[EventTypes]): +class AudioRecognition(utils.EventEmitter[EventTypes]): def __init__( self, - *, - room: rtc.Room, - vad: voice_activity_detection.VAD, - stt: speech_to_text.STT, - participant: rtc.RemoteParticipant, - transcription: bool, ) -> None: super().__init__() self._room, self._vad, self._stt, self._participant, self._transcription = ( diff --git a/livekit-agents/livekit/agents/pipeline/impl.py b/livekit-agents/livekit/agents/pipeline/impl.py new file mode 100644 index 000000000..14641a889 --- /dev/null +++ b/livekit-agents/livekit/agents/pipeline/impl.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +import asyncio +from typing import Protocol + +from .. import io, llm, stt, utils, vad +from ..utils import aio + + +class _TurnDetector(Protocol): + # TODO: Move those two functions to EOU ctor (capabilities dataclass) + def unlikely_threshold(self) -> float: ... + def supports_language(self, language: str | None) -> bool: ... + + async def predict_end_of_turn(self, chat_ctx: llm.ChatContext) -> float: ... + + +class AudioRecognition: + """ + Audio recognition part of the PipelineAgent. + The class is always instantiated but no tasks may be running if STT/VAD is disabled + + This class is also responsible for the end of turn detection. 
+ """ + + UNLIKELY_END_OF_TURN_EXTRA_DELAY = 6.0 + + def __init__( + self, + *, + pipeline_agent: "PipelineAgent", + stt: stt.STT | None, + vad: vad.VAD | None, + turn_detector: _TurnDetector | None = None, + min_endpointing_delay: float, + ) -> None: + self._pipeline_agent = weakref.ref(pipeline_agent) + self._stt_atask: asyncio.Task[None] | None = None + self._vad_atask: asyncio.Task[None] | None = None + self._end_of_turn_task: asyncio.Task[None] | None = None + self._audio_input: io.AudioStream | None = None + self._min_endpointing_delay = min_endpointing_delay + + self._init_stt(stt) + self._init_vad(vad) + self._turn_detector = turn_detector + + self._speaking = False + self._audio_transcript = "" + self._last_language: str | None = None + + @property + def audio_input(self) -> io.AudioStream | None: + return self._audio_input + + @audio_input.setter + def audio_input(self, audio_input: io.AudioStream | None) -> None: + self._init_stt(self._stt) + self._init_vad(self._vad) + self._audio_input = audio_input + + async def _on_stt_event(self, ev: stt.SpeechEvent) -> None: + if ev.type == stt.SpeechEventType.FINAL_TRANSCRIPT: + transcript = ev.alternatives[0].text + if not transcript: + return + + logger.debug( + "received user transcript", + extra={"user_transcript": new_transcript}, + ) + + self._audio_transcript += f" {transcript}" + self._audio_transcript = self._audio_transcript.lstrip() + + if not self._speaking: + self._run_eou_detection(pipeline_agent.chat_ctx, self._audio_transcript) + + async def _on_vad_event(self, ev: vad.VADEvent) -> None: + if ev.type == vad.VADEventType.START_OF_SPEECH: + self._speaking = True + + if self._end_of_turn_task is not None: + self._end_of_turn_task.cancel() + + elif ev.tupe == vad.VADEventType.END_OF_SPEECH: + self._speaking = False + + def _on_end_of_turn(self) -> None: + # start llm generation + pass + + async def aclose(self) -> None: + if self._stt_atask is not None: + await aio.gracefully_cancel(self._stt_atask) + + if self._vad_atask is not None: + await aio.gracefully_cancel(self._vad_atask) + + if self._end_of_turn_task is not None: + await aio.gracefully_cancel(self._end_of_turn_task) + + def _run_eou_detection( + self, chat_ctx: llm.ChatContext, new_transcript: str + ) -> None: + chat_ctx = pipeline_agent.chat_ctx.copy() + chat_ctx.append(role="user", text=new_transcript) + turn_detector = self._turn_detector + + @utils.log_exceptions(logger=logger) + async def _bounce_eou_task() -> None: + await asyncio.sleep(self._min_endpointing_delay) + + if turn_detector is not None and turn_detector.supports_language( + self._last_language + ): + end_of_turn_probability = await turn_detector.predict_end_of_turn( + chat_ctx + ) + unlikely_threshold = turn_detector.unlikely_threshold() + if end_of_turn_probability > unlikely_threshold: + await asyncio.sleep(self.UNLIKELY_END_OF_TURN_EXTRA_DELAY) + + self._on_end_of_turn() + + if self._end_of_turn_task is not None: + self._end_of_turn_task.cancel() + + self._end_of_turn_task = asyncio.create_task(_bounce_eou_task()) + + async def _stt_task( + self, stt: stt.STT, audio_input: io.AudioStream, task: asyncio.Task[None] | None + ) -> None: + if task is not None: + await aio.gracefully_cancel(task) + + stream = stt.stream() + + async def _forward() -> None: + async for frame in audio_input: + stream.push_frame(frame) + + forward_task = asyncio.create_task(_forward()) + + try: + async for ev in stream: + await self._on_stt_event(ev) + finally: + await stream.aclose() + await 
aio.gracefully_cancel(forward_task) + + async def _vad_task( + self, vad: vad.VAD, audio_input: io.AudioStream, task: asyncio.Task[None] | None + ) -> None: + if task is not None: + await aio.gracefully_cancel(task) + + stream = vad.stream() + + async def _forward() -> None: + async for frame in audio_input: + stream.push_frame(frame) + + forward_task = asyncio.create_task(_forward()) + + try: + async for ev in stream: + await self._on_vad_event(ev) + finally: + await stream.aclose() + await aio.gracefully_cancel(forward_task) + + def init_stt(self, stt: stt.STT, audio_input: io.AudioStream) -> None: + self._stt = stt + self._stt_atask = asyncio.create_task( + self._stt_task(stt, audio_input, self._stt_atask) + ) + + def init_vad(self, vad: vad.VAD, audio_input: io.AudioStream) -> None: + self._vad = vad + self._vad_atask = asyncio.create_task( + self._vad_task(vad, audio_input, self._vad_atask) + ) diff --git a/livekit-agents/livekit/agents/pipeline/io.py b/livekit-agents/livekit/agents/pipeline/io.py new file mode 100644 index 000000000..3c630ddeb --- /dev/null +++ b/livekit-agents/livekit/agents/pipeline/io.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from typing import ( + AsyncIterable, + Awaitable, + Callable, + Optional, + Protocol, + Union, +) + +from livekit import rtc + +from .. import llm, stt + +STTNode = Callable[ + [AsyncIterable[rtc.AudioFrame]], + Union[Awaitable[Optional[AsyncIterable[stt.SpeechEvent]]],], +] +LLMNode = Callable[ + [llm.ChatContext, Optional[llm.FunctionContext]], + Union[ + Optional[Union[AsyncIterable[llm.ChatChunk], AsyncIterable[str], str]], + Awaitable[ + Optional[Union[AsyncIterable[llm.ChatChunk], AsyncIterable[str], str]], + ], + ], +] +TTSNode = Callable[[AsyncIterable[str]], Optional[AsyncIterable[rtc.AudioFrame]]] + + +AudioStream = AsyncIterable[rtc.AudioFrame] +VideoStream = AsyncIterable[rtc.VideoFrame] + + +class AudioSink(Protocol): + async def capture_frame(self, audio: rtc.AudioFrame) -> None: ... + + def flush(self) -> None: ... + + +class TextSink(Protocol): + async def capture_text(self, text: str) -> None: ... + + def flush(self) -> None: ... + + +class VideoSink(Protocol): + async def capture_frame(self, text: rtc.VideoFrame) -> None: ... + + def flush(self) -> None: ... diff --git a/livekit-agents/livekit/agents/pipeline/multimodal.py b/livekit-agents/livekit/agents/pipeline/multimodal.py new file mode 100644 index 000000000..6c68cbd43 --- /dev/null +++ b/livekit-agents/livekit/agents/pipeline/multimodal.py @@ -0,0 +1,13 @@ + + +EventTypes = Literal[""] + + +class MultimodalModel: + """Model handling multiple modalities (video/audio/text) + + MultimodalModel assumes stateful multimodal input and output. (MultimodalSession). + """ + + def __init(self) -> None: + pass diff --git a/livekit-agents/livekit/agents/pipeline/pipeline2.py b/livekit-agents/livekit/agents/pipeline/pipeline2.py new file mode 100644 index 000000000..13ed91fcc --- /dev/null +++ b/livekit-agents/livekit/agents/pipeline/pipeline2.py @@ -0,0 +1,342 @@ +from __future__ import annotations, print_function + +import asyncio +from typing import ( + AsyncIterable, + Callable, + Literal, + Optional, + Union, +) + +from livekit import rtc + +from .. 
import io, llm, stt, tts, utils, vad +from ..llm import ChatContext, FunctionContext +from .audio_recognition import AudioRecognition, _TurnDetector +from .generation import _LLMGenerationData, do_llm_inference + + +class AgentInput: + def __init__(self, video_changed: Callable, audio_changed: Callable) -> None: + self._video_stream: io.VideoStream | None = None + self._audio_stream: io.AudioStream | None = None + self._video_changed = video_changed + self._audio_changed = audio_changed + + @property + def video(self) -> io.VideoStream | None: + return self._video_stream + + @video.setter + def video(self, stream: io.VideoStream | None) -> None: + self._video_stream = stream + self._video_changed() + + @property + def audio(self) -> io.AudioStream | None: + return self._audio_stream + + @audio.setter + def audio(self, stream: io.AudioStream | None) -> None: + self._audio_stream = stream + self._audio_changed() + + +class AgentOutput: + def __init__( + self, video_changed: Callable, audio_changed: Callable, text_changed: Callable + ) -> None: + self._video_sink: io.VideoSink | None = None + self._audio_sink: io.AudioSink | None = None + self._text_sink: io.TextSink | None = None + self._video_changed = video_changed + self._audio_changed = audio_changed + self._text_changed = text_changed + + @property + def video(self) -> io.VideoSink | None: + return self._video_sink + + @video.setter + def video(self, sink: io.VideoSink | None) -> None: + self._video_sink = sink + self._video_changed() + + @property + def audio(self) -> io.AudioSink | None: + return self._audio_sink + + @audio.setter + def audio(self, sink: io.AudioSink | None) -> None: + self._audio_sink = sink + self._audio_changed() + + @property + def text(self) -> io.TextSink | None: + return self._text_sink + + @text.setter + def text(self, sink: io.TextSink | None) -> None: + self._text_sink = sink + self._text_changed() + + +class GenerationHandle: + def __init__( + self, *, speech_id: str, allow_interruptions: bool, task: asyncio.Task + ) -> None: + self._id = speech_id + self._allow_interruptions = allow_interruptions + self._task = task + + @staticmethod + def from_task( + task: asyncio.Task, *, allow_interruptions: bool = True + ) -> GenerationHandle: + return GenerationHandle( + speech_id=utils.shortuuid("gen_"), + allow_interruptions=allow_interruptions, + task=task, + ) + + @property + def id(self) -> str: + return self._id + + @property + def allow_interruptions(self) -> bool: + return self._allow_interruptions + + def interrupt(self) -> None: + if not self._allow_interruptions: + raise ValueError("This generation handle does not allow interruptions") + + +class AgentTask: + pass + + +EventTypes = Literal[ + "user_started_speaking", + "user_stopped_speaking", + "agent_started_speaking", + "agent_stopped_speaking", + "user_message_committed", + "agent_message_committed", + "agent_message_interrupted", +] + + +class PipelineAgent(rtc.EventEmitter[EventTypes]): + def __init__( + self, + *, + llm: llm.LLM | None = None, + vad: vad.VAD | None = None, + stt: stt.STT | None = None, + tts: tts.TTS | None = None, + turn_detector: _TurnDetector | None = None, + language: str | None = None, + chat_ctx: ChatContext | None = None, + fnc_ctx: FunctionContext | None = None, + allow_interruptions: bool = True, + min_endpointing_delay: float = 0.5, + max_fnc_steps: int = 5, + loop: asyncio.AbstractEventLoop | None = None, + ) -> None: + super().__init__() + self._loop = loop or asyncio.get_event_loop() + + self._chat_ctx = chat_ctx or 
ChatContext() + self._fnc_ctx = fnc_ctx + + self._stt, self._vad, self._llm, self._tts = stt, vad, llm, tts + self._turn_detector = turn_detector + + self._audio_recognition = AudioRecognition( + agent=self, + stt=self.stt_node, + vad=vad, + turn_detector=turn_detector, + min_endpointing_delay=min_endpointing_delay, + chat_ctx=self._chat_ctx, + loop=self._loop, + ) + + self._max_fnc_steps = max_fnc_steps + self._audio_recognition.on("end_of_turn", self._on_audio_end_of_turn) + + # configurable IO + self._input = AgentInput( + self._on_video_input_changed, self._on_audio_input_changed + ) + self._output = AgentOutput( + self._on_video_output_changed, + self._on_audio_output_changed, + self._on_text_output_changed, + ) + + # current generation happening (including all function calls & steps) + self._current_generation: GenerationHandle | None = None + + # -- Pipeline nodes -- + # They can all be overriden by subclasses, by default they use the STT/LLM/TTS specified in the + # constructor of the PipelineAgent + + async def stt_node( + self, audio: AsyncIterable[rtc.AudioFrame] + ) -> Optional[AsyncIterable[stt.SpeechEvent]]: + assert self._stt is not None, "stt_node called but no STT node is available" + + async with self._stt.stream() as stream: + + async def _forward_input(): + async for frame in audio: + stream.push_frame(frame) + + forward_task = asyncio.create_task(_forward_input()) + try: + async for event in stream: + yield event + finally: + forward_task.cancel() + + async def llm_node( + self, chat_ctx: llm.ChatContext, fnc_ctx: llm.FunctionContext | None + ) -> Union[ + Optional[AsyncIterable[llm.ChatChunk]], + Optional[AsyncIterable[str]], + Optional[str], + ]: + assert self._llm is not None, "llm_node called but no LLM node is available" + + async with self._llm.chat(chat_ctx=chat_ctx, fnc_ctx=fnc_ctx) as stream: + async for chunk in stream: + yield chunk + + async def tts_node( + self, text: AsyncIterable[str] + ) -> Optional[AsyncIterable[rtc.AudioFrame]]: + assert self._tts is not None, "tts_node called but no TTS node is available" + + def start(self) -> None: + self._audio_recognition.start() + + async def aclose(self) -> None: + await self._audio_recognition.aclose() + + @property + def input(self) -> AgentInput: + return self._input + + @property + def output(self) -> AgentOutput: + return self._output + + # TODO(theomonnom): find a better name than `generation` + @property + def current_generation(self) -> GenerationHandle | None: + return self._current_generation + + @property + def chat_ctx(self) -> llm.ChatContext: + return self._chat_ctx + + def update_options(self) -> None: + pass + + def say(self, text: str | AsyncIterable[str]) -> GenerationHandle: + # say also send to the text output sink... 
pfff + pass + + def generate_reply(self, user_input: str) -> GenerationHandle: + self._chat_ctx.append(role="user", text=user_input) + + # TODO(theomonnom): Use the agent task chat_ctx + task = asyncio.create_task( + self._generate_task(chat_ctx=self._chat_ctx, fnc_ctx=self._fnc_ctx) + ) + gen_handle = GenerationHandle.from_task(task) + return gen_handle + + # -- Main generation task -- + + async def _generate_task( + self, *, chat_ctx: ChatContext, fnc_ctx: FunctionContext | None + ) -> None: + # new messages generated during the generation (including function calls) + new_messages: list[llm.ChatMessage] = [] + + for i in range( + self._max_fnc_steps + 1 + ): # +1 to ignore the first step that doesn't contain any tools + llm_gen_data = _LLMGenerationData( + chat_ctx=chat_ctx, + # if i >= 2, the LLM supports having multiple steps + fnc_ctx=fnc_ctx if i < self._max_fnc_steps - 1 and i >= 2 else None, + text_ch=utils.aio.Chan(), + tools_ch=utils.aio.Chan(), + ) + + llm_task = asyncio.create_task( + do_llm_inference(node=self.llm_node, data=llm_gen_data) + ) + llm_task.add_done_callback(lambda _: llm_gen_data.text_ch.close()) + llm_task.add_done_callback(lambda _: llm_gen_data.tools_ch.close()) + + # TODO(theomonnom) Do TTS concurrently here if needed + + async def _collect_text_output() -> str: + """collect and forward the generated text to the current agent output""" + generated_text = "" + async for delta in llm_gen_data.text_ch: + if self.output.text is not None: + generated_text += delta + await self.output.text.capture_text(delta) + + if self.output.text is not None: + self.output.text.flush() + + return generated_text + + collect_text_task = asyncio.create_task( + _collect_text_output(), name="_generate_task.collect_text" + ) + + tools: list[llm.FunctionCallInfo] = [] + async for tool in llm_gen_data.tools_ch: + tools.append(tool) + + new_text = await collect_text_task + if len(new_text) > 0: + new_messages.append(llm.ChatMessage(role="assistant", content=new_text)) + + if len(tools) == 0: + break # no more fnc step needed + + # -- Audio recognition -- + + def _on_audio_end_of_turn(self, new_transcript: str) -> None: + pass + + # --- + + # -- User changed input/output streams/sinks -- + + def _on_video_input_changed(self) -> None: + pass + + def _on_audio_input_changed(self) -> None: + self._audio_recognition.audio_input = self._input.audio + + def _on_video_output_changed(self) -> None: + pass + + def _on_audio_output_changed(self) -> None: + pass + + def _on_text_output_changed(self) -> None: + pass + + # --- diff --git a/livekit-agents/livekit/agents/voice_assistant/__init__.py b/livekit-agents/livekit/agents/voice_assistant/__init__.py deleted file mode 100644 index 2dd455ca5..000000000 --- a/livekit-agents/livekit/agents/voice_assistant/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from ..pipeline import AgentCallContext, AgentTranscriptionOptions, VoicePipelineAgent - -AssistantTranscriptionOptions = AgentTranscriptionOptions -AssistantCallContext = AgentCallContext -VoiceAssistant = VoicePipelineAgent - -__all__ = [ - "AssistantTranscriptionOptions", - "AssistantCallContext", - "VoiceAssistant", -] diff --git a/livekit-agents/livekit/agents/worker.py b/livekit-agents/livekit/agents/worker.py index 4708a34d3..d2bed747e 100644 --- a/livekit-agents/livekit/agents/worker.py +++ b/livekit-agents/livekit/agents/worker.py @@ -213,7 +213,7 @@ def validate_config(self, devmode: bool): ) -EventTypes = Literal["worker_registered"] +EventTypes = Literal["worker_started", "worker_registered"] 
class Worker(utils.EventEmitter[EventTypes]): @@ -222,6 +222,7 @@ def __init__( opts: WorkerOptions, *, devmode: bool = True, + register: bool = True, loop: asyncio.AbstractEventLoop | None = None, ) -> None: super().__init__() @@ -263,6 +264,7 @@ def __init__( self._close_future: asyncio.Future[None] | None = None self._msg_chan = utils.aio.Chan[agent.WorkerMessage](128, loop=self._loop) self._devmode = devmode + self._register = register # using spawn context for all platforms. We may have further optimizations for # Linux with forkserver, but for now, this is the safest option @@ -313,7 +315,7 @@ def __init__( loop=self._loop, ) - self._main_task: asyncio.Task[None] | None = None + self._conn_task: asyncio.Task[None] | None = None async def run(self): if not self._closed: @@ -347,11 +349,18 @@ def _update_job_status(proc: ipc.job_executor.JobExecutor) -> None: self._http_session = aiohttp.ClientSession() self._close_future = asyncio.Future(loop=self._loop) - self._main_task = asyncio.create_task(self._worker_task(), name="worker_task") tasks = [ - self._main_task, asyncio.create_task(self._http_server.run(), name="http_server"), ] + + if self._register: + self._conn_task = asyncio.create_task( + self._connection_task(), name="worker_conn_task" + ) + tasks.append(self._conn_task) + + self.emit("worker_started") + try: await asyncio.gather(*tasks) finally: @@ -402,12 +411,32 @@ async def simulate_job( api.RoomParticipantIdentity(room=room, identity=participant_identity) ) - msg = agent.WorkerMessage() - msg.simulate_job.room.CopyFrom(room_obj) - if participant: - msg.simulate_job.participant.CopyFrom(participant) + agent_id = utils.shortuuid("simulated-agent-") + token = ( + api.AccessToken(self._opts.api_key, self._opts.api_secret) + .with_identity(agent_id) + .with_kind("agent") + .with_grants(api.VideoGrants(room_join=True, room=room, agent=True)) + .to_jwt() + ) - await self._queue_msg(msg) + job = agent.Job( + id=utils.shortuuid("simulated-job-"), + room=room_obj, + type=agent.JobType.JT_ROOM, + participant=participant, + ) + + running_info = RunningJobInfo( + accept_arguments=JobAcceptArguments( + identity=agent_id, name="", metadata="" + ), + job=job, + url=self._opts.ws_url, + token=token, + ) + + await self._proc_pool.launch_job(running_info) async def aclose(self) -> None: if self._closed: @@ -420,10 +449,11 @@ async def aclose(self) -> None: assert self._close_future is not None assert self._http_session is not None assert self._api is not None - assert self._main_task is not None self._closed = True - self._main_task.cancel() + + if self._conn_task is not None: + await utils.aio.gracefully_cancel(self._conn_task) await self._proc_pool.aclose() @@ -451,7 +481,7 @@ async def _queue_msg(self, msg: agent.WorkerMessage) -> None: await self._msg_chan.send(msg) - async def _worker_task(self) -> None: + async def _connection_task(self) -> None: assert self._http_session is not None retry_count = 0 @@ -605,7 +635,6 @@ async def _reload_jobs(self, jobs: list[RunningJobInfo]) -> None: "reloading job", extra={"job_id": aj.job.id, "agent_name": aj.job.agent_name}, ) - url = self._opts.ws_url # take the original jwt token and extend it while keeping all the same data that was generated # by the SFU for the original join token. 
@@ -619,7 +648,7 @@ async def _reload_jobs(self, jobs: list[RunningJobInfo]) -> None: running_info = RunningJobInfo( accept_arguments=aj.accept_arguments, job=aj.job, - url=url, + url=self._opts.ws_url, token=jwt.encode(decoded, self._opts.api_secret, algorithm="HS256"), ) await self._proc_pool.launch_job(running_info) From 7fc317ad946cbdb30a1cfb78292d3579f7a24412 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Mon, 30 Dec 2024 16:22:14 +0100 Subject: [PATCH 02/19] wip --- examples/minimal_worker.py | 4 +- livekit-agents/livekit/agents/io.py | 2 +- .../livekit/agents/pipeline/chat_cli.py | 130 ++++++++++++------ .../livekit/agents/pipeline/generation.py | 107 +++++++++----- livekit-agents/livekit/agents/pipeline/io.py | 10 +- .../livekit/agents/pipeline/multimodal.py | 2 - .../livekit/agents/pipeline/pipeline2.py | 104 +++++++++----- livekit-agents/livekit/agents/utils/log.py | 4 +- 8 files changed, 246 insertions(+), 117 deletions(-) diff --git a/examples/minimal_worker.py b/examples/minimal_worker.py index 2f8ada9ce..1a5fe6b40 100644 --- a/examples/minimal_worker.py +++ b/examples/minimal_worker.py @@ -3,7 +3,7 @@ from dotenv import load_dotenv from livekit.agents import AutoSubscribe, JobContext, WorkerOptions, WorkerType, cli from livekit.agents.pipeline import ChatCLI, PipelineAgent -from livekit.plugins import deepgram, openai +from livekit.plugins import deepgram, openai, cartesia logger = logging.getLogger("my-worker") logger.setLevel(logging.INFO) @@ -14,7 +14,7 @@ async def entrypoint(ctx: JobContext): await ctx.connect(auto_subscribe=AutoSubscribe.SUBSCRIBE_ALL) - agent = PipelineAgent(llm=openai.LLM(), stt=deepgram.STT()) + agent = PipelineAgent(llm=openai.LLM(), stt=deepgram.STT(), tts=cartesia.TTS()) agent.start() # start a chat inside the CLI diff --git a/livekit-agents/livekit/agents/io.py b/livekit-agents/livekit/agents/io.py index 3d4cc6820..4cafcbd74 100644 --- a/livekit-agents/livekit/agents/io.py +++ b/livekit-agents/livekit/agents/io.py @@ -18,7 +18,7 @@ class TextChunk: class AudioSink(Protocol): - async def capture_frame(self, audio: rtc.AudioFrame) -> None: ... + async def capture_frame(self, frame: rtc.AudioFrame) -> None: ... def flush(self) -> None: ... diff --git a/livekit-agents/livekit/agents/pipeline/chat_cli.py b/livekit-agents/livekit/agents/pipeline/chat_cli.py index 2a98b0edf..a2177a7ef 100644 --- a/livekit-agents/livekit/agents/pipeline/chat_cli.py +++ b/livekit-agents/livekit/agents/pipeline/chat_cli.py @@ -13,7 +13,8 @@ import sounddevice as sd from livekit import rtc -from ..utils import aio +from ..utils import aio, log_exceptions +from ..log import logger from . 
import io from .pipeline2 import PipelineAgent @@ -32,7 +33,44 @@ def _normalize_db(amplitude_db: float, db_min: float, db_max: float) -> float: return (amplitude_db - db_min) / (db_max - db_min) -class ChatCLI(io.TextSink): +class _TextSink(io.TextSink): + def __init__(self, cli: "ChatCLI") -> None: + self._cli = cli + self._capturing = False + + async def capture_text(self, text: str) -> None: + if not self._capturing: + self._capturing = True + sys.stdout.write("\r") + sys.stdout.flush() + click.echo(_esc(36), nl=False) + + click.echo(text, nl=False) + + def flush(self) -> None: + if self._capturing: + click.echo(_esc(0)) + self._capturing = False + + +class _AudioSink(io.AudioSink): + def __init__(self, cli: "ChatCLI") -> None: + self._cli = cli + self._capturing = False + + async def capture_frame(self, frame: rtc.AudioFrame) -> None: + if not self._capturing: + self._capturing = True + + if self._cli._output_stream is not None: + self._cli._output_stream.write(frame.data) + + def flush(self) -> None: + if self._capturing: + self._capturing = False + + +class ChatCLI: def __init__( self, agent: PipelineAgent, @@ -41,16 +79,19 @@ def __init__( ) -> None: self._loop = loop or asyncio.get_event_loop() self._agent = agent - self._generation_done_ev = threading.Event() self._done_fut = asyncio.Future() self._micro_db = INPUT_DB_MIN - self._input_ch = aio.Chan[rtc.AudioFrame](loop=self._loop) + self._audio_input_ch = aio.Chan[rtc.AudioFrame](loop=self._loop) + self._input_stream: sd.InputStream | None = None - self._input_mode: Literal["audio", "text"] = "audio" - self._text_buffer = [] # in text mode + self._output_stream: sd.OutputStream | None = None + self._cli_mode: Literal["text", "audio"] = "audio" + + self._text_input_buf = [] - self._text_capturing = False + self._text_sink = _TextSink(self) + self._audio_sink = _AudioSink(self) def _print_welcome(self): print(_esc(34) + "=" * 50 + _esc(0)) @@ -75,6 +116,7 @@ def _on_input(): old_settings = termios.tcgetattr(fd) self._update_microphone(enable=True) + self._update_speaker(enable=True) try: tty.setcbreak(fd) @@ -87,6 +129,7 @@ def _on_input(): await aio.gracefully_cancel(render_cli_task) self._update_microphone(enable=False) + self._update_speaker(enable=False) finally: termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) self._loop.remove_reader(fd) @@ -106,27 +149,43 @@ def _update_microphone(self, *, enable: bool) -> None: samplerate=24000, ) self._input_stream.start() - self._agent.input.audio = self._input_ch + self._agent.input.audio = self._audio_input_ch elif self._input_stream is not None: self._input_stream.stop() self._input_stream.close() self._input_stream = None self._agent.input.audio = None + def _update_speaker(self, *, enable: bool) -> None: + _, output_device = sd.default.device + if output_device is not None and enable: + self._output_stream = sd.OutputStream( + dtype="int16", + channels=1, + device=output_device, + samplerate=24000, + ) + self._output_stream.start() + self._agent.output.audio = self._audio_sink + elif self._output_stream is not None: + self._output_stream.stop() + self._output_stream.close() + self._output_stream = None + self._agent.output.audio = None + def _update_text_output(self, *, enable: bool) -> None: if enable: - self._agent.output.text = self + self._agent.output.text = self._text_sink else: self._agent.output.text = None - self._text_buffer = [] - self._text_capturing = False + self._text_input_buf = [] def _input_sd_callback(self, indata: np.ndarray, frame_count: int, *_) -> None: 
rms = np.sqrt(np.mean(indata.astype(np.float32) ** 2)) max_int16 = np.iinfo(np.int16).max self._micro_db = 20.0 * np.log10(rms / max_int16 + 1e-6) self._loop.call_soon_threadsafe( - self._input_ch.send_nowait, + self._audio_input_ch.send_nowait, rtc.AudioFrame( data=indata.tobytes(), samples_per_channel=frame_count, @@ -135,6 +194,7 @@ def _input_sd_callback(self, indata: np.ndarray, frame_count: int, *_) -> None: ), ) + @log_exceptions(logger=logger) async def _input_cli_task(self, in_ch: aio.Chan[str]) -> None: while True: char = await in_ch.recv() @@ -142,32 +202,34 @@ async def _input_cli_task(self, in_ch: aio.Chan[str]) -> None: break if char == "\x02": # Ctrl+B - if self._input_mode == "audio": - self._input_mode = "text" + if self._cli_mode == "audio": + self._cli_mode = "text" self._update_text_output(enable=True) self._update_microphone(enable=False) + self._update_speaker(enable=False) click.echo("\nSwitched to Text Input Mode.", nl=False) else: - self._input_mode = "audio" + self._cli_mode = "audio" self._update_text_output(enable=False) self._update_microphone(enable=True) - self._text_buffer = [] + self._update_speaker(enable=True) + self._text_input_buf = [] click.echo("\nSwitched to Audio Input Mode.", nl=False) - if self._input_mode == "text": # Read input + if self._cli_mode == "text": # Read input if char in ("\r", "\n"): - text = "".join(self._text_buffer) + text = "".join(self._text_input_buf) if text: - self._text_buffer = [] + self._text_input_buf = [] self._agent.generate_reply(text) click.echo("\n", nl=False) - elif char == "\x7f": # Backspace - if self._text_buffer: - self._text_buffer.pop() + elif char == "\x7f": # Backspace + if self._text_input_buf: + self._text_input_buf.pop() sys.stdout.write("\b \b") sys.stdout.flush() elif char.isprintable(): - self._text_buffer.append(char) + self._text_input_buf.append(char) click.echo(char, nl=False) sys.stdout.flush() @@ -175,9 +237,9 @@ async def _render_cli_task(self) -> None: next_frame = time.perf_counter() while True: next_frame += 1 / FPS - if self._input_mode == "audio": + if self._cli_mode == "audio": self._print_audio_mode() - elif self._input_mode == "text" and not self._text_capturing: + elif self._cli_mode == "text" and not self._text_sink._capturing: self._print_text_mode() await asyncio.sleep(max(0, next_frame - time.perf_counter())) @@ -199,21 +261,5 @@ def _print_text_mode(self): sys.stdout.write("\r") sys.stdout.flush() prompt = "Enter your message: " - sys.stdout.write(f"[Text] {prompt}{''.join(self._text_buffer)}") + sys.stdout.write(f"[Text] {prompt}{''.join(self._text_input_buf)}") sys.stdout.flush() - - # io.Text Sink implementation - - async def capture_text(self, text: str) -> None: - if not self._text_capturing: - self._text_capturing = True - sys.stdout.write("\r") - sys.stdout.flush() - click.echo(_esc(36), nl=False) - - click.echo(text, nl=False) - - def flush(self) -> None: - if self._text_capturing: - click.echo(_esc(0)) - self._text_capturing = False diff --git a/livekit-agents/livekit/agents/pipeline/generation.py b/livekit-agents/livekit/agents/pipeline/generation.py index 1c1608184..b3c4f3975 100644 --- a/livekit-agents/livekit/agents/pipeline/generation.py +++ b/livekit-agents/livekit/agents/pipeline/generation.py @@ -2,62 +2,105 @@ import asyncio from dataclasses import dataclass -from typing import AsyncIterable +from typing import AsyncIterable, Protocol, Tuple, runtime_checkable +from livekit import rtc from ..llm import ChatChunk, ChatContext, FunctionCallInfo, FunctionContext from 
..utils import aio from . import io +@runtime_checkable +class _ACloseable(Protocol): + async def aclose(self): ... + + @dataclass class _LLMGenerationData: - chat_ctx: ChatContext - fnc_ctx: FunctionContext | None text_ch: aio.Chan[str] tools_ch: aio.Chan[FunctionCallInfo] + generated_text: str = "" # + + +def do_llm_inference( + *, node: io.LLMNode, chat_ctx: ChatContext, fnc_ctx: FunctionContext | None +) -> Tuple[asyncio.Task, _LLMGenerationData]: + text_ch = aio.Chan() + tools_ch = aio.Chan() + + data = _LLMGenerationData(text_ch=text_ch, tools_ch=tools_ch) + + async def _inference_task(): + llm_node = node(chat_ctx, fnc_ctx) + if asyncio.iscoroutine(llm_node): + llm_node = await llm_node + if isinstance(llm_node, str): + data.generated_text = llm_node + text_ch.send_nowait(llm_node) + return True -async def do_llm_inference(*, node: io.LLMNode, data: _LLMGenerationData) -> bool: - llm_node = node(data.chat_ctx, data.fnc_ctx) - if asyncio.iscoroutine(llm_node): - llm_node = await llm_node + if isinstance(llm_node, AsyncIterable): + # forward llm stream to output channels + try: + async for chunk in llm_node: + # io.LLMNode can either return a string or a ChatChunk + if isinstance(chunk, str): + data.generated_text += chunk + text_ch.send_nowait(chunk) - if isinstance(llm_node, str): - data.text_ch.send_nowait(llm_node) - return True + elif isinstance(chunk, ChatChunk): + if not chunk.choices: + continue # this can happens if we receive the stats chunk - if isinstance(llm_node, AsyncIterable): - # forward llm stream to output channels - async for chunk in llm_node: - # io.LLMNode can either return a string or a ChatChunk - if isinstance(chunk, str): - data.text_ch.send_nowait(chunk) + delta = chunk.choices[0].delta - elif isinstance(chunk, ChatChunk): - if not chunk.choices: - continue # this can happens if we receive the stats chunk + if delta.tool_calls: + for tool in delta.tool_calls: + tools_ch.send_nowait(tool) - delta = chunk.choices[0].delta + if delta.content: + data.generated_text += delta.content + text_ch.send_nowait(delta.content) + finally: + if isinstance(llm_node, _ACloseable): + await llm_node.aclose() - if delta.tool_calls: - for tool in delta.tool_calls: - data.tools_ch.send_nowait(tool) + return True - if delta.content: - data.text_ch.send_nowait(delta.content) + return False - return True + llm_task = asyncio.create_task(_inference_task()) + llm_task.add_done_callback(lambda _: text_ch.close()) + llm_task.add_done_callback(lambda _: tools_ch.close()) - return False + return llm_task, data @dataclass class _TTSGenerationData: - input_ch: AsyncIterable[str] - audio_ch: aio.Chan[bytes] + audio_ch: aio.Chan[rtc.AudioFrame] + + +def do_tts_inference( + *, node: io.TTSNode, input: AsyncIterable[str] +) -> Tuple[asyncio.Task, _TTSGenerationData]: + audio_ch = aio.Chan[rtc.AudioFrame]() + + async def _inference_task(): + tts_node = node(input) + if asyncio.iscoroutine(tts_node): + tts_node = await tts_node + + if isinstance(tts_node, AsyncIterable): + async for audio_frame in tts_node: + audio_ch.send_nowait(audio_frame) + + return True -async def do_tts_inference(*, node: io.TTSNode, data: _TTSGenerationData) -> bool: - tts_node = node(data.input_ch) + return False + tts_task = asyncio.create_task(_inference_task()) + tts_task.add_done_callback(lambda _: audio_ch.close()) - return False + return tts_task, _TTSGenerationData(audio_ch=audio_ch) diff --git a/livekit-agents/livekit/agents/pipeline/io.py b/livekit-agents/livekit/agents/pipeline/io.py index 3c630ddeb..3297b2e76 
100644 --- a/livekit-agents/livekit/agents/pipeline/io.py +++ b/livekit-agents/livekit/agents/pipeline/io.py @@ -26,7 +26,13 @@ ], ], ] -TTSNode = Callable[[AsyncIterable[str]], Optional[AsyncIterable[rtc.AudioFrame]]] +TTSNode = Callable[ + [AsyncIterable[str]], + Union[ + Optional[AsyncIterable[rtc.AudioFrame]], + Awaitable[Optional[AsyncIterable[rtc.AudioFrame]]], + ], +] AudioStream = AsyncIterable[rtc.AudioFrame] @@ -34,7 +40,7 @@ class AudioSink(Protocol): - async def capture_frame(self, audio: rtc.AudioFrame) -> None: ... + async def capture_frame(self, frame: rtc.AudioFrame) -> None: ... def flush(self) -> None: ... diff --git a/livekit-agents/livekit/agents/pipeline/multimodal.py b/livekit-agents/livekit/agents/pipeline/multimodal.py index 6c68cbd43..d572a37ad 100644 --- a/livekit-agents/livekit/agents/pipeline/multimodal.py +++ b/livekit-agents/livekit/agents/pipeline/multimodal.py @@ -1,5 +1,3 @@ - - EventTypes = Literal[""] diff --git a/livekit-agents/livekit/agents/pipeline/pipeline2.py b/livekit-agents/livekit/agents/pipeline/pipeline2.py index 13ed91fcc..8a4bce71a 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline2.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline2.py @@ -14,7 +14,12 @@ from .. import io, llm, stt, tts, utils, vad from ..llm import ChatContext, FunctionContext from .audio_recognition import AudioRecognition, _TurnDetector -from .generation import _LLMGenerationData, do_llm_inference +from .generation import ( + _LLMGenerationData, + _TTSGenerationData, + do_llm_inference, + do_tts_inference, +) class AgentInput: @@ -220,6 +225,21 @@ async def tts_node( ) -> Optional[AsyncIterable[rtc.AudioFrame]]: assert self._tts is not None, "tts_node called but no TTS node is available" + async with self._tts.stream() as stream: + + async def _forward_input(): + async for chunk in text: + stream.push_text(chunk) + + stream.end_input() + + forward_task = asyncio.create_task(_forward_input()) + try: + async for ev in stream: + yield ev.frame + finally: + await utils.aio.gracefully_cancel(forward_task) + def start(self) -> None: self._audio_recognition.start() @@ -265,52 +285,68 @@ def generate_reply(self, user_input: str) -> GenerationHandle: async def _generate_task( self, *, chat_ctx: ChatContext, fnc_ctx: FunctionContext | None ) -> None: + async def _forward_llm_text(llm_output: AsyncIterable[str]) -> None: + """collect and forward the generated text to the current agent output""" + async for delta in llm_output: + if self.output.text is not None: + await self.output.text.capture_text(delta) + + if self.output.text is not None: + self.output.text.flush() + + async def _forward_tts_audio(tts_output: AsyncIterable[rtc.AudioFrame]) -> None: + """collect and forward the generated audio to the current agent output""" + async for frame in tts_output: + if self.output.audio is not None: + await self.output.audio.capture_frame(frame) + + if self.output.audio is not None: + self.output.audio.flush() + # new messages generated during the generation (including function calls) new_messages: list[llm.ChatMessage] = [] + # TODO(theomonnom): how nested fnc calls is going to work with realtime API? 
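Aside: a minimal usage sketch (not part of the patch) of the do_llm_inference helper introduced above. The node may return a plain string, a coroutine, or an async iterable of str/ChatChunk; the (task, data) pair it returns exposes channels that are closed automatically by the task's done callbacks. The import path and the echo_llm_node are illustrative assumptions.

    import asyncio
    from livekit.agents.llm import ChatContext
    # assumption: generation.py is importable at this path
    from livekit.agents.pipeline.generation import do_llm_inference

    async def echo_llm_node(chat_ctx: ChatContext, fnc_ctx=None):
        # returning a plain string is allowed; an async generator yielding str or
        # ChatChunk would instead be forwarded chunk by chunk into text_ch/tools_ch
        return "Hello! This is a canned reply."

    async def main() -> None:
        llm_task, data = do_llm_inference(
            node=echo_llm_node, chat_ctx=ChatContext(), fnc_ctx=None
        )
        async for delta in data.text_ch:  # channel is closed when llm_task finishes
            print(delta, end="", flush=True)
        await llm_task
        print("\nfull text:", data.generated_text)

    asyncio.run(main())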
for i in range( self._max_fnc_steps + 1 ): # +1 to ignore the first step that doesn't contain any tools - llm_gen_data = _LLMGenerationData( - chat_ctx=chat_ctx, - # if i >= 2, the LLM supports having multiple steps - fnc_ctx=fnc_ctx if i < self._max_fnc_steps - 1 and i >= 2 else None, - text_ch=utils.aio.Chan(), - tools_ch=utils.aio.Chan(), - ) + # if i >= 2, the LLM supports having multiple steps + fnc_ctx = fnc_ctx if i < self._max_fnc_steps - 1 and i >= 2 else None + chat_ctx = chat_ctx - llm_task = asyncio.create_task( - do_llm_inference(node=self.llm_node, data=llm_gen_data) + llm_task, llm_gen_data = do_llm_inference( + node=self.llm_node, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx ) - llm_task.add_done_callback(lambda _: llm_gen_data.text_ch.close()) - llm_task.add_done_callback(lambda _: llm_gen_data.tools_ch.close()) - - # TODO(theomonnom) Do TTS concurrently here if needed - - async def _collect_text_output() -> str: - """collect and forward the generated text to the current agent output""" - generated_text = "" - async for delta in llm_gen_data.text_ch: - if self.output.text is not None: - generated_text += delta - await self.output.text.capture_text(delta) - - if self.output.text is not None: - self.output.text.flush() - - return generated_text - - collect_text_task = asyncio.create_task( - _collect_text_output(), name="_generate_task.collect_text" + tts_text_input, llm_output = utils.aio.itertools.tee(llm_gen_data.text_ch) + forward_llm_task = asyncio.create_task( + _forward_llm_text(llm_output), name="_generate_task.forward_llm_text" ) + tts_task: asyncio.Task | None = None + forward_tts_task: asyncio.Task | None = None + if self._output.audio is not None: + tts_task, tts_gen_data = do_tts_inference( + node=self.tts_node, input=tts_text_input + ) + forward_tts_task = asyncio.create_task( + _forward_tts_audio(tts_gen_data.audio_ch), + name="_generate_task.forward_tts_audio", + ) + tools: list[llm.FunctionCallInfo] = [] async for tool in llm_gen_data.tools_ch: - tools.append(tool) + tools.append(tool) # TODO(theomonnom): optimize function calls response + + await asyncio.gather(llm_task, forward_llm_task) + + if tts_task is not None and forward_tts_task is not None: + await asyncio.gather(tts_task, forward_tts_task) - new_text = await collect_text_task - if len(new_text) > 0: - new_messages.append(llm.ChatMessage(role="assistant", content=new_text)) + generated_text = llm_gen_data.generated_text + if len(generated_text) > 0: + new_messages.append( + llm.ChatMessage(role="assistant", content=generated_text) + ) if len(tools) == 0: break # no more fnc step needed diff --git a/livekit-agents/livekit/agents/utils/log.py b/livekit-agents/livekit/agents/utils/log.py index 5b985fc59..e53854f20 100644 --- a/livekit-agents/livekit/agents/utils/log.py +++ b/livekit-agents/livekit/agents/utils/log.py @@ -15,7 +15,7 @@ async def async_fn_logs(*args: Any, **kwargs: Any): try: return await fn(*args, **kwargs) except Exception: - err = f"Error in {fn.__name__}" + err = f"Error in {async_fn_logs.__name__}" if msg: err += f" – {msg}" logger.exception(err) @@ -29,7 +29,7 @@ def fn_logs(*args: Any, **kwargs: Any): try: return fn(*args, **kwargs) except Exception: - err = f"Error in {fn.__name__}" + err = f"Error in {fn_logs.__name__}" if msg: err += f" – {msg}" logger.exception(err) From 5b9a187047725f1a1dc5bce481c171785c570791 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Fri, 3 Jan 2025 18:04:51 +0100 Subject: [PATCH 03/19] updated io.AudioSink API to support interruptions --- 
livekit-agents/livekit/agents/io.py | 35 ------ .../livekit/agents/pipeline/chat_cli.py | 60 +++++++++- .../livekit/agents/pipeline/generation.py | 1 + livekit-agents/livekit/agents/pipeline/io.py | 106 ++++++++++++++++-- .../livekit/agents/pipeline/pipeline2.py | 104 +++++++++++++---- livekit-agents/livekit/agents/worker.py | 1 + 6 files changed, 240 insertions(+), 67 deletions(-) delete mode 100644 livekit-agents/livekit/agents/io.py diff --git a/livekit-agents/livekit/agents/io.py b/livekit-agents/livekit/agents/io.py deleted file mode 100644 index 4cafcbd74..000000000 --- a/livekit-agents/livekit/agents/io.py +++ /dev/null @@ -1,35 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from typing import AsyncIterable, Protocol - -from livekit import rtc - - -@dataclass -class TextChunk: - text: str - is_final: bool - - -AudioStream = AsyncIterable[rtc.AudioFrame] -VideoStream = AsyncIterable[rtc.VideoFrame] -TextStream = AsyncIterable[TextChunk] - - -class AudioSink(Protocol): - async def capture_frame(self, frame: rtc.AudioFrame) -> None: ... - - def flush(self) -> None: ... - - -class TextSink(Protocol): - async def capture_text(self, text: str) -> None: ... - - def flush(self) -> None: ... - - -class VideoSink(Protocol): - async def capture_frame(self, text: rtc.VideoFrame) -> None: ... - - def flush(self) -> None: ... diff --git a/livekit-agents/livekit/agents/pipeline/chat_cli.py b/livekit-agents/livekit/agents/pipeline/chat_cli.py index a2177a7ef..630392247 100644 --- a/livekit-agents/livekit/agents/pipeline/chat_cli.py +++ b/livekit-agents/livekit/agents/pipeline/chat_cli.py @@ -3,7 +3,6 @@ import asyncio import sys import termios -import threading import time import tty from typing import Literal @@ -13,8 +12,8 @@ import sounddevice as sd from livekit import rtc -from ..utils import aio, log_exceptions from ..log import logger +from ..utils import aio, log_exceptions from . 
import io from .pipeline2 import PipelineAgent @@ -55,19 +54,72 @@ def flush(self) -> None: class _AudioSink(io.AudioSink): def __init__(self, cli: "ChatCLI") -> None: + super().__init__(sample_rate=24000) self._cli = cli self._capturing = False + self._pushed_duration: float = 0.0 + self._capture_start: float = 0.0 + self._dispatch_handle: asyncio.TimerHandle | None = None + + self._flush_complete = asyncio.Event() + self._flush_complete.set() async def capture_frame(self, frame: rtc.AudioFrame) -> None: + await super().capture_frame(frame) + await self._flush_complete.wait() + + if not frame.duration: + return + if not self._capturing: self._capturing = True + self._buffer_duration = 0.0 + self._capture_start = time.monotonic() + + self._pushed_duration += frame.duration if self._cli._output_stream is not None: self._cli._output_stream.write(frame.data) + def clear_buffer(self) -> None: + self._capturing = False + + if self._cli._output_stream is not None and self._cli._output_stream.active: + # restarting the stream will clear the buffer + self._cli._output_stream.stop() + self._cli._output_stream.start() + + if self._pushed_duration > 0.0: + if self._dispatch_handle is not None: + self._dispatch_handle.cancel() + + self._flush_complete.set() + self._pushed_duration = 0.0 + played_duration = min( + time.monotonic() - self._capture_start, self._pushed_duration + ) + self.on_playback_finished( + playback_position=played_duration, + interrupted=played_duration + 1.0 < self._pushed_duration, + ) + def flush(self) -> None: if self._capturing: + self._flush_complete.clear() self._capturing = False + to_wait = min( + 0.0, self._pushed_duration - (time.monotonic() - self._capture_start) + ) + self._dispatch_handle = self._cli._loop.call_later( + to_wait, self._dispatch_playback_finished + ) + + def _dispatch_playback_finished(self) -> None: + self.on_playback_finished( + playback_position=self._pushed_duration, interrupted=False + ) + self._flush_complete.set() + self._pushed_duration = 0.0 class ChatCLI: @@ -171,7 +223,7 @@ def _update_speaker(self, *, enable: bool) -> None: self._output_stream.stop() self._output_stream.close() self._output_stream = None - self._agent.output.audio = None + self._agent.output.audio def _update_text_output(self, *, enable: bool) -> None: if enable: @@ -261,5 +313,5 @@ def _print_text_mode(self): sys.stdout.write("\r") sys.stdout.flush() prompt = "Enter your message: " - sys.stdout.write(f"[Text] {prompt}{''.join(self._text_input_buf)}") + sys.stdout.write(f"[Text {prompt}{''.join(self._text_input_buf)}") sys.stdout.flush() diff --git a/livekit-agents/livekit/agents/pipeline/generation.py b/livekit-agents/livekit/agents/pipeline/generation.py index b3c4f3975..a8cd4a17b 100644 --- a/livekit-agents/livekit/agents/pipeline/generation.py +++ b/livekit-agents/livekit/agents/pipeline/generation.py @@ -3,6 +3,7 @@ import asyncio from dataclasses import dataclass from typing import AsyncIterable, Protocol, Tuple, runtime_checkable + from livekit import rtc from ..llm import ChatChunk, ChatContext, FunctionCallInfo, FunctionContext diff --git a/livekit-agents/livekit/agents/pipeline/io.py b/livekit-agents/livekit/agents/pipeline/io.py index 3297b2e76..7ce206b8a 100644 --- a/livekit-agents/livekit/agents/pipeline/io.py +++ b/livekit-agents/livekit/agents/pipeline/io.py @@ -1,11 +1,13 @@ from __future__ import annotations +from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import ( AsyncIterable, Awaitable, Callable, + Literal, Optional, - 
Protocol, Union, ) @@ -13,6 +15,8 @@ from .. import llm, stt +import asyncio + STTNode = Callable[ [AsyncIterable[rtc.AudioFrame]], Union[Awaitable[Optional[AsyncIterable[stt.SpeechEvent]]],], @@ -39,19 +43,105 @@ VideoStream = AsyncIterable[rtc.VideoFrame] -class AudioSink(Protocol): - async def capture_frame(self, frame: rtc.AudioFrame) -> None: ... +@dataclass +class PlaybackFinishedEvent: + playback_position: float + """How much of the audio was played back""" + interrupted: bool + """interrupted is True if playback was interrupted (clear_buffer() was called)""" - def flush(self) -> None: ... +class AudioSink(ABC, rtc.EventEmitter[Literal["playback_finished"]]): + def __init__(self, *, sample_rate: int | None = None) -> None: + """ + Args: + sample_rate: The sample rate required by the audio sink, if None, any sample rate is accepted + """ + super().__init__() + self._sample_rate = sample_rate + self.__capturing = False + self.__playback_finished_event = asyncio.Event() -class TextSink(Protocol): - async def capture_text(self, text: str) -> None: ... + self.__nb_playback_finished_needed = 0 + self.__playback_finished_count = 0 - def flush(self) -> None: ... + def on_playback_finished( + self, *, playback_position: float, interrupted: bool + ) -> None: + """ + Developers building audio sinks must call this method when a playback/segment is finished. + Segments are segmented by calls to flush() or clear_buffer() + """ + self.__nb_playback_finished_needed = max( + 0, self.__nb_playback_finished_needed - 1 + ) + self.__playback_finished_count += 1 + self.__playback_finished_event.set() + + ev = PlaybackFinishedEvent( + playback_position=playback_position, interrupted=interrupted + ) + self.__last_playback_ev = ev + self.emit("playback_finished", ev) + + async def wait_for_playout(self) -> PlaybackFinishedEvent: + """ + Wait for the past audio segments to finish playing out. + + Returns: + PlaybackFinishedEvent: The event that was emitted when the audio finished playing out + (only the last segment information) + """ + needed = self.__nb_playback_finished_needed + initial_count = self.__playback_finished_count + target_count = initial_count + needed + + while self.__playback_finished_count < target_count: + await self.__playback_finished_event.wait() + self.__playback_finished_event.clear() + + return self.__last_playback_ev + + @property + def sample_rate(self) -> int | None: + """The sample rate required by the audio sink, if None, any sample rate is accepted""" + return self._sample_rate + + @abstractmethod + async def capture_frame(self, frame: rtc.AudioFrame) -> None: + """Capture an audio frame for playback, frames can be pushed faster than real-time""" + if not self.__capturing: + self.__capturing = True + self.__nb_playback_finished_needed += 1 + + @abstractmethod + def flush(self) -> None: + """Flush any buffered audio, marking the current playback/segment as complete""" + if self.__capturing: + self.__capturing = False + + @abstractmethod + def clear_buffer(self) -> None: + """Clear the buffer, stopping playback immediately""" + ... + + +class TextSink(ABC): + @abstractmethod + async def capture_text(self, text: str) -> None: + """Capture a text segment (Used by the output of LLM nodes)""" + ... + + @abstractmethod + def flush(self) -> None: + """Mark the current text segment as complete (e.g LLM generation is complete)""" + ... 
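For sink implementors, a minimal sketch (not from this patch) of a subclass of the new AudioSink ABC: it calls super() so the base class keeps its segment bookkeeping, and reports completion through on_playback_finished() so wait_for_playout() can resolve. The in-memory buffering and the 48 kHz rate are illustrative assumptions; AudioSink is assumed to be the class defined just above.

    from livekit import rtc

    class BufferedAudioSink(AudioSink):
        """Illustrative sink that simply accumulates frames in memory."""

        def __init__(self) -> None:
            super().__init__(sample_rate=48000)  # assumption: fixed 48 kHz output
            self._frames: list[rtc.AudioFrame] = []

        async def capture_frame(self, frame: rtc.AudioFrame) -> None:
            await super().capture_frame(frame)  # marks the start of a new segment
            self._frames.append(frame)

        def flush(self) -> None:
            super().flush()  # closes the current segment
            played = sum(f.duration for f in self._frames)
            self._frames.clear()
            # a real sink would call this only once the audio actually finished playing
            self.on_playback_finished(playback_position=played, interrupted=False)

        def clear_buffer(self) -> None:
            # dropping buffered audio counts as an interrupted playback
            self._frames.clear()
            self.on_playback_finished(playback_position=0.0, interrupted=True)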
-class VideoSink(Protocol): +# TODO(theomonnom): Add documentation to VideoSink +class VideoSink(ABC): + @abstractmethod async def capture_frame(self, text: rtc.VideoFrame) -> None: ... + @abstractmethod def flush(self) -> None: ... diff --git a/livekit-agents/livekit/agents/pipeline/pipeline2.py b/livekit-agents/livekit/agents/pipeline/pipeline2.py index 8a4bce71a..cb5cbf4f5 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline2.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline2.py @@ -1,6 +1,9 @@ from __future__ import annotations, print_function import asyncio +import contextlib + +from dataclasses import dataclass from typing import ( AsyncIterable, Callable, @@ -11,12 +14,12 @@ from livekit import rtc -from .. import io, llm, stt, tts, utils, vad +from .. import llm, stt, tts, utils, vad from ..llm import ChatContext, FunctionContext +from ..log import logger +from . import io from .audio_recognition import AudioRecognition, _TurnDetector from .generation import ( - _LLMGenerationData, - _TTSGenerationData, do_llm_inference, do_tts_inference, ) @@ -93,6 +96,7 @@ def __init__( ) -> None: self._id = speech_id self._allow_interruptions = allow_interruptions + self._interrupted = False self._task = task @staticmethod @@ -109,6 +113,10 @@ def from_task( def id(self) -> str: return self._id + @property + def interrupted(self) -> bool: + return self._interrupted + @property def allow_interruptions(self) -> bool: return self._allow_interruptions @@ -117,9 +125,11 @@ def interrupt(self) -> None: if not self._allow_interruptions: raise ValueError("This generation handle does not allow interruptions") + if self._task.done(): + return -class AgentTask: - pass + self._interrupted = True + self._task.cancel() EventTypes = Literal[ @@ -133,6 +143,14 @@ class AgentTask: ] +@dataclass +class _PipelineOptions: + language: str | None + allow_interruptions: bool + min_interruption_duration: float + min_endpointing_delay: float + + class PipelineAgent(rtc.EventEmitter[EventTypes]): def __init__( self, @@ -146,6 +164,7 @@ def __init__( chat_ctx: ChatContext | None = None, fnc_ctx: FunctionContext | None = None, allow_interruptions: bool = True, + min_interruption_duration: float = 1.0, min_endpointing_delay: float = 0.5, max_fnc_steps: int = 5, loop: asyncio.AbstractEventLoop | None = None, @@ -169,8 +188,16 @@ def __init__( loop=self._loop, ) + self._opts = _PipelineOptions( + language=language, + allow_interruptions=allow_interruptions, + min_interruption_duration=min_interruption_duration, + min_endpointing_delay=min_endpointing_delay, + ) + self._max_fnc_steps = max_fnc_steps self._audio_recognition.on("end_of_turn", self._on_audio_end_of_turn) + self._audio_recognition.on("vad_inference_done", self._on_vad_inference_done) # configurable IO self._input = AgentInput( @@ -267,41 +294,53 @@ def update_options(self) -> None: pass def say(self, text: str | AsyncIterable[str]) -> GenerationHandle: - # say also send to the text output sink... 
pfff pass def generate_reply(self, user_input: str) -> GenerationHandle: + if ( + self._current_generation is not None + and not self._current_generation.interrupted + ): + raise ValueError("another reply is already in progress") + self._chat_ctx.append(role="user", text=user_input) # TODO(theomonnom): Use the agent task chat_ctx task = asyncio.create_task( - self._generate_task(chat_ctx=self._chat_ctx, fnc_ctx=self._fnc_ctx) + self._generate_reply_task(chat_ctx=self._chat_ctx, fnc_ctx=self._fnc_ctx) ) gen_handle = GenerationHandle.from_task(task) return gen_handle # -- Main generation task -- - async def _generate_task( + @utils.log_exceptions(logger=logger) + async def _generate_reply_task( self, *, chat_ctx: ChatContext, fnc_ctx: FunctionContext | None ) -> None: + @utils.log_exceptions(logger=logger) async def _forward_llm_text(llm_output: AsyncIterable[str]) -> None: """collect and forward the generated text to the current agent output""" + if self.output.text is None: + return + async for delta in llm_output: - if self.output.text is not None: - await self.output.text.capture_text(delta) + await self.output.text.capture_text(delta) - if self.output.text is not None: - self.output.text.flush() + self.output.text.flush() - async def _forward_tts_audio(tts_output: AsyncIterable[rtc.AudioFrame]) -> None: + @utils.log_exceptions(logger=logger) + async def _forward_tts_audio( + tts_output: AsyncIterable[rtc.AudioFrame], wait_for_playout: bool = True + ) -> None: """collect and forward the generated audio to the current agent output""" + if self.output.audio is None: + return + async for frame in tts_output: - if self.output.audio is not None: - await self.output.audio.capture_frame(frame) + await self.output.audio.capture_frame(frame) - if self.output.audio is not None: - self.output.audio.flush() + self.output.audio.flush() # new messages generated during the generation (including function calls) new_messages: list[llm.ChatMessage] = [] @@ -319,7 +358,8 @@ async def _forward_tts_audio(tts_output: AsyncIterable[rtc.AudioFrame]) -> None: ) tts_text_input, llm_output = utils.aio.itertools.tee(llm_gen_data.text_ch) forward_llm_task = asyncio.create_task( - _forward_llm_text(llm_output), name="_generate_task.forward_llm_text" + _forward_llm_text(llm_output), + name="_generate_reply_task.forward_llm_text", ) tts_task: asyncio.Task | None = None @@ -330,7 +370,7 @@ async def _forward_tts_audio(tts_output: AsyncIterable[rtc.AudioFrame]) -> None: ) forward_tts_task = asyncio.create_task( _forward_tts_audio(tts_gen_data.audio_ch), - name="_generate_task.forward_tts_audio", + name="_generate_reply_task.forward_tts_audio", ) tools: list[llm.FunctionCallInfo] = [] @@ -339,8 +379,11 @@ async def _forward_tts_audio(tts_output: AsyncIterable[rtc.AudioFrame]) -> None: await asyncio.gather(llm_task, forward_llm_task) + # TODO(theomonnom): Simplify this if tts_task is not None and forward_tts_task is not None: + assert self._output.audio is not None await asyncio.gather(tts_task, forward_tts_task) + playback_ev = await self._output.audio.wait_for_playout() generated_text = llm_gen_data.generated_text if len(generated_text) > 0: @@ -354,7 +397,28 @@ async def _forward_tts_audio(tts_output: AsyncIterable[rtc.AudioFrame]) -> None: # -- Audio recognition -- def _on_audio_end_of_turn(self, new_transcript: str) -> None: - pass + # When the audio recognition detects the end of a user turn: + # - check if there is no current generation happening + # - cancel the current generation if it allows interruptions 
(otherwise skip this current + # turn) + # - generate a reply to the user input + + if self._current_generation is not None: + if self._current_generation.allow_interruptions: + logger.warning( + "skipping user input, current speech generation cannot be interrupted", + extra={"user_input": new_transcript}, + ) + return + + self._current_generation.interrupt() + + self.generate_reply(new_transcript) + + def _on_vad_inference_done(self, ev: vad.VADEvent) -> None: + if ev.speech_duration > self._opts.min_interruption_duration: + if self._current_generation is not None: + self._current_generation.interrupt() # --- diff --git a/livekit-agents/livekit/agents/worker.py b/livekit-agents/livekit/agents/worker.py index 7e729373d..52032d7a1 100644 --- a/livekit-agents/livekit/agents/worker.py +++ b/livekit-agents/livekit/agents/worker.py @@ -428,6 +428,7 @@ async def simulate_job( ) running_info = RunningJobInfo( + worker_id=self._id, accept_arguments=JobAcceptArguments( identity=agent_id, name="", metadata="" ), From 80cb9c45c431eec6132408592b8f5b10e3fc1fae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Sun, 12 Jan 2025 16:27:49 +0100 Subject: [PATCH 04/19] tracing & pipeline --- .../livekit/agents/debug/__init__.py | 8 + .../livekit/agents/debug/index.html | 518 +++++++++++++++++ .../livekit/agents/debug/tracing.py | 197 +++++++ livekit-agents/livekit/agents/http_server.py | 9 +- .../livekit/agents/ipc/job_executor.py | 5 + .../livekit/agents/ipc/job_proc_executor.py | 20 +- .../livekit/agents/ipc/job_proc_lazy_main.py | 17 + .../livekit/agents/ipc/job_thread_executor.py | 23 + livekit-agents/livekit/agents/ipc/proto.py | 32 +- .../agents/pipeline/audio_recognition.py | 94 +++- .../livekit/agents/pipeline/chat_cli.py | 29 +- .../livekit/agents/pipeline/generation.py | 6 +- livekit-agents/livekit/agents/pipeline/io.py | 65 +++ .../livekit/agents/pipeline/pipeline2.py | 521 ++++++++++++------ .../livekit/agents/utils/aio/__init__.py | 2 + .../livekit/agents/utils/aio/wait_group.py | 30 + livekit-agents/livekit/agents/worker.py | 72 ++- livekit-agents/setup.py | 5 +- 18 files changed, 1445 insertions(+), 208 deletions(-) create mode 100644 livekit-agents/livekit/agents/debug/__init__.py create mode 100644 livekit-agents/livekit/agents/debug/index.html create mode 100644 livekit-agents/livekit/agents/debug/tracing.py create mode 100644 livekit-agents/livekit/agents/utils/aio/wait_group.py diff --git a/livekit-agents/livekit/agents/debug/__init__.py b/livekit-agents/livekit/agents/debug/__init__.py new file mode 100644 index 000000000..7a3535d28 --- /dev/null +++ b/livekit-agents/livekit/agents/debug/__init__.py @@ -0,0 +1,8 @@ +from .tracing import Tracing, TracingGraph, TracingHandle + + +__all__ = [ + "Tracing", + "TracingGraph", + "TracingHandle", +] diff --git a/livekit-agents/livekit/agents/debug/index.html b/livekit-agents/livekit/agents/debug/index.html new file mode 100644 index 000000000..72a00157c --- /dev/null +++ b/livekit-agents/livekit/agents/debug/index.html @@ -0,0 +1,518 @@ + + + + + + lkagents - tracing + + + + +
[debug/index.html contents omitted: the tracing dashboard markup ("Worker" and "Runners" panels) did not survive extraction, leaving only stray diff "+" markers here.]
+ + + + + diff --git a/livekit-agents/livekit/agents/debug/tracing.py b/livekit-agents/livekit/agents/debug/tracing.py new file mode 100644 index 000000000..5ecc59d5e --- /dev/null +++ b/livekit-agents/livekit/agents/debug/tracing.py @@ -0,0 +1,197 @@ +from __future__ import annotations +import asyncio +import time + +from aiohttp import web +from typing import TYPE_CHECKING, Any, Literal +from .. import job + +if TYPE_CHECKING: + from ..worker import Worker + + +class TracingGraph: + def __init__( + self, + title: str, + y_label: str, + x_label: str, + y_range: tuple[float, float] | None, + x_type: Literal["time", "value"], + max_data_points: int, + ) -> None: + self._title = title + self._y_label = y_label + self._x_label = x_label + self._y_range = y_range + self._max_data_points = max_data_points + self._x_type = x_type + self._data: list[tuple[float | int, float]] = [] + + def plot(self, x: float | int, y: float) -> None: + self._data.append((x, y)) + if len(self._data) > self._max_data_points: + self._data.pop(0) + + +class TracingHandle: + def __init__(self) -> None: + self._kv = {} + self._events: list[dict] = [] + self._graphs: list[TracingGraph] = [] + + def store_kv(self, key: str, value: str | dict) -> None: + self._kv[key] = value + + def log_event(self, name: str, data: dict | None) -> None: + self._events.append({"name": name, "data": data, "timestamp": time.time()}) + + def add_graph( + self, + *, + title: str, + x_label: str, + y_label: str, + y_range: tuple[float, float] | None = None, + x_type: Literal["time", "value"] = "value", + max_data_points: int = 512, + ) -> TracingGraph: + graph = TracingGraph(title, y_label, x_label, y_range, x_type, max_data_points) + self._graphs.append(graph) + return graph + + def _export(self) -> dict[str, Any]: + return { + "kv": self._kv, + "events": self._events, + "graph": [ + { + "title": chart._title, + "x_label": chart._x_label, + "y_label": chart._y_label, + "y_range": chart._y_range, + "x_type": chart._x_type, + "data": chart._data, + } + for chart in self._graphs + ], + } + + +class Tracing: + _instance = None + + def __init__(self): + self._handles: dict[str, TracingHandle] = {} + + @classmethod + def with_handle(cls, handle: str) -> TracingHandle: + if cls._instance is None: + cls._instance = cls() + + if handle not in cls._instance._handles: + cls._instance._handles[handle] = TracingHandle() + + return cls._instance._handles[handle] + + @staticmethod + def _get_current_handle() -> TracingHandle: + try: + job_id = job.get_current_job_context().job.id + return Tracing._get_job_handle(job_id) + except RuntimeError: + pass + + return Tracing.with_handle("global") + + @staticmethod + def _get_job_handle(job_id: str) -> TracingHandle: + return Tracing.with_handle(f"job_{job_id}") + + @staticmethod + def store_kv(key: str, value: str | dict) -> None: + Tracing._get_current_handle().store_kv(key, value) + + @staticmethod + def log_event(name: str, data: dict | None = None) -> None: + Tracing._get_current_handle().log_event(name, data) + + @staticmethod + def add_graph( + *, + title: str, + x_label: str, + y_label: str, + y_range: tuple[float, float] | None = None, + x_type: Literal["time", "value"] = "value", + max_data_points: int = 512, + ) -> TracingGraph: + return Tracing._get_current_handle().add_graph( + title=title, + x_label=x_label, + y_label=y_label, + y_range=y_range, + x_type=x_type, + max_data_points=max_data_points, + ) + + +def _create_tracing_app(w: Worker) -> web.Application: + async def tracing_index(request: 
web.Request) -> web.Response: + import aiofiles + import importlib.resources + + with importlib.resources.path("livekit.agents.debug", "index.html") as path: + async with aiofiles.open(path, mode="r") as f: + content = await f.read() + + return web.Response(text=content, content_type="text/html") + + async def runners(request: web.Request) -> web.Response: + data = { + "runners": [ + { + "id": runner.id, + "status": runner.status.name, + "job_id": runner.running_job.job.id if runner.running_job else None, + "room": runner.running_job.job.room.name + if runner.running_job + else None, + } + for runner in w._proc_pool.processes + if runner.started + ] + } + + return web.json_response(data) + + async def runner(request: web.Request) -> web.Response: + runner_id = request.query.get("id") + if not runner_id: + return web.Response(status=400) + + # TODO: avoid + runner = next((r for r in w._proc_pool.processes if r.id == runner_id), None) + if not runner: + return web.Response(status=404) + + info = await asyncio.wait_for( + runner.tracing_info(), timeout=5.0 + ) # proc could be stuck + return web.json_response({"tracing": info}) + + async def worker(request: web.Request) -> web.Response: + return web.json_response( + { + "id": w.id, + "tracing": Tracing.with_handle("global")._export(), + } + ) + + app = web.Application() + app.add_routes([web.get("", tracing_index)]) + app.add_routes([web.get("/", tracing_index)]) + app.add_routes([web.get("/runners/", runners)]) + app.add_routes([web.get("/runner/", runner)]) + app.add_routes([web.get("/worker/", worker)]) + return app diff --git a/livekit-agents/livekit/agents/http_server.py b/livekit-agents/livekit/agents/http_server.py index c1e379fcb..e769d530d 100644 --- a/livekit-agents/livekit/agents/http_server.py +++ b/livekit-agents/livekit/agents/http_server.py @@ -6,10 +6,6 @@ from aiohttp import web -async def health_check(_: Any): - return web.Response(text="OK") - - class HttpServer: def __init__( self, host: str, port: int, loop: asyncio.AbstractEventLoop | None = None @@ -18,9 +14,12 @@ def __init__( self._host = host self._port = port self._app = web.Application(loop=self._loop) - self._app.add_routes([web.get("/", health_check)]) self._close_future = asyncio.Future[None](loop=self._loop) + @property + def app(self) -> web.Application: + return self._app + async def run(self) -> None: self._runner = web.AppRunner(self._app) await self._runner.setup() diff --git a/livekit-agents/livekit/agents/ipc/job_executor.py b/livekit-agents/livekit/agents/ipc/job_executor.py index dccf1831d..683925cbd 100644 --- a/livekit-agents/livekit/agents/ipc/job_executor.py +++ b/livekit-agents/livekit/agents/ipc/job_executor.py @@ -7,6 +7,9 @@ class JobExecutor(Protocol): + @property + def id(self) -> str: ... + @property def started(self) -> bool: ... @@ -32,6 +35,8 @@ async def aclose(self) -> None: ... async def launch_job(self, info: RunningJobInfo) -> None: ... + async def tracing_info(self) -> dict[str, Any]: ... + class JobStatus(Enum): RUNNING = "running" diff --git a/livekit-agents/livekit/agents/ipc/job_proc_executor.py b/livekit-agents/livekit/agents/ipc/job_proc_executor.py index 75a45a1a3..5c1ecb38b 100644 --- a/livekit-agents/livekit/agents/ipc/job_proc_executor.py +++ b/livekit-agents/livekit/agents/ipc/job_proc_executor.py @@ -8,7 +8,7 @@ from ..job import JobContext, JobProcess, RunningJobInfo from ..log import logger -from ..utils import aio, log_exceptions +from ..utils import aio, log_exceptions, shortuuid from . 
import channel, proto from .inference_executor import InferenceExecutor from .job_executor import JobStatus @@ -52,6 +52,24 @@ def __init__( self._job_entrypoint_fnc = job_entrypoint_fnc self._inference_executor = inference_executor self._inference_tasks: list[asyncio.Task[None]] = [] + self._id = shortuuid("PCEXEC_") + self._tracing_requests = dict[str, asyncio.Future[proto.TracingResponse]]() + + @property + def id(self) -> str: + return self._id + + async def tracing_info(self) -> dict[str, Any]: + if not self.started: + raise RuntimeError("process not started") + + tracing_req = proto.TracingRequest() + tracing_req.request_id = shortuuid("trace_req_") + fut = asyncio.Future[proto.TracingResponse]() + self._tracing_requests[tracing_req.request_id] = fut + await channel.asend_message(self._pch, tracing_req) + resp = await fut + return resp.info @property def status(self) -> JobStatus: diff --git a/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py b/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py index 531dd7a36..4baec41b0 100644 --- a/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py +++ b/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py @@ -28,6 +28,7 @@ def _no_traceback_excepthook(exc_type, exc_val, traceback): from livekit import rtc from ..job import JobContext, JobProcess, _JobContextVar +from ..debug import tracing from ..log import logger from ..utils import aio, http_context, log_exceptions, shortuuid from .channel import Message @@ -40,6 +41,8 @@ def _no_traceback_excepthook(exc_type, exc_val, traceback): InitializeRequest, ShutdownRequest, StartJobRequest, + TracingRequest, + TracingResponse, ) @@ -169,6 +172,20 @@ async def _read_ipc_task(): if isinstance(msg, InferenceResponse): self._inf_client._on_inference_response(msg) + if isinstance(msg, TracingRequest): + if not self.has_running_job: + logger.warning("tracing request received without running job") + return + + await self._client.send( + TracingResponse( + request_id=msg.request_id, + info=tracing.Tracing._get_job_handle( + self._job_ctx.job.id + )._export(), + ) + ) + read_task = asyncio.create_task(_read_ipc_task(), name="job_ipc_read") await self._exit_proc_flag.wait() diff --git a/livekit-agents/livekit/agents/ipc/job_thread_executor.py b/livekit-agents/livekit/agents/ipc/job_thread_executor.py index 6705422ab..feace496d 100644 --- a/livekit-agents/livekit/agents/ipc/job_thread_executor.py +++ b/livekit-agents/livekit/agents/ipc/job_thread_executor.py @@ -60,6 +60,24 @@ def __init__( self._inference_executor = inference_executor self._inference_tasks: list[asyncio.Task[None]] = [] + self._id = utils.shortuuid("THEXEC_") + self._tracing_requests = dict[str, asyncio.Future[proto.TracingResponse]]() + + @property + def id(self) -> str: + return self._id + + async def tracing_info(self) -> dict[str, Any]: + if not self.started: + raise RuntimeError("thread not started") + + tracing_req = proto.TracingRequest() + tracing_req.request_id = utils.shortuuid("trace_req_") + fut = asyncio.Future[proto.TracingResponse]() + self._tracing_requests[tracing_req.request_id] = fut + await channel.asend_message(self._pch, tracing_req) + resp = await fut + return resp.info @property def status(self) -> JobStatus: @@ -271,6 +289,11 @@ async def _monitor_task(self) -> None: asyncio.create_task(self._do_inference_task(msg)) ) + if isinstance(msg, proto.TracingResponse): + fut = self._tracing_requests.pop(msg.request_id) + with contextlib.suppress(asyncio.InvalidStateError): + fut.set_result(msg) + 
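Both executors above use the same request-id/future correlation: tracing_info() stores a future keyed by the request id before sending the TracingRequest, and the read loop resolves that future when the matching TracingResponse arrives. A self-contained toy sketch of the idea (independent of the IPC channel and message types, which are only partially shown here):

    import asyncio
    import uuid

    class RequestRouter:
        """Toy version of the executors' request/response correlation."""

        def __init__(self) -> None:
            self._pending: dict[str, asyncio.Future[dict]] = {}

        async def request(self, send, payload: dict) -> dict:
            # tag the outgoing message and park a future under that id
            request_id = uuid.uuid4().hex
            fut: asyncio.Future[dict] = asyncio.get_running_loop().create_future()
            self._pending[request_id] = fut
            await send({"request_id": request_id, **payload})
            return await fut  # resolved by on_response()

        def on_response(self, msg: dict) -> None:
            # called from the read loop when a response comes back
            fut = self._pending.pop(msg["request_id"], None)
            if fut is not None and not fut.done():
                fut.set_result(msg)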
@utils.log_exceptions(logger=logger) async def _ping_task(self) -> None: ping_interval = utils.aio.interval(self._opts.ping_interval) diff --git a/livekit-agents/livekit/agents/ipc/proto.py b/livekit-agents/livekit/agents/ipc/proto.py index 509964b55..f650d8163 100644 --- a/livekit-agents/livekit/agents/ipc/proto.py +++ b/livekit-agents/livekit/agents/ipc/proto.py @@ -1,8 +1,9 @@ from __future__ import annotations import io +import pickle from dataclasses import dataclass, field -from typing import ClassVar +from typing import Any, ClassVar from livekit.protocol import agent @@ -181,6 +182,33 @@ def read(self, b: io.BytesIO) -> None: self.error = channel.read_string(b) +@dataclass +class TracingRequest: + MSG_ID: ClassVar[int] = 9 + request_id: str = "" + + def write(self, b: io.BytesIO) -> None: + channel.write_string(b, self.request_id) + + def read(self, b: io.BytesIO) -> None: + self.request_id = channel.read_string(b) + + +@dataclass +class TracingResponse: + MSG_ID: ClassVar[int] = 10 + request_id: str = "" + info: dict[str, Any] = field(default_factory=dict) + + def write(self, b: io.BytesIO) -> None: + channel.write_string(b, self.request_id) + channel.write_bytes(b, pickle.dumps(self.info)) + + def read(self, b: io.BytesIO) -> None: + self.request_id = channel.read_string(b) + self.info = pickle.loads(channel.read_bytes(b)) + + IPC_MESSAGES = { InitializeRequest.MSG_ID: InitializeRequest, InitializeResponse.MSG_ID: InitializeResponse, @@ -191,4 +219,6 @@ def read(self, b: io.BytesIO) -> None: Exiting.MSG_ID: Exiting, InferenceRequest.MSG_ID: InferenceRequest, InferenceResponse.MSG_ID: InferenceResponse, + TracingRequest.MSG_ID: TracingRequest, + TracingResponse.MSG_ID: TracingResponse, } diff --git a/livekit-agents/livekit/agents/pipeline/audio_recognition.py b/livekit-agents/livekit/agents/pipeline/audio_recognition.py index 9b238ba25..33dc49e1a 100644 --- a/livekit-agents/livekit/agents/pipeline/audio_recognition.py +++ b/livekit-agents/livekit/agents/pipeline/audio_recognition.py @@ -1,7 +1,7 @@ from __future__ import annotations import asyncio -from typing import TYPE_CHECKING, Literal, Protocol +from typing import TYPE_CHECKING, AsyncIterable, Literal, Protocol from livekit import rtc @@ -10,6 +10,8 @@ from ..utils import aio from . 
import io +from ..debug import tracing + if TYPE_CHECKING: from .pipeline2 import PipelineAgent @@ -55,6 +57,7 @@ def __init__( ) -> None: super().__init__() self._agent = agent + self._audio_input_atask: asyncio.Task[None] | None = None self._stt_atask: asyncio.Task[None] | None = None self._vad_atask: asyncio.Task[None] | None = None self._end_of_turn_task: asyncio.Task[None] | None = None @@ -69,6 +72,17 @@ def __init__( self._speaking = False self._audio_transcript = "" self._last_language: str | None = None + self._vad_graph = tracing.Tracing.add_graph( + title="vad", + x_label="time", + y_label="speech_probability", + x_type="time", + y_range=(0, 1), + max_data_points=int(30 * 30), + ) + + self._stt_ch: aio.Chan[rtc.AudioFrame] | None = None + self._vad_ch: aio.Chan[rtc.AudioFrame] | None = None def start(self) -> None: self.update_stt(self._stt) @@ -84,7 +98,18 @@ def audio_input(self, audio_input: io.AudioStream | None) -> None: self.update_stt(self._stt) self.update_vad(self._vad) + if self._audio_input and self._audio_input_atask is None: + self._audio_input_atask = asyncio.create_task( + self._audio_input_task(self._audio_input) + ) + elif self._audio_input_atask is not None: + self._audio_input_atask.cancel() + self._audio_input_atask = None + async def aclose(self) -> None: + if self._audio_input_atask is not None: + await aio.gracefully_cancel(self._audio_input_atask) + if self._stt_atask is not None: await aio.gracefully_cancel(self._stt_atask) @@ -96,23 +121,27 @@ async def aclose(self) -> None: def update_stt(self, stt: io.STTNode | None) -> None: self._stt = stt - if self._audio_input and stt: + self._stt_ch = aio.Chan[rtc.AudioFrame]() self._stt_atask = asyncio.create_task( - self._stt_task(stt, self._audio_input, self._stt_atask) + self._stt_task(stt, self._stt_ch, self._stt_atask) ) elif self._stt_atask is not None: self._stt_atask.cancel() + self._stt_atask = None + self._stt_ch = None def update_vad(self, vad: vad.VAD | None) -> None: self._vad = vad - if self._audio_input and vad: + self._vad_ch = aio.Chan[rtc.AudioFrame]() self._vad_atask = asyncio.create_task( - self._vad_task(vad, self._audio_input, self._vad_atask) + self._vad_task(vad, self._vad_ch, self._vad_atask) ) elif self._vad_atask is not None: self._vad_atask.cancel() + self._vad_atask = None + self._vad_ch = None async def _on_stt_event(self, ev: stt.SpeechEvent) -> None: if ev.type == stt.SpeechEventType.FINAL_TRANSCRIPT: @@ -126,16 +155,32 @@ async def _on_stt_event(self, ev: stt.SpeechEvent) -> None: extra={"user_transcript": transcript}, ) + tracing.Tracing.log_event( + "user transcript", + { + "transcript": transcript, + "buffered_transcript": self._audio_transcript, + }, + ) + self._audio_transcript += f" {transcript}" self._audio_transcript = self._audio_transcript.lstrip() if not self._speaking: - self._run_eou_detection(self._agent.chat_ctx, self._audio_transcript) + self._run_eou_detection(self._agent.chat_ctx) elif ev.type == stt.SpeechEventType.INTERIM_TRANSCRIPT: self.emit("interim_transcript", ev) + tracing.Tracing.log_event( + "user interim transcript", + { + "interim transcript": ev.alternatives[0].text, + }, + ) + async def _on_vad_event(self, ev: vad.VADEvent) -> None: if ev.type == vad.VADEventType.START_OF_SPEECH: + tracing.Tracing.log_event("start of speech") self.emit("start_of_speech", ev) self._speaking = True @@ -143,17 +188,23 @@ async def _on_vad_event(self, ev: vad.VADEvent) -> None: self._end_of_turn_task.cancel() elif ev.type == vad.VADEventType.INFERENCE_DONE: + 
self._vad_graph.plot(ev.timestamp, ev.probability) self.emit("vad_inference_done", ev) elif ev.type == vad.VADEventType.END_OF_SPEECH: + tracing.Tracing.log_event("end of speech") self.emit("end_of_speech", ev) self._speaking = False - def _run_eou_detection( - self, chat_ctx: llm.ChatContext, new_transcript: str - ) -> None: + if not self._speaking: + self._run_eou_detection(self._agent.chat_ctx) + + def _run_eou_detection(self, chat_ctx: llm.ChatContext) -> None: + if not self._audio_transcript: + return + chat_ctx = self._chat_ctx.copy() - chat_ctx.append(role="user", text=new_transcript) + chat_ctx.append(role="user", text=self._audio_transcript) turn_detector = self._turn_detector @utils.log_exceptions(logger=logger) @@ -166,11 +217,19 @@ async def _bounce_eou_task() -> None: end_of_turn_probability = await turn_detector.predict_end_of_turn( chat_ctx ) + tracing.Tracing.log_event( + "end of user turn probability", + {"probability": end_of_turn_probability}, + ) unlikely_threshold = turn_detector.unlikely_threshold() if end_of_turn_probability > unlikely_threshold: await asyncio.sleep(self.UNLIKELY_END_OF_TURN_EXTRA_DELAY) - self.emit("end_of_turn", new_transcript) + tracing.Tracing.log_event( + "end of user turn", {"transcript": self._audio_transcript} + ) + self.emit("end_of_turn", self._audio_transcript) + self._audio_transcript = "" if self._end_of_turn_task is not None: self._end_of_turn_task.cancel() @@ -190,7 +249,10 @@ async def _stt_task( if asyncio.iscoroutine(node): node = await node - if node is not None: + if node is None: + return + + if isinstance(node, AsyncIterable): async for ev in node: assert isinstance( ev, stt.SpeechEvent @@ -217,3 +279,11 @@ async def _forward() -> None: finally: await stream.aclose() await aio.gracefully_cancel(forward_task) + + async def _audio_input_task(self, audio_input: io.AudioStream) -> None: + async for frame in audio_input: + if self._stt_ch is not None: + self._stt_ch.send_nowait(frame) + + if self._vad_ch is not None: + self._vad_ch.send_nowait(frame) diff --git a/livekit-agents/livekit/agents/pipeline/chat_cli.py b/livekit-agents/livekit/agents/pipeline/chat_cli.py index 630392247..7f197229c 100644 --- a/livekit-agents/livekit/agents/pipeline/chat_cli.py +++ b/livekit-agents/livekit/agents/pipeline/chat_cli.py @@ -79,7 +79,9 @@ async def capture_frame(self, frame: rtc.AudioFrame) -> None: self._pushed_duration += frame.duration if self._cli._output_stream is not None: - self._cli._output_stream.write(frame.data) + await self._cli._loop.run_in_executor( + None, self._cli._output_stream.write, frame.data + ) def clear_buffer(self) -> None: self._capturing = False @@ -145,6 +147,8 @@ def __init__( self._text_sink = _TextSink(self) self._audio_sink = _AudioSink(self) + self._saved_frames = [] + def _print_welcome(self): print(_esc(34) + "=" * 50 + _esc(0)) print(_esc(34) + " Livekit Agents - ChatCLI" + _esc(0)) @@ -236,15 +240,22 @@ def _input_sd_callback(self, indata: np.ndarray, frame_count: int, *_) -> None: rms = np.sqrt(np.mean(indata.astype(np.float32) ** 2)) max_int16 = np.iinfo(np.int16).max self._micro_db = 20.0 * np.log10(rms / max_int16 + 1e-6) - self._loop.call_soon_threadsafe( - self._audio_input_ch.send_nowait, - rtc.AudioFrame( - data=indata.tobytes(), - samples_per_channel=frame_count, - sample_rate=24000, - num_channels=1, - ), + + frame = rtc.AudioFrame( + data=indata.tobytes(), + samples_per_channel=frame_count, + sample_rate=24000, + num_channels=1, ) + self._saved_frames.append(frame) + + if len(self._saved_frames) > 
20 * 5: + frmae = rtc.combine_audio_frames(self._saved_frames) + wav = frmae.to_wav_bytes() + with open("audio.wav", "wb") as f: + f.write(wav) + + self._loop.call_soon_threadsafe(self._audio_input_ch.send_nowait, frame) @log_exceptions(logger=logger) async def _input_cli_task(self, in_ch: aio.Chan[str]) -> None: diff --git a/livekit-agents/livekit/agents/pipeline/generation.py b/livekit-agents/livekit/agents/pipeline/generation.py index a8cd4a17b..03b4e69ac 100644 --- a/livekit-agents/livekit/agents/pipeline/generation.py +++ b/livekit-agents/livekit/agents/pipeline/generation.py @@ -1,7 +1,7 @@ from __future__ import annotations import asyncio -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import AsyncIterable, Protocol, Tuple, runtime_checkable from livekit import rtc @@ -20,7 +20,8 @@ async def aclose(self): ... class _LLMGenerationData: text_ch: aio.Chan[str] tools_ch: aio.Chan[FunctionCallInfo] - generated_text: str = "" # + generated_text: str = "" + generated_tools: list[FunctionCallInfo] = field(default_factory=list) def do_llm_inference( @@ -58,6 +59,7 @@ async def _inference_task(): if delta.tool_calls: for tool in delta.tool_calls: + data.generated_tools.append(tool) tools_ch.send_nowait(tool) if delta.content: diff --git a/livekit-agents/livekit/agents/pipeline/io.py b/livekit-agents/livekit/agents/pipeline/io.py index 7ce206b8a..accc00ee5 100644 --- a/livekit-agents/livekit/agents/pipeline/io.py +++ b/livekit-agents/livekit/agents/pipeline/io.py @@ -145,3 +145,68 @@ async def capture_frame(self, text: rtc.VideoFrame) -> None: ... @abstractmethod def flush(self) -> None: ... + + +class AgentInput: + def __init__(self, video_changed: Callable, audio_changed: Callable) -> None: + self._video_stream: VideoStream | None = None + self._audio_stream: AudioStream | None = None + self._video_changed = video_changed + self._audio_changed = audio_changed + + @property + def video(self) -> VideoStream | None: + return self._video_stream + + @video.setter + def video(self, stream: VideoStream | None) -> None: + self._video_stream = stream + self._video_changed() + + @property + def audio(self) -> AudioStream | None: + return self._audio_stream + + @audio.setter + def audio(self, stream: AudioStream | None) -> None: + self._audio_stream = stream + self._audio_changed() + + +class AgentOutput: + def __init__( + self, video_changed: Callable, audio_changed: Callable, text_changed: Callable + ) -> None: + self._video_sink: VideoSink | None = None + self._audio_sink: AudioSink | None = None + self._text_sink: TextSink | None = None + self._video_changed = video_changed + self._audio_changed = audio_changed + self._text_changed = text_changed + + @property + def video(self) -> VideoSink | None: + return self._video_sink + + @video.setter + def video(self, sink: VideoSink | None) -> None: + self._video_sink = sink + self._video_changed() + + @property + def audio(self) -> AudioSink | None: + return self._audio_sink + + @audio.setter + def audio(self, sink: AudioSink | None) -> None: + self._audio_sink = sink + self._audio_changed() + + @property + def text(self) -> TextSink | None: + return self._text_sink + + @text.setter + def text(self, sink: TextSink | None) -> None: + self._text_sink = sink + self._text_changed() diff --git a/livekit-agents/livekit/agents/pipeline/pipeline2.py b/livekit-agents/livekit/agents/pipeline/pipeline2.py index cb5cbf4f5..8d130621c 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline2.py +++ 
b/livekit-agents/livekit/agents/pipeline/pipeline2.py @@ -2,11 +2,12 @@ import asyncio import contextlib +import heapq from dataclasses import dataclass from typing import ( AsyncIterable, - Callable, + Tuple, Literal, Optional, Union, @@ -14,7 +15,7 @@ from livekit import rtc -from .. import llm, stt, tts, utils, vad +from .. import llm, stt, tts, utils, vad, debug, tokenize from ..llm import ChatContext, FunctionContext from ..log import logger from . import io @@ -22,114 +23,72 @@ from .generation import ( do_llm_inference, do_tts_inference, + _TTSGenerationData, ) -class AgentInput: - def __init__(self, video_changed: Callable, audio_changed: Callable) -> None: - self._video_stream: io.VideoStream | None = None - self._audio_stream: io.AudioStream | None = None - self._video_changed = video_changed - self._audio_changed = audio_changed - - @property - def video(self) -> io.VideoStream | None: - return self._video_stream - - @video.setter - def video(self, stream: io.VideoStream | None) -> None: - self._video_stream = stream - self._video_changed() - - @property - def audio(self) -> io.AudioStream | None: - return self._audio_stream - - @audio.setter - def audio(self, stream: io.AudioStream | None) -> None: - self._audio_stream = stream - self._audio_changed() - - -class AgentOutput: - def __init__( - self, video_changed: Callable, audio_changed: Callable, text_changed: Callable - ) -> None: - self._video_sink: io.VideoSink | None = None - self._audio_sink: io.AudioSink | None = None - self._text_sink: io.TextSink | None = None - self._video_changed = video_changed - self._audio_changed = audio_changed - self._text_changed = text_changed - - @property - def video(self) -> io.VideoSink | None: - return self._video_sink - - @video.setter - def video(self, sink: io.VideoSink | None) -> None: - self._video_sink = sink - self._video_changed() - - @property - def audio(self) -> io.AudioSink | None: - return self._audio_sink - - @audio.setter - def audio(self, sink: io.AudioSink | None) -> None: - self._audio_sink = sink - self._audio_changed() - - @property - def text(self) -> io.TextSink | None: - return self._text_sink - - @text.setter - def text(self, sink: io.TextSink | None) -> None: - self._text_sink = sink - self._text_changed() - - -class GenerationHandle: +class SpeechHandle: def __init__( - self, *, speech_id: str, allow_interruptions: bool, task: asyncio.Task + self, *, speech_id: str, allow_interruptions: bool, step_index: int ) -> None: self._id = speech_id + self._step_index = step_index self._allow_interruptions = allow_interruptions - self._interrupted = False - self._task = task + self._interrupt_fut = asyncio.Future() + self._done_fut = asyncio.Future() + self._play_fut = asyncio.Future() + self._playout_done_fut = asyncio.Future() @staticmethod - def from_task( - task: asyncio.Task, *, allow_interruptions: bool = True - ) -> GenerationHandle: - return GenerationHandle( - speech_id=utils.shortuuid("gen_"), + def create(allow_interruptions: bool = True, step_index: int = 0) -> SpeechHandle: + return SpeechHandle( + speech_id=utils.shortuuid("SH_"), allow_interruptions=allow_interruptions, - task=task, + step_index=step_index, ) @property def id(self) -> str: return self._id + @property + def step_index(self) -> int: + return self._step_index + @property def interrupted(self) -> bool: - return self._interrupted + return self._interrupt_fut.done() @property def allow_interruptions(self) -> bool: return self._allow_interruptions + def play(self) -> None: + 
self._play_fut.set_result(None) + + def done(self) -> bool: + return self._done_fut.done() + def interrupt(self) -> None: if not self._allow_interruptions: raise ValueError("This generation handle does not allow interruptions") - if self._task.done(): + if self.done(): return - self._interrupted = True - self._task.cancel() + self._done_fut.set_result(None) + self._interrupt_fut.set_result(None) + + async def wait_for_playout(self) -> None: + await asyncio.shield(self._playout_done_fut) + + def _mark_playout_done(self) -> None: + self._playout_done_fut.set_result(None) + + def _mark_done(self) -> None: + with contextlib.suppress(asyncio.InvalidStateError): + # will raise InvalidStateError if the future is already done (interrupted) + self._done_fut.set_result(None) EventTypes = Literal[ @@ -152,6 +111,13 @@ class _PipelineOptions: class PipelineAgent(rtc.EventEmitter[EventTypes]): + SPEECH_PRIORITY_LOW = 0 + """Priority for messages that should be played after all other messages in the queue""" + SPEECH_PRIORITY_NORMAL = 5 + """Every speech generates by the PipelineAgent defaults to this priority.""" + SPEECH_PRIORITY_HIGH = 10 + """Priority for important messages that should be played before others.""" + def __init__( self, *, @@ -164,7 +130,7 @@ def __init__( chat_ctx: ChatContext | None = None, fnc_ctx: FunctionContext | None = None, allow_interruptions: bool = True, - min_interruption_duration: float = 1.0, + min_interruption_duration: float = 0.5, min_endpointing_delay: float = 0.5, max_fnc_steps: int = 5, loop: asyncio.AbstractEventLoop | None = None, @@ -176,8 +142,28 @@ def __init__( self._fnc_ctx = fnc_ctx self._stt, self._vad, self._llm, self._tts = stt, vad, llm, tts - self._turn_detector = turn_detector + if tts and not tts.capabilities.streaming: + from .. import tts as text_to_speech + + tts = text_to_speech.StreamAdapter( + tts=tts, sentence_tokenizer=tokenize.basic.SentenceTokenizer() + ) + + if stt and not stt.capabilities.streaming: + from .. 
import stt as speech_to_text + + if vad is None: + raise ValueError( + "VAD is required when streaming is not supported by the STT" + ) + + stt = speech_to_text.StreamAdapter( + stt=stt, + vad=vad, + ) + + self._turn_detector = turn_detector self._audio_recognition = AudioRecognition( agent=self, stt=self.stt_node, @@ -200,17 +186,21 @@ def __init__( self._audio_recognition.on("vad_inference_done", self._on_vad_inference_done) # configurable IO - self._input = AgentInput( + self._input = io.AgentInput( self._on_video_input_changed, self._on_audio_input_changed ) - self._output = AgentOutput( + self._output = io.AgentOutput( self._on_video_output_changed, self._on_audio_output_changed, self._on_text_output_changed, ) - # current generation happening (including all function calls & steps) - self._current_generation: GenerationHandle | None = None + self._current_speech: SpeechHandle | None = None + self._speech_q: list[Tuple[int, SpeechHandle]] = [] + self._speech_q_changed = asyncio.Event() + self._speech_tasks = [] + + self._speech_scheduler_task: asyncio.Task | None = None # -- Pipeline nodes -- # They can all be overriden by subclasses, by default they use the STT/LLM/TTS specified in the @@ -232,7 +222,7 @@ async def _forward_input(): async for event in stream: yield event finally: - forward_task.cancel() + await utils.aio.gracefully_cancel(forward_task) async def llm_node( self, chat_ctx: llm.ChatContext, fnc_ctx: llm.FunctionContext | None @@ -269,22 +259,25 @@ async def _forward_input(): def start(self) -> None: self._audio_recognition.start() + self._speech_scheduler_task = asyncio.create_task( + self._playout_scheduler(), name="_playout_scheduler" + ) async def aclose(self) -> None: await self._audio_recognition.aclose() @property - def input(self) -> AgentInput: + def input(self) -> io.AgentInput: return self._input @property - def output(self) -> AgentOutput: + def output(self) -> io.AgentOutput: return self._output # TODO(theomonnom): find a better name than `generation` @property - def current_generation(self) -> GenerationHandle | None: - return self._current_generation + def current_speech(self) -> SpeechHandle | None: + return self._current_speech @property def chat_ctx(self) -> llm.ChatContext: @@ -293,30 +286,60 @@ def chat_ctx(self) -> llm.ChatContext: def update_options(self) -> None: pass - def say(self, text: str | AsyncIterable[str]) -> GenerationHandle: + def say(self, text: str | AsyncIterable[str]) -> SpeechHandle: pass - def generate_reply(self, user_input: str) -> GenerationHandle: - if ( - self._current_generation is not None - and not self._current_generation.interrupted - ): + def generate_reply(self, user_input: str) -> SpeechHandle: + if self._current_speech is not None and not self._current_speech.interrupted: raise ValueError("another reply is already in progress") - self._chat_ctx.append(role="user", text=user_input) + debug.Tracing.log_event("generate_reply", {"user_input": user_input}) + self._chat_ctx.append(role="user", text=user_input) # TODO(theomonnom) Remove - # TODO(theomonnom): Use the agent task chat_ctx + handle = SpeechHandle.create(allow_interruptions=self._opts.allow_interruptions) task = asyncio.create_task( - self._generate_reply_task(chat_ctx=self._chat_ctx, fnc_ctx=self._fnc_ctx) + self._generate_pipeline_reply_task( + handle=handle, + chat_ctx=self._chat_ctx, + fnc_ctx=self._fnc_ctx, + ), + name="_generate_pipeline_reply", ) - gen_handle = GenerationHandle.from_task(task) - return gen_handle + self._schedule_speech(handle, task, 
self.SPEECH_PRIORITY_NORMAL)
+        return handle
 
     # -- Main generation task --
 
+    def _schedule_speech(
+        self, speech: SpeechHandle, task: asyncio.Task, priority: int
+    ) -> None:
+        self._speech_tasks.append(task)
+        task.add_done_callback(lambda _: self._speech_tasks.remove(task))
+
+        heapq.heappush(self._speech_q, (priority, speech))
+        self._speech_q_changed.set()
+
+    @utils.log_exceptions(logger=logger)
+    async def _playout_scheduler(self) -> None:
+        while True:
+            await self._speech_q_changed.wait()
+
+            while self._speech_q:
+                _, speech = heapq.heappop(self._speech_q)
+                self._current_speech = speech
+                speech.play()
+                await speech.wait_for_playout()
+                self._current_speech = None
+
+            self._speech_q_changed.clear()
+
     @utils.log_exceptions(logger=logger)
-    async def _generate_reply_task(
-        self, *, chat_ctx: ChatContext, fnc_ctx: FunctionContext | None
+    async def _generate_pipeline_reply_task(
+        self,
+        *,
+        handle: SpeechHandle,
+        chat_ctx: ChatContext,
+        fnc_ctx: FunctionContext | None,
     ) -> None:
         @utils.log_exceptions(logger=logger)
         async def _forward_llm_text(llm_output: AsyncIterable[str]) -> None:
@@ -324,75 +347,243 @@ async def _forward_llm_text(llm_output: AsyncIterable[str]) -> None:
             if self.output.text is None:
                 return
 
-            async for delta in llm_output:
-                await self.output.text.capture_text(delta)
-
-            self.output.text.flush()
+            try:
+                async for delta in llm_output:
+                    await self.output.text.capture_text(delta)
+            finally:
+                self.output.text.flush()
 
         @utils.log_exceptions(logger=logger)
-        async def _forward_tts_audio(
-            tts_output: AsyncIterable[rtc.AudioFrame], wait_for_playout: bool = True
-        ) -> None:
-            """collect and forward the generated audio to the current agent output"""
+        async def _forward_tts_audio(tts_output: AsyncIterable[rtc.AudioFrame]) -> None:
+            """collect and forward the generated audio to the current agent output (generally playout)"""
 
             if self.output.audio is None:
                 return
 
-            async for frame in tts_output:
-                await self.output.audio.capture_frame(frame)
+            try:
+                async for frame in tts_output:
+                    await self.output.audio.capture_frame(frame)
+            finally:
+                self.output.audio.flush()  # always flush (even if the task is interrupted)
 
-            self.output.audio.flush()
 
+        @utils.log_exceptions(logger=logger)
+        async def _execute_tools(
+            tools_ch: utils.aio.Chan[llm.FunctionCallInfo],
+            called_functions: set[llm.CalledFunction],
+        ) -> None:
+            """execute tools, when cancelled, stop executing new tools but wait for the pending ones"""
+            try:
+                async for tool in tools_ch:
+                    logger.debug(
+                        "executing tool",
+                        extra={
+                            "function": tool.function_info.name,
+                            "speech_id": handle.id,
+                        },
+                    )
+                    debug.Tracing.log_event(
+                        "executing tool",
+                        {
+                            "function": tool.function_info.name,
+                            "speech_id": handle.id,
+                        },
+                    )
+                    cfnc = tool.execute()
+                    called_functions.add(cfnc)
+            except asyncio.CancelledError:
+                # don't cancel function calls that are already running; wait for them to finish
+                pending_tools = [cfn for cfn in called_functions if not cfn.task.done()]
+
+                if pending_tools:
+                    names = [cfn.call_info.function_info.name for cfn in pending_tools]
+
+                    logger.debug(
+                        "waiting for function call to finish before cancelling",
+                        extra={
+                            "functions": names,
+                            "speech_id": handle.id,
+                        },
+                    )
+                    debug.Tracing.log_event(
+                        "waiting for function call to finish before cancelling",
+                        {
+                            "functions": names,
+                            "speech_id": handle.id,
+                        },
+                    )
+                    await asyncio.gather(*[cfn.task for cfn in pending_tools])
+            finally:
+                if len(called_functions) > 0:
+                    logger.debug(
+                        "tools execution completed",
+                        extra={"speech_id": 
handle.id},
+                    )
+                    debug.Tracing.log_event(
+                        "tools execution completed",
+                        {"speech_id": handle.id},
+                    )
+
+        debug.Tracing.log_event(
+            "generation started",
+            {"speech_id": handle.id, "step_index": handle.step_index},
+        )
 
-        # new messages generated during the generation (including function calls)
-        new_messages: list[llm.ChatMessage] = []
+        wg = utils.aio.WaitGroup()
+        tasks = []
+        llm_task, llm_gen_data = do_llm_inference(
+            node=self.llm_node,
+            chat_ctx=chat_ctx,
+            fnc_ctx=fnc_ctx
+            if handle.step_index < self._max_fnc_steps - 1 and handle.step_index >= 2
+            else None,
+        )
+        tasks.append(llm_task)
+        wg.add(1)
+        llm_task.add_done_callback(lambda _: wg.done())
+        tts_text_input, llm_output = utils.aio.itertools.tee(llm_gen_data.text_ch)
+
+        tts_task: asyncio.Task | None = None
+        tts_gen_data: _TTSGenerationData | None = None
+        if self._output.audio is not None:
+            tts_task, tts_gen_data = do_tts_inference(
+                node=self.tts_node, input=tts_text_input
+            )
+            tasks.append(tts_task)
+            wg.add(1)
+            tts_task.add_done_callback(lambda _: wg.done())
+
+        # wait for the play() method to be called
+        await asyncio.wait(
+            [
+                handle._play_fut,
+                handle._interrupt_fut,
+            ],
+            return_when=asyncio.FIRST_COMPLETED,
+        )
 
-        # TODO(theomonnom): how nested fnc calls is going to work with realtime API?
-        for i in range(
-            self._max_fnc_steps + 1
-        ):  # +1 to ignore the first step that doesn't contain any tools
-            # if i >= 2, the LLM supports having multiple steps
-            fnc_ctx = fnc_ctx if i < self._max_fnc_steps - 1 and i >= 2 else None
-            chat_ctx = chat_ctx
+        if handle.interrupted:
+            await utils.aio.gracefully_cancel(*tasks)
+            handle._mark_done()
+            return  # return directly (the generated output wasn't used)
 
-            llm_task, llm_gen_data = do_llm_inference(
-                node=self.llm_node, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx
+        # forward tasks are started after the play() method is called
+        # they redirect the generated text/audio to the output channels
+        forward_llm_task = asyncio.create_task(
+            _forward_llm_text(llm_output),
+            name="_generate_reply_task.forward_llm_text",
+        )
+        tasks.append(forward_llm_task)
+        wg.add(1)
+        forward_llm_task.add_done_callback(lambda _: wg.done())
+
+        forward_tts_task: asyncio.Task | None = None
+        if tts_gen_data is not None:
+            forward_tts_task = asyncio.create_task(
+                _forward_tts_audio(tts_gen_data.audio_ch),
+                name="_generate_reply_task.forward_tts_audio",
             )
-            tts_text_input, llm_output = utils.aio.itertools.tee(llm_gen_data.text_ch)
 
-            forward_llm_task = asyncio.create_task(
-                _forward_llm_text(llm_output),
-                name="_generate_reply_task.forward_llm_text",
+            tasks.append(forward_tts_task)
+            wg.add(1)
+            forward_tts_task.add_done_callback(lambda _: wg.done())
+
+        # start to execute tools (only after play())
+        called_functions: set[llm.CalledFunction] = set()
+        tools_task = asyncio.create_task(
+            _execute_tools(llm_gen_data.tools_ch, called_functions),
+            name="_generate_reply_task.execute_tools",
+        )
+        tasks.append(tools_task)
+        wg.add(1)
+        tools_task.add_done_callback(lambda _: wg.done())
+
+        # wait for the tasks to finish
+        await asyncio.wait(
+            [
+                wg.wait(),
+                handle._interrupt_fut,
+            ],
+            return_when=asyncio.FIRST_COMPLETED,
+        )
+
+        # wait for the end of the playout if the audio is enabled
+        if forward_tts_task is not None:
+            assert self._output.audio is not None
+            await asyncio.wait(
+                [
+                    self._output.audio.wait_for_playout(),
+                    handle._interrupt_fut,
+                ],
+                return_when=asyncio.FIRST_COMPLETED,
             )
 
-            tts_task: asyncio.Task | None = None
-            forward_tts_task: asyncio.Task | None = None
-            if 
self._output.audio is not None:
-                tts_task, tts_gen_data = do_tts_inference(
-                    node=self.tts_node, input=tts_text_input
+        if handle.interrupted:
+            await utils.aio.gracefully_cancel(*tasks)
+
+            if len(called_functions) > 0:
+                functions = [
+                    cfnc.call_info.function_info.name for cfnc in called_functions
+                ]
+                logger.debug(
+                    "speech interrupted, ignoring generation of the function call results",
+                    extra={"speech_id": handle.id, "functions": functions},
                 )
-                forward_tts_task = asyncio.create_task(
-                    _forward_tts_audio(tts_gen_data.audio_ch),
-                    name="_generate_reply_task.forward_tts_audio",
+                debug.Tracing.log_event(
+                    "speech interrupted, ignoring generation of the function call results",
+                    {"speech_id": handle.id, "functions": functions},
                 )
 
-            tools: list[llm.FunctionCallInfo] = []
-            async for tool in llm_gen_data.tools_ch:
-                tools.append(tool)  # TODO(theomonnom): optimize function calls response
-
-            await asyncio.gather(llm_task, forward_llm_task)
-
-            # TODO(theomonnom): Simplify this
-            if tts_task is not None and forward_tts_task is not None:
+            # if the audio playout was enabled, clear the buffer
+            if forward_tts_task is not None:
                 assert self._output.audio is not None
-                await asyncio.gather(tts_task, forward_tts_task)
+
+                self._output.audio.clear_buffer()
                 playback_ev = await self._output.audio.wait_for_playout()
 
-            generated_text = llm_gen_data.generated_text
-            if len(generated_text) > 0:
-                new_messages.append(
-                    llm.ChatMessage(role="assistant", content=generated_text)
+                debug.Tracing.log_event(
+                    "playout interrupted",
+                    {
+                        "playback_position": playback_ev.playback_position,
+                        "speech_id": handle.id,
+                    },
                 )
 
-            if len(tools) == 0:
-                break  # no more fnc step needed
+                handle._mark_playout_done()
+                # TODO(theomonnom): calculate the played text based on playback_ev.playback_position
+
+            handle._mark_done()
+            return
+
+        handle._mark_playout_done()
+        debug.Tracing.log_event("playout completed", {"speech_id": handle.id})
+
+        if len(called_functions) > 0:
+            if handle.step_index >= self._max_fnc_steps:
+                logger.warning(
+                    "maximum number of function call steps reached",
+                    extra={"speech_id": handle.id},
+                )
+                debug.Tracing.log_event(
+                    "maximum number of function call steps reached",
+                    {"speech_id": handle.id},
+                )
+                handle._mark_done()
+                return
+
+            # create a new SpeechHandle to generate the result of the function calls
+            handle = SpeechHandle.create(
+                allow_interruptions=self._opts.allow_interruptions,
+                step_index=handle.step_index + 1,
+            )
+            task = asyncio.create_task(
+                self._generate_pipeline_reply_task(
+                    handle=handle,
+                    chat_ctx=chat_ctx,
+                    fnc_ctx=fnc_ctx,
+                ),
+                name="_generate_pipeline_reply",
+            )
+            self._schedule_speech(handle, task, self.SPEECH_PRIORITY_NORMAL)
+
+        handle._mark_done()
 
     # -- Audio recognition --
 
@@ -403,22 +594,34 @@ def _on_audio_end_of_turn(self, new_transcript: str) -> None:
         #   turn)
         #  - generate a reply to the user input
 
-        if self._current_generation is not None:
-            if self._current_generation.allow_interruptions:
+        if self._current_speech is not None:
+            if not self._current_speech.allow_interruptions:
                 logger.warning(
                     "skipping user input, current speech generation cannot be interrupted",
                     extra={"user_input": new_transcript},
                 )
                 return
 
-            self._current_generation.interrupt()
+            debug.Tracing.log_event(
+                "speech interrupted, new user turn detected",
+                {"speech_id": self._current_speech.id},
+            )
+            self._current_speech.interrupt()
 
         self.generate_reply(new_transcript)
 
     def _on_vad_inference_done(self, ev: vad.VADEvent) -> None:
         if ev.speech_duration > 
self._opts.min_interruption_duration: - if self._current_generation is not None: - self._current_generation.interrupt() + if ( + self._current_speech is not None + and not self._current_speech.interrupted + and self._current_speech.allow_interruptions + ): + debug.Tracing.log_event( + "speech interrupted by vad", + {"speech_id": self._current_speech.id}, + ) + self._current_speech.interrupt() # --- diff --git a/livekit-agents/livekit/agents/utils/aio/__init__.py b/livekit-agents/livekit/agents/utils/aio/__init__.py index df97e26e9..bdef55dd7 100644 --- a/livekit-agents/livekit/agents/utils/aio/__init__.py +++ b/livekit-agents/livekit/agents/utils/aio/__init__.py @@ -6,6 +6,7 @@ from .interval import Interval, interval from .sleep import Sleep, SleepFinished, sleep from .task_set import TaskSet +from .wait_group import WaitGroup async def gracefully_cancel(*futures: asyncio.Future): @@ -45,6 +46,7 @@ def _release_waiter(waiter, *args): "SleepFinished", "sleep", "TaskSet", + "WaitGroup", "debug", "gracefully_cancel", "duplex_unix", diff --git a/livekit-agents/livekit/agents/utils/aio/wait_group.py b/livekit-agents/livekit/agents/utils/aio/wait_group.py new file mode 100644 index 000000000..5b190e8ba --- /dev/null +++ b/livekit-agents/livekit/agents/utils/aio/wait_group.py @@ -0,0 +1,30 @@ +import asyncio + + +class WaitGroup: + """ + asyncio wait group implementation (similar to sync.WaitGroup in go) + """ + + def __init__(self): + self._counter = 0 + self._zero_event = asyncio.Event() + self._zero_event.set() + + def add(self, delta: int = 1): + new_value = self._counter + delta + if new_value < 0: + raise ValueError("WaitGroup counter cannot go negative.") + + self._counter = new_value + + if self._counter == 0: + self._zero_event.set() + else: + self._zero_event.clear() + + def done(self): + self.add(-1) + + async def wait(self): + await self._zero_event.wait() diff --git a/livekit-agents/livekit/agents/worker.py b/livekit-agents/livekit/agents/worker.py index 52032d7a1..7693b0406 100644 --- a/livekit-agents/livekit/agents/worker.py +++ b/livekit-agents/livekit/agents/worker.py @@ -14,6 +14,7 @@ from __future__ import annotations +import time import asyncio import contextlib import datetime @@ -42,6 +43,7 @@ from livekit.protocol import agent, models from . import http_server, ipc, utils +from .debug import tracing from ._exceptions import AssignmentTimeoutError from .inference_runner import _InferenceRunner from .job import ( @@ -56,8 +58,11 @@ from .utils.hw import get_cpu_monitor from .version import __version__ +from aiohttp import web + ASSIGNMENT_TIMEOUT = 7.5 -UPDATE_LOAD_INTERVAL = 2.5 +UPDATE_STATUS_INTERVAL = 2.5 +UPDATE_LOAD_INTERVAL = 0.5 def _default_initialize_process_fnc(proc: JobProcess) -> Any: @@ -198,7 +203,7 @@ class WorkerOptions: By default it uses ``LIVEKIT_API_SECRET`` from environment""" host: str = "" # default to all interfaces port: int | _WorkerEnvOption[int] = _WorkerEnvOption( - dev_default=0, prod_default=8081 + dev_default=8080, prod_default=8081 ) """Port for local HTTP server to listen on. 
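
The WaitGroup helper added above mirrors Go's sync.WaitGroup and is the pattern pipeline2.py uses to track its per-speech tasks (wg.add(1) paired with task.add_done_callback(lambda _: wg.done())). A minimal usage sketch, assuming the helper is imported from livekit.agents.utils.aio as exported by the __init__.py change above:

    import asyncio

    from livekit.agents.utils.aio import WaitGroup


    async def main() -> None:
        wg = WaitGroup()

        async def job(delay: float) -> None:
            await asyncio.sleep(delay)

        for i in range(3):
            wg.add(1)  # one unit of work per task
            task = asyncio.create_task(job(0.1 * i))
            # decrement the counter when the task completes,
            # same pattern as the generation task in pipeline2.py
            task.add_done_callback(lambda _: wg.done())

        await wg.wait()  # resolves once the counter is back to zero


    asyncio.run(main())
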
@@ -315,8 +320,24 @@ def __init__( loop=self._loop, ) + async def health_check(_: Any): + return web.Response(text="OK") + + self._http_server.app.add_routes([web.get("/", health_check)]) + self._http_server.app.add_subapp("/tracing", tracing._create_tracing_app(self)) + self._conn_task: asyncio.Task[None] | None = None + self._worker_load: float = 0.0 + self._worker_load_graph = tracing.Tracing.add_graph( + title="worker_load", + x_label="time", + y_label="load", + x_type="time", + y_range=(0, 1), + max_data_points=int(1 / UPDATE_LOAD_INTERVAL * 30), + ) + async def run(self): if not self._closed: raise Exception("worker is already running") @@ -341,16 +362,37 @@ def _update_job_status(proc: ipc.job_executor.JobExecutor) -> None: self._proc_pool.on("process_started", _update_job_status) self._proc_pool.on("process_closed", _update_job_status) self._proc_pool.on("process_job_launched", _update_job_status) - self._proc_pool.start() + self._api = api.LiveKitAPI( self._opts.ws_url, self._opts.api_key, self._opts.api_secret ) self._http_session = aiohttp.ClientSession() self._close_future = asyncio.Future(loop=self._loop) + @utils.log_exceptions(logger=logger) + async def _load_task(): + """periodically check load""" + interval = utils.aio.interval(UPDATE_LOAD_INTERVAL) + while True: + await interval.tick() + + def load_fnc(): + signature = inspect.signature(self._opts.load_fnc) + parameters = list(signature.parameters.values()) + if len(parameters) == 0: + return self._opts.load_fnc() # type: ignore + + return self._opts.load_fnc(self) # type: ignore + + self._worker_load = await asyncio.get_event_loop().run_in_executor( + None, load_fnc + ) + self._worker_load_graph.plot(time.time(), self._worker_load) + tasks = [ asyncio.create_task(self._http_server.run(), name="http_server"), + asyncio.create_task(_load_task(), name="load_task"), ] if self._register: @@ -482,7 +524,9 @@ async def _queue_msg(self, msg: agent.WorkerMessage) -> None: await self._msg_chan.send(msg) + @utils.log_exceptions(logger=logger) async def _connection_task(self) -> None: + print("connection task") assert self._http_session is not None retry_count = 0 @@ -565,9 +609,11 @@ async def _connection_task(self) -> None: async def _run_ws(self, ws: aiohttp.ClientWebSocketResponse): closing_ws = False + print("running ws") + async def _load_task(): - """periodically check load and update worker status""" - interval = utils.aio.interval(UPDATE_LOAD_INTERVAL) + """periodically update worker status""" + interval = utils.aio.interval(UPDATE_STATUS_INTERVAL) while True: await interval.tick() await self._update_worker_status() @@ -790,17 +836,7 @@ async def _update_worker_status(self): await self._queue_msg(msg) return - def load_fnc(): - signature = inspect.signature(self._opts.load_fnc) - parameters = list(signature.parameters.values()) - if len(parameters) == 0: - return self._opts.load_fnc() # type: ignore - - return self._opts.load_fnc(self) # type: ignore - - current_load = await asyncio.get_event_loop().run_in_executor(None, load_fnc) - - is_full = current_load >= _WorkerEnvOption.getvalue( + is_full = self._worker_load >= _WorkerEnvOption.getvalue( self._opts.load_threshold, self._devmode ) currently_available = not is_full and not self._draining @@ -812,14 +848,14 @@ def load_fnc(): ) update = agent.UpdateWorkerStatus( - load=current_load, status=status, job_count=job_cnt + load=self._worker_load, status=status, job_count=job_cnt ) # only log if status has changed if self._previous_status != status and not self._draining: 
self._previous_status = status extra = { - "load": current_load, + "load": self._worker_load, "threshold": self._opts.load_threshold, } if is_full: diff --git a/livekit-agents/setup.py b/livekit-agents/setup.py index 9ff541808..944266f82 100644 --- a/livekit-agents/setup.py +++ b/livekit-agents/setup.py @@ -69,7 +69,10 @@ "codecs": ["av>=12.0.0", "numpy>=1.26.0"], "images": ["pillow>=10.3.0"], }, - package_data={"livekit.agents": ["py.typed"]}, + package_data={ + "livekit.agents": ["py.typed"], + "livekit.agents.debug": ["index.html"], + }, project_urls={ "Documentation": "https://docs.livekit.io", "Website": "https://livekit.io/", From 35d97627d674210dda037a76e9286f03b554990d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Sun, 12 Jan 2025 20:57:49 +0100 Subject: [PATCH 05/19] multimodal wip --- .../agents/multimodal/agent_playout.py | 179 -- .../agents/multimodal/multimodal_agent.py | 499 ----- .../livekit/agents/multimodal/realtime.py | 90 + .../livekit/agents/pipeline/agent_output.py | 297 --- .../livekit/agents/pipeline/agent_playout.py | 184 -- .../agents/pipeline/audio_recognition.py | 9 - .../livekit/agents/pipeline/events.py | 11 + .../livekit/agents/pipeline/human_input.py | 150 -- .../livekit/agents/pipeline/impl.py | 183 -- livekit-agents/livekit/agents/pipeline/io.py | 2 +- .../livekit/agents/pipeline/multimodal.py | 11 - .../livekit/agents/pipeline/pipeline2.py | 645 ------- .../livekit/agents/pipeline/pipeline_agent.py | 1659 +++++------------ .../livekit/agents/pipeline/plotter.py | 201 -- .../livekit/agents/pipeline/speech_handle.py | 235 --- .../livekit/agents/transcription/__init__.py | 7 - .../livekit/agents/transcription/_utils.py | 32 - .../agents/transcription/stt_forwarder.py | 126 -- .../agents/transcription/tts_forwarder.py | 430 ----- 19 files changed, 599 insertions(+), 4351 deletions(-) delete mode 100644 livekit-agents/livekit/agents/multimodal/agent_playout.py delete mode 100644 livekit-agents/livekit/agents/multimodal/multimodal_agent.py create mode 100644 livekit-agents/livekit/agents/multimodal/realtime.py delete mode 100644 livekit-agents/livekit/agents/pipeline/agent_output.py delete mode 100644 livekit-agents/livekit/agents/pipeline/agent_playout.py create mode 100644 livekit-agents/livekit/agents/pipeline/events.py delete mode 100644 livekit-agents/livekit/agents/pipeline/human_input.py delete mode 100644 livekit-agents/livekit/agents/pipeline/impl.py delete mode 100644 livekit-agents/livekit/agents/pipeline/multimodal.py delete mode 100644 livekit-agents/livekit/agents/pipeline/pipeline2.py delete mode 100644 livekit-agents/livekit/agents/pipeline/plotter.py delete mode 100644 livekit-agents/livekit/agents/pipeline/speech_handle.py delete mode 100644 livekit-agents/livekit/agents/transcription/__init__.py delete mode 100644 livekit-agents/livekit/agents/transcription/_utils.py delete mode 100644 livekit-agents/livekit/agents/transcription/stt_forwarder.py delete mode 100644 livekit-agents/livekit/agents/transcription/tts_forwarder.py diff --git a/livekit-agents/livekit/agents/multimodal/agent_playout.py b/livekit-agents/livekit/agents/multimodal/agent_playout.py deleted file mode 100644 index f1dbda1e7..000000000 --- a/livekit-agents/livekit/agents/multimodal/agent_playout.py +++ /dev/null @@ -1,179 +0,0 @@ -from __future__ import annotations - -import asyncio -from typing import AsyncIterable, Literal - -from livekit import rtc -from livekit.agents import transcription, utils - -from ..log import logger - -EventTypes = 
Literal["playout_started", "playout_stopped"] - - -class PlayoutHandle: - def __init__( - self, - *, - audio_source: rtc.AudioSource, - item_id: str, - content_index: int, - transcription_fwd: transcription.TTSSegmentsForwarder, - ) -> None: - self._audio_source = audio_source - self._tr_fwd = transcription_fwd - self._item_id = item_id - self._content_index = content_index - - self._int_fut = asyncio.Future[None]() - self._done_fut = asyncio.Future[None]() - - self._interrupted = False - - self._pushed_duration = 0.0 - self._total_played_time: float | None = None # set when the playout is done - - @property - def item_id(self) -> str: - return self._item_id - - @property - def audio_samples(self) -> int: - if self._total_played_time is not None: - return int(self._total_played_time * 24000) - - return int((self._pushed_duration - self._audio_source.queued_duration) * 24000) - - @property - def text_chars(self) -> int: - return len(self._tr_fwd.played_text) - - @property - def content_index(self) -> int: - return self._content_index - - @property - def interrupted(self) -> bool: - return self._interrupted - - def done(self) -> bool: - return self._done_fut.done() or self._interrupted - - def interrupt(self) -> None: - if self.done(): - return - - self._int_fut.set_result(None) - self._interrupted = True - - -class AgentPlayout(utils.EventEmitter[EventTypes]): - def __init__(self, *, audio_source: rtc.AudioSource) -> None: - super().__init__() - self._source = audio_source - self._playout_atask: asyncio.Task[None] | None = None - - def play( - self, - *, - item_id: str, - content_index: int, - transcription_fwd: transcription.TTSSegmentsForwarder, - text_stream: AsyncIterable[str], - audio_stream: AsyncIterable[rtc.AudioFrame], - ) -> PlayoutHandle: - handle = PlayoutHandle( - audio_source=self._source, - item_id=item_id, - content_index=content_index, - transcription_fwd=transcription_fwd, - ) - self._playout_atask = asyncio.create_task( - self._playout_task(self._playout_atask, handle, text_stream, audio_stream) - ) - - return handle - - @utils.log_exceptions(logger=logger) - async def _playout_task( - self, - old_task: asyncio.Task[None], - handle: PlayoutHandle, - text_stream: AsyncIterable[str], - audio_stream: AsyncIterable[rtc.AudioFrame], - ) -> None: - if old_task is not None: - await utils.aio.gracefully_cancel(old_task) - - first_frame = True - - @utils.log_exceptions(logger=logger) - async def _play_text_stream(): - async for text in text_stream: - handle._tr_fwd.push_text(text) - - handle._tr_fwd.mark_text_segment_end() - - @utils.log_exceptions(logger=logger) - async def _capture_task(): - nonlocal first_frame - - samples_per_channel = 1200 - bstream = utils.audio.AudioByteStream( - 24000, - 1, - samples_per_channel=samples_per_channel, - ) - - async for frame in audio_stream: - if first_frame: - handle._tr_fwd.segment_playout_started() - self.emit("playout_started") - first_frame = False - - handle._tr_fwd.push_audio(frame) - - for f in bstream.write(frame.data.tobytes()): - handle._pushed_duration += f.samples_per_channel / f.sample_rate - await self._source.capture_frame(f) - - for f in bstream.flush(): - handle._pushed_duration += f.samples_per_channel / f.sample_rate - await self._source.capture_frame(f) - - handle._tr_fwd.mark_audio_segment_end() - - await self._source.wait_for_playout() - - read_text_task = asyncio.create_task(_play_text_stream()) - capture_task = asyncio.create_task(_capture_task()) - - try: - await asyncio.wait( - [capture_task, handle._int_fut], - 
return_when=asyncio.FIRST_COMPLETED, - ) - finally: - await utils.aio.gracefully_cancel(capture_task) - - handle._total_played_time = ( - handle._pushed_duration - self._source.queued_duration - ) - - if handle.interrupted or capture_task.exception(): - self._source.clear_queue() # make sure to remove any queued frames - - await utils.aio.gracefully_cancel(read_text_task) - - # make sure the text_data.sentence_stream is closed - handle._tr_fwd.mark_text_segment_end() - - if not first_frame and not handle.interrupted: - handle._tr_fwd.segment_playout_finished() - - await handle._tr_fwd.aclose() - handle._done_fut.set_result(None) - - # emit playout_stopped after the transcription forwarder has been closed - if not first_frame: - self.emit("playout_stopped", handle.interrupted) diff --git a/livekit-agents/livekit/agents/multimodal/multimodal_agent.py b/livekit-agents/livekit/agents/multimodal/multimodal_agent.py deleted file mode 100644 index f02bb2e64..000000000 --- a/livekit-agents/livekit/agents/multimodal/multimodal_agent.py +++ /dev/null @@ -1,499 +0,0 @@ -from __future__ import annotations - -import asyncio -from dataclasses import dataclass -from typing import ( - Any, - AsyncIterable, - Callable, - Literal, - Optional, - Protocol, - TypeVar, - Union, - overload, -) - -import aiohttp -from livekit import rtc -from livekit.agents import llm, stt, tokenize, transcription, utils, vad -from livekit.agents.llm import ChatMessage -from livekit.agents.metrics import MultimodalLLMMetrics - -from ..log import logger -from ..types import ATTRIBUTE_AGENT_STATE, AgentState -from . import agent_playout - -EventTypes = Literal[ - "user_started_speaking", - "user_stopped_speaking", - "agent_started_speaking", - "agent_stopped_speaking", - "user_speech_committed", - "agent_speech_committed", - "agent_speech_interrupted", - "function_calls_collected", - "function_calls_finished", - "metrics_collected", -] - - -class _InputTranscriptionProto(Protocol): - item_id: str - """id of the item""" - transcript: str - """transcript of the input audio""" - - -class _ContentProto(Protocol): - response_id: str - item_id: str - output_index: int - content_index: int - text: str - audio: list[rtc.AudioFrame] - text_stream: AsyncIterable[str] - audio_stream: AsyncIterable[rtc.AudioFrame] - content_type: Literal["text", "audio"] - - -class _CapabilitiesProto(Protocol): - supports_truncate: bool - - -class _RealtimeAPI(Protocol): - """Realtime API protocol""" - - @property - def capabilities(self) -> _CapabilitiesProto: ... - def session( - self, - *, - chat_ctx: llm.ChatContext | None = None, - fnc_ctx: llm.FunctionContext | None = None, - ) -> _RealtimeAPISession: - """ - Create a new realtime session with the given chat and function contexts. - """ - pass - - -T = TypeVar("T", bound=Callable[..., Any]) - - -class _RealtimeAPISession(Protocol): - async def set_chat_ctx(self, ctx: llm.ChatContext) -> None: ... - @overload - def on(self, event: str, callback: None = None) -> Callable[[T], T]: ... - @overload - def on(self, event: str, callback: T) -> T: ... - def on( - self, event: str, callback: Optional[T] = None - ) -> Union[T, Callable[[T], T]]: ... - - def _push_audio(self, frame: rtc.AudioFrame) -> None: ... - @property - def fnc_ctx(self) -> llm.FunctionContext | None: ... - @fnc_ctx.setter - def fnc_ctx(self, value: llm.FunctionContext | None) -> None: ... - def chat_ctx_copy(self) -> llm.ChatContext: ... - def _recover_from_text_response(self, item_id: str) -> None: ... 
- def _update_conversation_item_content( - self, - item_id: str, - content: llm.ChatContent | list[llm.ChatContent] | None = None, - ) -> None: ... - def _truncate_conversation_item( - self, item_id: str, content_index: int, audio_end_ms: int - ) -> None: ... - - -@dataclass(frozen=True) -class AgentTranscriptionOptions: - user_transcription: bool = True - """Whether to forward the user transcription to the client""" - agent_transcription: bool = True - """Whether to forward the agent transcription to the client""" - agent_transcription_speed: float = 1.0 - """The speed at which the agent's speech transcription is forwarded to the client. - We try to mimic the agent's speech speed by adjusting the transcription speed.""" - sentence_tokenizer: tokenize.SentenceTokenizer = tokenize.basic.SentenceTokenizer() - """The tokenizer used to split the speech into sentences. - This is used to decide when to mark a transcript as final for the agent transcription.""" - word_tokenizer: tokenize.WordTokenizer = tokenize.basic.WordTokenizer( - ignore_punctuation=False - ) - """The tokenizer used to split the speech into words. - This is used to simulate the "interim results" of the agent transcription.""" - hyphenate_word: Callable[[str], list[str]] = tokenize.basic.hyphenate_word - """A function that takes a string (word) as input and returns a list of strings, - representing the hyphenated parts of the word.""" - - -@dataclass(frozen=True) -class _ImplOptions: - transcription: AgentTranscriptionOptions - - -class MultimodalAgent(utils.EventEmitter[EventTypes]): - def __init__( - self, - *, - model: _RealtimeAPI, - vad: vad.VAD | None = None, - chat_ctx: llm.ChatContext | None = None, - fnc_ctx: llm.FunctionContext | None = None, - transcription: AgentTranscriptionOptions = AgentTranscriptionOptions(), - max_text_response_retries: int = 5, - loop: asyncio.AbstractEventLoop | None = None, - ): - """Create a new MultimodalAgent. - - Args: - model: RealtimeAPI instance. - vad: Voice Activity Detection (VAD) instance. - chat_ctx: Chat context for the assistant. - fnc_ctx: Function context for the assistant. - transcription: Options for assistant transcription. - max_text_response_retries: Maximum number of retries to recover - from text responses to audio mode. OpenAI's realtime API has a - chance to return text responses instead of audio if the chat - context includes text system or assistant messages. The agent will - attempt to recover to audio mode by deleting the text response - and appending an empty audio message to the conversation. - loop: Event loop to use. Default to asyncio.get_event_loop(). 
- """ - super().__init__() - self._loop = loop or asyncio.get_event_loop() - - self._model = model - self._vad = vad - self._chat_ctx = chat_ctx - self._fnc_ctx = fnc_ctx - - self._opts = _ImplOptions( - transcription=transcription, - ) - - # audio input - self._read_micro_atask: asyncio.Task | None = None - self._subscribed_track: rtc.RemoteAudioTrack | None = None - self._input_audio_ch = utils.aio.Chan[rtc.AudioFrame]() - - # audio output - self._playing_handle: agent_playout.PlayoutHandle | None = None - - self._linked_participant: rtc.RemoteParticipant | None = None - self._started, self._closed = False, False - - self._update_state_task: asyncio.Task | None = None - self._http_session: aiohttp.ClientSession | None = None - - self._text_response_retries = 0 - self._max_text_response_retries = max_text_response_retries - - @property - def vad(self) -> vad.VAD | None: - return self._vad - - @property - def fnc_ctx(self) -> llm.FunctionContext | None: - return self._session.fnc_ctx - - @fnc_ctx.setter - def fnc_ctx(self, value: llm.FunctionContext | None) -> None: - self._session.fnc_ctx = value - - def chat_ctx_copy(self) -> llm.ChatContext: - return self._session.chat_ctx_copy() - - async def set_chat_ctx(self, ctx: llm.ChatContext) -> None: - await self._session.set_chat_ctx(ctx) - - def start( - self, room: rtc.Room, participant: rtc.RemoteParticipant | str | None = None - ) -> None: - if self._started: - raise RuntimeError("voice assistant already started") - - room.on("participant_connected", self._on_participant_connected) - room.on("track_published", self._subscribe_to_microphone) - room.on("track_subscribed", self._subscribe_to_microphone) - - self._room, self._participant = room, participant - - if participant is not None: - if isinstance(participant, rtc.RemoteParticipant): - self._link_participant(participant.identity) - else: - self._link_participant(participant) - else: - # no participant provided, try to find the first participant in the room - for participant in self._room.remote_participants.values(): - self._link_participant(participant.identity) - break - - self._session = self._model.session( - chat_ctx=self._chat_ctx, fnc_ctx=self._fnc_ctx - ) - - # Create a task to wait for initialization and start the main task - async def _init_and_start(): - try: - await self._session._init_sync_task - logger.info("Session initialized with chat context") - self._main_atask = asyncio.create_task(self._main_task()) - except Exception as e: - logger.exception("Failed to initialize session") - raise e - - # Schedule the initialization and start task - asyncio.create_task(_init_and_start()) - - @self._session.on("response_content_added") - def _on_content_added(message: _ContentProto): - tr_fwd = transcription.TTSSegmentsForwarder( - room=self._room, - participant=self._room.local_participant, - speed=self._opts.transcription.agent_transcription_speed, - sentence_tokenizer=self._opts.transcription.sentence_tokenizer, - word_tokenizer=self._opts.transcription.word_tokenizer, - hyphenate_word=self._opts.transcription.hyphenate_word, - ) - - self._playing_handle = self._agent_playout.play( - item_id=message.item_id, - content_index=message.content_index, - transcription_fwd=tr_fwd, - text_stream=message.text_stream, - audio_stream=message.audio_stream, - ) - - @self._session.on("response_content_done") - def _response_content_done(message: _ContentProto): - if message.content_type == "text": - if self._text_response_retries >= self._max_text_response_retries: - raise RuntimeError( - 
f"The OpenAI Realtime API returned a text response " - f"after {self._max_text_response_retries} retries. " - f"Please try to reduce the number of text system or " - f"assistant messages in the chat context." - ) - - self._text_response_retries += 1 - logger.warning( - "The OpenAI Realtime API returned a text response instead of audio. " - "Attempting to recover to audio mode...", - extra={ - "item_id": message.item_id, - "text": message.text, - "retries": self._text_response_retries, - }, - ) - self._session._recover_from_text_response(message.item_id) - else: - self._text_response_retries = 0 - - @self._session.on("input_speech_committed") - def _input_speech_committed(): - self._stt_forwarder.update( - stt.SpeechEvent( - type=stt.SpeechEventType.INTERIM_TRANSCRIPT, - alternatives=[stt.SpeechData(language="", text="")], - ) - ) - - @self._session.on("input_speech_transcription_completed") - def _input_speech_transcription_completed(ev: _InputTranscriptionProto): - self._stt_forwarder.update( - stt.SpeechEvent( - type=stt.SpeechEventType.FINAL_TRANSCRIPT, - alternatives=[stt.SpeechData(language="", text=ev.transcript)], - ) - ) - user_msg = ChatMessage.create( - text=ev.transcript, role="user", id=ev.item_id - ) - - self._session._update_conversation_item_content( - ev.item_id, user_msg.content - ) - - self.emit("user_speech_committed", user_msg) - logger.debug( - "committed user speech", - extra={"user_transcript": ev.transcript}, - ) - - @self._session.on("input_speech_started") - def _input_speech_started(): - self.emit("user_started_speaking") - self._update_state("listening") - if self._playing_handle is not None and not self._playing_handle.done(): - self._playing_handle.interrupt() - - if self._model.capabilities.supports_truncate: - self._session._truncate_conversation_item( - item_id=self._playing_handle.item_id, - content_index=self._playing_handle.content_index, - audio_end_ms=int( - self._playing_handle.audio_samples / 24000 * 1000 - ), - ) - - @self._session.on("input_speech_stopped") - def _input_speech_stopped(): - self.emit("user_stopped_speaking") - - @self._session.on("function_calls_collected") - def _function_calls_collected(fnc_call_infos: list[llm.FunctionCallInfo]): - self.emit("function_calls_collected", fnc_call_infos) - - @self._session.on("function_calls_finished") - def _function_calls_finished(called_fncs: list[llm.CalledFunction]): - self.emit("function_calls_finished", called_fncs) - - @self._session.on("metrics_collected") - def _metrics_collected(metrics: MultimodalLLMMetrics): - self.emit("metrics_collected", metrics) - - def _update_state(self, state: AgentState, delay: float = 0.0): - """Set the current state of the agent""" - - @utils.log_exceptions(logger=logger) - async def _run_task(delay: float) -> None: - await asyncio.sleep(delay) - - if self._room.isconnected(): - await self._room.local_participant.set_attributes( - {ATTRIBUTE_AGENT_STATE: state} - ) - - if self._update_state_task is not None: - self._update_state_task.cancel() - - self._update_state_task = asyncio.create_task(_run_task(delay)) - - @utils.log_exceptions(logger=logger) - async def _main_task(self) -> None: - self._update_state("initializing") - self._audio_source = rtc.AudioSource(24000, 1) - self._agent_playout = agent_playout.AgentPlayout( - audio_source=self._audio_source - ) - - def _on_playout_started() -> None: - self.emit("agent_started_speaking") - self._update_state("speaking") - - def _on_playout_stopped(interrupted: bool) -> None: - self.emit("agent_stopped_speaking") 
- self._update_state("listening") - - if self._playing_handle is not None: - collected_text = self._playing_handle._tr_fwd.played_text - if interrupted: - collected_text += "..." - - msg = ChatMessage.create( - text=collected_text, - role="assistant", - id=self._playing_handle.item_id, - ) - if self._model.capabilities.supports_truncate: - self._session._update_conversation_item_content( - self._playing_handle.item_id, msg.content - ) - - if interrupted: - self.emit("agent_speech_interrupted", msg) - else: - self.emit("agent_speech_committed", msg) - - logger.debug( - "committed agent speech", - extra={ - "agent_transcript": collected_text, - "interrupted": interrupted, - }, - ) - - self._agent_playout.on("playout_started", _on_playout_started) - self._agent_playout.on("playout_stopped", _on_playout_stopped) - - track = rtc.LocalAudioTrack.create_audio_track( - "assistant_voice", self._audio_source - ) - self._agent_publication = await self._room.local_participant.publish_track( - track, rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE) - ) - - await self._agent_publication.wait_for_subscription() - - bstream = utils.audio.AudioByteStream( - 24000, - 1, - samples_per_channel=2400, - ) - async for frame in self._input_audio_ch: - for f in bstream.write(frame.data.tobytes()): - self._session._push_audio(f) - - def _on_participant_connected(self, participant: rtc.RemoteParticipant): - if self._linked_participant is None: - return - - self._link_participant(participant.identity) - - def _link_participant(self, participant_identity: str) -> None: - self._linked_participant = self._room.remote_participants.get( - participant_identity - ) - if self._linked_participant is None: - logger.error("_link_participant must be called with a valid identity") - return - - self._subscribe_to_microphone() - - async def _micro_task(self, track: rtc.LocalAudioTrack) -> None: - stream_24khz = rtc.AudioStream(track, sample_rate=24000, num_channels=1) - async for ev in stream_24khz: - self._input_audio_ch.send_nowait(ev.frame) - - def _subscribe_to_microphone(self, *args, **kwargs) -> None: - """Subscribe to the participant microphone if found""" - - if self._linked_participant is None: - return - - for publication in self._linked_participant.track_publications.values(): - if publication.source != rtc.TrackSource.SOURCE_MICROPHONE: - continue - - if not publication.subscribed: - publication.set_subscribed(True) - - if ( - publication.track is not None - and publication.track != self._subscribed_track - ): - self._subscribed_track = publication.track # type: ignore - self._stt_forwarder = transcription.STTSegmentsForwarder( - room=self._room, - participant=self._linked_participant, - track=self._subscribed_track, - ) - - if self._read_micro_atask is not None: - self._read_micro_atask.cancel() - - self._read_micro_atask = asyncio.create_task( - self._micro_task(self._subscribed_track) # type: ignore - ) - break - - def _ensure_session(self) -> aiohttp.ClientSession: - if not self._http_session: - self._http_session = utils.http_context.http_session() - - return self._http_session diff --git a/livekit-agents/livekit/agents/multimodal/realtime.py b/livekit-agents/livekit/agents/multimodal/realtime.py new file mode 100644 index 000000000..d8ab54424 --- /dev/null +++ b/livekit-agents/livekit/agents/multimodal/realtime.py @@ -0,0 +1,90 @@ +from dataclasses import dataclass +from abc import ABC, abstractmethod + +from livekit import rtc + +from .. 
import llm + +from typing import AsyncIterable, Union, Literal, Generic, TypeVar + + +@dataclass +class InputSpeechStartedEvent: + pass + + +@dataclass +class InputSpeechStoppedEvent: + pass + + +@dataclass +class GenerationCreatedEvent: + message_id: str + text_stream: AsyncIterable[str] + audio_stream: AsyncIterable[rtc.AudioFrame] + tool_calls: AsyncIterable[llm.FunctionCallInfo] + + +@dataclass +class ErrorEvent: + type: str + message: str + + +@dataclass +class RealtimeCapabilities: + message_truncation: bool + + +class RealtimeModel: + def __init__(self, capabilities: RealtimeCapabilities) -> None: + self._capabilities = capabilities + + @property + def capabilities(self) -> RealtimeCapabilities: + return self._capabilities + + @abstractmethod + def session(self) -> "RealtimeSession": ... + + +EventTypes = Literal[ + "input_speech_started", # serverside VAD + "input_speech_stopped", # serverside VAD + "generation_created", + "error", +] + +TEvent = TypeVar("TEvent") + + +class RealtimeSession( + ABC, + rtc.EventEmitter[Union[EventTypes, TEvent]], + Generic[TEvent], +): + def __init__(self, realtime_model: RealtimeModel) -> None: + super().__init__() + self._realtime_model = realtime_model + + @property + def realtime_model(self) -> RealtimeModel: + return self._realtime_model + + @property + @abstractmethod + def chat_ctx(self) -> llm.ChatContext: ... + + @abstractmethod + async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None: ... + + @abstractmethod + def push_audio(self, frame: rtc.AudioFrame) -> None: ... + + @abstractmethod + def generate_reply(self) -> None: ... # when VAD is disabled + + # message_id is the ID of the message to truncate (inside the ChatCtx) + @abstractmethod + def truncate(self, *, message_id: str, audio_end_ms: int) -> None: ... diff --git a/livekit-agents/livekit/agents/pipeline/agent_output.py b/livekit-agents/livekit/agents/pipeline/agent_output.py deleted file mode 100644 index 14a836ef7..000000000 --- a/livekit-agents/livekit/agents/pipeline/agent_output.py +++ /dev/null @@ -1,297 +0,0 @@ -from __future__ import annotations - -import asyncio -import inspect -from typing import Any, AsyncIterable, Awaitable, Callable, Union - -from livekit import rtc - -from .. import llm, tokenize, utils -from .. import transcription as agent_transcription -from .. 
import tts as text_to_speech -from .agent_playout import AgentPlayout, PlayoutHandle -from .log import logger - -SpeechSource = Union[AsyncIterable[str], str, Awaitable[str]] - - -class SynthesisHandle: - def __init__( - self, - *, - speech_id: str, - tts_source: SpeechSource, - transcript_source: SpeechSource, - agent_playout: AgentPlayout, - tts: text_to_speech.TTS, - transcription_fwd: agent_transcription.TTSSegmentsForwarder, - ) -> None: - ( - self._tts_source, - self._transcript_source, - self._agent_playout, - self._tts, - self._tr_fwd, - ) = ( - tts_source, - transcript_source, - agent_playout, - tts, - transcription_fwd, - ) - self._buf_ch = utils.aio.Chan[rtc.AudioFrame]() - self._play_handle: PlayoutHandle | None = None - self._interrupt_fut = asyncio.Future[None]() - self._speech_id = speech_id - - @property - def speech_id(self) -> str: - return self._speech_id - - @property - def tts_forwarder(self) -> agent_transcription.TTSSegmentsForwarder: - return self._tr_fwd - - @property - def validated(self) -> bool: - return self._play_handle is not None - - @property - def interrupted(self) -> bool: - return self._interrupt_fut.done() - - @property - def play_handle(self) -> PlayoutHandle | None: - return self._play_handle - - def play(self) -> PlayoutHandle: - """Validate the speech for playout""" - if self.interrupted: - raise RuntimeError("synthesis was interrupted") - - self._play_handle = self._agent_playout.play( - self._speech_id, self._buf_ch, transcription_fwd=self._tr_fwd - ) - return self._play_handle - - def interrupt(self) -> None: - """Interrupt the speech""" - if self.interrupted: - return - - logger.debug( - "agent interrupted", - extra={"speech_id": self.speech_id}, - ) - - if self._play_handle is not None: - self._play_handle.interrupt() - - self._interrupt_fut.set_result(None) - - -class AgentOutput: - def __init__( - self, - *, - room: rtc.Room, - agent_playout: AgentPlayout, - llm: llm.LLM, - tts: text_to_speech.TTS, - ) -> None: - self._room, self._agent_playout, self._llm, self._tts = ( - room, - agent_playout, - llm, - tts, - ) - self._tasks = set[asyncio.Task[Any]]() - - @property - def playout(self) -> AgentPlayout: - return self._agent_playout - - async def aclose(self) -> None: - for task in self._tasks: - task.cancel() - - await asyncio.gather(*self._tasks, return_exceptions=True) - - def synthesize( - self, - *, - speech_id: str, - tts_source: SpeechSource, - transcript_source: SpeechSource, - transcription: bool, - transcription_speed: float, - sentence_tokenizer: tokenize.SentenceTokenizer, - word_tokenizer: tokenize.WordTokenizer, - hyphenate_word: Callable[[str], list[str]], - ) -> SynthesisHandle: - def _before_forward( - fwd: agent_transcription.TTSSegmentsForwarder, - transcription: rtc.Transcription, - ): - if not transcription: - transcription.segments = [] - - return transcription - - transcription_fwd = agent_transcription.TTSSegmentsForwarder( - room=self._room, - participant=self._room.local_participant, - speed=transcription_speed, - sentence_tokenizer=sentence_tokenizer, - word_tokenizer=word_tokenizer, - hyphenate_word=hyphenate_word, - before_forward_cb=_before_forward, - ) - - handle = SynthesisHandle( - tts_source=tts_source, - transcript_source=transcript_source, - agent_playout=self._agent_playout, - tts=self._tts, - transcription_fwd=transcription_fwd, - speech_id=speech_id, - ) - - task = asyncio.create_task(self._synthesize_task(handle)) - self._tasks.add(task) - task.add_done_callback(self._tasks.remove) - return handle - - 
@utils.log_exceptions(logger=logger) - async def _synthesize_task(self, handle: SynthesisHandle) -> None: - """Synthesize speech from the source""" - tts_source = handle._tts_source - transcript_source = handle._transcript_source - - if isinstance(tts_source, Awaitable): - tts_source = await tts_source - if isinstance(transcript_source, Awaitable): - transcript_source = await transcript_source - - if isinstance(tts_source, str): - co = self._str_synthesis_task(tts_source, transcript_source, handle) - else: - co = self._stream_synthesis_task(tts_source, transcript_source, handle) - - synth = asyncio.create_task(co) - synth.add_done_callback(lambda _: handle._buf_ch.close()) - try: - _ = await asyncio.wait( - [synth, handle._interrupt_fut], return_when=asyncio.FIRST_COMPLETED - ) - finally: - await utils.aio.gracefully_cancel(synth) - - @utils.log_exceptions(logger=logger) - async def _read_transcript_task( - self, transcript_source: AsyncIterable[str] | str, handle: SynthesisHandle - ) -> None: - try: - if isinstance(transcript_source, str): - handle._tr_fwd.push_text(transcript_source) - else: - async for seg in transcript_source: - if not handle._tr_fwd.closed: - handle._tr_fwd.push_text(seg) - - if not handle.tts_forwarder.closed: - handle.tts_forwarder.mark_text_segment_end() - finally: - if inspect.isasyncgen(transcript_source): - await transcript_source.aclose() - - @utils.log_exceptions(logger=logger) - async def _str_synthesis_task( - self, - tts_text: str, - transcript_source: AsyncIterable[str] | str, - handle: SynthesisHandle, - ) -> None: - """synthesize speech from a string""" - read_transcript_atask: asyncio.Task | None = None - - first_frame = True - tts_stream = handle._tts.synthesize(tts_text) - try: - async for audio in tts_stream: - if first_frame: - first_frame = False - read_transcript_atask = asyncio.create_task( - self._read_transcript_task(transcript_source, handle) - ) - - handle._buf_ch.send_nowait(audio.frame) - if not handle.tts_forwarder.closed: - handle.tts_forwarder.push_audio(audio.frame) - - if not handle.tts_forwarder.closed: - handle.tts_forwarder.mark_audio_segment_end() - - if read_transcript_atask is not None: - await read_transcript_atask - finally: - await tts_stream.aclose() - - if read_transcript_atask is not None: - await utils.aio.gracefully_cancel(read_transcript_atask) - - @utils.log_exceptions(logger=logger) - async def _stream_synthesis_task( - self, - tts_source: AsyncIterable[str], - transcript_source: AsyncIterable[str] | str, - handle: SynthesisHandle, - ) -> None: - """synthesize speech from streamed text""" - - @utils.log_exceptions(logger=logger) - async def _read_generated_audio_task( - tts_stream: text_to_speech.SynthesizeStream, - ) -> None: - try: - async for audio in tts_stream: - if not handle._tr_fwd.closed: - handle._tr_fwd.push_audio(audio.frame) - - handle._buf_ch.send_nowait(audio.frame) - finally: - if handle._tr_fwd and not handle._tr_fwd.closed: - handle._tr_fwd.mark_audio_segment_end() - - await tts_stream.aclose() - - tts_stream: text_to_speech.SynthesizeStream | None = None - read_tts_atask: asyncio.Task | None = None - read_transcript_atask: asyncio.Task | None = None - - try: - async for seg in tts_source: - if tts_stream is None: - tts_stream = handle._tts.stream() - read_tts_atask = asyncio.create_task( - _read_generated_audio_task(tts_stream) - ) - read_transcript_atask = asyncio.create_task( - self._read_transcript_task(transcript_source, handle) - ) - - tts_stream.push_text(seg) - - if tts_stream is not None: - 
tts_stream.end_input() - assert read_transcript_atask and read_tts_atask - await read_tts_atask - await read_transcript_atask - - finally: - if read_tts_atask is not None: - assert read_transcript_atask is not None - await utils.aio.gracefully_cancel(read_tts_atask, read_transcript_atask) - - if inspect.isasyncgen(tts_source): - await tts_source.aclose() diff --git a/livekit-agents/livekit/agents/pipeline/agent_playout.py b/livekit-agents/livekit/agents/pipeline/agent_playout.py deleted file mode 100644 index 482b0e942..000000000 --- a/livekit-agents/livekit/agents/pipeline/agent_playout.py +++ /dev/null @@ -1,184 +0,0 @@ -from __future__ import annotations - -import asyncio -from typing import AsyncIterable, Literal - -from livekit import rtc - -from .. import transcription, utils -from .log import logger - -EventTypes = Literal["playout_started", "playout_stopped"] - - -class PlayoutHandle: - def __init__( - self, - speech_id: str, - audio_source: rtc.AudioSource, - playout_source: AsyncIterable[rtc.AudioFrame], - transcription_fwd: transcription.TTSSegmentsForwarder, - ) -> None: - self._playout_source = playout_source - self._audio_source = audio_source - self._tr_fwd = transcription_fwd - self._interrupted = False - self._int_fut = asyncio.Future[None]() - self._done_fut = asyncio.Future[None]() - self._speech_id = speech_id - - self._pushed_duration = 0.0 - - self._total_played_time: float | None = None # set whem the playout is done - - @property - def speech_id(self) -> str: - return self._speech_id - - @property - def interrupted(self) -> bool: - return self._interrupted - - @property - def time_played(self) -> float: - if self._total_played_time is not None: - return self._total_played_time - - return self._pushed_duration - self._audio_source.queued_duration - - def done(self) -> bool: - return self._done_fut.done() or self._interrupted - - def interrupt(self) -> None: - if self.done(): - return - - self._int_fut.set_result(None) - self._interrupted = True - - def join(self) -> asyncio.Future: - return self._done_fut - - -class AgentPlayout(utils.EventEmitter[EventTypes]): - def __init__(self, *, audio_source: rtc.AudioSource) -> None: - super().__init__() - self._audio_source = audio_source - self._target_volume = 1.0 - self._playout_atask: asyncio.Task[None] | None = None - self._closed = False - - @property - def target_volume(self) -> float: - return self._target_volume - - @target_volume.setter - def target_volume(self, value: float) -> None: - self._target_volume = value - - @property - def smoothed_volume(self) -> float: - return self._target_volume - - async def aclose(self) -> None: - if self._closed: - return - - self._closed = True - - if self._playout_atask is not None: - await self._playout_atask - - def play( - self, - speech_id: str, - playout_source: AsyncIterable[rtc.AudioFrame], - transcription_fwd: transcription.TTSSegmentsForwarder, - ) -> PlayoutHandle: - if self._closed: - raise ValueError("cancellable source is closed") - - handle = PlayoutHandle( - speech_id=speech_id, - audio_source=self._audio_source, - playout_source=playout_source, - transcription_fwd=transcription_fwd, - ) - self._playout_atask = asyncio.create_task( - self._playout_task(self._playout_atask, handle) - ) - - return handle - - @utils.log_exceptions(logger=logger) - async def _playout_task( - self, old_task: asyncio.Task[None] | None, handle: PlayoutHandle - ) -> None: - if old_task is not None: - await utils.aio.gracefully_cancel(old_task) - - if self._audio_source.queued_duration > 0: 
- # this should not happen, but log it just in case - logger.warning( - "new playout while the source is still playing", - extra={ - "speech_id": handle.speech_id, - "queued_duration": self._audio_source.queued_duration, - }, - ) - - first_frame = True - - @utils.log_exceptions(logger=logger) - async def _capture_task(): - nonlocal first_frame - async for frame in handle._playout_source: - if first_frame: - handle._tr_fwd.segment_playout_started() - - logger.debug( - "speech playout started", - extra={"speech_id": handle.speech_id}, - ) - - self.emit("playout_started") - first_frame = False - - handle._pushed_duration += frame.samples_per_channel / frame.sample_rate - await self._audio_source.capture_frame(frame) - - if self._audio_source.queued_duration > 0: - await self._audio_source.wait_for_playout() - - capture_task = asyncio.create_task(_capture_task()) - try: - await asyncio.wait( - [capture_task, handle._int_fut], - return_when=asyncio.FIRST_COMPLETED, - ) - finally: - await utils.aio.gracefully_cancel(capture_task) - - handle._total_played_time = ( - handle._pushed_duration - self._audio_source.queued_duration - ) - - if handle.interrupted or capture_task.exception(): - self._audio_source.clear_queue() # make sure to remove any queued frames - - if not first_frame: - if not handle.interrupted: - handle._tr_fwd.segment_playout_finished() - - self.emit("playout_stopped", handle.interrupted) - - await handle._tr_fwd.aclose() - handle._done_fut.set_result(None) - - logger.debug( - "speech playout finished", - extra={ - "speech_id": handle.speech_id, - "interrupted": handle.interrupted, - }, - ) diff --git a/livekit-agents/livekit/agents/pipeline/audio_recognition.py b/livekit-agents/livekit/agents/pipeline/audio_recognition.py index 33dc49e1a..7ae05430d 100644 --- a/livekit-agents/livekit/agents/pipeline/audio_recognition.py +++ b/livekit-agents/livekit/agents/pipeline/audio_recognition.py @@ -171,16 +171,8 @@ async def _on_stt_event(self, ev: stt.SpeechEvent) -> None: elif ev.type == stt.SpeechEventType.INTERIM_TRANSCRIPT: self.emit("interim_transcript", ev) - tracing.Tracing.log_event( - "user interim transcript", - { - "interim transcript": ev.alternatives[0].text, - }, - ) - async def _on_vad_event(self, ev: vad.VADEvent) -> None: if ev.type == vad.VADEventType.START_OF_SPEECH: - tracing.Tracing.log_event("start of speech") self.emit("start_of_speech", ev) self._speaking = True @@ -192,7 +184,6 @@ async def _on_vad_event(self, ev: vad.VADEvent) -> None: self.emit("vad_inference_done", ev) elif ev.type == vad.VADEventType.END_OF_SPEECH: - tracing.Tracing.log_event("end of speech") self.emit("end_of_speech", ev) self._speaking = False diff --git a/livekit-agents/livekit/agents/pipeline/events.py b/livekit-agents/livekit/agents/pipeline/events.py new file mode 100644 index 000000000..bcf311a76 --- /dev/null +++ b/livekit-agents/livekit/agents/pipeline/events.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass + + +@dataclass +class UserStartedSpeakingEvent: + pass + + +@dataclass +class UserStoppedSpeakingEvent: + pass diff --git a/livekit-agents/livekit/agents/pipeline/human_input.py b/livekit-agents/livekit/agents/pipeline/human_input.py deleted file mode 100644 index bd875cf99..000000000 --- a/livekit-agents/livekit/agents/pipeline/human_input.py +++ /dev/null @@ -1,150 +0,0 @@ -from __future__ import annotations - -import asyncio -from typing import Literal - -from livekit import rtc - -from .. import stt as speech_to_text -from .. import transcription, utils -from .. 
import vad as voice_activity_detection -from .log import logger - -EventTypes = Literal[ - "start_of_speech", - "vad_inference_done", - "end_of_speech", - "final_transcript", - "interim_transcript", -] - - -class AudioRecognition(utils.EventEmitter[EventTypes]): - def __init__( - self, - ) -> None: - super().__init__() - self._room, self._vad, self._stt, self._participant, self._transcription = ( - room, - vad, - stt, - participant, - transcription, - ) - self._subscribed_track: rtc.RemoteAudioTrack | None = None - self._recognize_atask: asyncio.Task[None] | None = None - - self._closed = False - self._speaking = False - self._speech_probability = 0.0 - - self._room.on("track_published", self._subscribe_to_microphone) - self._room.on("track_subscribed", self._subscribe_to_microphone) - self._subscribe_to_microphone() - - async def aclose(self) -> None: - if self._closed: - raise RuntimeError("HumanInput already closed") - - self._closed = True - self._room.off("track_published", self._subscribe_to_microphone) - self._room.off("track_subscribed", self._subscribe_to_microphone) - self._speaking = False - - if self._recognize_atask is not None: - await utils.aio.gracefully_cancel(self._recognize_atask) - - @property - def speaking(self) -> bool: - return self._speaking - - @property - def speaking_probability(self) -> float: - return self._speech_probability - - def _subscribe_to_microphone(self, *args, **kwargs) -> None: - """ - Subscribe to the participant microphone if found and not already subscribed. - Do nothing if no track is found. - """ - for publication in self._participant.track_publications.values(): - if publication.source != rtc.TrackSource.SOURCE_MICROPHONE: - continue - - if not publication.subscribed: - publication.set_subscribed(True) - - track: rtc.RemoteAudioTrack | None = publication.track # type: ignore - if track is not None and track != self._subscribed_track: - self._subscribed_track = track - if self._recognize_atask is not None: - self._recognize_atask.cancel() - - self._recognize_atask = asyncio.create_task( - self._recognize_task(rtc.AudioStream(track, sample_rate=16000)) - ) - break - - @utils.log_exceptions(logger=logger) - async def _recognize_task(self, audio_stream: rtc.AudioStream) -> None: - """ - Receive the frames from the user audio stream and detect voice activity. 
- """ - vad_stream = self._vad.stream() - stt_stream = self._stt.stream() - - def _before_forward( - fwd: transcription.STTSegmentsForwarder, transcription: rtc.Transcription - ): - if not self._transcription: - transcription.segments = [] - - return transcription - - stt_forwarder = transcription.STTSegmentsForwarder( - room=self._room, - participant=self._participant, - track=self._subscribed_track, - before_forward_cb=_before_forward, - ) - - async def _audio_stream_co() -> None: - # forward the audio stream to the VAD and STT streams - async for ev in audio_stream: - stt_stream.push_frame(ev.frame) - vad_stream.push_frame(ev.frame) - - async def _vad_stream_co() -> None: - async for ev in vad_stream: - if ev.type == voice_activity_detection.VADEventType.START_OF_SPEECH: - self._speaking = True - self.emit("start_of_speech", ev) - elif ev.type == voice_activity_detection.VADEventType.INFERENCE_DONE: - self._speech_probability = ev.probability - self.emit("vad_inference_done", ev) - elif ev.type == voice_activity_detection.VADEventType.END_OF_SPEECH: - self._speaking = False - self.emit("end_of_speech", ev) - - async def _stt_stream_co() -> None: - async for ev in stt_stream: - stt_forwarder.update(ev) - - if ev.type == speech_to_text.SpeechEventType.FINAL_TRANSCRIPT: - self.emit("final_transcript", ev) - elif ev.type == speech_to_text.SpeechEventType.INTERIM_TRANSCRIPT: - self.emit("interim_transcript", ev) - - tasks = [ - asyncio.create_task(_audio_stream_co()), - asyncio.create_task(_vad_stream_co()), - asyncio.create_task(_stt_stream_co()), - ] - try: - await asyncio.gather(*tasks) - finally: - await utils.aio.gracefully_cancel(*tasks) - - await stt_forwarder.aclose() - await stt_stream.aclose() - await vad_stream.aclose() diff --git a/livekit-agents/livekit/agents/pipeline/impl.py b/livekit-agents/livekit/agents/pipeline/impl.py deleted file mode 100644 index 14641a889..000000000 --- a/livekit-agents/livekit/agents/pipeline/impl.py +++ /dev/null @@ -1,183 +0,0 @@ -from __future__ import annotations - -import asyncio -from typing import Protocol - -from .. import io, llm, stt, utils, vad -from ..utils import aio - - -class _TurnDetector(Protocol): - # TODO: Move those two functions to EOU ctor (capabilities dataclass) - def unlikely_threshold(self) -> float: ... - def supports_language(self, language: str | None) -> bool: ... - - async def predict_end_of_turn(self, chat_ctx: llm.ChatContext) -> float: ... - - -class AudioRecognition: - """ - Audio recognition part of the PipelineAgent. - The class is always instantiated but no tasks may be running if STT/VAD is disabled - - This class is also responsible for the end of turn detection. 
- """ - - UNLIKELY_END_OF_TURN_EXTRA_DELAY = 6.0 - - def __init__( - self, - *, - pipeline_agent: "PipelineAgent", - stt: stt.STT | None, - vad: vad.VAD | None, - turn_detector: _TurnDetector | None = None, - min_endpointing_delay: float, - ) -> None: - self._pipeline_agent = weakref.ref(pipeline_agent) - self._stt_atask: asyncio.Task[None] | None = None - self._vad_atask: asyncio.Task[None] | None = None - self._end_of_turn_task: asyncio.Task[None] | None = None - self._audio_input: io.AudioStream | None = None - self._min_endpointing_delay = min_endpointing_delay - - self._init_stt(stt) - self._init_vad(vad) - self._turn_detector = turn_detector - - self._speaking = False - self._audio_transcript = "" - self._last_language: str | None = None - - @property - def audio_input(self) -> io.AudioStream | None: - return self._audio_input - - @audio_input.setter - def audio_input(self, audio_input: io.AudioStream | None) -> None: - self._init_stt(self._stt) - self._init_vad(self._vad) - self._audio_input = audio_input - - async def _on_stt_event(self, ev: stt.SpeechEvent) -> None: - if ev.type == stt.SpeechEventType.FINAL_TRANSCRIPT: - transcript = ev.alternatives[0].text - if not transcript: - return - - logger.debug( - "received user transcript", - extra={"user_transcript": new_transcript}, - ) - - self._audio_transcript += f" {transcript}" - self._audio_transcript = self._audio_transcript.lstrip() - - if not self._speaking: - self._run_eou_detection(pipeline_agent.chat_ctx, self._audio_transcript) - - async def _on_vad_event(self, ev: vad.VADEvent) -> None: - if ev.type == vad.VADEventType.START_OF_SPEECH: - self._speaking = True - - if self._end_of_turn_task is not None: - self._end_of_turn_task.cancel() - - elif ev.tupe == vad.VADEventType.END_OF_SPEECH: - self._speaking = False - - def _on_end_of_turn(self) -> None: - # start llm generation - pass - - async def aclose(self) -> None: - if self._stt_atask is not None: - await aio.gracefully_cancel(self._stt_atask) - - if self._vad_atask is not None: - await aio.gracefully_cancel(self._vad_atask) - - if self._end_of_turn_task is not None: - await aio.gracefully_cancel(self._end_of_turn_task) - - def _run_eou_detection( - self, chat_ctx: llm.ChatContext, new_transcript: str - ) -> None: - chat_ctx = pipeline_agent.chat_ctx.copy() - chat_ctx.append(role="user", text=new_transcript) - turn_detector = self._turn_detector - - @utils.log_exceptions(logger=logger) - async def _bounce_eou_task() -> None: - await asyncio.sleep(self._min_endpointing_delay) - - if turn_detector is not None and turn_detector.supports_language( - self._last_language - ): - end_of_turn_probability = await turn_detector.predict_end_of_turn( - chat_ctx - ) - unlikely_threshold = turn_detector.unlikely_threshold() - if end_of_turn_probability > unlikely_threshold: - await asyncio.sleep(self.UNLIKELY_END_OF_TURN_EXTRA_DELAY) - - self._on_end_of_turn() - - if self._end_of_turn_task is not None: - self._end_of_turn_task.cancel() - - self._end_of_turn_task = asyncio.create_task(_bounce_eou_task()) - - async def _stt_task( - self, stt: stt.STT, audio_input: io.AudioStream, task: asyncio.Task[None] | None - ) -> None: - if task is not None: - await aio.gracefully_cancel(task) - - stream = stt.stream() - - async def _forward() -> None: - async for frame in audio_input: - stream.push_frame(frame) - - forward_task = asyncio.create_task(_forward()) - - try: - async for ev in stream: - await self._on_stt_event(ev) - finally: - await stream.aclose() - await 
aio.gracefully_cancel(forward_task) - - async def _vad_task( - self, vad: vad.VAD, audio_input: io.AudioStream, task: asyncio.Task[None] | None - ) -> None: - if task is not None: - await aio.gracefully_cancel(task) - - stream = vad.stream() - - async def _forward() -> None: - async for frame in audio_input: - stream.push_frame(frame) - - forward_task = asyncio.create_task(_forward()) - - try: - async for ev in stream: - await self._on_vad_event(ev) - finally: - await stream.aclose() - await aio.gracefully_cancel(forward_task) - - def init_stt(self, stt: stt.STT, audio_input: io.AudioStream) -> None: - self._stt = stt - self._stt_atask = asyncio.create_task( - self._stt_task(stt, audio_input, self._stt_atask) - ) - - def init_vad(self, vad: vad.VAD, audio_input: io.AudioStream) -> None: - self._vad = vad - self._vad_atask = asyncio.create_task( - self._vad_task(vad, audio_input, self._vad_atask) - ) diff --git a/livekit-agents/livekit/agents/pipeline/io.py b/livekit-agents/livekit/agents/pipeline/io.py index accc00ee5..fea6d01a8 100644 --- a/livekit-agents/livekit/agents/pipeline/io.py +++ b/livekit-agents/livekit/agents/pipeline/io.py @@ -19,7 +19,7 @@ STTNode = Callable[ [AsyncIterable[rtc.AudioFrame]], - Union[Awaitable[Optional[AsyncIterable[stt.SpeechEvent]]],], + Union[Awaitable[Optional[AsyncIterable[stt.SpeechEvent]]]], ] LLMNode = Callable[ [llm.ChatContext, Optional[llm.FunctionContext]], diff --git a/livekit-agents/livekit/agents/pipeline/multimodal.py b/livekit-agents/livekit/agents/pipeline/multimodal.py deleted file mode 100644 index d572a37ad..000000000 --- a/livekit-agents/livekit/agents/pipeline/multimodal.py +++ /dev/null @@ -1,11 +0,0 @@ -EventTypes = Literal[""] - - -class MultimodalModel: - """Model handling multiple modalities (video/audio/text) - - MultimodalModel assumes stateful multimodal input and output. (MultimodalSession). - """ - - def __init(self) -> None: - pass diff --git a/livekit-agents/livekit/agents/pipeline/pipeline2.py b/livekit-agents/livekit/agents/pipeline/pipeline2.py deleted file mode 100644 index 8d130621c..000000000 --- a/livekit-agents/livekit/agents/pipeline/pipeline2.py +++ /dev/null @@ -1,645 +0,0 @@ -from __future__ import annotations, print_function - -import asyncio -import contextlib -import heapq - -from dataclasses import dataclass -from typing import ( - AsyncIterable, - Tuple, - Literal, - Optional, - Union, -) - -from livekit import rtc - -from .. import llm, stt, tts, utils, vad, debug, tokenize -from ..llm import ChatContext, FunctionContext -from ..log import logger -from . 
import io -from .audio_recognition import AudioRecognition, _TurnDetector -from .generation import ( - do_llm_inference, - do_tts_inference, - _TTSGenerationData, -) - - -class SpeechHandle: - def __init__( - self, *, speech_id: str, allow_interruptions: bool, step_index: int - ) -> None: - self._id = speech_id - self._step_index = step_index - self._allow_interruptions = allow_interruptions - self._interrupt_fut = asyncio.Future() - self._done_fut = asyncio.Future() - self._play_fut = asyncio.Future() - self._playout_done_fut = asyncio.Future() - - @staticmethod - def create(allow_interruptions: bool = True, step_index: int = 0) -> SpeechHandle: - return SpeechHandle( - speech_id=utils.shortuuid("SH_"), - allow_interruptions=allow_interruptions, - step_index=step_index, - ) - - @property - def id(self) -> str: - return self._id - - @property - def step_index(self) -> int: - return self._step_index - - @property - def interrupted(self) -> bool: - return self._interrupt_fut.done() - - @property - def allow_interruptions(self) -> bool: - return self._allow_interruptions - - def play(self) -> None: - self._play_fut.set_result(None) - - def done(self) -> bool: - return self._done_fut.done() - - def interrupt(self) -> None: - if not self._allow_interruptions: - raise ValueError("This generation handle does not allow interruptions") - - if self.done(): - return - - self._done_fut.set_result(None) - self._interrupt_fut.set_result(None) - - async def wait_for_playout(self) -> None: - await asyncio.shield(self._playout_done_fut) - - def _mark_playout_done(self) -> None: - self._playout_done_fut.set_result(None) - - def _mark_done(self) -> None: - with contextlib.suppress(asyncio.InvalidStateError): - # will raise InvalidStateError if the future is already done (interrupted) - self._done_fut.set_result(None) - - -EventTypes = Literal[ - "user_started_speaking", - "user_stopped_speaking", - "agent_started_speaking", - "agent_stopped_speaking", - "user_message_committed", - "agent_message_committed", - "agent_message_interrupted", -] - - -@dataclass -class _PipelineOptions: - language: str | None - allow_interruptions: bool - min_interruption_duration: float - min_endpointing_delay: float - - -class PipelineAgent(rtc.EventEmitter[EventTypes]): - SPEECH_PRIORITY_LOW = 0 - """Priority for messages that should be played after all other messages in the queue""" - SPEECH_PRIORITY_NORMAL = 5 - """Every speech generates by the PipelineAgent defaults to this priority.""" - SPEECH_PRIORITY_HIGH = 10 - """Priority for important messages that should be played before others.""" - - def __init__( - self, - *, - llm: llm.LLM | None = None, - vad: vad.VAD | None = None, - stt: stt.STT | None = None, - tts: tts.TTS | None = None, - turn_detector: _TurnDetector | None = None, - language: str | None = None, - chat_ctx: ChatContext | None = None, - fnc_ctx: FunctionContext | None = None, - allow_interruptions: bool = True, - min_interruption_duration: float = 0.5, - min_endpointing_delay: float = 0.5, - max_fnc_steps: int = 5, - loop: asyncio.AbstractEventLoop | None = None, - ) -> None: - super().__init__() - self._loop = loop or asyncio.get_event_loop() - - self._chat_ctx = chat_ctx or ChatContext() - self._fnc_ctx = fnc_ctx - - self._stt, self._vad, self._llm, self._tts = stt, vad, llm, tts - - if tts and not tts.capabilities.streaming: - from .. 
import tts as text_to_speech - - tts = text_to_speech.StreamAdapter( - tts=tts, sentence_tokenizer=tokenize.basic.SentenceTokenizer() - ) - - if stt and not stt.capabilities.streaming: - from .. import stt as speech_to_text - - if vad is None: - raise ValueError( - "VAD is required when streaming is not supported by the STT" - ) - - stt = speech_to_text.StreamAdapter( - stt=stt, - vad=vad, - ) - - self._turn_detector = turn_detector - self._audio_recognition = AudioRecognition( - agent=self, - stt=self.stt_node, - vad=vad, - turn_detector=turn_detector, - min_endpointing_delay=min_endpointing_delay, - chat_ctx=self._chat_ctx, - loop=self._loop, - ) - - self._opts = _PipelineOptions( - language=language, - allow_interruptions=allow_interruptions, - min_interruption_duration=min_interruption_duration, - min_endpointing_delay=min_endpointing_delay, - ) - - self._max_fnc_steps = max_fnc_steps - self._audio_recognition.on("end_of_turn", self._on_audio_end_of_turn) - self._audio_recognition.on("vad_inference_done", self._on_vad_inference_done) - - # configurable IO - self._input = io.AgentInput( - self._on_video_input_changed, self._on_audio_input_changed - ) - self._output = io.AgentOutput( - self._on_video_output_changed, - self._on_audio_output_changed, - self._on_text_output_changed, - ) - - self._current_speech: SpeechHandle | None = None - self._speech_q: list[Tuple[int, SpeechHandle]] = [] - self._speech_q_changed = asyncio.Event() - self._speech_tasks = [] - - self._speech_scheduler_task: asyncio.Task | None = None - - # -- Pipeline nodes -- - # They can all be overriden by subclasses, by default they use the STT/LLM/TTS specified in the - # constructor of the PipelineAgent - - async def stt_node( - self, audio: AsyncIterable[rtc.AudioFrame] - ) -> Optional[AsyncIterable[stt.SpeechEvent]]: - assert self._stt is not None, "stt_node called but no STT node is available" - - async with self._stt.stream() as stream: - - async def _forward_input(): - async for frame in audio: - stream.push_frame(frame) - - forward_task = asyncio.create_task(_forward_input()) - try: - async for event in stream: - yield event - finally: - await utils.aio.gracefully_cancel(forward_task) - - async def llm_node( - self, chat_ctx: llm.ChatContext, fnc_ctx: llm.FunctionContext | None - ) -> Union[ - Optional[AsyncIterable[llm.ChatChunk]], - Optional[AsyncIterable[str]], - Optional[str], - ]: - assert self._llm is not None, "llm_node called but no LLM node is available" - - async with self._llm.chat(chat_ctx=chat_ctx, fnc_ctx=fnc_ctx) as stream: - async for chunk in stream: - yield chunk - - async def tts_node( - self, text: AsyncIterable[str] - ) -> Optional[AsyncIterable[rtc.AudioFrame]]: - assert self._tts is not None, "tts_node called but no TTS node is available" - - async with self._tts.stream() as stream: - - async def _forward_input(): - async for chunk in text: - stream.push_text(chunk) - - stream.end_input() - - forward_task = asyncio.create_task(_forward_input()) - try: - async for ev in stream: - yield ev.frame - finally: - await utils.aio.gracefully_cancel(forward_task) - - def start(self) -> None: - self._audio_recognition.start() - self._speech_scheduler_task = asyncio.create_task( - self._playout_scheduler(), name="_playout_scheduler" - ) - - async def aclose(self) -> None: - await self._audio_recognition.aclose() - - @property - def input(self) -> io.AgentInput: - return self._input - - @property - def output(self) -> io.AgentOutput: - return self._output - - # TODO(theomonnom): find a better name 
than `generation` - @property - def current_speech(self) -> SpeechHandle | None: - return self._current_speech - - @property - def chat_ctx(self) -> llm.ChatContext: - return self._chat_ctx - - def update_options(self) -> None: - pass - - def say(self, text: str | AsyncIterable[str]) -> SpeechHandle: - pass - - def generate_reply(self, user_input: str) -> SpeechHandle: - if self._current_speech is not None and not self._current_speech.interrupted: - raise ValueError("another reply is already in progress") - - debug.Tracing.log_event("generate_reply", {"user_input": user_input}) - self._chat_ctx.append(role="user", text=user_input) # TODO(theomonnom) Remove - - handle = SpeechHandle.create(allow_interruptions=self._opts.allow_interruptions) - task = asyncio.create_task( - self._generate_pipeline_reply_task( - handle=handle, - chat_ctx=self._chat_ctx, - fnc_ctx=self._fnc_ctx, - ), - name="_generate_pipeline_reply", - ) - self._schedule_speech(handle, task, self.SPEECH_PRIORITY_NORMAL) - return handle - - # -- Main generation task -- - - def _schedule_speech( - self, speech: SpeechHandle, task: asyncio.Task, priority: int - ) -> None: - self._speech_tasks.append(task) - task.add_done_callback(lambda _: self._speech_tasks.remove(task)) - - heapq.heappush(self._speech_q, (priority, speech)) - self._speech_q_changed.set() - - @utils.log_exceptions(logger=logger) - async def _playout_scheduler(self) -> None: - while True: - await self._speech_q_changed.wait() - - while self._speech_q: - _, speech = heapq.heappop(self._speech_q) - self._current_speech = speech - speech.play() - await speech.wait_for_playout() - self._current_speech = None - - self._speech_q_changed.clear() - - @utils.log_exceptions(logger=logger) - async def _generate_pipeline_reply_task( - self, - *, - handle: SpeechHandle, - chat_ctx: ChatContext, - fnc_ctx: FunctionContext | None, - ) -> None: - @utils.log_exceptions(logger=logger) - async def _forward_llm_text(llm_output: AsyncIterable[str]) -> None: - """collect and forward the generated text to the current agent output""" - if self.output.text is None: - return - - try: - async for delta in llm_output: - await self.output.text.capture_text(delta) - finally: - self.output.text.flush() - - @utils.log_exceptions(logger=logger) - async def _forward_tts_audio(tts_output: AsyncIterable[rtc.AudioFrame]) -> None: - """collect and forward the generated audio to the current agent output (generally playout)""" - if self.output.audio is None: - return - - try: - async for frame in tts_output: - await self.output.audio.capture_frame(frame) - finally: - self.output.audio.flush() # always flush (even if the task is interrupted) - - @utils.log_exceptions(logger=logger) - async def _execute_tools( - tools_ch: utils.aio.Chan[llm.FunctionCallInfo], - called_functions: set[llm.CalledFunction], - ) -> None: - """execute tools, when cancelled, stop executing new tools but wait for the pending ones""" - try: - async for tool in tools_ch: - logger.debug( - "executing tool", - extra={ - "function": tool.function_info.name, - "speech_id": handle.id, - }, - ) - debug.Tracing.log_event( - "executing tool", - { - "function": tool.function_info.name, - "speech_id": handle.id, - }, - ) - cfnc = tool.execute() - called_functions.add(cfnc) - except asyncio.CancelledError: - # don't allow to cancel running function calla if they're still running - pending_tools = [cfn for cfn in called_functions if not cfn.task.done()] - - if pending_tools: - names = [cfn.call_info.function_info.name for cfn in 
pending_tools] - - logger.debug( - "waiting for function call to finish before cancelling", - extra={ - "functions": names, - "speech_id": handle.id, - }, - ) - debug.Tracing.log_event( - "waiting for function call to finish before cancelling", - { - "functions": names, - "speech_id": handle.id, - }, - ) - await asyncio.gather(*[cfn.task for cfn in pending_tools]) - finally: - if len(called_functions) > 0: - logger.debug( - "tools execution completed", - extra={"speech_id": handle.id}, - ) - debug.Tracing.log_event( - "tools execution completed", - {"speech_id": handle.id}, - ) - - debug.Tracing.log_event( - "generation started", - {"speech_id": handle.id, "step_index": handle.step_index}, - ) - - wg = utils.aio.WaitGroup() - tasks = [] - llm_task, llm_gen_data = do_llm_inference( - node=self.llm_node, - chat_ctx=chat_ctx, - fnc_ctx=fnc_ctx - if handle.step_index < self._max_fnc_steps - 1 and handle.step_index >= 2 - else None, - ) - tasks.append(llm_task) - wg.add(1) - llm_task.add_done_callback(lambda _: wg.done()) - tts_text_input, llm_output = utils.aio.itertools.tee(llm_gen_data.text_ch) - - tts_task: asyncio.Task | None = None - tts_gen_data: _TTSGenerationData | None = None - if self._output.audio is not None: - tts_task, tts_gen_data = do_tts_inference( - node=self.tts_node, input=tts_text_input - ) - tasks.append(tts_task) - wg.add(1) - tts_task.add_done_callback(lambda _: wg.done()) - - # wait for the play() method to be called - await asyncio.wait( - [ - handle._play_fut, - handle._interrupt_fut, - ], - return_when=asyncio.FIRST_COMPLETED, - ) - - if handle.interrupted: - await utils.aio.gracefully_cancel(*tasks) - handle._mark_done() - return # return directly (the generated output wasn't used) - - # forward tasks are started after the play() method is called - # they redirect the generated text/audio to the output channels - forward_llm_task = asyncio.create_task( - _forward_llm_text(llm_output), - name="_generate_reply_task.forward_llm_text", - ) - tasks.append(forward_llm_task) - wg.add(1) - forward_llm_task.add_done_callback(lambda _: wg.done()) - - forward_tts_task: asyncio.Task | None = None - if tts_gen_data is not None: - forward_tts_task = asyncio.create_task( - _forward_tts_audio(tts_gen_data.audio_ch), - name="_generate_reply_task.forward_tts_audio", - ) - tasks.append(forward_tts_task) - wg.add(1) - forward_tts_task.add_done_callback(lambda _: wg.done()) - - # start to execute tools (only after play()) - called_functions: set[llm.CalledFunction] = set() - tools_task = asyncio.create_task( - _execute_tools(llm_gen_data.tools_ch, called_functions), - name="_generate_reply_task.execute_tools", - ) - tasks.append(tools_task) - wg.add(1) - tools_task.add_done_callback(lambda _: wg.done()) - - # wait for the tasks to finish - await asyncio.wait( - [ - wg.wait(), - handle._interrupt_fut, - ], - return_when=asyncio.FIRST_COMPLETED, - ) - - # wait for the end of the playout if the audio is enabled - if forward_llm_task is not None: - assert self._output.audio is not None - await asyncio.wait( - [ - self._output.audio.wait_for_playout(), - handle._interrupt_fut, - ], - return_when=asyncio.FIRST_COMPLETED, - ) - - if handle.interrupted: - await utils.aio.gracefully_cancel(*tasks) - - if len(called_functions) > 0: - functions = [ - cfnc.call_info.function_info.name for cfnc in called_functions - ] - logger.debug( - "speech interrupted, ignoring generation of the function calls results", - extra={"speech_id": handle.id, "functions": functions}, - ) - debug.Tracing.log_event( - 
"speech interrupted, ignoring generation of the function calls results", - {"speech_id": handle.id, "functions": functions}, - ) - - # if the audio playout was enabled, clear the buffer - if forward_tts_task is not None: - assert self._output.audio is not None - - self._output.audio.clear_buffer() - playback_ev = await self._output.audio.wait_for_playout() - - debug.Tracing.log_event( - "playout interrupted", - { - "playback_position": playback_ev.playback_position, - "speech_id": handle.id, - }, - ) - - handle._mark_playout_done() - # TODO(theomonnom): calculate the played text based on playback_ev.playback_position - - handle._mark_done() - return - - handle._mark_playout_done() - debug.Tracing.log_event("playout completed", {"speech_id": handle.id}) - - if len(called_functions) > 0: - if handle.step_index >= self._max_fnc_steps: - logger.warning( - "maximum number of function calls steps reached", - extra={"speech_id": handle.id}, - ) - debug.Tracing.log_event( - "maximum number of function calls steps reached", - {"speech_id": handle.id}, - ) - handle._mark_done() - return - - # create a new SpeechHandle to generate the result of the function calls - handle = SpeechHandle.create( - allow_interruptions=self._opts.allow_interruptions, - step_index=handle.step_index + 1, - ) - task = asyncio.create_task( - self._generate_pipeline_reply_task( - handle=handle, - chat_ctx=chat_ctx, - fnc_ctx=fnc_ctx, - ), - name="_generate_pipeline_reply", - ) - self._schedule_speech(handle, task, self.SPEECH_PRIORITY_NORMAL) - - handle._mark_done() - - # -- Audio recognition -- - - def _on_audio_end_of_turn(self, new_transcript: str) -> None: - # When the audio recognition detects the end of a user turn: - # - check if there is no current generation happening - # - cancel the current generation if it allows interruptions (otherwise skip this current - # turn) - # - generate a reply to the user input - - if self._current_speech is not None: - if self._current_speech.allow_interruptions: - logger.warning( - "skipping user input, current speech generation cannot be interrupted", - extra={"user_input": new_transcript}, - ) - return - - debug.Tracing.log_event( - "speech interrupted, new user turn detected", - {"speech_id": self._current_speech.id}, - ) - self._current_speech.interrupt() - - self.generate_reply(new_transcript) - - def _on_vad_inference_done(self, ev: vad.VADEvent) -> None: - if ev.speech_duration > self._opts.min_interruption_duration: - if ( - self._current_speech is not None - and not self._current_speech.interrupted - and self._current_speech.allow_interruptions - ): - debug.Tracing.log_event( - "speech interrupted by vad", - {"speech_id": self._current_speech.id}, - ) - self._current_speech.interrupt() - - # --- - - # -- User changed input/output streams/sinks -- - - def _on_video_input_changed(self) -> None: - pass - - def _on_audio_input_changed(self) -> None: - self._audio_recognition.audio_input = self._input.audio - - def _on_video_output_changed(self) -> None: - pass - - def _on_audio_output_changed(self) -> None: - pass - - def _on_text_output_changed(self) -> None: - pass - - # --- diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index 806579fb6..f668476ec 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -1,1330 +1,665 @@ -from __future__ import annotations +from __future__ import annotations, print_function import 
asyncio -import contextvars -import time +import contextlib +import heapq + from dataclasses import dataclass from typing import ( - Any, - AsyncGenerator, AsyncIterable, - Awaitable, - Callable, + Tuple, Literal, Optional, - Protocol, Union, ) from livekit import rtc -from .. import metrics, stt, tokenize, tts, utils, vad -from ..llm import LLM, ChatContext, ChatMessage, FunctionContext, LLMStream -from ..types import ATTRIBUTE_AGENT_STATE, AgentState -from .agent_output import AgentOutput, SpeechSource, SynthesisHandle -from .agent_playout import AgentPlayout -from .human_input import HumanInput -from .log import logger -from .plotter import AssistantPlotter -from .speech_handle import SpeechHandle - -BeforeLLMCallback = Callable[ - ["VoicePipelineAgent", ChatContext], - Union[ - Optional[LLMStream], - Awaitable[Optional[LLMStream]], - Literal[False], - Awaitable[Literal[False]], - ], -] +from .. import llm, stt, tts, utils, vad, debug, tokenize +from ..llm import ChatContext, FunctionContext +from ..log import logger +from . import io, events +from .audio_recognition import AudioRecognition, _TurnDetector +from .generation import ( + do_llm_inference, + do_tts_inference, + _TTSGenerationData, +) -WillSynthesizeAssistantReply = BeforeLLMCallback -BeforeTTSCallback = Callable[ - ["VoicePipelineAgent", Union[str, AsyncIterable[str]]], - SpeechSource, -] +class AgentContext: + pass -EventTypes = Literal[ - "user_started_speaking", - "user_stopped_speaking", - "agent_started_speaking", - "agent_stopped_speaking", - "user_speech_committed", - "agent_speech_committed", - "agent_speech_interrupted", - "function_calls_collected", - "function_calls_finished", - "metrics_collected", -] - -_CallContextVar = contextvars.ContextVar["AgentCallContext"]( - "voice_assistant_contextvar" -) - - -class AgentCallContext: - def __init__(self, assistant: "VoicePipelineAgent", llm_stream: LLMStream) -> None: - self._assistant = assistant - self._metadata = dict[str, Any]() - self._llm_stream = llm_stream - self._extra_chat_messages: list[ChatMessage] = [] +class SpeechHandle: + def __init__( + self, *, speech_id: str, allow_interruptions: bool, step_index: int + ) -> None: + self._id = speech_id + self._step_index = step_index + self._allow_interruptions = allow_interruptions + self._interrupt_fut = asyncio.Future() + self._done_fut = asyncio.Future() + self._play_fut = asyncio.Future() + self._playout_done_fut = asyncio.Future() @staticmethod - def get_current() -> "AgentCallContext": - return _CallContextVar.get() + def create(allow_interruptions: bool = True, step_index: int = 0) -> SpeechHandle: + return SpeechHandle( + speech_id=utils.shortuuid("speech_"), + allow_interruptions=allow_interruptions, + step_index=step_index, + ) @property - def agent(self) -> "VoicePipelineAgent": - return self._assistant + def id(self) -> str: + return self._id @property - def chat_ctx(self) -> ChatContext: - return self._llm_stream.chat_ctx - - def store_metadata(self, key: str, value: Any) -> None: - self._metadata[key] = value + def step_index(self) -> int: + return self._step_index - def get_metadata(self, key: str, default: Any = None) -> Any: - return self._metadata.get(key, default) + @property + def interrupted(self) -> bool: + return self._interrupt_fut.done() - def llm_stream(self) -> LLMStream: - return self._llm_stream + @property + def allow_interruptions(self) -> bool: + return self._allow_interruptions - def add_extra_chat_message(self, message: ChatMessage) -> None: - """Append chat message to the end of 
function outputs for the answer LLM call""" - self._extra_chat_messages.append(message) + def play(self) -> None: + self._play_fut.set_result(None) - @property - def extra_chat_messages(self) -> list[ChatMessage]: - return self._extra_chat_messages + def done(self) -> bool: + return self._done_fut.done() + def interrupt(self) -> None: + if not self._allow_interruptions: + raise ValueError("This generation handle does not allow interruptions") -def _default_before_llm_cb( - agent: VoicePipelineAgent, chat_ctx: ChatContext -) -> LLMStream: - return agent.llm.chat( - chat_ctx=chat_ctx, - fnc_ctx=agent.fnc_ctx, - ) + if self.done(): + return + self._done_fut.set_result(None) + self._interrupt_fut.set_result(None) -@dataclass -class SpeechData: - sequence_id: str + async def wait_for_playout(self) -> None: + await asyncio.shield(self._playout_done_fut) + def _mark_playout_done(self) -> None: + self._playout_done_fut.set_result(None) -SpeechDataContextVar = contextvars.ContextVar[SpeechData]("voice_assistant_speech_data") + def _mark_done(self) -> None: + with contextlib.suppress(asyncio.InvalidStateError): + # will raise InvalidStateError if the future is already done (interrupted) + self._done_fut.set_result(None) -def _default_before_tts_cb( - agent: VoicePipelineAgent, text: str | AsyncIterable[str] -) -> str | AsyncIterable[str]: - return text +EventTypes = Literal[ + "user_started_speaking", + "user_stopped_speaking", + "agent_started_speaking", + "agent_stopped_speaking", + "user_message_committed", + "agent_message_committed", + "agent_message_interrupted", +] -@dataclass(frozen=True) -class _ImplOptions: +@dataclass +class _PipelineOptions: + language: str | None allow_interruptions: bool - int_speech_duration: float - int_min_words: int + min_interruption_duration: float min_endpointing_delay: float - max_endpointing_delay: float - max_nested_fnc_calls: int - preemptive_synthesis: bool - before_llm_cb: BeforeLLMCallback - before_tts_cb: BeforeTTSCallback - plotting: bool - transcription: AgentTranscriptionOptions - - -@dataclass(frozen=True) -class AgentTranscriptionOptions: - user_transcription: bool = True - """Whether to forward the user transcription to the client""" - agent_transcription: bool = True - """Whether to forward the agent transcription to the client""" - agent_transcription_speed: float = 1.0 - """The speed at which the agent's speech transcription is forwarded to the client. - We try to mimic the agent's speech speed by adjusting the transcription speed.""" - sentence_tokenizer: tokenize.SentenceTokenizer = tokenize.basic.SentenceTokenizer() - """The tokenizer used to split the speech into sentences. - This is used to decide when to mark a transcript as final for the agent transcription.""" - word_tokenizer: tokenize.WordTokenizer = tokenize.basic.WordTokenizer( - ignore_punctuation=False - ) - """The tokenizer used to split the speech into words. - This is used to simulate the "interim results" of the agent transcription.""" - hyphenate_word: Callable[[str], list[str]] = tokenize.basic.hyphenate_word - """A function that takes a string (word) as input and returns a list of strings, - representing the hyphenated parts of the word.""" - - -class _TurnDetector(Protocol): - # When endpoint probability is below this threshold we think the user is not finished speaking - # so we will use a long delay - def unlikely_threshold(self) -> float: ... - def supports_language(self, language: str | None) -> bool: ... 
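A minimal sketch of an end-of-turn detector satisfying the _TurnDetector protocol shown above (unlikely_threshold, supports_language, predict_end_of_turn); the class name, threshold, and constant probability are illustrative assumptions, not part of this patch:

from livekit.agents import llm


class ConstantTurnDetector:
    """Illustrative detector: always reports the user turn as finished."""

    def unlikely_threshold(self) -> float:
        # probabilities below this value mean the user is probably not done speaking
        return 0.15

    def supports_language(self, language: str | None) -> bool:
        return True  # language-agnostic placeholder

    async def predict_end_of_turn(self, chat_ctx: llm.ChatContext) -> float:
        # a real detector would run an end-of-utterance model over the chat context
        return 1.0


# hypothetical wiring through the constructor argument introduced by this patch:
# agent = PipelineAgent(llm=..., stt=..., vad=..., turn_detector=ConstantTurnDetector())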
- async def predict_end_of_turn(self, chat_ctx: ChatContext) -> float: ... - - -class VoicePipelineAgent(utils.EventEmitter[EventTypes]): - """ - A pipeline agent (VAD + STT + LLM + TTS) implementation. - """ - - MIN_TIME_PLAYED_FOR_COMMIT = 1.5 - """Minimum time played for the user speech to be committed to the chat context""" + max_fnc_steps: int + + +class PipelineAgent(rtc.EventEmitter[EventTypes]): + SPEECH_PRIORITY_LOW = 0 + """Priority for messages that should be played after all other messages in the queue""" + SPEECH_PRIORITY_NORMAL = 5 + """Every speech generates by the PipelineAgent defaults to this priority.""" + SPEECH_PRIORITY_HIGH = 10 + """Priority for important messages that should be played before others.""" def __init__( self, *, - vad: vad.VAD, - stt: stt.STT, - llm: LLM, - tts: tts.TTS, + llm: llm.LLM | None = None, + vad: vad.VAD | None = None, + stt: stt.STT | None = None, + tts: tts.TTS | None = None, turn_detector: _TurnDetector | None = None, + language: str | None = None, chat_ctx: ChatContext | None = None, fnc_ctx: FunctionContext | None = None, allow_interruptions: bool = True, - interrupt_speech_duration: float = 0.5, - interrupt_min_words: int = 0, + min_interruption_duration: float = 0.5, min_endpointing_delay: float = 0.5, - max_endpointing_delay: float = 6.0, - max_nested_fnc_calls: int = 1, - preemptive_synthesis: bool = False, - transcription: AgentTranscriptionOptions = AgentTranscriptionOptions(), - before_llm_cb: BeforeLLMCallback = _default_before_llm_cb, - before_tts_cb: BeforeTTSCallback = _default_before_tts_cb, - plotting: bool = False, + max_fnc_steps: int = 5, loop: asyncio.AbstractEventLoop | None = None, - # backward compatibility - will_synthesize_assistant_reply: WillSynthesizeAssistantReply | None = None, ) -> None: - """ - Create a new VoicePipelineAgent. - - Args: - vad: Voice Activity Detection (VAD) instance. - stt: Speech-to-Text (STT) instance. - llm: Large Language Model (LLM) instance. - tts: Text-to-Speech (TTS) instance. - chat_ctx: Chat context for the assistant. - fnc_ctx: Function context for the assistant. - allow_interruptions: Whether to allow the user to interrupt the assistant. - interrupt_speech_duration: Minimum duration of speech to consider for interruption. - interrupt_min_words: Minimum number of words to consider for interruption. - Defaults to 0 as this may increase the latency depending on the STT. - min_endpointing_delay: Delay to wait before considering the user finished speaking. - max_nested_fnc_calls: Maximum number of nested function calls allowed for chaining - function calls (e.g functions that depend on each other). - preemptive_synthesis: Whether to preemptively synthesize responses. - transcription: Options for assistant transcription. - before_llm_cb: Callback called when the assistant is about to synthesize a reply. - This can be used to customize the reply (e.g: inject context/RAG). - - Returning None will create a default LLM stream. You can also return your own llm - stream by calling the llm.chat() method. - - Returning False will cancel the synthesis of the reply. - before_tts_cb: Callback called when the assistant is about to - synthesize a speech. This can be used to customize text before the speech synthesis. - (e.g: editing the pronunciation of a word). - plotting: Whether to enable plotting for debugging. matplotlib must be installed. - loop: Event loop to use. Default to asyncio.get_event_loop(). 
- """ super().__init__() self._loop = loop or asyncio.get_event_loop() - if will_synthesize_assistant_reply is not None: - logger.warning( - "will_synthesize_assistant_reply is deprecated and will be removed in 1.5.0, use before_llm_cb instead", - ) - before_llm_cb = will_synthesize_assistant_reply - - self._opts = _ImplOptions( - plotting=plotting, - allow_interruptions=allow_interruptions, - int_speech_duration=interrupt_speech_duration, - int_min_words=interrupt_min_words, - min_endpointing_delay=min_endpointing_delay, - max_endpointing_delay=max_endpointing_delay, - max_nested_fnc_calls=max_nested_fnc_calls, - preemptive_synthesis=preemptive_synthesis, - transcription=transcription, - before_llm_cb=before_llm_cb, - before_tts_cb=before_tts_cb, - ) - self._plotter = AssistantPlotter(self._loop) + self._chat_ctx = chat_ctx or ChatContext() + self._fnc_ctx = fnc_ctx - # wrap with StreamAdapter automatically when streaming is not supported on a specific TTS/STT. - # To override StreamAdapter options, create the adapter manually. + self._stt, self._vad, self._llm, self._tts = stt, vad, llm, tts - if not tts.capabilities.streaming: + if tts and not tts.capabilities.streaming: from .. import tts as text_to_speech tts = text_to_speech.StreamAdapter( tts=tts, sentence_tokenizer=tokenize.basic.SentenceTokenizer() ) - if not stt.capabilities.streaming: + if stt and not stt.capabilities.streaming: from .. import stt as speech_to_text + if vad is None: + raise ValueError( + "VAD is required when streaming is not supported by the STT" + ) + stt = speech_to_text.StreamAdapter( stt=stt, vad=vad, ) - self._stt, self._vad, self._llm, self._tts = stt, vad, llm, tts self._turn_detector = turn_detector - self._chat_ctx = chat_ctx or ChatContext() - self._fnc_ctx = fnc_ctx - self._started, self._closed = False, False - - self._human_input: HumanInput | None = None - self._agent_output: AgentOutput | None = None - - # done when the agent output track is published - self._track_published_fut = asyncio.Future[None]() + self._audio_recognition = AudioRecognition( + agent=self, + stt=self.stt_node, + vad=vad, + turn_detector=turn_detector, + min_endpointing_delay=min_endpointing_delay, + chat_ctx=self._chat_ctx, + loop=self._loop, + ) - self._pending_agent_reply: SpeechHandle | None = None - self._agent_reply_task: asyncio.Task[None] | None = None + self._opts = _PipelineOptions( + language=language, + allow_interruptions=allow_interruptions, + min_interruption_duration=min_interruption_duration, + min_endpointing_delay=min_endpointing_delay, + max_fnc_steps=max_fnc_steps, + ) - self._playing_speech: SpeechHandle | None = None - self._transcribed_text, self._transcribed_interim_text = "", "" + self._audio_recognition.on("end_of_turn", self._on_audio_end_of_turn) + self._audio_recognition.on("start_of_speech", self._on_start_of_speech) + self._audio_recognition.on("end_of_speech", self._on_end_of_speech) + self._audio_recognition.on("vad_inference_done", self._on_vad_inference_done) - self._deferred_validation = _DeferredReplyValidation( - self._validate_reply_if_possible, - min_endpointing_delay=self._opts.min_endpointing_delay, - max_endpointing_delay=self._opts.max_endpointing_delay, - turn_detector=self._turn_detector, - agent=self, + # configurable IO + self._input = io.AgentInput( + self._on_video_input_changed, self._on_audio_input_changed + ) + self._output = io.AgentOutput( + self._on_video_output_changed, + self._on_audio_output_changed, + self._on_text_output_changed, ) - self._speech_q: 
list[SpeechHandle] = [] + self._current_speech: SpeechHandle | None = None + self._speech_q: list[Tuple[int, SpeechHandle]] = [] self._speech_q_changed = asyncio.Event() + self._speech_tasks = [] - self._update_state_task: asyncio.Task | None = None - - self._last_final_transcript_time: float | None = None - self._last_speech_time: float | None = None - - @property - def fnc_ctx(self) -> FunctionContext | None: - return self._fnc_ctx + self._speech_scheduler_task: asyncio.Task | None = None - @fnc_ctx.setter - def fnc_ctx(self, fnc_ctx: FunctionContext | None) -> None: - self._fnc_ctx = fnc_ctx + # -- Pipeline nodes -- + # They can all be overriden by subclasses, by default they use the STT/LLM/TTS specified in the + # constructor of the PipelineAgent - @property - def chat_ctx(self) -> ChatContext: - return self._chat_ctx + async def stt_node( + self, audio: AsyncIterable[rtc.AudioFrame] + ) -> Optional[AsyncIterable[stt.SpeechEvent]]: + assert self._stt is not None, "stt_node called but no STT node is available" - @property - def llm(self) -> LLM: - return self._llm + async with self._stt.stream() as stream: - @property - def tts(self) -> tts.TTS: - return self._tts + async def _forward_input(): + async for frame in audio: + stream.push_frame(frame) - @property - def stt(self) -> stt.STT: - return self._stt + forward_task = asyncio.create_task(_forward_input()) + try: + async for event in stream: + yield event + finally: + await utils.aio.gracefully_cancel(forward_task) - @property - def vad(self) -> vad.VAD: - return self._vad + async def llm_node( + self, chat_ctx: llm.ChatContext, fnc_ctx: llm.FunctionContext | None + ) -> Union[ + Optional[AsyncIterable[llm.ChatChunk]], + Optional[AsyncIterable[str]], + Optional[str], + ]: + assert self._llm is not None, "llm_node called but no LLM node is available" - def start( - self, room: rtc.Room, participant: rtc.RemoteParticipant | str | None = None - ) -> None: - """Start the voice assistant - - Args: - room: the room to use - participant: the participant to listen to, can either be a participant or a participant identity - If None, the first participant found in the room will be selected - """ - if self._started: - raise RuntimeError("voice assistant already started") - - @self._stt.on("metrics_collected") - def _on_stt_metrics(stt_metrics: metrics.STTMetrics) -> None: - self.emit( - "metrics_collected", - metrics.PipelineSTTMetrics( - **stt_metrics.__dict__, - ), - ) + async with self._llm.chat(chat_ctx=chat_ctx, fnc_ctx=fnc_ctx) as stream: + async for chunk in stream: + yield chunk - @self._tts.on("metrics_collected") - def _on_tts_metrics(tts_metrics: metrics.TTSMetrics) -> None: - speech_data = SpeechDataContextVar.get(None) - if speech_data is None: - return + async def tts_node( + self, text: AsyncIterable[str] + ) -> Optional[AsyncIterable[rtc.AudioFrame]]: + assert self._tts is not None, "tts_node called but no TTS node is available" - self.emit( - "metrics_collected", - metrics.PipelineTTSMetrics( - **tts_metrics.__dict__, - sequence_id=speech_data.sequence_id, - ), - ) + async with self._tts.stream() as stream: - @self._llm.on("metrics_collected") - def _on_llm_metrics(llm_metrics: metrics.LLMMetrics) -> None: - speech_data = SpeechDataContextVar.get(None) - if speech_data is None: - return - self.emit( - "metrics_collected", - metrics.PipelineLLMMetrics( - **llm_metrics.__dict__, - sequence_id=speech_data.sequence_id, - ), - ) + async def _forward_input(): + async for chunk in text: + stream.push_text(chunk) - 
@self._vad.on("metrics_collected") - def _on_vad_metrics(vad_metrics: vad.VADMetrics) -> None: - self.emit( - "metrics_collected", metrics.PipelineVADMetrics(**vad_metrics.__dict__) - ) + stream.end_input() - room.on("participant_connected", self._on_participant_connected) - self._room, self._participant = room, participant - - if participant is not None: - if isinstance(participant, rtc.RemoteParticipant): - self._link_participant(participant.identity) - else: - self._link_participant(participant) - else: - # no participant provided, try to find the first participant in the room - for participant in self._room.remote_participants.values(): - self._link_participant(participant.identity) - break - - self._main_atask = asyncio.create_task(self._main_task()) - - def on(self, event: EventTypes, callback: Callable[[Any], None] | None = None): - """Register a callback for an event - - Args: - event: the event to listen to (see EventTypes) - - user_started_speaking: the user started speaking - - user_stopped_speaking: the user stopped speaking - - agent_started_speaking: the agent started speaking - - agent_stopped_speaking: the agent stopped speaking - - user_speech_committed: the user speech was committed to the chat context - - agent_speech_committed: the agent speech was committed to the chat context - - agent_speech_interrupted: the agent speech was interrupted - - function_calls_collected: received the complete set of functions to be executed - - function_calls_finished: all function calls have been completed - callback: the callback to call when the event is emitted - """ - return super().on(event, callback) - - async def say( - self, - source: str | LLMStream | AsyncIterable[str], - *, - allow_interruptions: bool = True, - add_to_chat_ctx: bool = True, - ) -> SpeechHandle: - """ - Play a speech source through the voice assistant. - - Args: - source: The source of the speech to play. - It can be a string, an LLMStream, or an asynchronous iterable of strings. - allow_interruptions: Whether to allow interruptions during the speech playback. - add_to_chat_ctx: Whether to add the speech to the chat context. - - Returns: - The speech handle for the speech that was played, can be used to - wait for the speech to finish. 
- """ - await self._track_published_fut - - call_ctx = None - fnc_source: str | AsyncIterable[str] | None = None - if add_to_chat_ctx: + forward_task = asyncio.create_task(_forward_input()) try: - call_ctx = AgentCallContext.get_current() - except LookupError: - # no active call context, ignore - pass - else: - if isinstance(source, LLMStream): - logger.warning( - "LLMStream will be ignored for function call chat context" - ) - elif isinstance(source, AsyncIterable): - source, fnc_source = utils.aio.itertools.tee(source, 2) # type: ignore - else: - fnc_source = source + async for ev in stream: + yield ev.frame + finally: + await utils.aio.gracefully_cancel(forward_task) - new_handle = SpeechHandle.create_assistant_speech( - allow_interruptions=allow_interruptions, add_to_chat_ctx=add_to_chat_ctx + def start(self) -> None: + self._audio_recognition.start() + self._speech_scheduler_task = asyncio.create_task( + self._playout_scheduler(), name="_playout_scheduler" ) - synthesis_handle = self._synthesize_agent_speech(new_handle.id, source) - new_handle.initialize(source=source, synthesis_handle=synthesis_handle) - - if self._playing_speech and not self._playing_speech.nested_speech_done: - self._playing_speech.add_nested_speech(new_handle) - else: - self._add_speech_for_playout(new_handle) - - # add the speech to the function call context if needed - if call_ctx is not None and fnc_source is not None: - if isinstance(fnc_source, AsyncIterable): - text = "" - async for chunk in fnc_source: - text += chunk - else: - text = fnc_source - - call_ctx.add_extra_chat_message( - ChatMessage.create(text=text, role="assistant") - ) - logger.debug( - "added speech to function call chat context", - extra={"text": text}, - ) - - return new_handle - - def interrupt(self, interrupt_all: bool = True) -> None: - """Interrupt the current speech - - Args: - interrupt_all: Whether to interrupt all pending speech - """ - if interrupt_all: - # interrupt all pending speech - if self._pending_agent_reply is not None: - self._pending_agent_reply.cancel(cancel_nested=True) - for speech in self._speech_q: - speech.cancel(cancel_nested=True) - - # interrupt the playing speech - if self._playing_speech is not None: - self._playing_speech.cancel() - - def _update_state(self, state: AgentState, delay: float = 0.0): - """Set the current state of the agent""" - - @utils.log_exceptions(logger=logger) - async def _run_task(delay: float) -> None: - await asyncio.sleep(delay) - - if self._room.isconnected(): - await self._room.local_participant.set_attributes( - {ATTRIBUTE_AGENT_STATE: state} - ) - - if self._update_state_task is not None: - self._update_state_task.cancel() - - self._update_state_task = asyncio.create_task(_run_task(delay)) async def aclose(self) -> None: - """Close the voice assistant""" - if not self._started: - return + await self._audio_recognition.aclose() - self._room.off("participant_connected", self._on_participant_connected) - await self._deferred_validation.aclose() + def emit(self, event: EventTypes, *args) -> None: + debug.Tracing.log_event(f'agent.on("{event}")') + return super().emit(event, *args) - def _on_participant_connected(self, participant: rtc.RemoteParticipant): - if self._human_input is not None: - return - - self._link_participant(participant.identity) - - def _link_participant(self, identity: str) -> None: - participant = self._room.remote_participants.get(identity) - if participant is None: - logger.error("_link_participant must be called with a valid identity") - return - - 
self._human_input = HumanInput( - room=self._room, - vad=self._vad, - stt=self._stt, - participant=participant, - transcription=self._opts.transcription.user_transcription, - ) - - def _on_start_of_speech(ev: vad.VADEvent) -> None: - self._plotter.plot_event("user_started_speaking") - self.emit("user_started_speaking") - self._deferred_validation.on_human_start_of_speech(ev) - - def _on_vad_inference_done(ev: vad.VADEvent) -> None: - if not self._track_published_fut.done(): - return - - assert self._agent_output is not None - - tv = 1.0 - if self._opts.allow_interruptions: - tv = max(0.0, 1.0 - ev.probability) - self._agent_output.playout.target_volume = tv - - smoothed_tv = self._agent_output.playout.smoothed_volume - - self._plotter.plot_value("raw_vol", tv) - self._plotter.plot_value("smoothed_vol", smoothed_tv) - self._plotter.plot_value("vad_probability", ev.probability) - - if ev.speech_duration >= self._opts.int_speech_duration: - self._interrupt_if_possible() - - if ev.raw_accumulated_speech > 0.0: - self._last_speech_time = ( - time.perf_counter() - ev.raw_accumulated_silence - ) - - def _on_end_of_speech(ev: vad.VADEvent) -> None: - self._plotter.plot_event("user_stopped_speaking") - self.emit("user_stopped_speaking") - self._deferred_validation.on_human_end_of_speech(ev) - - def _on_interim_transcript(ev: stt.SpeechEvent) -> None: - self._transcribed_interim_text = ev.alternatives[0].text - - def _on_final_transcript(ev: stt.SpeechEvent) -> None: - new_transcript = ev.alternatives[0].text - if not new_transcript: - return - - logger.debug( - "received user transcript", - extra={"user_transcript": new_transcript}, - ) + @property + def input(self) -> io.AgentInput: + return self._input - self._last_final_transcript_time = time.perf_counter() + @property + def output(self) -> io.AgentOutput: + return self._output - self._transcribed_text += ( - " " if self._transcribed_text else "" - ) + new_transcript + # TODO(theomonnom): find a better name than `generation` + @property + def current_speech(self) -> SpeechHandle | None: + return self._current_speech - if self._opts.preemptive_synthesis: - if ( - self._playing_speech is None - or self._playing_speech.allow_interruptions - ): - self._synthesize_agent_reply() + @property + def chat_ctx(self) -> llm.ChatContext: + return self._chat_ctx - self._deferred_validation.on_human_final_transcript( - new_transcript, ev.alternatives[0].language - ) + def update_options(self) -> None: + pass - words = self._opts.transcription.word_tokenizer.tokenize( - text=new_transcript - ) - if len(words) >= 3: - # VAD can sometimes not detect that the human is speaking - # to make the interruption more reliable, we also interrupt on the final transcript. 
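In the replacement flow, callers interact with a reply through the SpeechHandle returned by generate_reply (both shown in this patch) rather than through these per-event callbacks. A usage sketch, with the agent instance, prompt text, and timeout as assumptions:

import asyncio

from livekit.agents.pipeline import PipelineAgent


async def reply_once(agent: PipelineAgent, user_input: str) -> None:
    # generate_reply schedules the speech and returns a handle immediately
    handle = agent.generate_reply(user_input)
    try:
        # resolves once the scheduled speech has finished playing out
        await asyncio.wait_for(handle.wait_for_playout(), timeout=30.0)
    except asyncio.TimeoutError:
        if handle.allow_interruptions and not handle.done():
            handle.interrupt()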
- self._interrupt_if_possible() + def say(self, text: str | AsyncIterable[str]) -> SpeechHandle: + pass - self._human_input.on("start_of_speech", _on_start_of_speech) - self._human_input.on("vad_inference_done", _on_vad_inference_done) - self._human_input.on("end_of_speech", _on_end_of_speech) - self._human_input.on("interim_transcript", _on_interim_transcript) - self._human_input.on("final_transcript", _on_final_transcript) + def generate_reply(self, user_input: str) -> SpeechHandle: + if self._current_speech is not None and not self._current_speech.interrupted: + raise ValueError("another reply is already in progress") - @utils.log_exceptions(logger=logger) - async def _main_task(self) -> None: - if self._opts.plotting: - await self._plotter.start() - - self._update_state("initializing") - audio_source = rtc.AudioSource(self._tts.sample_rate, self._tts.num_channels) - track = rtc.LocalAudioTrack.create_audio_track("assistant_voice", audio_source) - self._agent_publication = await self._room.local_participant.publish_track( - track, rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE) - ) + debug.Tracing.log_event("generate_reply", {"user_input": user_input}) + self._chat_ctx.append(role="user", text=user_input) # TODO(theomonnom) Remove - agent_playout = AgentPlayout(audio_source=audio_source) - self._agent_output = AgentOutput( - room=self._room, - agent_playout=agent_playout, - llm=self._llm, - tts=self._tts, + handle = SpeechHandle.create(allow_interruptions=self._opts.allow_interruptions) + task = asyncio.create_task( + self._generate_pipeline_reply_task( + handle=handle, + chat_ctx=self._chat_ctx, + fnc_ctx=self._fnc_ctx, + ), + name="_generate_pipeline_reply", ) + self._schedule_speech(handle, task, self.SPEECH_PRIORITY_NORMAL) + return handle - def _on_playout_started() -> None: - self._plotter.plot_event("agent_started_speaking") - self.emit("agent_started_speaking") - self._update_state("speaking") + # -- Main generation task -- - def _on_playout_stopped(interrupted: bool) -> None: - self._plotter.plot_event("agent_stopped_speaking") - self.emit("agent_stopped_speaking") - self._update_state("listening") - - agent_playout.on("playout_started", _on_playout_started) - agent_playout.on("playout_stopped", _on_playout_stopped) + def _schedule_speech( + self, speech: SpeechHandle, task: asyncio.Task, priority: int + ) -> None: + self._speech_tasks.append(task) + task.add_done_callback(lambda _: self._speech_tasks.remove(task)) - self._track_published_fut.set_result(None) + heapq.heappush(self._speech_q, (priority, speech)) + self._speech_q_changed.set() + @utils.log_exceptions(logger=logger) + async def _playout_scheduler(self) -> None: while True: await self._speech_q_changed.wait() while self._speech_q: - speech = self._speech_q[0] - self._playing_speech = speech - await self._play_speech(speech) - self._speech_q.pop(0) # Remove the element only after playing - self._playing_speech = None + _, speech = heapq.heappop(self._speech_q) + self._current_speech = speech + speech.play() + await speech.wait_for_playout() + self._current_speech = None self._speech_q_changed.clear() - def _synthesize_agent_reply(self): - """Synthesize the agent reply to the user question, also make sure only one reply - is synthesized/played at a time""" - - if self._pending_agent_reply is not None: - self._pending_agent_reply.cancel() - - if self._human_input is not None and not self._human_input.speaking: - self._update_state("thinking", 0.2) - - self._pending_agent_reply = new_handle = 
SpeechHandle.create_assistant_reply( - allow_interruptions=self._opts.allow_interruptions, - add_to_chat_ctx=True, - user_question=self._transcribed_text, - ) - - self._agent_reply_task = asyncio.create_task( - self._synthesize_answer_task(self._agent_reply_task, new_handle) - ) - @utils.log_exceptions(logger=logger) - async def _synthesize_answer_task( - self, old_task: asyncio.Task[None], handle: SpeechHandle + async def _generate_pipeline_reply_task( + self, + *, + handle: SpeechHandle, + chat_ctx: ChatContext, + fnc_ctx: FunctionContext | None, ) -> None: - if old_task is not None: - await utils.aio.gracefully_cancel(old_task) - - copied_ctx = self._chat_ctx.copy() - playing_speech = self._playing_speech - if playing_speech is not None and playing_speech.initialized: - if ( - not playing_speech.user_question or playing_speech.user_committed - ) and not playing_speech.speech_committed: - # the speech is playing but not committed yet, add it to the chat context for this new reply synthesis - # First add the previous function call message if any - if playing_speech.extra_tools_messages: - copied_ctx.messages.extend(playing_speech.extra_tools_messages) - - # Then add the previous assistant message - copied_ctx.messages.append( - ChatMessage.create( - text=playing_speech.synthesis_handle.tts_forwarder.played_text, - role="assistant", - ) - ) - - # we want to add this question even if it's empty. during false positive interruptions, - # adding an empty user message gives the LLM context so it could continue from where - # it had been interrupted. - copied_ctx.messages.append( - ChatMessage.create(text=handle.user_question, role="user") - ) - - tk = SpeechDataContextVar.set(SpeechData(sequence_id=handle.id)) - try: - llm_stream = self._opts.before_llm_cb(self, copied_ctx) - if asyncio.iscoroutine(llm_stream): - llm_stream = await llm_stream - - if llm_stream is False: - handle.cancel() + @utils.log_exceptions(logger=logger) + async def _forward_llm_text(llm_output: AsyncIterable[str]) -> None: + """collect and forward the generated text to the current agent output""" + if self.output.text is None: return - # fallback to default impl if no custom/user stream is returned - if not isinstance(llm_stream, LLMStream): - llm_stream = _default_before_llm_cb(self, chat_ctx=copied_ctx) + try: + async for delta in llm_output: + await self.output.text.capture_text(delta) + finally: + self.output.text.flush() - if handle.interrupted: + @utils.log_exceptions(logger=logger) + async def _forward_tts_audio(tts_output: AsyncIterable[rtc.AudioFrame]) -> None: + """collect and forward the generated audio to the current agent output (generally playout)""" + if self.output.audio is None: return - synthesis_handle = self._synthesize_agent_speech(handle.id, llm_stream) - handle.initialize(source=llm_stream, synthesis_handle=synthesis_handle) - finally: - SpeechDataContextVar.reset(tk) - - async def _play_speech(self, speech_handle: SpeechHandle) -> None: - try: - await speech_handle.wait_for_initialization() - except asyncio.CancelledError: - return - - await self._agent_publication.wait_for_subscription() + try: + async for frame in tts_output: + await self.output.audio.capture_frame(frame) + finally: + self.output.audio.flush() # always flush (even if the task is interrupted) - synthesis_handle = speech_handle.synthesis_handle - if synthesis_handle.interrupted: - return + @utils.log_exceptions(logger=logger) + async def _execute_tools( + tools_ch: utils.aio.Chan[llm.FunctionCallInfo], + called_functions: 
set[llm.CalledFunction], + ) -> None: + """execute tools, when cancelled, stop executing new tools but wait for the pending ones""" + try: + async for tool in tools_ch: + logger.debug( + "executing tool", + extra={ + "function": tool.function_info.name, + "speech_id": handle.id, + }, + ) + debug.Tracing.log_event( + "executing tool", + { + "function": tool.function_info.name, + "speech_id": handle.id, + }, + ) + cfnc = tool.execute() + called_functions.add(cfnc) + except asyncio.CancelledError: + # don't allow to cancel running function calla if they're still running + pending_tools = [cfn for cfn in called_functions if not cfn.task.done()] - user_question = speech_handle.user_question + if pending_tools: + names = [cfn.call_info.function_info.name for cfn in pending_tools] - play_handle = synthesis_handle.play() - join_fut = play_handle.join() + logger.debug( + "waiting for function call to finish before cancelling", + extra={ + "functions": names, + "speech_id": handle.id, + }, + ) + debug.Tracing.log_event( + "waiting for function call to finish before cancelling", + { + "functions": names, + "speech_id": handle.id, + }, + ) + await asyncio.gather(*[cfn.task for cfn in pending_tools]) + finally: + if len(called_functions) > 0: + logger.debug( + "tools execution completed", + extra={"speech_id": handle.id}, + ) + debug.Tracing.log_event( + "tools execution completed", + {"speech_id": handle.id}, + ) - def _commit_user_question_if_needed() -> None: - if ( - not user_question - or synthesis_handle.interrupted - or speech_handle.user_committed - ): - return + debug.Tracing.log_event( + "generation started", + {"speech_id": handle.id, "step_index": handle.step_index}, + ) - is_using_tools = isinstance(speech_handle.source, LLMStream) and len( - speech_handle.source.function_calls + wg = utils.aio.WaitGroup() + tasks = [] + llm_task, llm_gen_data = do_llm_inference( + node=self.llm_node, + chat_ctx=chat_ctx, + fnc_ctx=( + fnc_ctx + if handle.step_index < self._opts.max_fnc_steps - 1 + and handle.step_index >= 2 + else None + ), + ) + tasks.append(llm_task) + wg.add(1) + llm_task.add_done_callback(lambda _: wg.done()) + tts_text_input, llm_output = utils.aio.itertools.tee(llm_gen_data.text_ch) + + tts_task: asyncio.Task | None = None + tts_gen_data: _TTSGenerationData | None = None + if self._output.audio is not None: + tts_task, tts_gen_data = do_tts_inference( + node=self.tts_node, input=tts_text_input ) + tasks.append(tts_task) + wg.add(1) + tts_task.add_done_callback(lambda _: wg.done()) + + # wait for the play() method to be called + await asyncio.wait( + [ + handle._play_fut, + handle._interrupt_fut, + ], + return_when=asyncio.FIRST_COMPLETED, + ) - # make sure at least some speech was played before committing the user message - # since we try to validate as fast as possible it is possible the agent gets interrupted - # really quickly (barely audible), we don't want to mark this question as "answered". 
- if ( - speech_handle.allow_interruptions - and not is_using_tools - and ( - play_handle.time_played < self.MIN_TIME_PLAYED_FOR_COMMIT - and not join_fut.done() - ) - ): - return - - user_msg = ChatMessage.create(text=user_question, role="user") - self._chat_ctx.messages.append(user_msg) - self.emit("user_speech_committed", user_msg) - - self._transcribed_text = self._transcribed_text[len(user_question) :] - speech_handle.mark_user_committed() - - # wait for the play_handle to finish and check every 1s if the user question should be committed - _commit_user_question_if_needed() + if handle.interrupted: + await utils.aio.gracefully_cancel(*tasks) + handle._mark_done() + return # return directly (the generated output wasn't used) - while not join_fut.done(): - await asyncio.wait( - [join_fut], return_when=asyncio.FIRST_COMPLETED, timeout=0.2 + # forward tasks are started after the play() method is called + # they redirect the generated text/audio to the output channels + forward_llm_task = asyncio.create_task( + _forward_llm_text(llm_output), + name="_generate_reply_task.forward_llm_text", + ) + tasks.append(forward_llm_task) + wg.add(1) + forward_llm_task.add_done_callback(lambda _: wg.done()) + + forward_tts_task: asyncio.Task | None = None + if tts_gen_data is not None: + forward_tts_task = asyncio.create_task( + _forward_tts_audio(tts_gen_data.audio_ch), + name="_generate_reply_task.forward_tts_audio", ) - - _commit_user_question_if_needed() - - if speech_handle.interrupted: - break - - _commit_user_question_if_needed() - - collected_text = speech_handle.synthesis_handle.tts_forwarder.played_text - interrupted = speech_handle.interrupted - is_using_tools = isinstance(speech_handle.source, LLMStream) and len( - speech_handle.source.function_calls + tasks.append(forward_tts_task) + wg.add(1) + forward_tts_task.add_done_callback(lambda _: wg.done()) + + # start to execute tools (only after play()) + called_functions: set[llm.CalledFunction] = set() + tools_task = asyncio.create_task( + _execute_tools(llm_gen_data.tools_ch, called_functions), + name="_generate_reply_task.execute_tools", + ) + tasks.append(tools_task) + wg.add(1) + tools_task.add_done_callback(lambda _: wg.done()) + + # wait for the tasks to finish + await asyncio.wait( + [ + wg.wait(), + handle._interrupt_fut, + ], + return_when=asyncio.FIRST_COMPLETED, ) - message_id_committed: str | None = None - if ( - collected_text - and speech_handle.add_to_chat_ctx - and (not user_question or speech_handle.user_committed) - ): - if speech_handle.extra_tools_messages: - if speech_handle.fnc_text_message_id is not None: - # there is a message alongside the function calls - msgs = self._chat_ctx.messages - if msgs and msgs[-1].id == speech_handle.fnc_text_message_id: - # replace it with the tool call message if it's the last in the ctx - msgs.pop() - elif speech_handle.extra_tools_messages[0].tool_calls: - # remove the content of the tool call message - speech_handle.extra_tools_messages[0].content = "" - self._chat_ctx.messages.extend(speech_handle.extra_tools_messages) - - if interrupted: - collected_text += "..." 
- - msg = ChatMessage.create(text=collected_text, role="assistant") - self._chat_ctx.messages.append(msg) - message_id_committed = msg.id - speech_handle.mark_speech_committed() - - if interrupted: - self.emit("agent_speech_interrupted", msg) - else: - self.emit("agent_speech_committed", msg) - - logger.debug( - "committed agent speech", - extra={ - "agent_transcript": collected_text, - "interrupted": interrupted, - "speech_id": speech_handle.id, - }, + # wait for the end of the playout if the audio is enabled + if forward_llm_task is not None: + assert self._output.audio is not None + await asyncio.wait( + [ + self._output.audio.wait_for_playout(), + handle._interrupt_fut, + ], + return_when=asyncio.FIRST_COMPLETED, ) - async def _execute_function_calls() -> None: - nonlocal interrupted, collected_text + if handle.interrupted: + await utils.aio.gracefully_cancel(*tasks) - # if the answer is using tools, execute the functions and automatically generate - # a response to the user question from the returned values - if not is_using_tools or interrupted: - return - - if speech_handle.fnc_nested_depth >= self._opts.max_nested_fnc_calls: - logger.warning( - "max function calls nested depth reached", - extra={ - "speech_id": speech_handle.id, - "fnc_nested_depth": speech_handle.fnc_nested_depth, - }, + if len(called_functions) > 0: + functions = [ + cfnc.call_info.function_info.name for cfnc in called_functions + ] + logger.debug( + "speech interrupted, ignoring generation of the function calls results", + extra={"speech_id": handle.id, "functions": functions}, + ) + debug.Tracing.log_event( + "speech interrupted, ignoring generation of the function calls results", + {"speech_id": handle.id, "functions": functions}, ) - return - - assert isinstance(speech_handle.source, LLMStream) - assert ( - not user_question or speech_handle.user_committed - ), "user speech should have been committed before using tools" - - llm_stream = speech_handle.source - - # execute functions - call_ctx = AgentCallContext(self, llm_stream) - tk = _CallContextVar.set(call_ctx) - new_function_calls = llm_stream.function_calls + # if the audio playout was enabled, clear the buffer + if forward_tts_task is not None: + assert self._output.audio is not None - self.emit("function_calls_collected", new_function_calls) + self._output.audio.clear_buffer() + playback_ev = await self._output.audio.wait_for_playout() - called_fncs = [] - for fnc in new_function_calls: - called_fnc = fnc.execute() - called_fncs.append(called_fnc) - logger.debug( - "executing ai function", - extra={ - "function": fnc.function_info.name, - "speech_id": speech_handle.id, + debug.Tracing.log_event( + "playout interrupted", + { + "playback_position": playback_ev.playback_position, + "speech_id": handle.id, }, ) - try: - await called_fnc.task - except Exception as e: - logger.exception( - "error executing ai function", - extra={ - "function": fnc.function_info.name, - "speech_id": speech_handle.id, - }, - exc_info=e, - ) - tool_calls_info = [] - tool_calls_results = [] + handle._mark_playout_done() + # TODO(theomonnom): calculate the played text based on playback_ev.playback_position - for called_fnc in called_fncs: - # ignore the function calls that returns None - if called_fnc.result is None and called_fnc.exception is None: - continue + handle._mark_done() + return - tool_calls_info.append(called_fnc.call_info) - tool_calls_results.append( - ChatMessage.create_tool_from_called_function(called_fnc) - ) + handle._mark_playout_done() + 
debug.Tracing.log_event("playout completed", {"speech_id": handle.id}) - if not tool_calls_info: + if len(called_functions) > 0: + if handle.step_index >= self._opts.max_fnc_steps: + logger.warning( + "maximum number of function calls steps reached", + extra={"speech_id": handle.id}, + ) + debug.Tracing.log_event( + "maximum number of function calls steps reached", + {"speech_id": handle.id}, + ) + handle._mark_done() return - # create a nested speech handle - extra_tools_messages = [ - ChatMessage.create_tool_calls(tool_calls_info, text=collected_text) - ] - extra_tools_messages.extend(tool_calls_results) - - new_speech_handle = SpeechHandle.create_tool_speech( - allow_interruptions=speech_handle.allow_interruptions, - add_to_chat_ctx=speech_handle.add_to_chat_ctx, - extra_tools_messages=extra_tools_messages, - fnc_nested_depth=speech_handle.fnc_nested_depth + 1, - fnc_text_message_id=message_id_committed, - ) - - # synthesize the tool speech with the chat ctx from llm_stream - chat_ctx = call_ctx.chat_ctx.copy() - chat_ctx.messages.extend(extra_tools_messages) - chat_ctx.messages.extend(call_ctx.extra_chat_messages) - fnc_ctx = self.fnc_ctx - if ( - fnc_ctx - and new_speech_handle.fnc_nested_depth - >= self._opts.max_nested_fnc_calls - ): - if len(fnc_ctx.ai_functions) > 1: - logger.info( - "max function calls nested depth reached, dropping function context. increase max_nested_fnc_calls to enable additional nesting.", - extra={ - "speech_id": speech_handle.id, - "fnc_nested_depth": speech_handle.fnc_nested_depth, - }, - ) - fnc_ctx = None - answer_llm_stream = self._llm.chat( - chat_ctx=chat_ctx, - fnc_ctx=fnc_ctx, - ) - - synthesis_handle = self._synthesize_agent_speech( - new_speech_handle.id, answer_llm_stream - ) - new_speech_handle.initialize( - source=answer_llm_stream, synthesis_handle=synthesis_handle - ) - speech_handle.add_nested_speech(new_speech_handle) - - self.emit("function_calls_finished", called_fncs) - _CallContextVar.reset(tk) - - if not is_using_tools: - speech_handle._set_done() - return - - fnc_task = asyncio.create_task(_execute_function_calls()) - while not speech_handle.nested_speech_done: - nesting_changed = asyncio.create_task( - speech_handle.nested_speech_changed.wait() - ) - nesting_done_fut: asyncio.Future = speech_handle._nested_speech_done_fut - await asyncio.wait( - [nesting_changed, fnc_task, nesting_done_fut], - return_when=asyncio.FIRST_COMPLETED, + # create a new SpeechHandle to generate the result of the function calls + handle = SpeechHandle.create( + allow_interruptions=self._opts.allow_interruptions, + step_index=handle.step_index + 1, ) - if not nesting_changed.done(): - nesting_changed.cancel() - - while speech_handle.nested_speech_handles: - speech = speech_handle.nested_speech_handles[0] - if speech_handle.nested_speech_done: - # in case tool speech is added after nested speech done - speech.cancel(cancel_nested=True) - speech_handle.nested_speech_handles.pop(0) - continue - - self._playing_speech = speech - await self._play_speech(speech) - speech_handle.nested_speech_handles.pop(0) - self._playing_speech = speech_handle - - speech_handle.nested_speech_changed.clear() - # break if the function calls task is done - if fnc_task.done(): - speech_handle.mark_nested_speech_done() - - if not fnc_task.done(): - logger.debug( - "cancelling function calls task", extra={"speech_id": speech_handle.id} + task = asyncio.create_task( + self._generate_pipeline_reply_task( + handle=handle, + chat_ctx=chat_ctx, + fnc_ctx=fnc_ctx, + ), + 
name="_generate_pipeline_reply", ) - fnc_task.cancel() + self._schedule_speech(handle, task, self.SPEECH_PRIORITY_NORMAL) - # mark the speech as done - speech_handle._set_done() + handle._mark_done() - def _synthesize_agent_speech( - self, - speech_id: str, - source: str | LLMStream | AsyncIterable[str], - ) -> SynthesisHandle: - assert ( - self._agent_output is not None - ), "agent output should be initialized when ready" - - tk = SpeechDataContextVar.set(SpeechData(speech_id)) - - async def _llm_stream_to_str_generator( - stream: LLMStream, - ) -> AsyncGenerator[str]: - try: - async for chunk in stream: - if not chunk.choices: - continue + # -- Audio recognition -- - content = chunk.choices[0].delta.content - if content is None: - continue + def _on_audio_end_of_turn(self, new_transcript: str) -> None: + # When the audio recognition detects the end of a user turn: + # - check if there is no current generation happening + # - cancel the current generation if it allows interruptions (otherwise skip this current + # turn) + # - generate a reply to the user input - yield content - finally: - await stream.aclose() - - if isinstance(source, LLMStream): - source = _llm_stream_to_str_generator(source) - - og_source = source - transcript_source = source - if isinstance(og_source, AsyncIterable): - og_source, transcript_source = utils.aio.itertools.tee(og_source, 2) - - tts_source = self._opts.before_tts_cb(self, og_source) - if tts_source is None: - raise ValueError("before_tts_cb must return str or AsyncIterable[str]") - - try: - return self._agent_output.synthesize( - speech_id=speech_id, - tts_source=tts_source, - transcript_source=transcript_source, - transcription=self._opts.transcription.agent_transcription, - transcription_speed=self._opts.transcription.agent_transcription_speed, - sentence_tokenizer=self._opts.transcription.sentence_tokenizer, - word_tokenizer=self._opts.transcription.word_tokenizer, - hyphenate_word=self._opts.transcription.hyphenate_word, - ) - finally: - SpeechDataContextVar.reset(tk) - - def _validate_reply_if_possible(self) -> None: - """Check if the new agent speech should be played""" - - if self._playing_speech and not self._playing_speech.interrupted: - should_ignore_input = False - if not self._playing_speech.allow_interruptions: - should_ignore_input = True - logger.debug( - "skipping validation, agent is speaking and does not allow interruptions", - extra={"speech_id": self._playing_speech.id}, - ) - elif not self._should_interrupt(): - should_ignore_input = True - logger.debug( - "interrupt threshold is not met", - extra={"speech_id": self._playing_speech.id}, + if self._current_speech is not None: + if self._current_speech.allow_interruptions: + logger.warning( + "skipping user input, current speech generation cannot be interrupted", + extra={"user_input": new_transcript}, ) - - if should_ignore_input: - self._transcribed_text = "" - return - - if self._pending_agent_reply is None: - if self._opts.preemptive_synthesis: return - # as long as we don't have a pending reply, we need to synthesize it - # in order to keep the conversation flowing. - # transcript could be empty at this moment, if the user interrupted the agent - # but did not generate any transcribed text. - self._synthesize_agent_reply() - - assert self._pending_agent_reply is not None - - # due to timing, we could end up with two pushed agent replies inside the speech queue. 
- # so make sure we directly interrupt every reply when validating a new one - for speech in self._speech_q: - if not speech.is_reply: - continue - - if speech.allow_interruptions: - speech.interrupt() - - logger.debug( - "validated agent reply", - extra={ - "speech_id": self._pending_agent_reply.id, - "text": self._transcribed_text, - }, - ) - - if self._last_speech_time is not None: - time_since_last_speech = time.perf_counter() - self._last_speech_time - transcription_delay = max( - (self._last_final_transcript_time or 0) - self._last_speech_time, 0 + debug.Tracing.log_event( + "speech interrupted, new user turn detected", + {"speech_id": self._current_speech.id}, ) + self._current_speech.interrupt() - eou_metrics = metrics.PipelineEOUMetrics( - timestamp=time.time(), - sequence_id=self._pending_agent_reply.id, - end_of_utterance_delay=time_since_last_speech, - transcription_delay=transcription_delay, - ) - self.emit("metrics_collected", eou_metrics) - - self._add_speech_for_playout(self._pending_agent_reply) - self._pending_agent_reply = None - self._transcribed_interim_text = "" - # self._transcribed_text is reset after MIN_TIME_PLAYED_FOR_COMMIT, see self._play_speech - - def _interrupt_if_possible(self) -> None: - """Check whether the current assistant speech should be interrupted""" - if self._playing_speech and self._should_interrupt(): - self._playing_speech.interrupt() - - def _should_interrupt(self) -> bool: - if self._playing_speech is None: - return False - - if ( - not self._playing_speech.allow_interruptions - or self._playing_speech.interrupted - ): - return False - - if self._opts.int_min_words != 0: - text = self._transcribed_interim_text or self._transcribed_text - interim_words = self._opts.transcription.word_tokenizer.tokenize(text=text) - if len(interim_words) < self._opts.int_min_words: - return False - - return True - - def _add_speech_for_playout(self, speech_handle: SpeechHandle) -> None: - self._speech_q.append(speech_handle) - self._speech_q_changed.set() + self.generate_reply(new_transcript) + def _on_vad_inference_done(self, ev: vad.VADEvent) -> None: + if ev.speech_duration > self._opts.min_interruption_duration: + if ( + self._current_speech is not None + and not self._current_speech.interrupted + and self._current_speech.allow_interruptions + ): + debug.Tracing.log_event( + "speech interrupted by vad", + {"speech_id": self._current_speech.id}, + ) + self._current_speech.interrupt() -class _DeferredReplyValidation: - """This class is used to try to find the best time to validate the agent reply.""" + def _on_start_of_speech(self, _: vad.VADEvent) -> None: + self.emit("user_started_speaking", events.UserStartedSpeakingEvent()) - # if the STT gives us punctuation, we can try validate the reply faster. - PUNCTUATION = ".!?" 
- PUNCTUATION_REDUCE_FACTOR = 0.75 + def _on_end_of_speech(self, _: vad.VADEvent) -> None: + self.emit("user_stopped_speaking", events.UserStoppedSpeakingEvent()) - FINAL_TRANSCRIPT_TIMEOUT = 5 + # --- - def __init__( - self, - validate_fnc: Callable[[], None], - min_endpointing_delay: float, - max_endpointing_delay: float, - turn_detector: _TurnDetector | None, - agent: VoicePipelineAgent, - ) -> None: - self._turn_detector = turn_detector - self._validate_fnc = validate_fnc - self._validating_task: asyncio.Task | None = None - self._last_final_transcript: str = "" - self._last_language: str | None = None - self._last_recv_start_of_speech_time: float = 0.0 - self._last_recv_end_of_speech_time: float = 0.0 - self._last_recv_transcript_time: float = 0.0 - self._speaking = False - - self._agent = agent - self._end_of_speech_delay = min_endpointing_delay - self._max_endpointing_delay = max_endpointing_delay + # -- User changed input/output streams/sinks -- - @property - def validating(self) -> bool: - return self._validating_task is not None and not self._validating_task.done() - - def _compute_delay(self) -> float | None: - """Computes the amount of time to wait before validating the agent reply. - - This function should be called after the agent has received final transcript, or after VAD - """ - # never interrupt the user while they are speaking - if self._speaking: - return None - - # if STT doesn't give us the final transcript after end of speech, we'll still validate the reply - # to prevent the agent from getting "stuck" - # in this case, the agent will not have final transcript, so it'll trigger the user input with empty - if not self._last_final_transcript: - return self.FINAL_TRANSCRIPT_TIMEOUT - - delay = self._end_of_speech_delay - if self._end_with_punctuation(): - delay = delay * self.PUNCTUATION_REDUCE_FACTOR - - # the delay should be computed from end of earlier timestamp, that's the true end of user speech - end_of_speech_time = self._last_recv_end_of_speech_time - if ( - self._last_recv_transcript_time > 0 - and self._last_recv_transcript_time > self._last_recv_start_of_speech_time - and self._last_recv_transcript_time < end_of_speech_time - ): - end_of_speech_time = self._last_recv_transcript_time - - elapsed_time = time.perf_counter() - end_of_speech_time - if elapsed_time < delay: - delay -= elapsed_time - else: - delay = 0 - return delay - - def on_human_final_transcript(self, transcript: str, language: str | None) -> None: - self._last_final_transcript += " " + transcript.strip() # type: ignore - self._last_language = language - self._last_recv_transcript_time = time.perf_counter() - - delay = self._compute_delay() - if delay is not None: - self._run(delay) - - def on_human_start_of_speech(self, ev: vad.VADEvent) -> None: - self._speaking = True - self._last_recv_start_of_speech_time = time.perf_counter() - if self.validating: - assert self._validating_task is not None - self._validating_task.cancel() - - def on_human_end_of_speech(self, ev: vad.VADEvent) -> None: - self._speaking = False - self._last_recv_end_of_speech_time = time.perf_counter() - - delay = self._compute_delay() - if delay is not None: - self._run(delay) + def _on_video_input_changed(self) -> None: + pass - async def aclose(self) -> None: - if self._validating_task is not None: - await utils.aio.gracefully_cancel(self._validating_task) + def _on_audio_input_changed(self) -> None: + self._audio_recognition.audio_input = self._input.audio - def _end_with_punctuation(self) -> bool: - return ( - 
len(self._last_final_transcript) > 0 - and self._last_final_transcript[-1] in self.PUNCTUATION - ) + def _on_video_output_changed(self) -> None: + pass - def _reset_states(self) -> None: - self._last_final_transcript = "" - self._last_recv_end_of_speech_time = 0.0 - self._last_recv_transcript_time = 0.0 + def _on_audio_output_changed(self) -> None: + pass - def _run(self, delay: float) -> None: - @utils.log_exceptions(logger=logger) - async def _run_task(chat_ctx: ChatContext, delay: float) -> None: - use_turn_detector = self._last_final_transcript and not self._speaking - if ( - use_turn_detector - and self._turn_detector is not None - and self._turn_detector.supports_language(self._last_language) - ): - start_time = time.perf_counter() - try: - eot_prob = await self._turn_detector.predict_end_of_turn(chat_ctx) - unlikely_threshold = self._turn_detector.unlikely_threshold() - elasped = time.perf_counter() - start_time - if eot_prob < unlikely_threshold: - delay = self._max_endpointing_delay - delay = max(0, delay - elasped) - except TimeoutError: - pass # inference process is unresponsive - - await asyncio.sleep(delay) - - self._reset_states() - self._validate_fnc() - - if self._validating_task is not None: - self._validating_task.cancel() - - detect_ctx = self._agent._chat_ctx.copy() - detect_ctx.messages.append( - ChatMessage.create(text=self._agent._transcribed_text, role="user") - ) - self._validating_task = asyncio.create_task(_run_task(detect_ctx, delay)) + def _on_text_output_changed(self) -> None: + pass + + # --- diff --git a/livekit-agents/livekit/agents/pipeline/plotter.py b/livekit-agents/livekit/agents/pipeline/plotter.py deleted file mode 100644 index c0a9a1ca9..000000000 --- a/livekit-agents/livekit/agents/pipeline/plotter.py +++ /dev/null @@ -1,201 +0,0 @@ -import asyncio -import contextlib -import io -import multiprocessing as mp -import selectors -import socket -import time -from dataclasses import dataclass -from typing import ClassVar, Literal, Tuple - -from .. 
import utils -from ..ipc import channel - -PlotType = Literal["vad_probability", "raw_vol", "smoothed_vol"] -EventType = Literal[ - "user_started_speaking", - "user_stopped_speaking", - "agent_started_speaking", - "agent_stopped_speaking", -] - - -@dataclass -class PlotMessage: - MSG_ID: ClassVar[int] = 1 - - which: PlotType = "vad_probability" - x: float = 0.0 - y: float = 0.0 - - def write(self, b: io.BytesIO) -> None: - channel.write_string(b, self.which) - channel.write_float(b, self.x) - channel.write_float(b, self.y) - - def read(self, b: io.BytesIO) -> None: - self.which = channel.read_string(b) # type: ignore - self.x = channel.read_float(b) - self.y = channel.read_float(b) - - -@dataclass -class PlotEventMessage: - MSG_ID: ClassVar[int] = 2 - - which: EventType = "user_started_speaking" - x: float = 0.0 - - def write(self, b: io.BytesIO) -> None: - channel.write_string(b, self.which) - channel.write_float(b, self.x) - - def read(self, b: io.BytesIO) -> None: - self.which = channel.read_string(b) # type: ignore - self.x = channel.read_float(b) - - -PLT_MESSAGES: dict = { - PlotMessage.MSG_ID: PlotMessage, - PlotEventMessage.MSG_ID: PlotEventMessage, -} - - -def _draw_plot(mp_cch): - try: - import matplotlib as mpl # type: ignore - import matplotlib.pyplot as plt # type: ignore - except ImportError: - raise ImportError( - "matplotlib is required to run use the VoiceAssistant plotter" - ) - - plt.style.use("ggplot") - mpl.rcParams["toolbar"] = "None" - - plot_data: dict[str, Tuple[list[float], list[float]]] = {} - plot_events: dict[str, list[float]] = {} - - fig, (pv, sp) = plt.subplots(2, sharex="all") - fig.canvas.manager.set_window_title("Voice Assistant") # type: ignore - - max_points = 250 - - duplex = utils.aio.duplex_unix._Duplex.open(mp_cch) - - selector = selectors.DefaultSelector() - selector.register(mp_cch, selectors.EVENT_READ) - - def _draw_cb(sp, pv): - while True: - events = selector.select(timeout=0.01) - if not events: - break - - msg = channel.recv_message(duplex, PLT_MESSAGES) - if isinstance(msg, PlotMessage): - data = plot_data.setdefault(msg.which, ([], [])) - data[0].append(msg.x) - data[1].append(msg.y) - data[0][:] = data[0][-max_points:] - data[1][:] = data[1][-max_points:] - - # remove old events older than 7.5s - for events in plot_events.values(): - while events and events[0] < msg.x - 7.5: - events.pop(0) - - elif isinstance(msg, PlotEventMessage): - events = plot_events.setdefault(msg.which, []) - events.append(msg.x) - - vad_raw = plot_data.setdefault("vad_probability", ([], [])) - raw_vol = plot_data.get("raw_vol", ([], [])) - vol = plot_data.get("smoothed_vol", ([], [])) - - pv.clear() - pv.set_ylim(0, 1) - pv.set(ylabel="assistant volume") - pv.plot(vol[0], vol[1], label="volume") - pv.plot(raw_vol[0], raw_vol[1], label="target_volume") - pv.legend() - - sp.clear() - sp.set_ylim(0, 1) - sp.set(xlabel="time (s)", ylabel="speech probability") - sp.plot(vad_raw[0], vad_raw[1], label="raw") - sp.legend() - - for start in plot_events.get("agent_started_speaking", []): - pv.axvline(x=start, color="r", linestyle="--") - - for stop in plot_events.get("agent_stopped_speaking", []): - pv.axvline(x=stop, color="r", linestyle="--") - - for start in plot_events.get("user_started_speaking", []): - sp.axvline(x=start, color="r", linestyle="--") - - for stop in plot_events.get("user_stopped_speaking", []): - sp.axvline(x=stop, color="r", linestyle="--") - - fig.canvas.draw() - - timer = fig.canvas.new_timer(interval=33) - timer.add_callback(_draw_cb, sp, pv) - 
timer.start() - plt.show() - - -class AssistantPlotter: - def __init__(self, loop: asyncio.AbstractEventLoop) -> None: - self._loop = loop - self._started = False - - async def start(self): - if self._started: - return - - mp_pch, mp_cch = socket.socketpair() - self._duplex = await utils.aio.duplex_unix._AsyncDuplex.open(mp_pch) - self._plot_proc = mp.Process(target=_draw_plot, args=(mp_cch,), daemon=True) - self._plot_proc.start() - mp_cch.close() - - self._started = True - self._closed = False - self._start_time = time.time() - - def plot_value(self, which: PlotType, y: float): - if not self._started: - return - - ts = time.time() - self._start_time - self._send_message(PlotMessage(which=which, x=ts, y=y)) - - def plot_event(self, which: EventType): - if not self._started: - return - - ts = time.time() - self._start_time - self._send_message(PlotEventMessage(which=which, x=ts)) - - def _send_message(self, msg: channel.Message) -> None: - if self._closed: - return - - async def _asend_message(): - try: - await channel.asend_message(self._duplex, msg) - except Exception: - self._closed = True - - asyncio.ensure_future(_asend_message()) - - async def terminate(self): - if not self._started: - return - - self._plot_proc.terminate() - - with contextlib.suppress(utils.aio.duplex_unix.DuplexClosed): - await self._duplex.aclose() diff --git a/livekit-agents/livekit/agents/pipeline/speech_handle.py b/livekit-agents/livekit/agents/pipeline/speech_handle.py deleted file mode 100644 index cd1f39dec..000000000 --- a/livekit-agents/livekit/agents/pipeline/speech_handle.py +++ /dev/null @@ -1,235 +0,0 @@ -from __future__ import annotations - -import asyncio -from typing import AsyncIterable - -from .. import utils -from ..llm import ChatMessage, LLMStream -from .agent_output import SynthesisHandle - - -class SpeechHandle: - def __init__( - self, - *, - id: str, - allow_interruptions: bool, - add_to_chat_ctx: bool, - is_reply: bool, - user_question: str, - fnc_nested_depth: int = 0, - extra_tools_messages: list[ChatMessage] | None = None, - fnc_text_message_id: str | None = None, - ) -> None: - self._id = id - self._allow_interruptions = allow_interruptions - self._add_to_chat_ctx = add_to_chat_ctx - - # is_reply is True when the speech is answering to a user question - self._is_reply = is_reply - self._user_question = user_question - self._user_committed = False - - self._init_fut = asyncio.Future[None]() - self._done_fut = asyncio.Future[None]() - self._initialized = False - self._speech_committed = False # speech committed (interrupted or not) - - # source and synthesis_handle are None until the speech is initialized - self._source: str | LLMStream | AsyncIterable[str] | None = None - self._synthesis_handle: SynthesisHandle | None = None - - # nested speech handle and function calls - self._fnc_nested_depth = fnc_nested_depth - self._fnc_extra_tools_messages: list[ChatMessage] | None = extra_tools_messages - self._fnc_text_message_id: str | None = fnc_text_message_id - - self._nested_speech_handles: list[SpeechHandle] = [] - self._nested_speech_changed = asyncio.Event() - self._nested_speech_done_fut = asyncio.Future[None]() - - @staticmethod - def create_assistant_reply( - *, - allow_interruptions: bool, - add_to_chat_ctx: bool, - user_question: str, - ) -> SpeechHandle: - return SpeechHandle( - id=utils.shortuuid(), - allow_interruptions=allow_interruptions, - add_to_chat_ctx=add_to_chat_ctx, - is_reply=True, - user_question=user_question, - ) - - @staticmethod - def create_assistant_speech( - *, - 
allow_interruptions: bool, - add_to_chat_ctx: bool, - ) -> SpeechHandle: - return SpeechHandle( - id=utils.shortuuid(), - allow_interruptions=allow_interruptions, - add_to_chat_ctx=add_to_chat_ctx, - is_reply=False, - user_question="", - ) - - @staticmethod - def create_tool_speech( - *, - allow_interruptions: bool, - add_to_chat_ctx: bool, - fnc_nested_depth: int, - extra_tools_messages: list[ChatMessage], - fnc_text_message_id: str | None = None, - ) -> SpeechHandle: - return SpeechHandle( - id=utils.shortuuid(), - allow_interruptions=allow_interruptions, - add_to_chat_ctx=add_to_chat_ctx, - is_reply=False, - user_question="", - fnc_nested_depth=fnc_nested_depth, - extra_tools_messages=extra_tools_messages, - fnc_text_message_id=fnc_text_message_id, - ) - - async def wait_for_initialization(self) -> None: - await asyncio.shield(self._init_fut) - - def initialize( - self, - *, - source: str | LLMStream | AsyncIterable[str], - synthesis_handle: SynthesisHandle, - ) -> None: - if self.interrupted: - raise RuntimeError("speech is interrupted") - - self._source = source - self._synthesis_handle = synthesis_handle - self._initialized = True - self._init_fut.set_result(None) - - def mark_user_committed(self) -> None: - self._user_committed = True - - def mark_speech_committed(self) -> None: - self._speech_committed = True - - @property - def user_committed(self) -> bool: - return self._user_committed - - @property - def speech_committed(self) -> bool: - return self._speech_committed - - @property - def id(self) -> str: - return self._id - - @property - def allow_interruptions(self) -> bool: - return self._allow_interruptions - - @property - def add_to_chat_ctx(self) -> bool: - return self._add_to_chat_ctx - - @property - def source(self) -> str | LLMStream | AsyncIterable[str]: - if self._source is None: - raise RuntimeError("speech not initialized") - return self._source - - @property - def synthesis_handle(self) -> SynthesisHandle: - if self._synthesis_handle is None: - raise RuntimeError("speech not initialized") - return self._synthesis_handle - - @synthesis_handle.setter - def synthesis_handle(self, synthesis_handle: SynthesisHandle) -> None: - """synthesis handle can be replaced for the same speech. - This is useful when we need to do a new generation. 
(e.g for automatic function call answers)""" - if self._synthesis_handle is None: - raise RuntimeError("speech not initialized") - - self._synthesis_handle = synthesis_handle - - @property - def initialized(self) -> bool: - return self._initialized - - @property - def is_reply(self) -> bool: - return self._is_reply - - @property - def user_question(self) -> str: - return self._user_question - - @property - def interrupted(self) -> bool: - return self._init_fut.cancelled() or ( - self._synthesis_handle is not None and self._synthesis_handle.interrupted - ) - - def join(self) -> asyncio.Future: - return self._done_fut - - def _set_done(self) -> None: - self._done_fut.set_result(None) - - def interrupt(self) -> None: - if not self.allow_interruptions: - raise RuntimeError("interruptions are not allowed") - self.cancel() - - def cancel(self, cancel_nested: bool = False) -> None: - self._init_fut.cancel() - - if self._synthesis_handle is not None: - self._synthesis_handle.interrupt() - - if cancel_nested: - for speech in self._nested_speech_handles: - speech.cancel(cancel_nested=True) - self.mark_nested_speech_done() - - @property - def fnc_nested_depth(self) -> int: - return self._fnc_nested_depth - - @property - def extra_tools_messages(self) -> list[ChatMessage] | None: - return self._fnc_extra_tools_messages - - @property - def fnc_text_message_id(self) -> str | None: - return self._fnc_text_message_id - - def add_nested_speech(self, speech_handle: SpeechHandle) -> None: - self._nested_speech_handles.append(speech_handle) - self._nested_speech_changed.set() - - @property - def nested_speech_handles(self) -> list[SpeechHandle]: - return self._nested_speech_handles - - @property - def nested_speech_changed(self) -> asyncio.Event: - return self._nested_speech_changed - - @property - def nested_speech_done(self) -> bool: - return self._nested_speech_done_fut.done() - - def mark_nested_speech_done(self) -> None: - if self._nested_speech_done_fut.done(): - return - self._nested_speech_done_fut.set_result(None) diff --git a/livekit-agents/livekit/agents/transcription/__init__.py b/livekit-agents/livekit/agents/transcription/__init__.py deleted file mode 100644 index 18820b744..000000000 --- a/livekit-agents/livekit/agents/transcription/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .stt_forwarder import STTSegmentsForwarder -from .tts_forwarder import TTSSegmentsForwarder - -__all__ = [ - "TTSSegmentsForwarder", - "STTSegmentsForwarder", -] diff --git a/livekit-agents/livekit/agents/transcription/_utils.py b/livekit-agents/livekit/agents/transcription/_utils.py deleted file mode 100644 index 4e24960dd..000000000 --- a/livekit-agents/livekit/agents/transcription/_utils.py +++ /dev/null @@ -1,32 +0,0 @@ -from __future__ import annotations - -from livekit import rtc - -from ..utils import shortuuid - - -def find_micro_track_id(room: rtc.Room, identity: str) -> str: - p: rtc.RemoteParticipant | rtc.LocalParticipant | None = ( - room.remote_participants.get(identity) - ) - if identity == room.local_participant.identity: - p = room.local_participant - - if p is None: - raise ValueError(f"participant {identity} not found") - - # find first micro track - track_id = None - for track in p.track_publications.values(): - if track.source == rtc.TrackSource.SOURCE_MICROPHONE: - track_id = track.sid - break - - if track_id is None: - raise ValueError(f"participant {identity} does not have a microphone track") - - return track_id - - -def segment_uuid() -> str: - return shortuuid("SG_") diff --git 
a/livekit-agents/livekit/agents/transcription/stt_forwarder.py b/livekit-agents/livekit/agents/transcription/stt_forwarder.py deleted file mode 100644 index 0d526a3a6..000000000 --- a/livekit-agents/livekit/agents/transcription/stt_forwarder.py +++ /dev/null @@ -1,126 +0,0 @@ -from __future__ import annotations - -import asyncio -import contextlib -from typing import Awaitable, Callable, Optional, Union - -from livekit import rtc - -from .. import stt -from ..log import logger -from . import _utils - -BeforeForwardCallback = Callable[ - ["STTSegmentsForwarder", rtc.Transcription], - Union[rtc.Transcription, Awaitable[Optional[rtc.Transcription]]], -] - - -WillForwardTranscription = BeforeForwardCallback - - -def _default_before_forward_cb( - fwd: STTSegmentsForwarder, transcription: rtc.Transcription -) -> rtc.Transcription: - return transcription - - -class STTSegmentsForwarder: - """ - Forward STT transcription to the users. (Useful for client-side rendering) - """ - - def __init__( - self, - *, - room: rtc.Room, - participant: rtc.Participant | str, - track: rtc.Track | rtc.TrackPublication | str | None = None, - before_forward_cb: BeforeForwardCallback = _default_before_forward_cb, - # backward compatibility - will_forward_transcription: WillForwardTranscription | None = None, - ): - identity = participant if isinstance(participant, str) else participant.identity - if track is None: - track = _utils.find_micro_track_id(room, identity) - elif isinstance(track, (rtc.TrackPublication, rtc.Track)): - track = track.sid - - if will_forward_transcription is not None: - logger.warning( - "will_forward_transcription is deprecated and will be removed in 1.5.0, use before_forward_cb instead", - ) - before_forward_cb = will_forward_transcription - - self._room, self._participant_identity, self._track_id = room, identity, track - self._before_forward_cb = before_forward_cb - self._queue = asyncio.Queue[Optional[rtc.TranscriptionSegment]]() - self._main_task = asyncio.create_task(self._run()) - self._current_id = _utils.segment_uuid() - - async def _run(self): - try: - while True: - seg = await self._queue.get() - if seg is None: - break - - base_transcription = rtc.Transcription( - participant_identity=self._participant_identity, - track_sid=self._track_id, - segments=[seg], # no history for now - ) - - transcription = self._before_forward_cb(self, base_transcription) - if asyncio.iscoroutine(transcription): - transcription = await transcription - - if not isinstance(transcription, rtc.Transcription): - transcription = _default_before_forward_cb(self, base_transcription) - - if transcription.segments and self._room.isconnected(): - await self._room.local_participant.publish_transcription( - transcription - ) - - except Exception: - logger.exception("error in stt transcription") - - def update(self, ev: stt.SpeechEvent): - if ev.type == stt.SpeechEventType.INTERIM_TRANSCRIPT: - # TODO(theomonnom): We always take the first alternative, we should mb expose opt to the - # user? 
- text = ev.alternatives[0].text - self._queue.put_nowait( - rtc.TranscriptionSegment( - id=self._current_id, - text=text, - start_time=0, - end_time=0, - final=False, - language="", # TODO - ) - ) - elif ev.type == stt.SpeechEventType.FINAL_TRANSCRIPT: - text = ev.alternatives[0].text - self._queue.put_nowait( - rtc.TranscriptionSegment( - id=self._current_id, - text=text, - start_time=0, - end_time=0, - final=True, - language="", # TODO - ) - ) - - self._current_id = _utils.segment_uuid() - - async def aclose(self, *, wait: bool = True) -> None: - self._queue.put_nowait(None) - - if not wait: - self._main_task.cancel() - - with contextlib.suppress(asyncio.CancelledError): - await self._main_task diff --git a/livekit-agents/livekit/agents/transcription/tts_forwarder.py b/livekit-agents/livekit/agents/transcription/tts_forwarder.py deleted file mode 100644 index 40c1410bc..000000000 --- a/livekit-agents/livekit/agents/transcription/tts_forwarder.py +++ /dev/null @@ -1,430 +0,0 @@ -from __future__ import annotations - -import asyncio -import contextlib -import time -from dataclasses import dataclass -from typing import Awaitable, Callable, Optional, Union - -from livekit import rtc -from livekit.rtc.participant import PublishTranscriptionError - -from .. import tokenize, utils -from ..log import logger -from ..tokenize.tokenizer import PUNCTUATIONS -from . import _utils - -# 3.83 is the "baseline", the number of hyphens per second TTS returns in avg. -STANDARD_SPEECH_RATE = 3.83 - - -BeforeForwardCallback = Callable[ - ["TTSSegmentsForwarder", rtc.Transcription], - Union[rtc.Transcription, Awaitable[Optional[rtc.Transcription]]], -] - - -WillForwardTranscription = BeforeForwardCallback - - -def _default_before_forward_callback( - fwd: TTSSegmentsForwarder, transcription: rtc.Transcription -) -> rtc.Transcription: - return transcription - - -@dataclass -class _TTSOptions: - room: rtc.Room - participant_identity: str - track_id: str - language: str - speed: float - word_tokenizer: tokenize.WordTokenizer - sentence_tokenizer: tokenize.SentenceTokenizer - hyphenate_word: Callable[[str], list[str]] - new_sentence_delay: float - before_forward_cb: BeforeForwardCallback - - -@dataclass -class _AudioData: - pushed_duration: float = 0.0 - done: bool = False - - -@dataclass -class _TextData: - sentence_stream: tokenize.SentenceStream - pushed_text: str = "" - done: bool = False - - forwarded_hyphens: int = 0 - forwarded_sentences: int = 0 - - -class TTSSegmentsForwarder: - """ - Forward TTS transcription to the users. This class tries to imitate the right timing of - speech with the synthesized text. The first estimation is based on the speed argument. Once - we have received the full audio of a specific text segment, we recalculate the avg speech - speed using the length of the text & audio and catch up/ slow down the transcription if needed. 
- """ - - def __init__( - self, - *, - room: rtc.Room, - participant: rtc.Participant | str, - track: rtc.Track | rtc.TrackPublication | str | None = None, - language: str = "", - speed: float = 1.0, - new_sentence_delay: float = 0.4, - word_tokenizer: tokenize.WordTokenizer = tokenize.basic.WordTokenizer(), - sentence_tokenizer: tokenize.SentenceTokenizer = tokenize.basic.SentenceTokenizer(), - hyphenate_word: Callable[[str], list[str]] = tokenize.basic.hyphenate_word, - before_forward_cb: BeforeForwardCallback = _default_before_forward_callback, - loop: asyncio.AbstractEventLoop | None = None, - # backward compatibility - will_forward_transcription: WillForwardTranscription | None = None, - ): - """ - Args: - room: room where the transcription will be sent - participant: participant or identity that is pushing the TTS - track: track where the TTS audio is being sent - language: language of the text - speed: average speech speed in characters per second (used by default if the full audio is not received yet) - new_sentence_delay: delay in seconds between sentences - auto_playout: if True, the forwarder will automatically start the transcription once the - first audio frame is received. If False, you need to call segment_playout_started - to start the transcription. - word_tokenizer: word tokenizer used to split the text into words - sentence_tokenizer: sentence tokenizer used to split the text into sentences - hyphenate_word: function that returns a list of hyphens for a given word - - """ - identity = participant if isinstance(participant, str) else participant.identity - - if track is None: - track = _utils.find_micro_track_id(room, identity) - elif isinstance(track, (rtc.TrackPublication, rtc.Track)): - track = track.sid - - if will_forward_transcription is not None: - logger.warning( - "will_forward_transcription is deprecated and will be removed in 1.5.0, use before_forward_cb instead", - ) - before_forward_cb = will_forward_transcription - - speed = speed * STANDARD_SPEECH_RATE - self._opts = _TTSOptions( - room=room, - participant_identity=identity, - track_id=track, - language=language, - speed=speed, - word_tokenizer=word_tokenizer, - sentence_tokenizer=sentence_tokenizer, - hyphenate_word=hyphenate_word, - new_sentence_delay=new_sentence_delay, - before_forward_cb=before_forward_cb, - ) - self._closed = False - self._loop = loop or asyncio.get_event_loop() - self._close_future = asyncio.Future[None]() - - self._playing_seg_index = -1 - self._finshed_seg_index = -1 - - self._text_q_changed = asyncio.Event() - self._text_q = list[Union[_TextData, None]]() - self._audio_q_changed = asyncio.Event() - self._audio_q = list[Union[_AudioData, None]]() - - self._text_data: _TextData | None = None - self._audio_data: _AudioData | None = None - - self._played_text = "" - - self._main_atask: asyncio.Task | None = None - self._task_set = utils.aio.TaskSet(loop) - - def segment_playout_started(self) -> None: - """ - Notify that the playout of the audio segment has started. - This will start forwarding the transcription for the current segment. - """ - self._check_not_closed() - self._playing_seg_index += 1 - - if self._main_atask is None: - self._main_atask = asyncio.create_task(self._main_task()) - - def segment_playout_finished(self) -> None: - """ - Notify that the playout of the audio segment has finished. - This will catchup and directly send the final transcription in case the forwarder is too - late. 
- """ - self._check_not_closed() - self._finshed_seg_index += 1 - - def push_audio(self, frame: rtc.AudioFrame) -> None: - self._check_not_closed() - - if self._audio_data is None: - self._audio_data = _AudioData() - self._audio_q.append(self._audio_data) - self._audio_q_changed.set() - - frame_duration = frame.samples_per_channel / frame.sample_rate - self._audio_data.pushed_duration += frame_duration - - def mark_audio_segment_end(self) -> None: - self._check_not_closed() - - if self._audio_data is None: - self.push_audio(rtc.AudioFrame(bytes(), 24000, 1, 0)) - - assert self._audio_data is not None - self._audio_data.done = True - self._audio_data = None - - def push_text(self, text: str) -> None: - self._check_not_closed() - - if self._text_data is None: - self._text_data = _TextData( - sentence_stream=self._opts.sentence_tokenizer.stream() - ) - self._text_q.append(self._text_data) - self._text_q_changed.set() - - self._text_data.pushed_text += text - self._text_data.sentence_stream.push_text(text) - - def mark_text_segment_end(self) -> None: - self._check_not_closed() - - if self._text_data is None: - self.push_text("") - - assert self._text_data is not None - self._text_data.done = True - self._text_data.sentence_stream.end_input() - self._text_data = None - - @property - def closed(self) -> bool: - return self._closed - - @property - def played_text(self) -> str: - return self._played_text - - async def aclose(self) -> None: - if self._closed: - return - - self._closed = True - self._close_future.set_result(None) - - for text_data in self._text_q: - assert text_data is not None - await text_data.sentence_stream.aclose() - - self._text_q.append(None) - self._audio_q.append(None) - self._text_q_changed.set() - self._audio_q_changed.set() - - await self._task_set.aclose() - - if self._main_atask is not None: - await self._main_atask - - @utils.log_exceptions(logger=logger) - async def _main_task(self) -> None: - """Main task that forwards the transcription to the room.""" - rtc_seg_ch = utils.aio.Chan[rtc.TranscriptionSegment]() - - @utils.log_exceptions(logger=logger) - async def _forward_task(): - async for rtc_seg in rtc_seg_ch: - base_transcription = rtc.Transcription( - participant_identity=self._opts.participant_identity, - track_sid=self._opts.track_id, - segments=[rtc_seg], # no history for now - ) - - transcription = self._opts.before_forward_cb(self, base_transcription) - if asyncio.iscoroutine(transcription): - transcription = await transcription - - # fallback to default impl if no custom/user stream is returned - if not isinstance(transcription, rtc.Transcription): - transcription = _default_before_forward_callback( - self, base_transcription - ) - - if transcription.segments and self._opts.room.isconnected(): - try: - await self._opts.room.local_participant.publish_transcription( - transcription - ) - except PublishTranscriptionError: - continue - - forward_task = asyncio.create_task(_forward_task()) - - seg_index = 0 - q_done = False - while not q_done: - await self._text_q_changed.wait() - await self._audio_q_changed.wait() - - while self._text_q and self._audio_q: - text_data = self._text_q.pop(0) - audio_data = self._audio_q.pop(0) - - if text_data is None or audio_data is None: - q_done = True - break - - # wait until the segment is validated and has started playing - while not self._closed: - if self._playing_seg_index >= seg_index: - break - - await self._sleep_if_not_closed(0.125) - - sentence_stream = text_data.sentence_stream - forward_start_time = time.time() - 
- async for ev in sentence_stream: - await self._sync_sentence_co( - seg_index, - forward_start_time, - text_data, - audio_data, - ev.token, - rtc_seg_ch, - ) - - seg_index += 1 - - self._text_q_changed.clear() - self._audio_q_changed.clear() - - rtc_seg_ch.close() - await forward_task - - async def _sync_sentence_co( - self, - segment_index: int, - segment_start_time: float, - text_data: _TextData, - audio_data: _AudioData, - sentence: str, - rtc_seg_ch: utils.aio.Chan[rtc.TranscriptionSegment], - ): - """Synchronize the transcription with the audio playout for a given sentence.""" - # put each sentence in a different transcription segment - - real_speed = None - if audio_data.pushed_duration > 0 and audio_data.done: - real_speed = ( - len(self._calc_hyphens(text_data.pushed_text)) - / audio_data.pushed_duration - ) - - seg_id = _utils.segment_uuid() - words = self._opts.word_tokenizer.tokenize(text=sentence) - processed_words: list[str] = [] - - og_text = self._played_text - for word in words: - if segment_index <= self._finshed_seg_index: - # playout of the audio segment already finished - # break the loop and send the final transcription - break - - if self._closed: - # transcription closed, early - return - - word_hyphens = len(self._opts.hyphenate_word(word)) - processed_words.append(word) - - # elapsed time since the start of the seg - elapsed_time = time.time() - segment_start_time - text = self._opts.word_tokenizer.format_words(processed_words) - - # remove any punctuation at the end of a non-final transcript - text = text.rstrip("".join(PUNCTUATIONS)) - - speed = self._opts.speed - if real_speed is not None: - speed = real_speed - estimated_pauses_s = ( - text_data.forwarded_sentences * self._opts.new_sentence_delay - ) - hyph_pauses = estimated_pauses_s * speed - - target_hyphens = round(speed * elapsed_time) - dt = target_hyphens - text_data.forwarded_hyphens - hyph_pauses - to_wait_hyphens = max(0.0, word_hyphens - dt) - delay = to_wait_hyphens / speed - else: - delay = word_hyphens / speed - - first_delay = min(delay / 2, 2 / speed) - await self._sleep_if_not_closed(first_delay) - - rtc_seg_ch.send_nowait( - rtc.TranscriptionSegment( - id=seg_id, - text=text, - start_time=0, - end_time=0, - final=False, - language=self._opts.language, - ) - ) - self._played_text = f"{og_text} {text}" - - await self._sleep_if_not_closed(delay - first_delay) - text_data.forwarded_hyphens += word_hyphens - - rtc_seg_ch.send_nowait( - rtc.TranscriptionSegment( - id=seg_id, - text=sentence, - start_time=0, - end_time=0, - final=True, - language=self._opts.language, - ) - ) - self._played_text = f"{og_text} {sentence}" - - await self._sleep_if_not_closed(self._opts.new_sentence_delay) - text_data.forwarded_sentences += 1 - - async def _sleep_if_not_closed(self, delay: float) -> None: - with contextlib.suppress(asyncio.TimeoutError): - await asyncio.wait([self._close_future], timeout=delay) - - def _calc_hyphens(self, text: str) -> list[str]: - hyphens: list[str] = [] - words = self._opts.word_tokenizer.tokenize(text=text) - for word in words: - new = self._opts.hyphenate_word(word) - hyphens.extend(new) - - return hyphens - - def _check_not_closed(self) -> None: - if self._closed: - raise RuntimeError("TTSForwarder is closed") From 5d5054db272473c6bcf53f59d0dc4cea86ea1e24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Mon, 13 Jan 2025 12:51:24 +0100 Subject: [PATCH 06/19] realtime API abstraction --- .../livekit/agents/multimodal/__init__.py | 24 +- 
.../livekit/agents/multimodal/realtime.py | 23 +- .../livekit/plugins/openai/beta/__init__.py | 17 - .../plugins/openai/beta/assistant_llm.py | 594 ------ .../plugins/openai/realtime/api_proto.py | 607 ------ .../plugins/openai/realtime/realtime_model.py | 1807 ++--------------- .../plugins/openai/realtime/remote_items.py | 126 -- 7 files changed, 229 insertions(+), 2969 deletions(-) delete mode 100644 livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/beta/__init__.py delete mode 100644 livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/beta/assistant_llm.py delete mode 100644 livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/api_proto.py delete mode 100644 livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/remote_items.py diff --git a/livekit-agents/livekit/agents/multimodal/__init__.py b/livekit-agents/livekit/agents/multimodal/__init__.py index f741e168a..fe6d5d654 100644 --- a/livekit-agents/livekit/agents/multimodal/__init__.py +++ b/livekit-agents/livekit/agents/multimodal/__init__.py @@ -1,13 +1,19 @@ -from .multimodal_agent import ( - AgentTranscriptionOptions, - MultimodalAgent, - _RealtimeAPI, - _RealtimeAPISession, +from .realtime import ( + RealtimeModel, + RealtimeCapabilities, + RealtimeSession, + InputSpeechStartedEvent, + InputSpeechStoppedEvent, + GenerationCreatedEvent, + ErrorEvent, ) __all__ = [ - "MultimodalAgent", - "AgentTranscriptionOptions", - "_RealtimeAPI", - "_RealtimeAPISession", + "RealtimeModel", + "RealtimeCapabilities", + "RealtimeSession", + "InputSpeechStartedEvent", + "InputSpeechStoppedEvent", + "GenerationCreatedEvent", + "ErrorEvent", ] diff --git a/livekit-agents/livekit/agents/multimodal/realtime.py b/livekit-agents/livekit/agents/multimodal/realtime.py index d8ab54424..114e33ca9 100644 --- a/livekit-agents/livekit/agents/multimodal/realtime.py +++ b/livekit-agents/livekit/agents/multimodal/realtime.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass from abc import ABC, abstractmethod @@ -38,7 +40,7 @@ class RealtimeCapabilities: class RealtimeModel: - def __init__(self, capabilities: RealtimeCapabilities) -> None: + def __init__(self, *, capabilities: RealtimeCapabilities) -> None: self._capabilities = capabilities @property @@ -48,6 +50,9 @@ def capabilities(self) -> RealtimeCapabilities: @abstractmethod def session(self) -> "RealtimeSession": ... + @abstractmethod + async def aclose(self) -> None: ... + EventTypes = Literal[ "input_speech_started", # serverside VAD @@ -79,12 +84,28 @@ def chat_ctx(self) -> llm.ChatContext: ... @abstractmethod async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None: ... + @property + @abstractmethod + def fnc_ctx(self) -> llm.FunctionContext | None: ... + + @abstractmethod + async def update_fnc_ctx(self, fnc_ctx: llm.FunctionContext | None) -> None: ... + @abstractmethod def push_audio(self, frame: rtc.AudioFrame) -> None: ... @abstractmethod def generate_reply(self) -> None: ... # when VAD is disabled + # cancel the current generation (do nothing if no generation is in progress) + @abstractmethod + def interrupt( + self, + ) -> None: ... + # message_id is the ID of the message to truncate (inside the ChatCtx) @abstractmethod def truncate(self, *, message_id: str, audio_end_ms: int) -> None: ... + + @abstractmethod + async def aclose(self) -> None: ... 
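A minimal usage sketch of the RealtimeModel / RealtimeSession abstraction introduced above (illustrative only, not part of the patch). It sticks to the abstract methods declared in multimodal/realtime.py and to the message_truncation capability used later in this series; the drive() helper, the frames iterable, and the truncate() argument values are placeholders.

    from livekit.agents import llm, multimodal


    async def drive(model: multimodal.RealtimeModel, frames) -> None:
        # frames: assumed to be an AsyncIterable[rtc.AudioFrame], e.g. microphone audio
        session = model.session()
        await session.update_chat_ctx(llm.ChatContext())

        async for frame in frames:
            session.push_audio(frame)

        # with server-side VAD disabled, the caller requests a reply explicitly
        session.generate_reply()

        # on a user interruption, stop the in-flight generation; providers that
        # support truncation can also trim the interrupted assistant message
        session.interrupt()
        if model.capabilities.message_truncation:
            session.truncate(message_id="<assistant item id>", audio_end_ms=1500)

        await session.aclose()
        await model.aclose()
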
diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/beta/__init__.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/beta/__init__.py deleted file mode 100644 index f062606fb..000000000 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/beta/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -from .assistant_llm import ( - AssistantCreateOptions, - AssistantLLM, - AssistantLoadOptions, - AssistantOptions, - OnFileUploaded, - OnFileUploadedInfo, -) - -__all__ = [ - "AssistantLLM", - "AssistantOptions", - "AssistantCreateOptions", - "AssistantLoadOptions", - "OnFileUploaded", - "OnFileUploadedInfo", -] diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/beta/assistant_llm.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/beta/assistant_llm.py deleted file mode 100644 index 7df336e89..000000000 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/beta/assistant_llm.py +++ /dev/null @@ -1,594 +0,0 @@ -# Copyright 2023 LiveKit, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import asyncio -import json -import uuid -from dataclasses import dataclass -from typing import Any, Callable, Dict, Literal, MutableSet, Union - -import httpx -from livekit import rtc -from livekit.agents import llm, utils -from livekit.agents.llm import ToolChoice -from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions - -from openai import AsyncAssistantEventHandler, AsyncClient -from openai.types.beta.threads import Text, TextDelta -from openai.types.beta.threads.run_create_params import AdditionalMessage -from openai.types.beta.threads.run_submit_tool_outputs_params import ToolOutput -from openai.types.beta.threads.runs import ( - CodeInterpreterToolCall, - FileSearchToolCall, - FunctionToolCall, - ToolCall, -) -from openai.types.file_object import FileObject - -from .._oai_api import build_oai_function_description -from ..log import logger -from ..models import ChatModels - -DEFAULT_MODEL = "gpt-4o" -OPENAI_MESSAGE_ID_KEY = "__openai_message_id__" -LIVEKIT_MESSAGE_ID_KEY = "__livekit_message_id__" -OPENAI_MESSAGES_ADDED_KEY = "__openai_messages_added__" -OPENAI_FILE_ID_KEY = "__openai_file_id__" - - -@dataclass -class LLMOptions: - model: str | ChatModels - - -@dataclass -class AssistantOptions: - """Options for creating (on-the-fly) or loading an assistant. 
Only one of create_options or load_options should be set.""" - - create_options: AssistantCreateOptions | None = None - load_options: AssistantLoadOptions | None = None - - -@dataclass -class AssistantCreateOptions: - name: str - instructions: str - model: ChatModels - temperature: float | None = None - # TODO: when we implement code_interpreter and file_search tools - # tool_resources: ToolResources | None = None - # tools: list[AssistantTools] = field(default_factory=list) - - -@dataclass -class AssistantLoadOptions: - assistant_id: str - thread_id: str | None - - -@dataclass -class OnFileUploadedInfo: - type: Literal["image"] - original_file: llm.ChatImage - openai_file_object: FileObject - - -OnFileUploaded = Callable[[OnFileUploadedInfo], None] - - -class AssistantLLM(llm.LLM): - def __init__( - self, - *, - assistant_opts: AssistantOptions, - client: AsyncClient | None = None, - api_key: str | None = None, - base_url: str | None = None, - on_file_uploaded: OnFileUploaded | None = None, - ) -> None: - super().__init__() - - test_ctx = llm.ChatContext() - if not hasattr(test_ctx, "_metadata"): - raise Exception( - "This beta feature of 'livekit-plugins-openai' requires a newer version of 'livekit-agents'" - ) - self._client = client or AsyncClient( - api_key=api_key, - base_url=base_url, - http_client=httpx.AsyncClient( - timeout=httpx.Timeout(timeout=30, connect=10, read=5, pool=5), - follow_redirects=True, - limits=httpx.Limits( - max_connections=1000, - max_keepalive_connections=100, - keepalive_expiry=120, - ), - ), - ) - self._assistant_opts = assistant_opts - self._running_fncs: MutableSet[asyncio.Task[Any]] = set() - self._on_file_uploaded = on_file_uploaded - self._tool_call_run_id_lookup = dict[str, str]() - self._submitted_tool_calls = set[str]() - - self._sync_openai_task: asyncio.Task[AssistantLoadOptions] | None = None - try: - self._sync_openai_task = asyncio.create_task(self._sync_openai()) - except Exception: - logger.error( - "failed to create sync openai task. 
This can happen when instantiating without a running asyncio event loop (such has when running tests)" - ) - self._done_futures = list[asyncio.Future[None]]() - - async def _sync_openai(self) -> AssistantLoadOptions: - if self._assistant_opts.create_options: - kwargs: Dict[str, Any] = { - "model": self._assistant_opts.create_options.model, - "name": self._assistant_opts.create_options.name, - "instructions": self._assistant_opts.create_options.instructions, - # "tools": [ - # {"type": t} for t in self._assistant_opts.create_options.tools - # ], - # "tool_resources": self._assistant_opts.create_options.tool_resources, - } - # TODO when we implement code_interpreter and file_search tools - # if self._assistant_opts.create_options.tool_resources: - # kwargs["tool_resources"] = ( - # self._assistant_opts.create_options.tool_resources - # ) - if self._assistant_opts.create_options.temperature: - kwargs["temperature"] = self._assistant_opts.create_options.temperature - assistant = await self._client.beta.assistants.create(**kwargs) - - thread = await self._client.beta.threads.create() - return AssistantLoadOptions(assistant_id=assistant.id, thread_id=thread.id) - elif self._assistant_opts.load_options: - if not self._assistant_opts.load_options.thread_id: - thread = await self._client.beta.threads.create() - self._assistant_opts.load_options.thread_id = thread.id - return self._assistant_opts.load_options - - raise Exception("One of create_options or load_options must be set") - - def chat( - self, - *, - chat_ctx: llm.ChatContext, - conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, - fnc_ctx: llm.FunctionContext | None = None, - temperature: float | None = None, - n: int | None = None, - parallel_tool_calls: bool | None = None, - tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] - | None = None, - ): - if n is not None: - logger.warning("OpenAI Assistants does not support the 'n' parameter") - - if parallel_tool_calls is not None: - logger.warning( - "OpenAI Assistants does not support the 'parallel_tool_calls' parameter" - ) - - if not self._sync_openai_task: - self._sync_openai_task = asyncio.create_task(self._sync_openai()) - - return AssistantLLMStream( - temperature=temperature, - assistant_llm=self, - sync_openai_task=self._sync_openai_task, - client=self._client, - chat_ctx=chat_ctx, - fnc_ctx=fnc_ctx, - on_file_uploaded=self._on_file_uploaded, - conn_options=conn_options, - ) - - async def _register_tool_call(self, tool_call_id: str, run_id: str) -> None: - self._tool_call_run_id_lookup[tool_call_id] = run_id - - async def _submit_tool_call_result(self, tool_call_id: str, result: str) -> None: - if tool_call_id in self._submitted_tool_calls: - return - logger.debug(f"submitting tool call {tool_call_id} result") - run_id = self._tool_call_run_id_lookup.get(tool_call_id) - if not run_id: - logger.error(f"tool call {tool_call_id} not found") - return - - if not self._sync_openai_task: - logger.error("sync_openai_task not set") - return - - thread_id = (await self._sync_openai_task).thread_id - if not thread_id: - logger.error("thread_id not set") - return - tool_output = ToolOutput(output=result, tool_call_id=tool_call_id) - await self._client.beta.threads.runs.submit_tool_outputs_and_poll( - tool_outputs=[tool_output], run_id=run_id, thread_id=thread_id - ) - self._submitted_tool_calls.add(tool_call_id) - logger.debug(f"submitted tool call {tool_call_id} result") - - -class AssistantLLMStream(llm.LLMStream): - class EventHandler(AsyncAssistantEventHandler): 
- def __init__( - self, - llm: AssistantLLM, - llm_stream: AssistantLLMStream, - event_ch: utils.aio.Chan[llm.ChatChunk], - chat_ctx: llm.ChatContext, - fnc_ctx: llm.FunctionContext | None = None, - ): - super().__init__() - self._llm = llm - self._llm_stream = llm_stream - self._chat_ctx = chat_ctx - self._event_ch = event_ch - self._fnc_ctx = fnc_ctx - - async def on_text_delta(self, delta: TextDelta, snapshot: Text): - assert self.current_run is not None - - self._event_ch.send_nowait( - llm.ChatChunk( - request_id=self.current_run.id, - choices=[ - llm.Choice( - delta=llm.ChoiceDelta(role="assistant", content=delta.value) - ) - ], - ) - ) - - async def on_tool_call_created(self, tool_call: ToolCall): - if not self.current_run: - logger.error("tool call created without run") - return - await self._llm._register_tool_call(tool_call.id, self.current_run.id) - - async def on_tool_call_done( - self, - tool_call: CodeInterpreterToolCall | FileSearchToolCall | FunctionToolCall, - ) -> None: - assert self.current_run is not None - - if tool_call.type == "code_interpreter": - logger.warning("code interpreter tool call not yet implemented") - elif tool_call.type == "file_search": - logger.warning("file_search tool call not yet implemented") - elif tool_call.type == "function": - if not self._fnc_ctx: - logger.error("function tool called without function context") - return - - fnc = llm.FunctionCallInfo( - function_info=self._fnc_ctx.ai_functions[tool_call.function.name], - arguments=json.loads(tool_call.function.arguments), - tool_call_id=tool_call.id, - raw_arguments=tool_call.function.arguments, - ) - - self._llm_stream._function_calls_info.append(fnc) - chunk = llm.ChatChunk( - request_id=self.current_run.id, - choices=[ - llm.Choice( - delta=llm.ChoiceDelta(role="assistant", tool_calls=[fnc]), - index=0, - ) - ], - ) - self._event_ch.send_nowait(chunk) - - def __init__( - self, - *, - assistant_llm: AssistantLLM, - client: AsyncClient, - sync_openai_task: asyncio.Task[AssistantLoadOptions], - chat_ctx: llm.ChatContext, - fnc_ctx: llm.FunctionContext | None, - temperature: float | None, - on_file_uploaded: OnFileUploaded | None, - conn_options: APIConnectOptions, - ) -> None: - super().__init__( - assistant_llm, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, conn_options=conn_options - ) - self._client = client - self._temperature = temperature - self._on_file_uploaded = on_file_uploaded - - # current function call that we're waiting for full completion (args are streamed) - self._tool_call_id: str | None = None - self._fnc_name: str | None = None - self._fnc_raw_arguments: str | None = None - self._create_stream_task = asyncio.create_task(self._main_task()) - self._sync_openai_task = sync_openai_task - - # Running stream is used to ensure that we only have one stream running at a time - self._done_future: asyncio.Future[None] = asyncio.Future() - - async def _run(self) -> None: - assert isinstance(self._llm, AssistantLLM) - - # This function's complexity is due to the fact that we need to sync chat_ctx messages with OpenAI. - # OpenAI also does not allow us to modify messages while a stream is running. So we need to make sure streams run - # sequentially. The strategy is as follows: - # - # 1. ensure that we have a thread_id and assistant_id from OpenAI. This comes from the _sync_openai_task - # 2. make sure all previous streams are done before starting a new one - # 3. delete messages that are no longer in the chat_ctx but are still in OpenAI by using the OpenAI message id - # 4. 
add new messages to OpenAI that are in the chat_ctx but not in OpenAI. We don't know the OpenAI message id yet - # so we create a random uuid (we call it the LiveKit message id) and set that in the metdata. - # 5. start the stream and wait for it to finish - # 6. get the OpenAI message ids for the messages we added to OpenAI by using the metadata - # 7. Resolve the OpenAI message id with all messages that have a LiveKit message id. - try: - load_options = await self._sync_openai_task - - # The assistants api does not let us modify messages while a stream is running. - # So we have to make sure previous streams are done before starting a new one. - await asyncio.gather(*self._llm._done_futures) - self._llm._done_futures.clear() - self._llm._done_futures.append(self._done_future) - - # OpenAI required submitting tool call outputs manually. We iterate - # tool outputs in the chat_ctx (from previous runs) and submit them - # before continuing. - for msg in self._chat_ctx.messages: - if msg.role == "tool": - if not msg.tool_call_id: - logger.error("tool message without tool_call_id") - continue - if not isinstance(msg.content, str): - logger.error("tool message content is not str") - continue - await self._llm._submit_tool_call_result( - msg.tool_call_id, msg.content - ) - - # At the chat_ctx level, create a map of thread_id to message_ids - # This is used to keep track of which messages have been added to the thread - # and which we may need to delete from OpenAI - if OPENAI_MESSAGES_ADDED_KEY not in self._chat_ctx._metadata: - self._chat_ctx._metadata[OPENAI_MESSAGES_ADDED_KEY] = dict() - - if ( - load_options.thread_id - not in self._chat_ctx._metadata[OPENAI_MESSAGES_ADDED_KEY] - ): - self._chat_ctx._metadata[OPENAI_MESSAGES_ADDED_KEY][ - load_options.thread_id - ] = set() - - # Keep this handy to make the code more readable later on - openai_addded_messages_set: set[str] = self._chat_ctx._metadata[ - OPENAI_MESSAGES_ADDED_KEY - ][load_options.thread_id] - - # Keep track of messages that are no longer in the chat_ctx but are still in OpenAI - # Note: Unfortuneately, this will add latency unfortunately. Usually it's just one message so we loop it but - # it will create an extra round trip to OpenAI before being able to run inference. - # TODO: parallelize it? 
- for msg in self._chat_ctx.messages: - msg_id = msg._metadata.get(OPENAI_MESSAGE_ID_KEY, {}).get( - load_options.thread_id - ) - assert load_options.thread_id - if msg_id and msg_id not in openai_addded_messages_set: - await self._client.beta.threads.messages.delete( - thread_id=load_options.thread_id, - message_id=msg_id, - ) - logger.debug( - f"Deleted message '{msg_id}' in thread '{load_options.thread_id}'" - ) - openai_addded_messages_set.remove(msg_id) - - # Upload any images in the chat_ctx that have not been uploaded to OpenAI - for msg in self._chat_ctx.messages: - if msg.role != "user": - continue - - if not isinstance(msg.content, list): - continue - - for cnt in msg.content: - if ( - not isinstance(cnt, llm.ChatImage) - or OPENAI_FILE_ID_KEY in cnt._cache - ): - continue - - if isinstance(cnt.image, str): - continue - - file_obj = await self._upload_frame( - cnt.image, cnt.inference_width, cnt.inference_height - ) - cnt._cache[OPENAI_FILE_ID_KEY] = file_obj.id - if self._on_file_uploaded: - self._on_file_uploaded( - OnFileUploadedInfo( - type="image", - original_file=cnt, - openai_file_object=file_obj, - ) - ) - - # Keep track of the new messages in the chat_ctx that we need to add to OpenAI - additional_messages: list[AdditionalMessage] = [] - for msg in self._chat_ctx.messages: - if msg.role != "user": - continue - - msg_id = str(uuid.uuid4()) - if OPENAI_MESSAGE_ID_KEY not in msg._metadata: - msg._metadata[OPENAI_MESSAGE_ID_KEY] = dict[str, str]() - - if LIVEKIT_MESSAGE_ID_KEY not in msg._metadata: - msg._metadata[LIVEKIT_MESSAGE_ID_KEY] = dict[str, str]() - - oai_msg_id_dict = msg._metadata[OPENAI_MESSAGE_ID_KEY] - lk_msg_id_dict = msg._metadata[LIVEKIT_MESSAGE_ID_KEY] - - if load_options.thread_id not in oai_msg_id_dict: - converted_msg = build_oai_message(msg) - converted_msg["private_message_id"] = msg_id - additional_messages.append( - AdditionalMessage( - role="user", - content=converted_msg["content"], - metadata={LIVEKIT_MESSAGE_ID_KEY: msg_id}, - ) - ) - lk_msg_id_dict[load_options.thread_id] = msg_id - - eh = AssistantLLMStream.EventHandler( - llm=self._llm, - event_ch=self._event_ch, - chat_ctx=self._chat_ctx, - fnc_ctx=self._fnc_ctx, - llm_stream=self, - ) - assert load_options.thread_id - kwargs: dict[str, Any] = { - "additional_messages": additional_messages, - "thread_id": load_options.thread_id, - "assistant_id": load_options.assistant_id, - "event_handler": eh, - "temperature": self._temperature, - } - if self._fnc_ctx: - kwargs["tools"] = [ - build_oai_function_description(f) - for f in self._fnc_ctx.ai_functions.values() - ] - - async with self._client.beta.threads.runs.stream(**kwargs) as stream: - await stream.until_done() - - # Populate the openai_message_id for the messages we added to OpenAI. Note, we do this after - # sending None to close the iterator so that it is done in parellel with any users of - # the stream. However, the next stream will not start until this is done. 
- lk_to_oai_lookup = dict[str, str]() - messages = await self._client.beta.threads.messages.list( - thread_id=load_options.thread_id, - limit=10, # We could be smarter and make a more exact query, but this is probably fine - ) - for oai_msg in messages.data: - if oai_msg.metadata.get(LIVEKIT_MESSAGE_ID_KEY): # type: ignore - lk_to_oai_lookup[oai_msg.metadata[LIVEKIT_MESSAGE_ID_KEY]] = ( # type: ignore - oai_msg.id - ) - - for msg in self._chat_ctx.messages: - if msg.role != "user": - continue - oai_msg_id_dict = msg._metadata.get(OPENAI_MESSAGE_ID_KEY) - lk_msg_id_dict = msg._metadata.get(LIVEKIT_MESSAGE_ID_KEY) - if oai_msg_id_dict is None or lk_msg_id_dict is None: - continue - - lk_msg_id = lk_msg_id_dict.get(load_options.thread_id) - if lk_msg_id and lk_msg_id in lk_to_oai_lookup: - oai_msg_id = lk_to_oai_lookup[lk_msg_id] - oai_msg_id_dict[load_options.thread_id] = oai_msg_id - openai_addded_messages_set.add(oai_msg_id) - # We don't need the LiveKit message id anymore - lk_msg_id_dict.pop(load_options.thread_id) - - finally: - self._done_future.set_result(None) - - async def _upload_frame( - self, - frame: rtc.VideoFrame, - inference_width: int | None, - inference_height: int | None, - ): - # inside our internal implementation, we allow to put extra metadata to - # each ChatImage (avoid to reencode each time we do a chatcompletion request) - opts = utils.images.EncodeOptions() - if inference_width and inference_height: - opts.resize_options = utils.images.ResizeOptions( - width=inference_width, - height=inference_height, - strategy="scale_aspect_fit", - ) - - encoded_data = utils.images.encode(frame, opts) - fileObj = await self._client.files.create( - file=("image.jpg", encoded_data), - purpose="vision", - ) - - return fileObj - - -def build_oai_message(msg: llm.ChatMessage): - oai_msg: dict[str, Any] = {"role": msg.role} - - if msg.name: - oai_msg["name"] = msg.name - - # add content if provided - if isinstance(msg.content, str): - oai_msg["content"] = msg.content - elif isinstance(msg.content, list): - oai_content: list[dict[str, Any]] = [] - for cnt in msg.content: - if isinstance(cnt, str): - oai_content.append({"type": "text", "text": cnt}) - elif isinstance(cnt, llm.ChatImage): - if cnt._cache[OPENAI_FILE_ID_KEY]: - oai_content.append( - { - "type": "image_file", - "image_file": {"file_id": cnt._cache[OPENAI_FILE_ID_KEY]}, - } - ) - - oai_msg["content"] = oai_content - - # make sure to provide when function has been called inside the context - # (+ raw_arguments) - if msg.tool_calls is not None: - tool_calls: list[dict[str, Any]] = [] - oai_msg["tool_calls"] = tool_calls - for fnc in msg.tool_calls: - tool_calls.append( - { - "id": fnc.tool_call_id, - "type": "function", - "function": { - "name": fnc.function_info.name, - "arguments": fnc.raw_arguments, - }, - } - ) - - # tool_call_id is set when the message is a response/result to a function call - # (content is a string in this case) - if msg.tool_call_id: - oai_msg["tool_call_id"] = msg.tool_call_id - - return oai_msg diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/api_proto.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/api_proto.py deleted file mode 100644 index 2bf9778d3..000000000 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/api_proto.py +++ /dev/null @@ -1,607 +0,0 @@ -from __future__ import annotations - -from typing import Literal, Union - -from typing_extensions import NotRequired, TypedDict - -SAMPLE_RATE = 24000 
-NUM_CHANNELS = 1 - -IN_FRAME_SIZE = 2400 # 100ms -OUT_FRAME_SIZE = 1200 # 50ms - - -class FunctionToolChoice(TypedDict): - type: Literal["function"] - name: str - - -Voice = Literal["alloy", "echo", "shimmer", "ash", "ballad", "coral", "sage", "verse"] -ToolChoice = Union[Literal["auto", "none", "required"], FunctionToolChoice] -Role = Literal["system", "assistant", "user", "tool"] -GenerationFinishedReason = Literal["stop", "max_tokens", "content_filter", "interrupt"] -AudioFormat = Literal["pcm16", "g711_ulaw", "g711_alaw"] -InputTranscriptionModel = Literal["whisper-1"] -Modality = Literal["text", "audio"] -ResponseStatus = Literal[ - "in_progress", "completed", "incomplete", "cancelled", "failed" -] - -# https://platform.openai.com/docs/models/gp#gpt-4o-realtime -OpenAIModel = Literal[ - "gpt-4o-realtime-preview", - "gpt-4o-realtime-preview-2024-10-01", - "gpt-4o-realtime-preview-2024-12-17", - "gpt-4o-mini-realtime-preview", - "gpt-4o-mini-realtime-preview-2024-12-17", -] -DefaultOpenAIModel = "gpt-4o-realtime-preview" - - -class TextContent(TypedDict): - type: Literal["text"] - text: str - - -class InputTextContent(TypedDict): - type: Literal["input_text"] - text: str - - -class AudioContent(TypedDict): - type: Literal["audio"] - audio: str # b64 - - -class InputAudioContent(TypedDict): - type: Literal["input_audio"] - audio: str # b64 - - -Content = Union[InputTextContent, TextContent, AudioContent, InputAudioContent] - - -class ContentPart(TypedDict): - type: Literal["text", "audio"] - audio: NotRequired[str] # b64 - transcript: NotRequired[str] - - -class InputAudioTranscription(TypedDict): - model: InputTranscriptionModel | str - - -class ServerVad(TypedDict): - type: Literal["server_vad"] - threshold: NotRequired[float] - prefix_padding_ms: NotRequired[int] - silence_duration_ms: NotRequired[int] - - -class FunctionTool(TypedDict): - type: Literal["function"] - name: str - description: NotRequired[str | None] - parameters: dict - - -class SystemItem(TypedDict): - id: str - object: Literal["realtime.item"] - type: Literal["message"] - role: Literal["system"] - content: list[InputTextContent] - - -class UserItem(TypedDict): - id: str - object: Literal["realtime.item"] - type: Literal["message"] - role: Literal["user"] - content: list[InputTextContent | InputAudioContent] - - -class AssistantItem(TypedDict): - id: str - object: Literal["realtime.item"] - type: Literal["message"] - role: Literal["assistant"] - content: list[TextContent | AudioContent] - - -class FunctionCallItem(TypedDict): - id: str - object: Literal["realtime.item"] - type: Literal["function_call"] - call_id: str - name: str - arguments: str - - -class FunctionCallOutputItem(TypedDict): - id: str - object: Literal["realtime.item"] - type: Literal["function_call_output"] - call_id: str - output: str - - -class CancelledStatusDetails(TypedDict): - type: Literal["cancelled"] - reason: Literal["turn_detected", "client_cancelled"] - - -class IncompleteStatusDetails(TypedDict): - type: Literal["incomplete"] - reason: Literal["max_output_tokens", "content_filter"] - - -class Error(TypedDict): - code: str - message: str - - -class FailedStatusDetails(TypedDict): - type: Literal["failed"] - error: NotRequired[Error | None] - - -ResponseStatusDetails = Union[ - CancelledStatusDetails, IncompleteStatusDetails, FailedStatusDetails -] - - -class InputTokenDetails(TypedDict): - cached_tokens: int - text_tokens: int - audio_tokens: int - cached_tokens_details: CachedTokenDetails - - -class CachedTokenDetails(TypedDict): - 
text_tokens: int - audio_tokens: int - - -class OutputTokenDetails(TypedDict): - text_tokens: int - audio_tokens: int - - -class Usage(TypedDict): - total_tokens: int - input_tokens: int - output_tokens: int - input_token_details: InputTokenDetails - output_token_details: OutputTokenDetails - - -class Resource: - class Session(TypedDict): - id: str - object: Literal["realtime.session"] - expires_at: int - model: str - modalities: list[Literal["text", "audio"]] - instructions: str - voice: Voice - input_audio_format: AudioFormat - output_audio_format: AudioFormat - input_audio_transcription: InputAudioTranscription | None - turn_detection: ServerVad | None - tools: list[FunctionTool] - tool_choice: ToolChoice - temperature: float - max_response_output_tokens: int | Literal["inf"] - - class Conversation(TypedDict): - id: str - object: Literal["realtime.conversation"] - - Item = Union[SystemItem, UserItem, FunctionCallItem, FunctionCallOutputItem] - - class Response(TypedDict): - id: str - object: Literal["realtime.response"] - status: ResponseStatus - status_details: NotRequired[ResponseStatusDetails | None] - output: list[Resource.Item] - usage: NotRequired[Usage | None] - - -class ClientEvent: - class SessionUpdateData(TypedDict): - modalities: list[Literal["text", "audio"]] - instructions: str - voice: Voice - input_audio_format: AudioFormat - output_audio_format: AudioFormat - input_audio_transcription: InputAudioTranscription | None - turn_detection: ServerVad | None - tools: list[FunctionTool] - tool_choice: ToolChoice - temperature: float - # microsoft does not support inf, but accepts None - max_response_output_tokens: int | Literal["inf"] | None - - class SessionUpdate(TypedDict): - event_id: NotRequired[str] - type: Literal["session.update"] - session: ClientEvent.SessionUpdateData - - class InputAudioBufferAppend(TypedDict): - event_id: NotRequired[str] - type: Literal["input_audio_buffer.append"] - audio: str # b64 - - class InputAudioBufferCommit(TypedDict): - event_id: NotRequired[str] - type: Literal["input_audio_buffer.commit"] - - class InputAudioBufferClear(TypedDict): - event_id: NotRequired[str] - type: Literal["input_audio_buffer.clear"] - - class UserItemCreate(TypedDict): - id: str | None - type: Literal["message"] - role: Literal["user"] - content: list[InputTextContent | InputAudioContent] - - class AssistantItemCreate(TypedDict): - id: str | None - type: Literal["message"] - role: Literal["assistant"] - content: list[TextContent] - - class SystemItemCreate(TypedDict): - id: str | None - type: Literal["message"] - role: Literal["system"] - content: list[InputTextContent] - - class FunctionCallOutputItemCreate(TypedDict): - id: str | None - type: Literal["function_call_output"] - call_id: str - output: str - - class FunctionCallItemCreate(TypedDict): - id: str | None - type: Literal["function_call"] - call_id: str - name: str - arguments: str - - ConversationItemCreateContent = Union[ - UserItemCreate, - AssistantItemCreate, - SystemItemCreate, - FunctionCallOutputItemCreate, - FunctionCallItemCreate, - ] - - class ConversationItemCreate(TypedDict): - event_id: NotRequired[str] - type: Literal["conversation.item.create"] - previous_item_id: NotRequired[str | None] - item: ClientEvent.ConversationItemCreateContent - - class ConversationItemTruncate(TypedDict): - event_id: NotRequired[str] - type: Literal["conversation.item.truncate"] - item_id: str - content_index: int - audio_end_ms: int - - class ConversationItemDelete(TypedDict): - event_id: NotRequired[str] - type: 
Literal["conversation.item.delete"] - item_id: str - - class ResponseCreateData(TypedDict, total=False): - modalities: list[Literal["text", "audio"]] - instructions: str - voice: Voice - output_audio_format: AudioFormat - tools: list[FunctionTool] - tool_choice: ToolChoice - temperature: float - max_output_tokens: int | Literal["inf"] - - class ResponseCreate(TypedDict): - event_id: NotRequired[str] - type: Literal["response.create"] - response: NotRequired[ClientEvent.ResponseCreateData] - - class ResponseCancel(TypedDict): - event_id: NotRequired[str] - type: Literal["response.cancel"] - - -class ServerEvent: - class ErrorContent(TypedDict): - type: str - code: NotRequired[str] - message: str - param: NotRequired[str] - event_id: NotRequired[str] - - class Error(TypedDict): - event_id: str - type: Literal["error"] - error: ServerEvent.ErrorContent - - class SessionCreated(TypedDict): - event_id: str - type: Literal["session.created"] - session: Resource.Session - - class SessionUpdated(TypedDict): - event_id: str - type: Literal["session.updated"] - session: Resource.Session - - class ConversationCreated(TypedDict): - event_id: str - type: Literal["conversation.created"] - conversation: Resource.Conversation - - class InputAudioBufferCommitted(TypedDict): - event_id: str - type: Literal["input_audio_buffer.committed"] - item_id: str - - class InputAudioBufferCleared(TypedDict): - event_id: str - type: Literal["input_audio_buffer.cleared"] - - class InputAudioBufferSpeechStarted(TypedDict): - event_id: str - type: Literal["input_audio_buffer.speech_started"] - item_id: str - audio_start_ms: int - - class InputAudioBufferSpeechStopped(TypedDict): - event_id: str - type: Literal["input_audio_buffer.speech_stopped"] - item_id: str - audio_end_ms: int - - class ConversationItemCreated(TypedDict): - event_id: str - type: Literal["conversation.item.created"] - previous_item_id: str | None - item: Resource.Item - - class ConversationItemInputAudioTranscriptionCompleted(TypedDict): - event_id: str - type: Literal["conversation.item.input_audio_transcription.completed"] - item_id: str - content_index: int - transcript: str - - class InputAudioTranscriptionError(TypedDict): - type: str - code: NotRequired[str] - message: str - param: NotRequired[str] - - class ConversationItemInputAudioTranscriptionFailed(TypedDict): - event_id: str - type: Literal["conversation.item.input_audio_transcription.failed"] - item_id: str - content_index: int - error: ServerEvent.InputAudioTranscriptionError - - class ConversationItemTruncated(TypedDict): - event_id: str - type: Literal["conversation.item.truncated"] - item_id: str - content_index: int - audio_end_ms: int - - class ConversationItemDeleted(TypedDict): - event_id: str - type: Literal["conversation.item.deleted"] - item_id: str - - class ResponseCreated(TypedDict): - event_id: str - type: Literal["response.created"] - response: Resource.Response - - class ResponseDone(TypedDict): - event_id: str - type: Literal["response.done"] - response: Resource.Response - - class ResponseOutputItemAdded(TypedDict): - event_id: str - type: Literal["response.output_item.added"] - response_id: str - output_index: int - item: Resource.Item - - class ResponseOutputItemDone(TypedDict): - event_id: str - type: Literal["response.output.done"] - response_id: str - output_index: int - item: Resource.Item - - class ResponseContentPartAdded(TypedDict): - event_id: str - type: Literal["response.content_part.added"] - item_id: str - response_id: str - output_index: int - 
content_index: int - part: ContentPart - - class ResponseContentPartDone(TypedDict): - event_id: str - type: Literal["response.content.done"] - response_id: str - output_index: int - content_index: int - part: ContentPart - - class ResponseTextDeltaAdded(TypedDict): - event_id: str - type: Literal["response.text.delta"] - response_id: str - output_index: int - content_index: int - delta: str - - class ResponseTextDone(TypedDict): - event_id: str - type: Literal["response.text.done"] - response_id: str - output_index: int - content_index: int - text: str - - class ResponseAudioTranscriptDelta(TypedDict): - event_id: str - type: Literal["response.audio_transcript.delta"] - response_id: str - output_index: int - content_index: int - delta: str - - class ResponseAudioTranscriptDone(TypedDict): - event_id: str - type: Literal["response.audio_transcript.done"] - response_id: str - output_index: int - content_index: int - transcript: str - - class ResponseAudioDelta(TypedDict): - event_id: str - type: Literal["response.audio.delta"] - response_id: str - output_index: int - content_index: int - delta: str # b64 - - class ResponseAudioDone(TypedDict): - event_id: str - type: Literal["response.audio.done"] - response_id: str - output_index: int - content_index: int - - class ResponseFunctionCallArgumentsDelta(TypedDict): - event_id: str - type: Literal["response.function_call_arguments.delta"] - response_id: str - output_index: int - delta: str - - class ResponseFunctionCallArgumentsDone(TypedDict): - event_id: str - type: Literal["response.function_call_arguments.done"] - response_id: str - output_index: int - arguments: str - - class RateLimitsData(TypedDict): - name: Literal["requests", "tokens", "input_tokens", "output_tokens"] - limit: int - remaining: int - reset_seconds: float - - class RateLimitsUpdated: - event_id: str - type: Literal["rate_limits.updated"] - limits: list[ServerEvent.RateLimitsData] - - -ClientEvents = Union[ - ClientEvent.SessionUpdate, - ClientEvent.InputAudioBufferAppend, - ClientEvent.InputAudioBufferCommit, - ClientEvent.InputAudioBufferClear, - ClientEvent.ConversationItemCreate, - ClientEvent.ConversationItemTruncate, - ClientEvent.ConversationItemDelete, - ClientEvent.ResponseCreate, - ClientEvent.ResponseCancel, -] - -ServerEvents = Union[ - ServerEvent.Error, - ServerEvent.SessionCreated, - ServerEvent.SessionUpdated, - ServerEvent.ConversationCreated, - ServerEvent.InputAudioBufferCommitted, - ServerEvent.InputAudioBufferCleared, - ServerEvent.InputAudioBufferSpeechStarted, - ServerEvent.InputAudioBufferSpeechStopped, - ServerEvent.ConversationItemCreated, - ServerEvent.ConversationItemInputAudioTranscriptionCompleted, - ServerEvent.ConversationItemInputAudioTranscriptionFailed, - ServerEvent.ConversationItemTruncated, - ServerEvent.ConversationItemDeleted, - ServerEvent.ResponseCreated, - ServerEvent.ResponseDone, - ServerEvent.ResponseOutputItemAdded, - ServerEvent.ResponseOutputItemDone, - ServerEvent.ResponseContentPartAdded, - ServerEvent.ResponseContentPartDone, - ServerEvent.ResponseTextDeltaAdded, - ServerEvent.ResponseTextDone, - ServerEvent.ResponseAudioTranscriptDelta, - ServerEvent.ResponseAudioTranscriptDone, - ServerEvent.ResponseAudioDelta, - ServerEvent.ResponseAudioDone, - ServerEvent.ResponseFunctionCallArgumentsDelta, - ServerEvent.ResponseFunctionCallArgumentsDone, - ServerEvent.RateLimitsUpdated, -] - -ClientEventType = Literal[ - "session.update", - "input_audio_buffer.append", - "input_audio_buffer.commit", - "input_audio_buffer.clear", - 
"conversation.item.create", - "conversation.item.truncate", - "conversation.item.delete", - "response.create", - "response.cancel", -] - -ServerEventType = Literal[ - "error", - "session.created", - "session.updated", - "conversation.created", - "input_audio_buffer.committed", - "input_audio_buffer.cleared", - "input_audio_buffer.speech_started", - "input_audio_buffer.speech_stopped", - "conversation.item.created", - "conversation.item.input_audio_transcription.completed", - "conversation.item.input_audio_transcription.failed", - "conversation.item.truncated", - "conversation.item.deleted", - "response.created", - "response.done", - "response.output_item.added", - "response.output_item.done", - "response.content_part.added", - "response.content_part.done", - "response.text.delta", - "response.text.done", - "response.audio_transcript.delta", - "response.audio_transcript.done", - "response.audio.delta", - "response.audio.done", - "response.function_call_arguments.delta", - "response.function_call_arguments.done", - "rate_limits.updated", -] diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py index be94cde3f..d29ce6235 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py @@ -1,1707 +1,284 @@ from __future__ import annotations import asyncio +import openai import base64 -import os -import time -from copy import deepcopy -from dataclasses import dataclass -from typing import AsyncIterable, Literal, Optional, Union, cast, overload -from urllib.parse import urlencode -import aiohttp -from livekit import rtc -from livekit.agents import llm, utils from livekit.agents.llm.function_context import _create_ai_function_info -from livekit.agents.metrics import MultimodalLLMError, MultimodalLLMMetrics -from typing_extensions import TypedDict - -from .._oai_api import build_oai_function_description -from . 
import api_proto, remote_items -from .log import logger - -EventTypes = Literal[ - "start_session", - "session_updated", - "error", - "input_speech_started", - "input_speech_stopped", - "input_speech_committed", - "input_speech_transcription_completed", - "input_speech_transcription_failed", - "response_created", - "response_output_added", # message & assistant - "response_content_added", # message type (audio/text) - "response_content_done", - "response_output_done", - "response_done", - "function_calls_collected", - "function_calls_finished", - "metrics_collected", -] - - -@dataclass -class InputTranscriptionCompleted: - item_id: str - """id of the item""" - transcript: str - """transcript of the input audio""" - +from openai.resources.beta.realtime.realtime import AsyncRealtimeConnection +from openai.types.beta.realtime import ( + InputAudioBufferSpeechStartedEvent, + InputAudioBufferSpeechStoppedEvent, + RealtimeClientEvent, + InputAudioBufferAppendEvent, + ResponseAudioDeltaEvent, + ResponseAudioDoneEvent, + ResponseAudioTranscriptDeltaEvent, + ResponseAudioTranscriptDoneEvent, + ResponseCancelEvent, + ResponseCreateEvent, + ConversationItemTruncateEvent, + ResponseCreatedEvent, + ResponseDoneEvent, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, +) -@dataclass -class InputTranscriptionFailed: - item_id: str - """id of the item""" - message: str - """error message""" +from dataclasses import dataclass +from livekit.agents import multimodal, llm, utils +from livekit import rtc -@dataclass -class RealtimeResponse: - id: str - """id of the message""" - status: api_proto.ResponseStatus - """status of the response""" - status_details: api_proto.ResponseStatusDetails | None - """details of the status (only with "incomplete, cancelled and failed")""" - output: list[RealtimeOutput] - """list of outputs""" - usage: api_proto.Usage | None - """usage of the response""" - done_fut: asyncio.Future[None] - """future that will be set when the response is completed""" - _created_timestamp: float - """timestamp when the response was created""" - _first_token_timestamp: float | None = None - """timestamp when the first token was received""" +from .log import logger -@dataclass -class RealtimeOutput: - response_id: str - """id of the response""" - item_id: str - """id of the item""" - output_index: int - """index of the output""" - role: api_proto.Role - """role of the message""" - type: Literal["message", "function_call"] - """type of the output""" - content: list[RealtimeContent] - """list of content""" - done_fut: asyncio.Future[None] - """future that will be set when the output is completed""" +# When a response is created with the OpenAI Realtime API, those events are sent in this order: +# 1. response.created (contains resp_id) +# 2. response.output_item.added (contains item_id) +# 3. conversation.item.created +# 4. response.content_part.added (type audio/text) +# 5. response.audio_transcript.delta (x2, x3, x4, etc) +# 6. response.audio.delta (x2, x3, x4, etc) +# 7. response.content_part.done +# 8. response.output_item.done (contains item_status: "completed/incomplete") +# 9. response.done (contains status_details for cancelled/failed/turn_detected/content_filter) +# +# Ourcode assumes a response will generate only one item with type "message." 
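A rough sketch (not part of the patch) of how the event ordering documented above can be routed into the _ResponseGeneration channels defined just below. The handler name _on_server_event and the self._current_gen attribute are assumptions; field access follows the OpenAI Realtime event types imported above.

    def _on_server_event(self, event) -> None:
        if event.type == "response.created":
            # a new response starts a fresh set of output streams
            self._current_gen = _ResponseGeneration(
                response_id=event.response.id,
                item_id="",
                audio_ch=utils.aio.Chan[rtc.AudioFrame](),
                text_ch=utils.aio.Chan[str](),
                tool_calls_ch=utils.aio.Chan[llm.FunctionCallInfo](),
            )
        elif event.type == "response.output_item.added":
            self._current_gen.item_id = event.item.id
        elif event.type == "response.audio_transcript.delta":
            # streamed transcript of the audio being generated
            self._current_gen.text_ch.send_nowait(event.delta)
        elif event.type == "response.audio.delta":
            pcm = base64.b64decode(event.delta)
            self._current_gen.audio_ch.send_nowait(
                rtc.AudioFrame(
                    data=pcm,
                    sample_rate=SAMPLE_RATE,
                    num_channels=NUM_CHANNELS,
                    samples_per_channel=len(pcm) // (2 * NUM_CHANNELS),
                )
            )
        elif event.type == "response.done":
            # the response (and its single "message" item) is complete
            self._current_gen.text_ch.close()
            self._current_gen.audio_ch.close()
            self._current_gen.tool_calls_ch.close()
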
-@dataclass -class RealtimeToolCall: - name: str - """name of the function""" - arguments: str - """accumulated arguments""" - tool_call_id: str - """id of the tool call""" +SAMPLE_RATE = 24000 +NUM_CHANNELS = 1 @dataclass -class Capabilities: - supports_truncate: bool +class _RealtimeOptions: + model: str @dataclass -class RealtimeContent: +class _ResponseGeneration: response_id: str - """id of the response""" item_id: str - """id of the item""" - output_index: int - """index of the output""" - content_index: int - """index of the content""" - text: str - """accumulated text content""" - audio: list[rtc.AudioFrame] - """accumulated audio content""" - text_stream: AsyncIterable[str] - """stream of text content""" - audio_stream: AsyncIterable[rtc.AudioFrame] - """stream of audio content""" - tool_calls: list[RealtimeToolCall] - """pending tool calls""" - content_type: api_proto.Modality - """type of the content""" - - -@dataclass -class ServerVadOptions: - threshold: float - prefix_padding_ms: int - silence_duration_ms: int - - -@dataclass -class InputTranscriptionOptions: - model: api_proto.InputTranscriptionModel | str - - -@dataclass -class RealtimeError: - event_id: str - type: str - message: str - code: Optional[str] - param: Optional[str] - - -@dataclass -class RealtimeSessionOptions: - model: api_proto.OpenAIModel | str - modalities: list[api_proto.Modality] - instructions: str - voice: api_proto.Voice - input_audio_format: api_proto.AudioFormat - output_audio_format: api_proto.AudioFormat - input_audio_transcription: InputTranscriptionOptions | None - turn_detection: ServerVadOptions | None - tool_choice: api_proto.ToolChoice - temperature: float - max_response_output_tokens: int | Literal["inf"] - - -@dataclass -class _ModelOptions(RealtimeSessionOptions): - api_key: str | None - base_url: str - entra_token: str | None - azure_deployment: str | None - is_azure: bool - api_version: str | None + audio_ch: utils.aio.Chan[rtc.AudioFrame] + text_ch: utils.aio.Chan[str] + tool_calls_ch: utils.aio.Chan[llm.FunctionCallInfo] -class _ContentPtr(TypedDict): - response_id: str - output_index: int - content_index: int - - -DEFAULT_SERVER_VAD_OPTIONS = ServerVadOptions( - threshold=0.5, - prefix_padding_ms=300, - silence_duration_ms=500, -) - -DEFAULT_INPUT_AUDIO_TRANSCRIPTION = InputTranscriptionOptions(model="whisper-1") - - -class RealtimeModel: - @overload - def __init__( - self, - *, - instructions: str = "", - modalities: list[api_proto.Modality] = ["text", "audio"], - model: api_proto.OpenAIModel | str = api_proto.DefaultOpenAIModel, - voice: api_proto.Voice = "alloy", - input_audio_format: api_proto.AudioFormat = "pcm16", - output_audio_format: api_proto.AudioFormat = "pcm16", - input_audio_transcription: InputTranscriptionOptions = DEFAULT_INPUT_AUDIO_TRANSCRIPTION, - turn_detection: ServerVadOptions = DEFAULT_SERVER_VAD_OPTIONS, - tool_choice: api_proto.ToolChoice = "auto", - temperature: float = 0.8, - max_response_output_tokens: int | Literal["inf"] = "inf", - api_key: str | None = None, - base_url: str | None = None, - http_session: aiohttp.ClientSession | None = None, - loop: asyncio.AbstractEventLoop | None = None, - ) -> None: ... 
- - @overload +class RealtimeModel(multimodal.RealtimeModel): def __init__( self, *, - azure_deployment: str | None = None, - entra_token: str | None = None, - api_key: str | None = None, - api_version: str | None = None, - base_url: str | None = None, - instructions: str = "", - modalities: list[api_proto.Modality] = ["text", "audio"], - voice: api_proto.Voice = "alloy", - input_audio_format: api_proto.AudioFormat = "pcm16", - output_audio_format: api_proto.AudioFormat = "pcm16", - input_audio_transcription: InputTranscriptionOptions = DEFAULT_INPUT_AUDIO_TRANSCRIPTION, - turn_detection: ServerVadOptions = DEFAULT_SERVER_VAD_OPTIONS, - tool_choice: api_proto.ToolChoice = "auto", - temperature: float = 0.8, - max_response_output_tokens: int | Literal["inf"] = "inf", - http_session: aiohttp.ClientSession | None = None, - loop: asyncio.AbstractEventLoop | None = None, - ) -> None: ... - - def __init__( - self, - *, - instructions: str = "", - modalities: list[api_proto.Modality] = ["text", "audio"], - model: api_proto.OpenAIModel | str = api_proto.DefaultOpenAIModel, - voice: api_proto.Voice = "alloy", - input_audio_format: api_proto.AudioFormat = "pcm16", - output_audio_format: api_proto.AudioFormat = "pcm16", - input_audio_transcription: InputTranscriptionOptions = DEFAULT_INPUT_AUDIO_TRANSCRIPTION, - turn_detection: ServerVadOptions = DEFAULT_SERVER_VAD_OPTIONS, - tool_choice: api_proto.ToolChoice = "auto", - temperature: float = 0.8, - max_response_output_tokens: int | Literal["inf"] = "inf", - base_url: str | None = None, - http_session: aiohttp.ClientSession | None = None, - loop: asyncio.AbstractEventLoop | None = None, - # azure specific parameters - azure_deployment: str | None = None, - entra_token: str | None = None, - api_key: str | None = None, - api_version: str | None = None, + model: str = "gpt-4o-realtime-preview-2024-12-17", + client: openai.AsyncClient | None = None, ) -> None: - """ - Initializes a RealtimeClient instance for interacting with OpenAI's Realtime API. - - Args: - instructions (str, optional): Initial system instructions for the model. Defaults to "". - api_key (str or None, optional): OpenAI API key. If None, will attempt to read from the environment variable OPENAI_API_KEY - modalities (list[api_proto.Modality], optional): Modalities to use, such as ["text", "audio"]. Defaults to ["text", "audio"]. - model (str or None, optional): The name of the model to use. Defaults to "gpt-4o-realtime-preview-2024-10-01". - voice (api_proto.Voice, optional): Voice setting for audio outputs. Defaults to "alloy". - input_audio_format (api_proto.AudioFormat, optional): Format of input audio data. Defaults to "pcm16". - output_audio_format (api_proto.AudioFormat, optional): Format of output audio data. Defaults to "pcm16". - input_audio_transcription (InputTranscriptionOptions, optional): Options for transcribing input audio. Defaults to DEFAULT_INPUT_AUDIO_TRANSCRIPTION. - turn_detection (ServerVadOptions, optional): Options for server-based voice activity detection (VAD). Defaults to DEFAULT_SERVER_VAD_OPTIONS. - tool_choice (api_proto.ToolChoice, optional): Tool choice for the model, such as "auto". Defaults to "auto". - temperature (float, optional): Sampling temperature for response generation. Defaults to 0.8. - max_response_output_tokens (int or Literal["inf"], optional): Maximum number of tokens in the response. Defaults to "inf". - base_url (str or None, optional): Base URL for the API endpoint. If None, defaults to OpenAI's default API URL. 
- http_session (aiohttp.ClientSession or None, optional): Async HTTP session to use for requests. If None, a new session will be created. - loop (asyncio.AbstractEventLoop or None, optional): Event loop to use for async operations. If None, the current event loop is used. - - Raises: - ValueError: If the API key is not provided and cannot be found in environment variables. - """ - super().__init__() - self._capabilities = Capabilities( - supports_truncate=True, - ) - self._base_url = base_url - - is_azure = ( - api_version is not None - or entra_token is not None - or azure_deployment is not None + super().__init__( + capabilities=multimodal.RealtimeCapabilities(message_truncation=True) ) - api_key = api_key or os.environ.get("OPENAI_API_KEY") - if api_key is None and not is_azure: - raise ValueError( - "OpenAI API key is required, either using the argument or by setting the OPENAI_API_KEY environmental variable" - ) - - if not base_url: - base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1") - - self._default_opts = _ModelOptions( - model=model, - modalities=modalities, - instructions=instructions, - voice=voice, - input_audio_format=input_audio_format, - output_audio_format=output_audio_format, - input_audio_transcription=input_audio_transcription, - turn_detection=turn_detection, - temperature=temperature, - tool_choice=tool_choice, - max_response_output_tokens=max_response_output_tokens, - api_key=api_key, - base_url=base_url, - azure_deployment=azure_deployment, - entra_token=entra_token, - is_azure=is_azure, - api_version=api_version, - ) - - self._loop = loop or asyncio.get_event_loop() - self._rt_sessions: list[RealtimeSession] = [] - self._http_session = http_session - - @classmethod - def with_azure( - cls, - *, - azure_deployment: str, - azure_endpoint: str | None = None, - api_version: str | None = None, - api_key: str | None = None, - entra_token: str | None = None, - base_url: str | None = None, - instructions: str = "", - modalities: list[api_proto.Modality] = ["text", "audio"], - voice: api_proto.Voice = "alloy", - input_audio_format: api_proto.AudioFormat = "pcm16", - output_audio_format: api_proto.AudioFormat = "pcm16", - input_audio_transcription: InputTranscriptionOptions = DEFAULT_INPUT_AUDIO_TRANSCRIPTION, - turn_detection: ServerVadOptions = DEFAULT_SERVER_VAD_OPTIONS, - tool_choice: api_proto.ToolChoice = "auto", - temperature: float = 0.8, - max_response_output_tokens: int | Literal["inf"] = "inf", - http_session: aiohttp.ClientSession | None = None, - loop: asyncio.AbstractEventLoop | None = None, - ): - """ - Create a RealtimeClient instance configured for Azure OpenAI Service. - - Args: - azure_deployment (str): The name of your Azure OpenAI deployment. - azure_endpoint (str or None, optional): The endpoint URL for your Azure OpenAI resource. If None, will attempt to read from the environment variable AZURE_OPENAI_ENDPOINT. - api_version (str or None, optional): API version to use with Azure OpenAI Service. If None, will attempt to read from the environment variable OPENAI_API_VERSION. - api_key (str or None, optional): Azure OpenAI API key. If None, will attempt to read from the environment variable AZURE_OPENAI_API_KEY. - entra_token (str or None, optional): Azure Entra authentication token. Required if not using API key authentication. - base_url (str or None, optional): Base URL for the API endpoint. If None, constructed from the azure_endpoint. - instructions (str, optional): Initial system instructions for the model. Defaults to "". 
- modalities (list[api_proto.Modality], optional): Modalities to use, such as ["text", "audio"]. Defaults to ["text", "audio"]. - voice (api_proto.Voice, optional): Voice setting for audio outputs. Defaults to "alloy". - input_audio_format (api_proto.AudioFormat, optional): Format of input audio data. Defaults to "pcm16". - output_audio_format (api_proto.AudioFormat, optional): Format of output audio data. Defaults to "pcm16". - input_audio_transcription (InputTranscriptionOptions, optional): Options for transcribing input audio. Defaults to DEFAULT_INPUT_AUDIO_TRANSCRIPTION. - turn_detection (ServerVadOptions, optional): Options for server-based voice activity detection (VAD). Defaults to DEFAULT_SERVER_VAD_OPTIONS. - tool_choice (api_proto.ToolChoice, optional): Tool choice for the model, such as "auto". Defaults to "auto". - temperature (float, optional): Sampling temperature for response generation. Defaults to 0.8. - max_response_output_tokens (int or Literal["inf"], optional): Maximum number of tokens in the response. Defaults to "inf". - http_session (aiohttp.ClientSession or None, optional): Async HTTP session to use for requests. If None, a new session will be created. - loop (asyncio.AbstractEventLoop or None, optional): Event loop to use for async operations. If None, the current event loop is used. + self._opts = _RealtimeOptions(model=model) + self._client = client or openai.AsyncClient() - Returns: - RealtimeClient: An instance of RealtimeClient configured for Azure OpenAI Service. + def session(self) -> "RealtimeSession": + return RealtimeSession(self) - Raises: - ValueError: If required Azure parameters are missing or invalid. - """ - api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY") - if api_key is None and entra_token is None: - raise ValueError( - "Missing credentials. Please pass one of `api_key`, `entra_token`, or the `AZURE_OPENAI_API_KEY` environment variable." - ) + async def aclose(self) -> None: ... - api_version = api_version or os.getenv("OPENAI_API_VERSION") - if api_version is None: - raise ValueError( - "Must provide either the `api_version` argument or the `OPENAI_API_VERSION` environment variable" - ) - if base_url is None: - azure_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT") - if azure_endpoint is None: - raise ValueError( - "Missing Azure endpoint. Please pass the `azure_endpoint` parameter or set the `AZURE_OPENAI_ENDPOINT` environment variable." 
- ) +class RealtimeSession(multimodal.RealtimeSession): + def __init__(self, realtime_model: RealtimeModel) -> None: + super().__init__(realtime_model) + self._realtime_model = realtime_model + self._chat_ctx = llm.ChatContext() + self._fnc_ctx: llm.FunctionContext | None = None + self._msg_ch = utils.aio.Chan[RealtimeClientEvent]() - base_url = f"{azure_endpoint.rstrip('/')}/openai" - elif azure_endpoint is not None: - raise ValueError("base_url and azure_endpoint are mutually exclusive") - - return cls( - instructions=instructions, - modalities=modalities, - voice=voice, - input_audio_format=input_audio_format, - output_audio_format=output_audio_format, - input_audio_transcription=input_audio_transcription, - turn_detection=turn_detection, - tool_choice=tool_choice, - temperature=temperature, - max_response_output_tokens=max_response_output_tokens, - api_key=api_key, - http_session=http_session, - loop=loop, - azure_deployment=azure_deployment, - api_version=api_version, - entra_token=entra_token, - base_url=base_url, - ) - - def _ensure_session(self) -> aiohttp.ClientSession: - if not self._http_session: - self._http_session = utils.http_context.http_session() - - return self._http_session - - @property - def sessions(self) -> list[RealtimeSession]: - return self._rt_sessions - - @property - def capabilities(self) -> Capabilities: - return self._capabilities - - def session( - self, - *, - chat_ctx: llm.ChatContext | None = None, - fnc_ctx: llm.FunctionContext | None = None, - modalities: list[api_proto.Modality] | None = None, - instructions: str | None = None, - voice: api_proto.Voice | None = None, - input_audio_format: api_proto.AudioFormat | None = None, - output_audio_format: api_proto.AudioFormat | None = None, - tool_choice: api_proto.ToolChoice | None = None, - input_audio_transcription: InputTranscriptionOptions | None = None, - turn_detection: ServerVadOptions | None = None, - temperature: float | None = None, - max_response_output_tokens: int | Literal["inf"] | None = None, - ) -> RealtimeSession: - opts = deepcopy(self._default_opts) - if modalities is not None: - opts.modalities = modalities - if instructions is not None: - opts.instructions = instructions - if voice is not None: - opts.voice = voice - if input_audio_format is not None: - opts.input_audio_format = input_audio_format - if output_audio_format is not None: - opts.output_audio_format = output_audio_format - if tool_choice is not None: - opts.tool_choice = tool_choice - if input_audio_transcription is not None: - opts.input_audio_transcription - if turn_detection is not None: - opts.turn_detection = turn_detection - if temperature is not None: - opts.temperature = temperature - if max_response_output_tokens is not None: - opts.max_response_output_tokens = max_response_output_tokens - - new_session = RealtimeSession( - chat_ctx=chat_ctx or llm.ChatContext(), - fnc_ctx=fnc_ctx, - opts=opts, - http_session=self._ensure_session(), - loop=self._loop, - ) - self._rt_sessions.append(new_session) - return new_session - - async def aclose(self) -> None: - for session in self._rt_sessions: - await session.aclose() - - -class RealtimeSession(utils.EventEmitter[EventTypes]): - class InputAudioBuffer: - def __init__(self, sess: RealtimeSession) -> None: - self._sess = sess - - def append(self, frame: rtc.AudioFrame) -> None: - self._sess._queue_msg( - { - "type": "input_audio_buffer.append", - "audio": base64.b64encode(frame.data).decode("utf-8"), - } - ) - - def clear(self) -> None: - self._sess._queue_msg({"type": 
"input_audio_buffer.clear"}) - - def commit(self) -> None: - self._sess._queue_msg({"type": "input_audio_buffer.commit"}) - - class ConversationItem: - def __init__(self, sess: RealtimeSession) -> None: - self._sess = sess - - def create( - self, message: llm.ChatMessage, previous_item_id: str | None = None - ) -> asyncio.Future[bool]: - fut = asyncio.Future[bool]() - - message_content = message.content - tool_call_id = message.tool_call_id - event: api_proto.ClientEvent.ConversationItemCreate | None = None - if tool_call_id: - if message.role == "tool": - # function_call_output - assert isinstance(message_content, str) - event = { - "type": "conversation.item.create", - "previous_item_id": previous_item_id, - "item": { - "id": message.id, - "type": "function_call_output", - "call_id": tool_call_id, - "output": message_content, - }, - } - else: - # function_call - if not message.tool_calls or message.name is None: - logger.warning( - "function call message has no name or tool calls: %s", - message, - extra=self._sess.logging_extra(), - ) - fut.set_result(False) - return fut - if len(message.tool_calls) > 1: - logger.warning( - "function call message has multiple tool calls, " - "only the first one will be used", - extra=self._sess.logging_extra(), - ) - - event = { - "type": "conversation.item.create", - "previous_item_id": previous_item_id, - "item": { - "id": message.id, - "type": "function_call", - "call_id": tool_call_id, - "name": message.name, - "arguments": message.tool_calls[0].raw_arguments, - }, - } - else: - if message_content is None: - logger.warning( - "message content is None, skipping: %s", - message, - extra=self._sess.logging_extra(), - ) - fut.set_result(False) - return fut - if not isinstance(message_content, list): - message_content = [message_content] - - if message.role == "user": - user_contents: list[ - api_proto.InputTextContent | api_proto.InputAudioContent - ] = [] - for cnt in message_content: - if isinstance(cnt, str): - user_contents.append( - { - "type": "input_text", - "text": cnt, - } - ) - elif isinstance(cnt, llm.ChatAudio): - user_contents.append( - { - "type": "input_audio", - "audio": base64.b64encode( - utils.merge_frames(cnt.frame).data - ).decode("utf-8"), - } - ) - - event = { - "type": "conversation.item.create", - "previous_item_id": previous_item_id, - "item": { - "id": message.id, - "type": "message", - "role": "user", - "content": user_contents, - }, - } - - elif message.role == "assistant": - assistant_contents: list[api_proto.TextContent] = [] - for cnt in message_content: - if isinstance(cnt, str): - assistant_contents.append( - { - "type": "text", - "text": cnt, - } - ) - elif isinstance(cnt, llm.ChatAudio): - logger.warning( - "audio content in assistant message is not supported" - ) - - event = { - "type": "conversation.item.create", - "previous_item_id": previous_item_id, - "item": { - "id": message.id, - "type": "message", - "role": "assistant", - "content": assistant_contents, - }, - } - elif message.role == "system": - system_contents: list[api_proto.InputTextContent] = [] - for cnt in message_content: - if isinstance(cnt, str): - system_contents.append({"type": "input_text", "text": cnt}) - elif isinstance(cnt, llm.ChatAudio): - logger.warning( - "audio content in system message is not supported" - ) - - event = { - "type": "conversation.item.create", - "previous_item_id": previous_item_id, - "item": { - "id": message.id, - "type": "message", - "role": "system", - "content": system_contents, - }, - } - - if event is None: - 
logger.warning( - "chat message is not supported inside the realtime API %s", - message, - extra=self._sess.logging_extra(), - ) - fut.set_result(False) - return fut - - self._sess._item_created_futs[message.id] = fut - self._sess._queue_msg(event) - return fut - - def truncate( - self, *, item_id: str, content_index: int, audio_end_ms: int - ) -> asyncio.Future[bool]: - fut = asyncio.Future[bool]() - self._sess._item_truncated_futs[item_id] = fut - self._sess._queue_msg( - { - "type": "conversation.item.truncate", - "item_id": item_id, - "content_index": content_index, - "audio_end_ms": audio_end_ms, - } - ) - return fut - - def delete(self, *, item_id: str) -> asyncio.Future[bool]: - fut = asyncio.Future[bool]() - self._sess._item_deleted_futs[item_id] = fut - self._sess._queue_msg( - { - "type": "conversation.item.delete", - "item_id": item_id, - } - ) - return fut - - class Conversation: - def __init__(self, sess: RealtimeSession) -> None: - self._sess = sess - - @property - def item(self) -> RealtimeSession.ConversationItem: - return RealtimeSession.ConversationItem(self._sess) - - class Response: - def __init__(self, sess: RealtimeSession) -> None: - self._sess = sess - - def create( - self, - *, - on_duplicate: Literal[ - "cancel_existing", "cancel_new", "keep_both" - ] = "keep_both", - ) -> asyncio.Future[bool]: - """Creates a new response. - - Args: - on_duplicate: How to handle when there is an existing response in progress: - - "cancel_existing": Cancel the existing response before creating new one - - "cancel_new": Skip creating new response if one is in progress - - "keep_both": Wait for the existing response to be done and then create a new one - - Returns: - Future that resolves when the response create request is queued - """ - if on_duplicate not in ("cancel_existing", "cancel_new", "keep_both"): - raise ValueError( - "invalid on_duplicate value, must be one of: " - "cancel_existing, cancel_new, keep_both" - ) - - # check if there is a pending response creation request sent - pending_create_fut = self._sess._response_create_fut - if pending_create_fut is not None: - if on_duplicate == "cancel_new": - logger.warning( - "skip new response creation due to previous pending response creation", - extra=self._sess.logging_extra(), - ) - _fut = asyncio.Future[bool]() - _fut.set_result(False) - return _fut - - active_resp_id = self._sess._active_response_id - _logging_extra = { - "existing_response_id": active_resp_id, - **self._sess.logging_extra(), - } - - if ( - not active_resp_id - or self._sess._pending_responses[active_resp_id].done_fut.done() - ): - # no active response in progress, create a new one - self._sess._queue_msg({"type": "response.create"}) - _fut = asyncio.Future[bool]() - _fut.set_result(True) - return _fut - - # there is an active response in progress - if on_duplicate == "cancel_new": - logger.warning( - "skip new response creation due to active response in progress", - extra=_logging_extra, - ) - _fut = asyncio.Future[bool]() - _fut.set_result(False) - return _fut - - if on_duplicate == "cancel_existing": - self.cancel() - logger.warning( - "cancelling in-progress response to create a new one", - extra=_logging_extra, - ) - elif on_duplicate == "keep_both": - logger.warning( - "waiting for in-progress response to be done " - "before creating a new one", - extra=_logging_extra, - ) - - # create a task to wait for the previous response and then create new one - async def wait_and_create() -> bool: - await self._sess._pending_responses[active_resp_id].done_fut - 
logger.info( - "in-progress response is done, creating a new one", - extra=_logging_extra, - ) - new_create_fut = asyncio.Future[None]() - self._sess._response_create_fut = new_create_fut - self._sess._queue_msg({"type": "response.create"}) - return True - - return asyncio.create_task(wait_and_create()) - - def cancel(self) -> None: - self._sess._queue_msg({"type": "response.cancel"}) - - def __init__( - self, - *, - opts: _ModelOptions, - http_session: aiohttp.ClientSession, - chat_ctx: llm.ChatContext, - fnc_ctx: llm.FunctionContext | None, - loop: asyncio.AbstractEventLoop, - ) -> None: - super().__init__() - self._label = f"{type(self).__module__}.{type(self).__name__}" + self._conn: AsyncRealtimeConnection | None = None self._main_atask = asyncio.create_task( - self._main_task(), name="openai-realtime-session" + self._main_task(), name="RealtimeSession._main_task" ) - # manage conversation items internally - self._remote_conversation_items = remote_items._RemoteConversationItems() - - # wait for the item to be created or deleted - self._item_created_futs: dict[str, asyncio.Future[bool]] = {} - self._item_deleted_futs: dict[str, asyncio.Future[bool]] = {} - self._item_truncated_futs: dict[str, asyncio.Future[bool]] = {} - - self._fnc_ctx = fnc_ctx - self._loop = loop - - self._opts = opts - self._send_ch = utils.aio.Chan[api_proto.ClientEvents]() - self._http_session = http_session - - self._pending_responses: dict[str, RealtimeResponse] = {} - self._active_response_id: str | None = None - self._response_create_fut: asyncio.Future[None] | None = None - self._session_id = "not-connected" - self.session_update() # initial session init - - # sync the chat context to the session - self._init_sync_task = asyncio.create_task(self.set_chat_ctx(chat_ctx)) - - self._fnc_tasks = utils.aio.TaskSet() - - async def aclose(self) -> None: - if self._send_ch.closed: - return - - self._send_ch.close() - await self._main_atask - - @property - def fnc_ctx(self) -> llm.FunctionContext | None: - return self._fnc_ctx - - @fnc_ctx.setter - def fnc_ctx(self, fnc_ctx: llm.FunctionContext | None) -> None: - self._fnc_ctx = fnc_ctx - - @property - def conversation(self) -> Conversation: - return RealtimeSession.Conversation(self) - - @property - def input_audio_buffer(self) -> InputAudioBuffer: - return RealtimeSession.InputAudioBuffer(self) - - def _push_audio(self, frame: rtc.AudioFrame) -> None: - self.input_audio_buffer.append(frame) - - @property - def response(self) -> Response: - return RealtimeSession.Response(self) - - def session_update( - self, - *, - modalities: list[api_proto.Modality] | None = None, - instructions: str | None = None, - voice: api_proto.Voice | None = None, - input_audio_format: api_proto.AudioFormat | None = None, - output_audio_format: api_proto.AudioFormat | None = None, - input_audio_transcription: InputTranscriptionOptions | None = None, - turn_detection: ServerVadOptions | None = None, - tool_choice: api_proto.ToolChoice | None = None, - temperature: float | None = None, - max_response_output_tokens: int | Literal["inf"] | None = None, - ) -> None: - self._opts = deepcopy(self._opts) - if modalities is not None: - self._opts.modalities = modalities - if instructions is not None: - self._opts.instructions = instructions - if voice is not None: - self._opts.voice = voice - if input_audio_format is not None: - self._opts.input_audio_format = input_audio_format - if output_audio_format is not None: - self._opts.output_audio_format = output_audio_format - if 
input_audio_transcription is not None: - self._opts.input_audio_transcription = input_audio_transcription - if turn_detection is not None: - self._opts.turn_detection = turn_detection - if tool_choice is not None: - self._opts.tool_choice = tool_choice - if temperature is not None: - self._opts.temperature = temperature - if max_response_output_tokens is not None: - self._opts.max_response_output_tokens = max_response_output_tokens - - tools = [] - if self._fnc_ctx is not None: - for fnc in self._fnc_ctx.ai_functions.values(): - # the realtime API is using internally-tagged polymorphism. - # build_oai_function_description was built for the ChatCompletion API - function_data = build_oai_function_description(fnc)["function"] - function_data["type"] = "function" - tools.append(function_data) - - server_vad_opts: api_proto.ServerVad | None = None - if self._opts.turn_detection is not None: - server_vad_opts = { - "type": "server_vad", - "threshold": self._opts.turn_detection.threshold, - "prefix_padding_ms": self._opts.turn_detection.prefix_padding_ms, - "silence_duration_ms": self._opts.turn_detection.silence_duration_ms, - } - input_audio_transcription_opts: api_proto.InputAudioTranscription | None = None - if self._opts.input_audio_transcription is not None: - input_audio_transcription_opts = { - "model": self._opts.input_audio_transcription.model, - } - - session_data: api_proto.ClientEvent.SessionUpdateData = { - "modalities": self._opts.modalities, - "instructions": self._opts.instructions, - "voice": self._opts.voice, - "input_audio_format": self._opts.input_audio_format, - "output_audio_format": self._opts.output_audio_format, - "input_audio_transcription": input_audio_transcription_opts, - "turn_detection": server_vad_opts, - "tools": tools, - "tool_choice": self._opts.tool_choice, - "temperature": self._opts.temperature, - "max_response_output_tokens": None, - } - - # azure doesn't support inf for max_response_output_tokens - if not self._opts.is_azure or isinstance( - self._opts.max_response_output_tokens, int - ): - session_data["max_response_output_tokens"] = ( - self._opts.max_response_output_tokens - ) - else: - del session_data["max_response_output_tokens"] # type: ignore - - self._queue_msg( - { - "type": "session.update", - "session": session_data, - } - ) - - def chat_ctx_copy(self) -> llm.ChatContext: - return self._remote_conversation_items.to_chat_context() - - async def set_chat_ctx(self, new_ctx: llm.ChatContext) -> None: - """Sync the chat context with the agent's chat context. - - Compute the minimum number of insertions and deletions to transform the old - chat context messages to the new chat context messages. 
- """ - original_ctx = self._remote_conversation_items.to_chat_context() - - # filter out messages that are not function calls and content is None - filtered_messages = [ - msg - for msg in new_ctx.messages - if msg.tool_call_id or msg.content is not None - ] - changes = utils._compute_changes( - original_ctx.messages, filtered_messages, key_fnc=lambda x: x.id - ) - logger.debug( - "sync chat context", - extra={ - "to_delete": [msg.id for msg in changes.to_delete], - "to_add": [ - (prev.id if prev else None, msg.id) for prev, msg in changes.to_add - ], - }, - ) - - # append an empty audio message if all new messages are text - if changes.to_add and not any( - isinstance(msg.content, llm.ChatAudio) for _, msg in changes.to_add - ): - # Patch: append an empty audio message to set the API in audio mode - changes.to_add.append((None, self._create_empty_user_audio_message(1.0))) - - _futs = [ - self.conversation.item.delete(item_id=msg.id) for msg in changes.to_delete - ] + [ - self.conversation.item.create(msg, prev.id if prev else None) - for prev, msg in changes.to_add - ] - - # wait for all the futures to complete - await asyncio.gather(*_futs) - - def _create_empty_user_audio_message(self, duration: float) -> llm.ChatMessage: - """Create an empty audio message with the given duration.""" - samples = int(duration * api_proto.SAMPLE_RATE) - return llm.ChatMessage( - role="user", - content=llm.ChatAudio( - frame=rtc.AudioFrame( - data=b"\x00\x00" * (samples * api_proto.NUM_CHANNELS), - sample_rate=api_proto.SAMPLE_RATE, - num_channels=api_proto.NUM_CHANNELS, - samples_per_channel=samples, - ) - ), - ) - - def _recover_from_text_response(self, item_id: str | None = None) -> None: - """Try to recover from a text response to audio mode. - - Sometimes the OpenAI Realtime API returns text instead of audio responses. - This method tries to recover from this by requesting a new response after - deleting the text response and creating an empty user audio message. 
- """ - if item_id: - # remove the text response if needed - self.conversation.item.delete(item_id=item_id) - self.conversation.item.create(self._create_empty_user_audio_message(1.0)) - self.response.create(on_duplicate="keep_both") - - def _truncate_conversation_item( - self, item_id: str, content_index: int, audio_end_ms: int - ) -> None: - self.conversation.item.truncate( - item_id=item_id, - content_index=content_index, - audio_end_ms=audio_end_ms, - ) - - def _update_conversation_item_content( - self, item_id: str, content: llm.ChatContent | list[llm.ChatContent] | None - ) -> None: - item = self._remote_conversation_items.get(item_id) - if item is None: - logger.warning( - "conversation item not found, skipping update", - extra={"item_id": item_id}, - ) - return - item.content = content - - def _queue_msg(self, msg: api_proto.ClientEvents) -> None: - self._send_ch.send_nowait(msg) + self._current_generation: _ResponseGeneration | None = None @utils.log_exceptions(logger=logger) async def _main_task(self) -> None: - try: - headers = {"User-Agent": "LiveKit Agents"} - query_params: dict[str, str] = {} - - base_url = self._opts.base_url - if self._opts.is_azure: - if self._opts.entra_token: - headers["Authorization"] = f"Bearer {self._opts.entra_token}" - - if self._opts.api_key: - headers["api-key"] = self._opts.api_key - - if self._opts.api_version: - query_params["api-version"] = self._opts.api_version - - if self._opts.azure_deployment: - query_params["deployment"] = self._opts.azure_deployment - else: - # OAI endpoint - headers["Authorization"] = f"Bearer {self._opts.api_key}" - headers["OpenAI-Beta"] = "realtime=v1" - - if self._opts.model: - query_params["model"] = self._opts.model - - url = f"{base_url.rstrip('/')}/realtime?{urlencode(query_params)}" - if url.startswith("http"): - url = url.replace("http", "ws", 1) - - ws_conn = await self._http_session.ws_connect( - url, - headers=headers, - ) - except Exception: - logger.exception("failed to connect to OpenAI API S2S") - return - - closing = False + self._conn = conn = await self._realtime_model._client.beta.realtime.connect( + model=self._realtime_model._opts.model + ).enter() @utils.log_exceptions(logger=logger) - async def _send_task(): - nonlocal closing - async for msg in self._send_ch: - await ws_conn.send_json(msg) - - closing = True - await ws_conn.close() + async def _listen_for_events() -> None: + async for event in conn: + if event.type == "input_audio_buffer.speech_started": + self._handle_input_audio_buffer_speech_started(event) + elif event.type == "input_audio_buffer.speech_stopped": + self._handle_input_audio_buffer_speech_stopped(event) + elif event.type == "response.created": + self._handle_response_created(event) + elif event.type == "response.output_item.added": + self._handle_response_output_item_added(event) + elif event.type == "response.audio_transcript.delta": + self._handle_response_audio_transcript_delta(event) + elif event.type == "response.audio.delta": + self._handle_response_audio_delta(event) + elif event.type == "response.audio_transcript.done": + self._handle_response_audio_transcript_done(event) + elif event.type == "response.audio.done": + self._handle_response_audio_done(event) + elif event.type == "response.output_item.done": + self._handle_response_output_item_done(event) + elif event.type == "response.done": + self._handle_response_done(event) @utils.log_exceptions(logger=logger) - async def _recv_task(): - while True: - msg = await ws_conn.receive() - if msg.type in ( - 
aiohttp.WSMsgType.CLOSED, - aiohttp.WSMsgType.CLOSE, - aiohttp.WSMsgType.CLOSING, - ): - if closing: - return - - raise Exception("OpenAI S2S connection closed unexpectedly") - - if msg.type != aiohttp.WSMsgType.TEXT: - logger.warning( - "unexpected OpenAI S2S message type %s", - msg.type, - extra=self.logging_extra(), - ) - continue - - try: - data = msg.json() - event: api_proto.ServerEventType = data["type"] - - if event == "session.created": - self._handle_session_created(data) - if event == "session.updated": - self._handle_session_updated(data) - elif event == "error": - self._handle_error(data) - elif event == "input_audio_buffer.speech_started": - self._handle_input_audio_buffer_speech_started(data) - elif event == "input_audio_buffer.speech_stopped": - self._handle_input_audio_buffer_speech_stopped(data) - elif event == "input_audio_buffer.committed": - self._handle_input_audio_buffer_speech_committed(data) - elif ( - event == "conversation.item.input_audio_transcription.completed" - ): - self._handle_conversation_item_input_audio_transcription_completed( - data - ) - elif event == "conversation.item.input_audio_transcription.failed": - self._handle_conversation_item_input_audio_transcription_failed( - data - ) - elif event == "conversation.item.created": - self._handle_conversation_item_created(data) - elif event == "conversation.item.deleted": - self._handle_conversation_item_deleted(data) - elif event == "conversation.item.truncated": - self._handle_conversation_item_truncated(data) - elif event == "response.created": - self._handle_response_created(data) - elif event == "response.output_item.added": - self._handle_response_output_item_added(data) - elif event == "response.content_part.added": - self._handle_response_content_part_added(data) - elif event == "response.audio.delta": - self._handle_response_audio_delta(data) - elif event == "response.audio_transcript.delta": - self._handle_response_audio_transcript_delta(data) - elif event == "response.audio.done": - self._handle_response_audio_done(data) - elif event == "response.text.done": - self._handle_response_text_done(data) - elif event == "response.audio_transcript.done": - self._handle_response_audio_transcript_done(data) - elif event == "response.content_part.done": - self._handle_response_content_part_done(data) - elif event == "response.output_item.done": - self._handle_response_output_item_done(data) - elif event == "response.done": - self._handle_response_done(data) - - except Exception: - logger.exception( - "failed to handle OpenAI S2S message", - extra={"websocket_message": msg, **self.logging_extra()}, - ) + async def _forward_input_audio() -> None: + async for msg in self._msg_ch: + await conn.send(msg) tasks = [ - asyncio.create_task(_send_task(), name="openai-realtime-send"), - asyncio.create_task(_recv_task(), name="openai-realtime-recv"), + asyncio.create_task(_listen_for_events(), name="_listen_for_events"), + asyncio.create_task(_forward_input_audio(), name="_forward_input_audio"), ] - try: await asyncio.gather(*tasks) finally: await utils.aio.gracefully_cancel(*tasks) - - def _handle_session_created( - self, session_created: api_proto.ServerEvent.SessionCreated - ): - self._session_id = session_created["session"]["id"] - - def _handle_session_updated( - self, session_updated: api_proto.ServerEvent.SessionUpdated - ): - session = session_updated["session"] - if session["turn_detection"] is None: - turn_detection = None - else: - turn_detection = ServerVadOptions( - 
threshold=session["turn_detection"]["threshold"], - prefix_padding_ms=session["turn_detection"]["prefix_padding_ms"], - silence_duration_ms=session["turn_detection"]["silence_duration_ms"], - ) - if session["input_audio_transcription"] is None: - input_audio_transcription = None - else: - input_audio_transcription = InputTranscriptionOptions( - model=session["input_audio_transcription"]["model"], - ) - - self.emit( - "session_updated", - RealtimeSessionOptions( - model=session["model"], - modalities=session["modalities"], - instructions=session["instructions"], - voice=session["voice"], - input_audio_format=session["input_audio_format"], - output_audio_format=session["output_audio_format"], - input_audio_transcription=input_audio_transcription, - turn_detection=turn_detection, - tool_choice=session["tool_choice"], - temperature=session["temperature"], - max_response_output_tokens=session["max_response_output_tokens"], - ), - ) - - def _handle_error(self, error: api_proto.ServerEvent.Error): - logger.error( - "OpenAI S2S error %s", - error, - extra=self.logging_extra(), - ) - error_content = error["error"] - self.emit( - "error", - RealtimeError( - event_id=error["event_id"], - type=error_content["type"], - message=error_content["message"], - code=error_content.get("code"), - param=error_content.get("param"), - ), - ) + await conn.close() def _handle_input_audio_buffer_speech_started( - self, speech_started: api_proto.ServerEvent.InputAudioBufferSpeechStarted - ): - self.emit("input_speech_started") + self, _: InputAudioBufferSpeechStartedEvent + ) -> None: + self.emit("input_speech_started", multimodal.InputSpeechStartedEvent()) def _handle_input_audio_buffer_speech_stopped( - self, speech_stopped: api_proto.ServerEvent.InputAudioBufferSpeechStopped - ): - self.emit("input_speech_stopped") - - def _handle_input_audio_buffer_speech_committed( - self, speech_committed: api_proto.ServerEvent.InputAudioBufferCommitted - ): - self.emit("input_speech_committed") - - def _handle_conversation_item_input_audio_transcription_completed( - self, - transcription_completed: api_proto.ServerEvent.ConversationItemInputAudioTranscriptionCompleted, - ): - transcript = transcription_completed["transcript"] - self.emit( - "input_speech_transcription_completed", - InputTranscriptionCompleted( - item_id=transcription_completed["item_id"], - transcript=transcript, - ), - ) - - def _handle_conversation_item_input_audio_transcription_failed( - self, - transcription_failed: api_proto.ServerEvent.ConversationItemInputAudioTranscriptionFailed, - ): - error = transcription_failed["error"] - logger.error( - "OAI S2S failed to transcribe input audio: %s", - error["message"], - extra=self.logging_extra(), - ) - self.emit( - "input_speech_transcription_failed", - InputTranscriptionFailed( - item_id=transcription_failed["item_id"], - message=error["message"], - ), - ) - - def _handle_conversation_item_created( - self, item_created: api_proto.ServerEvent.ConversationItemCreated - ): - previous_item_id = item_created["previous_item_id"] - item = item_created["item"] - item_type = item["type"] - item_id = item["id"] - - # Create message based on item type - # Leave the content empty and fill it in later from the content parts - if item_type == "message": - # Handle message items (system/user/assistant) - item = cast(Union[api_proto.SystemItem, api_proto.UserItem], item) - role = item["role"] - message = llm.ChatMessage(id=item_id, role=role) - if item.get("content"): - content = item["content"][0] - if content["type"] in 
("text", "input_text"): - content = cast(api_proto.InputTextContent, content) - message.content = content["text"] - elif content["type"] == "input_audio" and content.get("audio"): - audio_data = base64.b64decode(content["audio"]) - message.content = llm.ChatAudio( - frame=rtc.AudioFrame( - data=audio_data, - sample_rate=api_proto.SAMPLE_RATE, - num_channels=api_proto.NUM_CHANNELS, - samples_per_channel=len(audio_data) // 2, - ) - ) - - elif item_type == "function_call": - # Handle function call items - item = cast(api_proto.FunctionCallItem, item) - message = llm.ChatMessage( - id=item_id, - role="assistant", - name=item["name"], - tool_call_id=item["call_id"], - ) - - elif item_type == "function_call_output": - # Handle function call output items - item = cast(api_proto.FunctionCallOutputItem, item) - message = llm.ChatMessage( - id=item_id, - role="tool", - tool_call_id=item["call_id"], - content=item["output"], - ) - - else: - logger.error( - f"unknown conversation item type {item_type}", - extra=self.logging_extra(), - ) - return - - # Insert into conversation items - self._remote_conversation_items.insert_after(previous_item_id, message) - if item_id in self._item_created_futs: - self._item_created_futs[item_id].set_result(True) - del self._item_created_futs[item_id] - logger.debug("conversation item created", extra=item_created) - - def _handle_conversation_item_deleted( - self, item_deleted: api_proto.ServerEvent.ConversationItemDeleted - ): - # Delete from conversation items - item_id = item_deleted["item_id"] - self._remote_conversation_items.delete(item_id) - if item_id in self._item_deleted_futs: - self._item_deleted_futs[item_id].set_result(True) - del self._item_deleted_futs[item_id] - logger.debug("conversation item deleted", extra=item_deleted) - - def _handle_conversation_item_truncated( - self, item_truncated: api_proto.ServerEvent.ConversationItemTruncated - ): - item_id = item_truncated["item_id"] - if item_id in self._item_truncated_futs: - self._item_truncated_futs[item_id].set_result(True) - del self._item_truncated_futs[item_id] + self, _: InputAudioBufferSpeechStoppedEvent + ) -> None: + self.emit("input_speech_stopped", multimodal.InputSpeechStoppedEvent()) - def _handle_response_created( - self, response_created: api_proto.ServerEvent.ResponseCreated - ): - response = response_created["response"] - done_fut = self._loop.create_future() - status_details = response.get("status_details") - new_response = RealtimeResponse( - id=response["id"], - status=response["status"], - status_details=status_details, - output=[], - usage=response.get("usage"), - done_fut=done_fut, - _created_timestamp=time.time(), + def _handle_response_created(self, event: ResponseCreatedEvent) -> None: + response_id = event.response.id + assert response_id is not None, "response.id is None" + self._current_generation = _ResponseGeneration( + response_id=response_id, + item_id="", + audio_ch=utils.aio.Chan(), + text_ch=utils.aio.Chan(), + tool_calls_ch=utils.aio.Chan(), ) - self._pending_responses[new_response.id] = new_response - self._active_response_id = new_response.id - - # complete the create future if it exists - if self._response_create_fut is not None: - self._response_create_fut.set_result(None) - self._response_create_fut = None - - self.emit("response_created", new_response) def _handle_response_output_item_added( - self, response_output_added: api_proto.ServerEvent.ResponseOutputItemAdded - ): - response_id = response_output_added["response_id"] - response = 
self._pending_responses[response_id] - done_fut = self._loop.create_future() - item_data = response_output_added["item"] - - item_type: Literal["message", "function_call"] = item_data["type"] # type: ignore - assert item_type in ("message", "function_call") - # function_call doesn't have a role field, defaulting it to assistant - item_role: api_proto.Role = item_data.get("role") or "assistant" # type: ignore + self, event: ResponseOutputItemAddedEvent + ) -> None: + assert self._current_generation is not None, "current_generation is None" + item_id = event.item.id + assert item_id is not None, "item.id is None" - new_output = RealtimeOutput( - response_id=response_id, - item_id=item_data["id"], - output_index=response_output_added["output_index"], - type=item_type, - role=item_role, - content=[], - done_fut=done_fut, - ) - response.output.append(new_output) - self.emit("response_output_added", new_output) + # We assume only one "message" item in the current approach + if self._current_generation.item_id and event.item.type == "message": + logger.warning("Received an unexpected second item with type `message`") + return - def _handle_response_content_part_added( - self, response_content_added: api_proto.ServerEvent.ResponseContentPartAdded - ): - response_id = response_content_added["response_id"] - response = self._pending_responses[response_id] - output_index = response_content_added["output_index"] - output = response.output[output_index] - content_type = response_content_added["part"]["type"] + if event.item.type == "function_call": + return - text_ch = utils.aio.Chan[str]() - audio_ch = utils.aio.Chan[rtc.AudioFrame]() + self._current_generation.item_id = item_id - new_content = RealtimeContent( - response_id=response_id, - item_id=response_content_added["item_id"], - output_index=output_index, - content_index=response_content_added["content_index"], - text="", - audio=[], - text_stream=text_ch, - audio_stream=audio_ch, - tool_calls=[], - content_type=content_type, - ) - output.content.append(new_content) - response._first_token_timestamp = time.time() - self.emit("response_content_added", new_content) + def _handle_response_audio_transcript_delta( + self, event: ResponseAudioTranscriptDeltaEvent + ) -> None: + assert self._current_generation is not None, "current_generation is None" + self._current_generation.text_ch.send_nowait(event.delta) - def _handle_response_audio_delta( - self, response_audio_delta: api_proto.ServerEvent.ResponseAudioDelta - ): - content = self._get_content(response_audio_delta) - data = base64.b64decode(response_audio_delta["delta"]) - audio = rtc.AudioFrame( + def _handle_response_audio_delta(self, event: ResponseAudioDeltaEvent) -> None: + assert self._current_generation is not None, "current_generation is None" + data = base64.b64decode(event.delta) + frame = rtc.AudioFrame( data=data, - sample_rate=api_proto.SAMPLE_RATE, - num_channels=api_proto.NUM_CHANNELS, + sample_rate=SAMPLE_RATE, + num_channels=NUM_CHANNELS, samples_per_channel=len(data) // 2, ) - content.audio.append(audio) - - assert isinstance(content.audio_stream, utils.aio.Chan) - content.audio_stream.send_nowait(audio) - - def _handle_response_audio_transcript_delta( - self, - response_audio_transcript_delta: api_proto.ServerEvent.ResponseAudioTranscriptDelta, - ): - content = self._get_content(response_audio_transcript_delta) - transcript = response_audio_transcript_delta["delta"] - content.text += transcript - - assert isinstance(content.text_stream, utils.aio.Chan) - 
content.text_stream.send_nowait(transcript) - - def _handle_response_audio_done( - self, response_audio_done: api_proto.ServerEvent.ResponseAudioDone - ): - content = self._get_content(response_audio_done) - assert isinstance(content.audio_stream, utils.aio.Chan) - content.audio_stream.close() - - def _handle_response_text_done( - self, response_text_done: api_proto.ServerEvent.ResponseTextDone - ): - content = self._get_content(response_text_done) - content.text = response_text_done["text"] + self._current_generation.audio_ch.send_nowait(frame) def _handle_response_audio_transcript_done( - self, - response_audio_transcript_done: api_proto.ServerEvent.ResponseAudioTranscriptDone, - ): - content = self._get_content(response_audio_transcript_done) - assert isinstance(content.text_stream, utils.aio.Chan) - content.text_stream.close() + self, _: ResponseAudioTranscriptDoneEvent + ) -> None: + assert self._current_generation is not None, "current_generation is None" + self._current_generation.text_ch.close() - def _handle_response_content_part_done( - self, response_content_done: api_proto.ServerEvent.ResponseContentPartDone - ): - content = self._get_content(response_content_done) - self.emit("response_content_done", content) + def _handle_response_audio_done(self, _: ResponseAudioDoneEvent) -> None: + assert self._current_generation is not None, "current_generation is None" + self._current_generation.audio_ch.close() def _handle_response_output_item_done( - self, response_output_done: api_proto.ServerEvent.ResponseOutputItemDone - ): - response_id = response_output_done["response_id"] - response = self._pending_responses[response_id] - output_index = response_output_done["output_index"] - output = response.output[output_index] + self, event: ResponseOutputItemDoneEvent + ) -> None: + assert self._current_generation is not None, "current_generation is None" - if output.type == "function_call": - if self._fnc_ctx is None: - logger.error( - "function call received but no fnc_ctx is available", - extra=self.logging_extra(), + item = event.item + if item.type == "function_call": + if self.fnc_ctx is None: + logger.warning( + "received a function_call item without a function context", + extra={"item": item}, ) return - # parse the arguments and call the function inside the fnc_ctx - item = response_output_done["item"] - assert item["type"] == "function_call" + assert item.call_id is not None, "call_id is None" + assert item.name is not None, "name is None" + assert item.arguments is not None, "arguments is None" fnc_call_info = _create_ai_function_info( - self._fnc_ctx, - item["call_id"], - item["name"], - item["arguments"], + self.fnc_ctx, + item.call_id, + item.name, + item.arguments, ) + self._current_generation.tool_calls_ch.send_nowait(fnc_call_info) - msg = self._remote_conversation_items.get(output.item_id) - if msg is not None: - # update the content of the message - assert msg.tool_call_id == item["call_id"] - assert msg.role == "assistant" - msg.name = item["name"] - msg.tool_calls = [fnc_call_info] - - self.emit("function_calls_collected", [fnc_call_info]) - - self._fnc_tasks.create_task( - self._run_fnc_task(fnc_call_info, output.item_id) - ) + def _handle_response_done(self, _: ResponseDoneEvent) -> None: + assert self._current_generation is not None, "current_generation is None" + self._current_generation.tool_calls_ch.close() + self._current_generation = None - output.done_fut.set_result(None) - self.emit("response_output_done", output) - - def _handle_response_done(self, 
response_done: api_proto.ServerEvent.ResponseDone): - response_data = response_done["response"] - response_id = response_data["id"] - response = self._pending_responses[response_id] - self._active_response_id = None - response.done_fut.set_result(None) - - response.status = response_data["status"] - response.status_details = response_data.get("status_details") - response.usage = response_data.get("usage") + @property + def chat_ctx(self) -> llm.ChatContext: + return self._chat_ctx.copy() - metrics_error = None - cancelled = False - if response.status == "failed": - assert response.status_details is not None + async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None: + # get the difference between self._chat_ctx and chat_ctx + pass - error = response.status_details.get("error", {}) - code: str | None = error.get("code") # type: ignore - message: str | None = error.get("message") # type: ignore - metrics_error = MultimodalLLMError( - type=response.status_details.get("type"), - code=code, - message=message, - ) + @property + def fnc_ctx(self) -> llm.FunctionContext | None: + return self._fnc_ctx - logger.error( - "response generation failed", - extra={"code": code, "error": message, **self.logging_extra()}, - ) - elif response.status == "incomplete": - assert response.status_details is not None - reason = response.status_details.get("reason") + async def update_fnc_ctx(self, fnc_ctx: llm.FunctionContext | None) -> None: + pass - metrics_error = MultimodalLLMError( - type=response.status_details.get("type"), - reason=reason, # type: ignore + def push_audio(self, frame: rtc.AudioFrame) -> None: + self._msg_ch.send_nowait( + InputAudioBufferAppendEvent( + type="input_audio_buffer.append", + audio=base64.b64encode(frame.data).decode("ascii"), ) - - logger.warning( - "response generation incomplete", - extra={"reason": reason, **self.logging_extra()}, - ) - elif response.status == "cancelled": - cancelled = True - - self.emit("response_done", response) - - # calculate metrics - ttft = -1.0 - if response._first_token_timestamp is not None: - ttft = response._first_token_timestamp - response._created_timestamp - duration = time.time() - response._created_timestamp - - usage = response.usage or {} # type: ignore - input_token_details = usage.get("input_token_details", {}) - metrics = MultimodalLLMMetrics( - timestamp=response._created_timestamp, - request_id=response.id, - ttft=ttft, - duration=duration, - cancelled=cancelled, - label=self._label, - completion_tokens=usage.get("output_tokens", 0), - prompt_tokens=usage.get("input_tokens", 0), - total_tokens=usage.get("total_tokens", 0), - tokens_per_second=usage.get("output_tokens", 0) / duration, - error=metrics_error, - input_token_details=MultimodalLLMMetrics.InputTokenDetails( - cached_tokens=input_token_details.get("cached_tokens", 0), - text_tokens=usage.get("input_token_details", {}).get("text_tokens", 0), - audio_tokens=usage.get("input_token_details", {}).get( - "audio_tokens", 0 - ), - cached_tokens_details=MultimodalLLMMetrics.CachedTokenDetails( - text_tokens=input_token_details.get( - "cached_tokens_details", {} - ).get("text_tokens", 0), - audio_tokens=input_token_details.get( - "cached_tokens_details", {} - ).get("audio_tokens", 0), - ), - ), - output_token_details=MultimodalLLMMetrics.OutputTokenDetails( - text_tokens=usage.get("output_token_details", {}).get("text_tokens", 0), - audio_tokens=usage.get("output_token_details", {}).get( - "audio_tokens", 0 - ), - ), ) - self.emit("metrics_collected", metrics) - def 
_get_content(self, ptr: _ContentPtr) -> RealtimeContent: - response = self._pending_responses[ptr["response_id"]] - output = response.output[ptr["output_index"]] - content = output.content[ptr["content_index"]] - return content + def generate_reply(self) -> None: + self._msg_ch.send_nowait(ResponseCreateEvent(type="response.create")) - @utils.log_exceptions(logger=logger) - async def _run_fnc_task(self, fnc_call_info: llm.FunctionCallInfo, item_id: str): - logger.debug( - "executing ai function", - extra={ - "function": fnc_call_info.function_info.name, - }, - ) + def interrupt(self) -> None: + self._msg_ch.send_nowait(ResponseCancelEvent(type="response.cancel")) - called_fnc = fnc_call_info.execute() - await called_fnc.task - - tool_call = llm.ChatMessage.create_tool_from_called_function(called_fnc) - logger.info( - "creating response for tool call", - extra={ - "function": fnc_call_info.function_info.name, - }, - ) - if tool_call.content is not None: - create_fut = self.conversation.item.create( - tool_call, - previous_item_id=item_id, + def truncate(self, *, message_id: str, audio_end_ms: int) -> None: + self._msg_ch.send_nowait( + ConversationItemTruncateEvent( + type="conversation.item.truncate", + content_index=0, + item_id=message_id, + audio_end_ms=audio_end_ms, ) - await self.response.create(on_duplicate="keep_both") - await create_fut - - # update the message with the tool call result - msg = self._remote_conversation_items.get(tool_call.id) - if msg is not None: - assert msg.tool_call_id == tool_call.tool_call_id - assert msg.role == "tool" - msg.name = tool_call.name - msg.content = tool_call.content - msg.tool_exception = tool_call.tool_exception - - self.emit("function_calls_finished", [called_fnc]) + ) - def logging_extra(self) -> dict: - return {"session_id": self._session_id} + async def aclose(self) -> None: + if self._conn is not None: + await self._conn.close() diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/remote_items.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/remote_items.py deleted file mode 100644 index 465f0789c..000000000 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/remote_items.py +++ /dev/null @@ -1,126 +0,0 @@ -from __future__ import annotations - -from collections import OrderedDict -from dataclasses import dataclass, field -from typing import Optional - -from livekit.agents import llm - -from .log import logger - - -@dataclass -class _ConversationItem: - """A node in the conversation linked list""" - - message: llm.ChatMessage - _prev: Optional[_ConversationItem] = field(default=None, repr=False) - _next: Optional[_ConversationItem] = field(default=None, repr=False) - - -class _RemoteConversationItems: - """Manages conversation items in a doubly-linked list""" - - def __init__(self) -> None: - self._head: Optional[_ConversationItem] = None - self._tail: Optional[_ConversationItem] = None - self._id_to_item: OrderedDict[str, _ConversationItem] = OrderedDict() - - @classmethod - def from_chat_context(cls, chat_ctx: llm.ChatContext) -> _RemoteConversationItems: - """Create ConversationItems from a ChatContext""" - items = cls() - for msg in chat_ctx.messages: - items.append(msg) - return items - - def to_chat_context(self) -> llm.ChatContext: - """Export to a ChatContext""" - chat_ctx = llm.ChatContext() - current = self._head - while current: - chat_ctx.messages.append(current.message.copy()) - current = current._next - return chat_ctx - - def append(self, 
message: llm.ChatMessage) -> None: - """Add a message to the end of the conversation""" - if message.id is None: - raise ValueError("Message must have an id") - - if message.id in self._id_to_item: - raise ValueError(f"Message with id {message.id} already exists") - - item = _ConversationItem(message=message) - item._prev = self._tail - item._next = None - - if self._tail: - self._tail._next = item - self._tail = item - - if not self._head: - self._head = item - - self._id_to_item[message.id] = item - - def insert_after(self, prev_item_id: str | None, message: llm.ChatMessage) -> None: - """Insert a message after the specified message ID. - If prev_item_id is None, append to the end.""" - if message.id is None: - raise ValueError("Message must have an id") - - if message.id in self._id_to_item: - raise ValueError(f"Message with id {message.id} already exists") - - if prev_item_id is None: - # Append to end instead of inserting at head - self.append(message) - return - - prev_item = self._id_to_item.get(prev_item_id) - if not prev_item: - logger.error( - f"Previous message with id {prev_item_id} not found, ignore it" - ) - return - - new_item = _ConversationItem(message=message) - new_item._prev = prev_item - new_item._next = prev_item._next - prev_item._next = new_item - if new_item._next: - new_item._next._prev = new_item - else: - self._tail = new_item - - self._id_to_item[message.id] = new_item - - def delete(self, item_id: str) -> None: - """Delete a message by its ID""" - item = self._id_to_item.get(item_id) - if not item: - logger.error(f"Message with id {item_id} not found for deletion") - return - - if item._prev: - item._prev._next = item._next - else: - self._head = item._next - - if item._next: - item._next._prev = item._prev - else: - self._tail = item._prev - - del self._id_to_item[item_id] - - def get(self, item_id: str) -> llm.ChatMessage | None: - """Get a message by its ID""" - item = self._id_to_item.get(item_id) - return item.message if item else None - - @property - def messages(self) -> list[llm.ChatMessage]: - """Return all messages in order""" - return [item.message for item in self._id_to_item.values()] From eb048d2671586e98db2305cc262f133d65657c3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Mon, 13 Jan 2025 14:36:23 +0100 Subject: [PATCH 07/19] chat_ctx diff with root --- livekit-agents/livekit/agents/llm/utils.py | 71 +++++ .../livekit/agents/utils/_message_change.py | 250 ++++++------------ tests/test_message_change.py | 170 ------------ 3 files changed, 153 insertions(+), 338 deletions(-) create mode 100644 livekit-agents/livekit/agents/llm/utils.py delete mode 100644 tests/test_message_change.py diff --git a/livekit-agents/livekit/agents/llm/utils.py b/livekit-agents/livekit/agents/llm/utils.py new file mode 100644 index 000000000..cc2a26c15 --- /dev/null +++ b/livekit-agents/livekit/agents/llm/utils.py @@ -0,0 +1,71 @@ +from __future__ import annotations +from dataclasses import dataclass + +from .chat_context import ChatContext + + +def _compute_lcs(old_ids: list[str], new_ids: list[str]) -> list[str]: + """ + Standard dynamic-programming LCS to get the common subsequence + of IDs (in order) that appear in both old_ids and new_ids. 
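+    For example (illustrative IDs only), old_ids=["a", "b", "c"] and
+    new_ids=["b", "c", "d"] share the in-order subsequence ["b", "c"],
+    which is what this function returns.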
+ """ + n, m = len(old_ids), len(new_ids) + dp = [[0] * (m + 1) for _ in range(n + 1)] + + # Fill DP table + for i in range(1, n + 1): + for j in range(1, m + 1): + if old_ids[i - 1] == new_ids[j - 1]: + dp[i][j] = dp[i - 1][j - 1] + 1 + else: + dp[i][j] = max(dp[i - 1][j], dp[i][j - 1]) + + # Backtrack to find the actual LCS sequence + lcs_ids = [] + i, j = n, m + while i > 0 and j > 0: + if old_ids[i - 1] == new_ids[j - 1]: + lcs_ids.append(old_ids[i - 1]) + i -= 1 + j -= 1 + elif dp[i - 1][j] > dp[i][j - 1]: + i -= 1 + else: + j -= 1 + + return list(reversed(lcs_ids)) + + +@dataclass +class DiffOps: + to_remove: list[str] + to_create: list[ + tuple[str | None, str] + ] # (previous_item_id, id), if previous_item_id is None, add to the root + + +def compute_chat_ctx_diff(old_ctx: ChatContext, new_ctx: ChatContext) -> DiffOps: + """Computes the minimal list of create/remove operations to transform old_ctx into new_ctx.""" + # TODO(theomonnom): Make ChatMessage hashable and also add update ops + + old_ids = [m.id for m in old_ctx.messages] + new_ids = [m.id for m in new_ctx.messages] + lcs_ids = set(_compute_lcs(old_ids, new_ids)) + + to_remove = [msg.id for msg in old_ctx.messages if msg.id not in lcs_ids] + to_create: list[tuple[str | None, str]] = [] + + last_id_in_sequence: str | None = None + for new_msg in new_ctx.messages: + if new_msg.id in lcs_ids: + last_id_in_sequence = new_msg.id + else: + if last_id_in_sequence is None: + prev_id = None # root + else: + prev_id = last_id_in_sequence + + to_create.append((prev_id, new_msg.id)) + last_id_in_sequence = new_msg.id + + return DiffOps(to_remove=to_remove, to_create=to_create) diff --git a/livekit-agents/livekit/agents/utils/_message_change.py b/livekit-agents/livekit/agents/utils/_message_change.py index b8a715eec..a5e214828 100644 --- a/livekit-agents/livekit/agents/utils/_message_change.py +++ b/livekit-agents/livekit/agents/utils/_message_change.py @@ -1,177 +1,91 @@ from dataclasses import dataclass from typing import Callable, Generic, TypeVar, Union -T = TypeVar("T") +T = TypeVar("T") -@dataclass -class MessageChange(Generic[T]): - """Represents changes needed to transform one list into another - The changes must be applied in order: - 1. First apply all deletions - 2. Then apply all insertions with their previous_item_id +def compute_lcs(old_ids: list[str], new_ids: list[str]) -> list[str]: """ + Standard dynamic-programming LCS to get the common subsequence + of IDs (in order) that appear in both old_ids and new_ids. + """ + n, m = len(old_ids), len(new_ids) + dp = [[0] * (m + 1) for _ in range(n + 1)] + + # Fill DP table + for i in range(1, n + 1): + for j in range(1, m + 1): + if old_ids[i - 1] == new_ids[j - 1]: + dp[i][j] = dp[i - 1][j - 1] + 1 + else: + dp[i][j] = max(dp[i - 1][j], dp[i][j - 1]) + + # Backtrack to find the actual LCS sequence + lcs_ids = [] + i, j = n, m + while i > 0 and j > 0: + if old_ids[i - 1] == new_ids[j - 1]: + lcs_ids.append(old_ids[i - 1]) + i -= 1 + j -= 1 + elif dp[i - 1][j] > dp[i][j - 1]: + i -= 1 + else: + j -= 1 + + return list(reversed(lcs_ids)) + + +def sync_chat_ctx(old_ctx: ChatContext, new_ctx: ChatContext) -> List[dict]: + old_ids = [m.id for m in old_ctx.messages] + new_ids = [m.id for m in new_ctx.messages] + + # 1. Find the set of message IDs we will keep + lcs_ids = set(compute_lcs(old_ids, new_ids)) + + requests = [] + + # 2. 
Remove messages from old that are NOT in LCS + # (this ensures that anything not in the final conversation is removed) + for old_msg in old_ctx.messages: + if old_msg.id not in lcs_ids: + requests.append( + { + "action": "remove", + "id": old_msg.id, + } + ) - to_delete: list[T] - """Items to delete from old list""" - to_add: list[tuple[Union[T, None], T]] - """Items to add as (previous_item, new_item) pairs""" - - -def compute_changes( - old_list: list[T], new_list: list[T], key_fnc: Callable[[T], str] -) -> MessageChange[T]: - """Compute minimum changes needed to transform old list into new list""" - # Convert to lists of ids - old_ids = [key_fnc(msg) for msg in old_list] - new_ids = [key_fnc(msg) for msg in new_list] - - # Create lookup maps - old_msgs = {key_fnc(msg): msg for msg in old_list} - new_msgs = {key_fnc(msg): msg for msg in new_list} - - # Compute changes using ids - changes = _compute_list_changes(old_ids, new_ids) - - # Convert back to items - return MessageChange( - to_delete=[old_msgs[id] for id in changes.to_delete], - to_add=[ - ( - None if prev is None else old_msgs.get(prev) or new_msgs[prev], - new_msgs[new], + # 3. Create the missing messages from new in the correct order + # We keep track of the "last message ID" that ends up in the final new sequence. + last_id_in_new_sequence: Optional[str] = None + + for i, new_msg in enumerate(new_ctx.messages): + if new_msg.id in lcs_ids: + # This message is already kept (it's in LCS), so just update + # our 'last_id_in_new_sequence' pointer to it, + # meaning "this message is logically next in the final conversation" + last_id_in_new_sequence = new_msg.id + else: + # This message is not in LCS: we need to create it + if last_id_in_new_sequence is None: + # Insert at the very beginning + prev_item = "root" + else: + # Insert after the last item that we have in the final conversation + prev_item = last_id_in_new_sequence + + requests.append( + { + "action": "create", + "id": new_msg.id, + "role": new_msg.role, + "previous_item_id": prev_item, # could be "root" or last_id + } ) - for prev, new in changes.to_add - ], - ) - - -def _compute_list_changes(old_list: list[T], new_list: list[T]) -> MessageChange[T]: - """Compute minimum changes needed to transform old_list into new_list - - Rules: - - Delete first, then insert - - Can't insert at start if list not empty (must delete all first) - - Each insert needs previous item except for first item in new list - - If an item changes position relative to others, it must be deleted and reinserted - - If first item in new list exists in old list, must delete all items before it - - Examples: - old [a b c d] new [b c d e] -> delete a, insert (d,e) - old [a b c d] new [e a b c d] -> delete all, insert (None,e),(e,a),(a,b),(b,c),(c,d) - old [a b c d] new [a b d e c] -> delete d, insert (b,d),(d,e) - old [a b c d] new [a d c b] -> delete c,d, insert (a,d),(d,c) - """ - if not new_list: - return MessageChange(to_delete=old_list, to_add=[]) - - # Find first item's position in old list - try: - first_idx = old_list.index(new_list[0]) - except ValueError: - # Special case: if first item is new, delete everything - prev_item: Union[T, None] = None - to_add: list[tuple[Union[T, None], T]] = [] - for x in new_list: - to_add.append((prev_item, x)) - prev_item = x - return MessageChange(to_delete=old_list, to_add=to_add) - - # Delete all items before first_idx - to_delete = old_list[:first_idx] - remaining_old = old_list[first_idx:] - - # Get positions of remaining items in new list - indices = [] - 
items = [] - new_positions = {x: i for i, x in enumerate(new_list)} - for x in remaining_old: - if x in new_positions: - indices.append(new_positions[x]) - items.append(x) - - # Try fast path first - check if remaining order is preserved - if _check_order_preserved(indices): - kept_indices = list(range(len(indices))) - else: - # Order changed, need to find kept items using LIS - # First item must be kept since we've already handled items before it - kept_indices = _find_longest_increasing_subsequence(indices) - - # Convert kept indices back to items - kept_items = {items[i] for i in kept_indices} - - # Add items that need to be deleted from remaining list - to_delete.extend(x for x in remaining_old if x not in kept_items) - - # Compute items to add by following new list order - to_add = [] - prev_item = None - for x in new_list: - if x not in kept_items: - to_add.append((prev_item, x)) - prev_item = x - - return MessageChange(to_delete=to_delete, to_add=to_add) - - -def _check_order_preserved(indices: list[int]) -> bool: - """Check if indices form an increasing sequence""" - if not indices: - return True - - # Check if indices form an increasing sequence - for i in range(1, len(indices)): - if indices[i] <= indices[i - 1]: - return False - - return True - - -def _find_longest_increasing_subsequence(indices: list[int]) -> list[int]: - """Find indices of the longest increasing subsequence - - Args: - indices: List of indices to find LIS from - - Returns: - List of indices into the input list that form the LIS - For example, indices = [0, 4, 1, 2] -> [0, 2, 3] - """ - if not indices: - return [] - - # Must include first index, find LIS starting from it - first_val = indices[0] - dp = [1] * len(indices) - prev = [-1] * len(indices) - best_len = 1 # At minimum we keep the first index - best_end = 0 # Start with first index - - # Start from second element - for i in range(1, len(indices)): - # Only consider sequences starting from first index - if indices[i] > first_val: - dp[i] = 2 - prev[i] = 0 - if dp[i] > best_len: - best_len = dp[i] - best_end = i - - # Try extending existing sequences - for j in range(1, i): - if indices[j] < indices[i] and prev[j] != -1 and dp[j] + 1 > dp[i]: - dp[i] = dp[j] + 1 - prev[i] = j - if dp[i] > best_len: - best_len = dp[i] - best_end = i - - # Reconstruct sequence - result = [] - while best_end != -1: - result.append(best_end) - best_end = prev[best_end] - result.reverse() - return result + + # Now this newly created message becomes the last in the final conversation + last_id_in_new_sequence = new_msg.id + + return requests diff --git a/tests/test_message_change.py b/tests/test_message_change.py deleted file mode 100644 index a02f395a7..000000000 --- a/tests/test_message_change.py +++ /dev/null @@ -1,170 +0,0 @@ -import pytest -from livekit.agents.llm import ChatMessage -from livekit.agents.utils._message_change import ( - _check_order_preserved, - _compute_list_changes, - _find_longest_increasing_subsequence, - compute_changes, -) - - -@pytest.mark.parametrize( - "indices,expected_seq,desc", - [ - # Basic cases - ([0, 1, 2], [0, 1, 2], "Already sorted"), - ([2, 1, 0], [2], "Must keep first (2)"), - ([2, 0, 1], [2], "Must keep first (2)"), - ([2, 1, 0, 3], [2, 3], "Keep first and what can follow"), - ([3, 0, 1, 2], [3], "Only first when nothing can follow"), - ([2, 1, 0, 3, 4], [2, 3, 4], "Keep first and increasing suffix"), - ([4, 1, 2, 3], [4], "Only first when better sequence exists"), - ([0, 1, 4, 2], [0, 1, 4], "Keep longest increasing with first"), - # 
Edge cases - ([], [], "Empty list"), - ([0], [0], "Single element"), - ([1], [1], "Single element not zero"), - ([2, 1], [2], "Two elements, keep first"), - ], -) -def test_find_longest_increasing_subsequence(indices, expected_seq, desc): - """Test the LIS algorithm with various cases""" - result = _find_longest_increasing_subsequence(indices) - result_seq = [indices[i] for i in result] if result else [] - - # Verify sequence is increasing - if result_seq: - assert all( - result_seq[i] < result_seq[i + 1] for i in range(len(result_seq) - 1) - ), f"Not increasing in {desc}" - - # Verify first element is included - if result: - assert result[0] == 0, f"First index not included in {desc}" - - # Verify sequence matches expected - assert ( - result_seq == expected_seq - ), f"Wrong sequence in {desc}: expected {expected_seq}, got {result_seq}" - - -@pytest.mark.parametrize( - "indices,expected", - [ - ([], True), - ([0], True), - ([0, 1, 2], True), - ([0, 2, 1], False), - ([1, 1, 2], False), - ], -) -def test_check_order_preserved(indices, expected): - assert _check_order_preserved(indices) is expected - - -@pytest.mark.parametrize( - "old,new,expected_delete,expected_add", - [ - # Empty lists - ([], [], [], []), - (["a"], [], ["a"], []), - ([], ["a"], [], [(None, "a")]), - # Simple append/delete - (["a", "b", "c"], ["a", "b", "c", "d"], [], [("c", "d")]), - (["a", "b", "c", "d"], ["a", "b", "c"], ["d"], []), - # Delete first item - (["a", "b", "c", "d"], ["b", "c", "d", "e"], ["a"], [("d", "e")]), - (["x", "y", "b", "c"], ["b", "c", "d"], ["x", "y"], [("c", "d")]), - # New first item - must delete all - ( - ["a", "b", "c", "d"], - ["e", "a", "b", "c"], - ["a", "b", "c", "d"], - [(None, "e"), ("e", "a"), ("a", "b"), ("b", "c")], - ), - # First item exists but order changes - (["a", "b", "c", "d"], ["b", "a", "c", "d"], ["a"], [("b", "a")]), - (["x", "y", "b", "c"], ["b", "d", "c"], ["x", "y"], [("b", "d")]), - # Complex reordering - ( - ["a", "b", "c", "d"], - ["a", "b", "d", "e", "c"], - ["d"], - [("b", "d"), ("d", "e")], - ), - ( - ["a", "b", "c", "d"], - ["a", "d", "c", "b"], - ["c", "d"], - [("a", "d"), ("d", "c")], - ), - ], -) -def test_compute_list_changes(old, new, expected_delete, expected_add): - changes = _compute_list_changes(old, new) - assert changes.to_delete == expected_delete - assert changes.to_add == expected_add - - -@pytest.mark.parametrize( - "old_ids,new_ids", - [ - (["a", "b", "c", "d"], ["b", "c", "d", "e"]), - (["a", "b", "c", "d"], ["e", "a", "b", "c"]), - (["a", "b", "c", "d"], ["a", "b", "d", "e", "c"]), - ], -) -def test_compute_changes(old_ids, new_ids): - """Test computing changes with ChatMessage objects""" - - def create_msg(id: str) -> ChatMessage: - return ChatMessage(role="test", id=id) - - old = [create_msg(id) for id in old_ids] - new = [create_msg(id) for id in new_ids] - - changes = compute_changes(old, new, lambda msg: msg.id) - - # Apply changes and verify result - result = [msg for msg in old if msg not in changes.to_delete] - - for prev, msg in changes.to_add: - if prev is None: - result.append(msg) - else: - idx = result.index(prev) + 1 - result.insert(idx, msg) - - assert [msg.id for msg in result] == new_ids - - -@pytest.mark.parametrize( - "old,new", - [ - (["a", "b", "c", "d"], ["b", "c", "d", "e"]), - (["a", "b", "c", "d"], ["e", "a", "b", "c"]), - (["a", "b", "c", "d"], ["a", "b", "d", "e", "c"]), - (["a", "b", "c", "d"], ["b", "a", "c", "d"]), - (["x", "y", "b", "c"], ["b", "d", "c"]), - (["a", "b", "c", "d"], ["a", "d", "c", "b"]), - ], -) 
-def test_changes_maintain_list_integrity(old, new): - """Test that applying changes maintains list integrity""" - - def apply_changes(old: list[str], changes): - # Apply deletions - result = [x for x in old if x not in changes.to_delete] - - # Apply insertions - for prev, item in changes.to_add: - if prev is None: - result.append(item) - else: - idx = result.index(prev) + 1 - result.insert(idx, item) - return result - - changes = _compute_list_changes(old, new) - result = apply_changes(old, changes) - assert result == new, f"Failed to transform {old} into {new}, got {result}" From 4047f829bc54cd42da6fea4af94222a511461be1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Mon, 13 Jan 2025 14:37:55 +0100 Subject: [PATCH 08/19] remove unused --- livekit-agents/livekit/agents/utils/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/livekit-agents/livekit/agents/utils/__init__.py b/livekit-agents/livekit/agents/utils/__init__.py index 254ae812d..ccd72958c 100644 --- a/livekit-agents/livekit/agents/utils/__init__.py +++ b/livekit-agents/livekit/agents/utils/__init__.py @@ -1,7 +1,6 @@ from livekit import rtc from . import aio, audio, codecs, http_context, hw, images -from ._message_change import compute_changes as _compute_changes # keep internal from .audio import AudioBuffer, combine_frames, merge_frames from .exp_filter import ExpFilter from .log import log_exceptions @@ -27,5 +26,4 @@ "aio", "hw", "is_given", - "_compute_changes", ] From 3a25d8d836dac91f844aa7e0348bbe009137958e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Mon, 13 Jan 2025 14:38:20 +0100 Subject: [PATCH 09/19] Delete _message_change.py --- .../livekit/agents/utils/_message_change.py | 91 ------------------- 1 file changed, 91 deletions(-) delete mode 100644 livekit-agents/livekit/agents/utils/_message_change.py diff --git a/livekit-agents/livekit/agents/utils/_message_change.py b/livekit-agents/livekit/agents/utils/_message_change.py deleted file mode 100644 index a5e214828..000000000 --- a/livekit-agents/livekit/agents/utils/_message_change.py +++ /dev/null @@ -1,91 +0,0 @@ -from dataclasses import dataclass -from typing import Callable, Generic, TypeVar, Union - - -T = TypeVar("T") - - -def compute_lcs(old_ids: list[str], new_ids: list[str]) -> list[str]: - """ - Standard dynamic-programming LCS to get the common subsequence - of IDs (in order) that appear in both old_ids and new_ids. - """ - n, m = len(old_ids), len(new_ids) - dp = [[0] * (m + 1) for _ in range(n + 1)] - - # Fill DP table - for i in range(1, n + 1): - for j in range(1, m + 1): - if old_ids[i - 1] == new_ids[j - 1]: - dp[i][j] = dp[i - 1][j - 1] + 1 - else: - dp[i][j] = max(dp[i - 1][j], dp[i][j - 1]) - - # Backtrack to find the actual LCS sequence - lcs_ids = [] - i, j = n, m - while i > 0 and j > 0: - if old_ids[i - 1] == new_ids[j - 1]: - lcs_ids.append(old_ids[i - 1]) - i -= 1 - j -= 1 - elif dp[i - 1][j] > dp[i][j - 1]: - i -= 1 - else: - j -= 1 - - return list(reversed(lcs_ids)) - - -def sync_chat_ctx(old_ctx: ChatContext, new_ctx: ChatContext) -> List[dict]: - old_ids = [m.id for m in old_ctx.messages] - new_ids = [m.id for m in new_ctx.messages] - - # 1. Find the set of message IDs we will keep - lcs_ids = set(compute_lcs(old_ids, new_ids)) - - requests = [] - - # 2. 
Remove messages from old that are NOT in LCS - # (this ensures that anything not in the final conversation is removed) - for old_msg in old_ctx.messages: - if old_msg.id not in lcs_ids: - requests.append( - { - "action": "remove", - "id": old_msg.id, - } - ) - - # 3. Create the missing messages from new in the correct order - # We keep track of the "last message ID" that ends up in the final new sequence. - last_id_in_new_sequence: Optional[str] = None - - for i, new_msg in enumerate(new_ctx.messages): - if new_msg.id in lcs_ids: - # This message is already kept (it's in LCS), so just update - # our 'last_id_in_new_sequence' pointer to it, - # meaning "this message is logically next in the final conversation" - last_id_in_new_sequence = new_msg.id - else: - # This message is not in LCS: we need to create it - if last_id_in_new_sequence is None: - # Insert at the very beginning - prev_item = "root" - else: - # Insert after the last item that we have in the final conversation - prev_item = last_id_in_new_sequence - - requests.append( - { - "action": "create", - "id": new_msg.id, - "role": new_msg.role, - "previous_item_id": prev_item, # could be "root" or last_id - } - ) - - # Now this newly created message becomes the last in the final conversation - last_id_in_new_sequence = new_msg.id - - return requests From 1bc6b3c44653b1b9e280e032f41ef37682b00870 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Wed, 15 Jan 2025 14:05:17 +0100 Subject: [PATCH 10/19] chat_ctx wip --- examples/minimal_worker.py | 6 +- .../livekit/agents/debug/__init__.py | 1 - .../livekit/agents/debug/tracing.py | 7 +- livekit-agents/livekit/agents/http_server.py | 1 - .../livekit/agents/ipc/job_proc_lazy_main.py | 2 +- livekit-agents/livekit/agents/llm/__init__.py | 20 +- .../livekit/agents/llm/chat_context.py | 180 ++++++------------ livekit-agents/livekit/agents/llm/utils.py | 1 + .../livekit/agents/multimodal/__init__.py | 10 +- .../livekit/agents/multimodal/realtime.py | 5 +- .../agents/pipeline/audio_recognition.py | 3 +- livekit-agents/livekit/agents/pipeline/io.py | 3 +- .../livekit/agents/pipeline/pipeline_agent.py | 9 +- livekit-agents/livekit/agents/worker.py | 7 +- .../plugins/openai/realtime/realtime_model.py | 123 ++++++++++-- .../livekit/plugins/openai/stt.py | 1 + 16 files changed, 208 insertions(+), 171 deletions(-) diff --git a/examples/minimal_worker.py b/examples/minimal_worker.py index 1a5fe6b40..d3d100513 100644 --- a/examples/minimal_worker.py +++ b/examples/minimal_worker.py @@ -3,7 +3,7 @@ from dotenv import load_dotenv from livekit.agents import AutoSubscribe, JobContext, WorkerOptions, WorkerType, cli from livekit.agents.pipeline import ChatCLI, PipelineAgent -from livekit.plugins import deepgram, openai, cartesia +from livekit.plugins import cartesia, deepgram, openai, silero logger = logging.getLogger("my-worker") logger.setLevel(logging.INFO) @@ -14,7 +14,9 @@ async def entrypoint(ctx: JobContext): await ctx.connect(auto_subscribe=AutoSubscribe.SUBSCRIBE_ALL) - agent = PipelineAgent(llm=openai.LLM(), stt=deepgram.STT(), tts=cartesia.TTS()) + agent = PipelineAgent( + llm=openai.LLM(), stt=deepgram.STT(), tts=cartesia.TTS(), vad=silero.VAD.load() + ) agent.start() # start a chat inside the CLI diff --git a/livekit-agents/livekit/agents/debug/__init__.py b/livekit-agents/livekit/agents/debug/__init__.py index 7a3535d28..57306db8d 100644 --- a/livekit-agents/livekit/agents/debug/__init__.py +++ b/livekit-agents/livekit/agents/debug/__init__.py @@ -1,6 +1,5 @@ from .tracing 
import Tracing, TracingGraph, TracingHandle - __all__ = [ "Tracing", "TracingGraph", diff --git a/livekit-agents/livekit/agents/debug/tracing.py b/livekit-agents/livekit/agents/debug/tracing.py index 5ecc59d5e..91867c61d 100644 --- a/livekit-agents/livekit/agents/debug/tracing.py +++ b/livekit-agents/livekit/agents/debug/tracing.py @@ -1,9 +1,11 @@ from __future__ import annotations + import asyncio import time +from typing import TYPE_CHECKING, Any, Literal from aiohttp import web -from typing import TYPE_CHECKING, Any, Literal + from .. import job if TYPE_CHECKING: @@ -138,9 +140,10 @@ def add_graph( def _create_tracing_app(w: Worker) -> web.Application: async def tracing_index(request: web.Request) -> web.Response: - import aiofiles import importlib.resources + import aiofiles + with importlib.resources.path("livekit.agents.debug", "index.html") as path: async with aiofiles.open(path, mode="r") as f: content = await f.read() diff --git a/livekit-agents/livekit/agents/http_server.py b/livekit-agents/livekit/agents/http_server.py index e769d530d..922de7815 100644 --- a/livekit-agents/livekit/agents/http_server.py +++ b/livekit-agents/livekit/agents/http_server.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -from typing import Any from aiohttp import web diff --git a/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py b/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py index 4baec41b0..ed0092012 100644 --- a/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py +++ b/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py @@ -27,8 +27,8 @@ def _no_traceback_excepthook(exc_type, exc_val, traceback): from livekit import rtc -from ..job import JobContext, JobProcess, _JobContextVar from ..debug import tracing +from ..job import JobContext, JobProcess, _JobContextVar from ..log import logger from ..utils import aio, http_context, log_exceptions, shortuuid from .channel import Message diff --git a/livekit-agents/livekit/agents/llm/__init__.py b/livekit-agents/livekit/agents/llm/__init__.py index d3a06f520..7c60fe701 100644 --- a/livekit-agents/livekit/agents/llm/__init__.py +++ b/livekit-agents/livekit/agents/llm/__init__.py @@ -1,10 +1,11 @@ from .chat_context import ( - ChatAudio, - ChatContent, ChatContext, - ChatImage, ChatMessage, - ChatRole, + FunctionCall, + FunctionCallOutput, + AudioContent, + ImageContent, + ChatItem, ) from .fallback_adapter import AvailabilityChangedEvent, FallbackAdapter from .function_context import ( @@ -28,16 +29,18 @@ LLMStream, ToolChoice, ) +from .utils import compute_chat_ctx_diff __all__ = [ "LLM", "LLMStream", "ChatContext", - "ChatRole", "ChatMessage", - "ChatAudio", - "ChatImage", - "ChatContent", + "FunctionCall", + "FunctionCallOutput", + "AudioContent", + "ImageContent", + "ChatItem", "ChatContext", "ChoiceDelta", "Choice", @@ -56,4 +59,5 @@ "AvailabilityChangedEvent", "ToolChoice", "_create_ai_function_info", + "compute_chat_ctx_diff", ] diff --git a/livekit-agents/livekit/agents/llm/chat_context.py b/livekit-agents/livekit/agents/llm/chat_context.py index ccde86bba..f1d7e9c03 100644 --- a/livekit-agents/livekit/agents/llm/chat_context.py +++ b/livekit-agents/livekit/agents/llm/chat_context.py @@ -13,19 +13,15 @@ # limitations under the License. from __future__ import annotations -from dataclasses import dataclass, field -from typing import Any, Literal, Union +from typing import ( + Literal, +) from livekit import rtc -from livekit.agents import utils +from pydantic import BaseModel -from . 
import function_context -ChatRole = Literal["system", "user", "assistant", "tool"] - - -@dataclass -class ChatImage: +class ImageContent(BaseModel): """ ChatImage is used to input images into the ChatContext on supported LLM providers / plugins. @@ -76,119 +72,59 @@ class ChatImage: Currently only supported by OpenAI (see https://platform.openai.com/docs/guides/vision?lang=node#low-or-high-fidelity-image-understanding) """ - _cache: dict[Any, Any] = field(default_factory=dict, repr=False, init=False) - """ - _cache is used internally by LLM implementations to store a processed version of the image - for later use. - """ -@dataclass -class ChatAudio: - frame: rtc.AudioFrame | list[rtc.AudioFrame] - - -ChatContent = Union[str, ChatImage, ChatAudio] - - -@dataclass -class ChatMessage: - role: ChatRole - id: str = field( - default_factory=lambda: utils.shortuuid("item_") - ) # used by the OAI realtime API - name: str | None = None - content: ChatContent | list[ChatContent] | None = None - tool_calls: list[function_context.FunctionCallInfo] | None = None - tool_call_id: str | None = None - tool_exception: Exception | None = None - _metadata: dict[str, Any] = field(default_factory=dict, repr=False, init=False) - - @staticmethod - def create_tool_from_called_function( - called_function: function_context.CalledFunction, - ) -> "ChatMessage": - if not called_function.task.done(): - raise ValueError("cannot create a tool result from a running ai function") - - tool_exception: Exception | None = None - try: - content = called_function.task.result() - except BaseException as e: - if isinstance(e, Exception): - tool_exception = e - content = f"Error: {e}" - - return ChatMessage( - role="tool", - name=called_function.call_info.function_info.name, - content=content, - tool_call_id=called_function.call_info.tool_call_id, - tool_exception=tool_exception, - ) - - @staticmethod - def create_tool_calls( - called_functions: list[function_context.FunctionCallInfo], - *, - text: str = "", - ) -> "ChatMessage": - return ChatMessage(role="assistant", tool_calls=called_functions, content=text) - - @staticmethod - def create( - *, - text: str = "", - images: list[ChatImage] = [], - role: ChatRole = "system", - id: str | None = None, - ) -> "ChatMessage": - id = id or utils.shortuuid("item_") - if len(images) == 0: - return ChatMessage(role=role, content=text, id=id) - else: - content: list[ChatContent] = [] - if text: - content.append(text) - - if len(images) > 0: - content.extend(images) - - return ChatMessage(role=role, content=content, id=id) - - def copy(self): - content = self.content - if isinstance(content, list): - content = content.copy() - - tool_calls = self.tool_calls - if tool_calls is not None: - tool_calls = tool_calls.copy() - - copied_msg = ChatMessage( - role=self.role, - id=self.id, - name=self.name, - content=content, - tool_calls=tool_calls, - tool_call_id=self.tool_call_id, - ) - copied_msg._metadata = self._metadata - return copied_msg - - -@dataclass +class AudioContent(BaseModel): + frame: list[rtc.AudioFrame] + transcript: str | None = None + + +class FunctionCall(BaseModel): + type: Literal["function_call"] + call_id: str + name: str + arguments: str + + +class FunctionCallOutput(BaseModel): + type: Literal["function_call_output"] + call_id: str + output: str + is_error: bool + + +class ChatMessage(BaseModel): + type: Literal["message"] + role: Literal["developer", "system", "user", "assistant"] + content: list[str | ImageContent | AudioContent] + hash: bytes | None = None + + +class 
ChatItem(BaseModel): + id: str + content: list[ChatMessage | FunctionCall | FunctionCallOutput] + + class ChatContext: - messages: list[ChatMessage] = field(default_factory=list) - _metadata: dict[str, Any] = field(default_factory=dict, repr=False, init=False) - - def append( - self, *, text: str = "", images: list[ChatImage] = [], role: ChatRole = "system" - ) -> ChatContext: - self.messages.append(ChatMessage.create(text=text, images=images, role=role)) - return self - - def copy(self) -> ChatContext: - copied_chat_ctx = ChatContext(messages=[m.copy() for m in self.messages]) - copied_chat_ctx._metadata = self._metadata - return copied_chat_ctx + def __init__(self, items: list[ChatItem] | None = None): + self._items: list[ChatItem] = items or [] + + @property + def items(self) -> list[ChatItem]: + return self._items + + def get_by_id(self, item_id: str) -> ChatItem | None: + # ideally, get_by_id should be O(1) + for item in self.items: + if item.id == item_id: + return item + + def copy(self) -> "ChatContext": + return ChatContext(self.items.copy()) + + def to_dict(self) -> dict: + raise NotImplementedError + + @classmethod + def from_dict(cls, _: dict) -> "ChatContext": + raise NotImplementedError diff --git a/livekit-agents/livekit/agents/llm/utils.py b/livekit-agents/livekit/agents/llm/utils.py index cc2a26c15..3e53ce0d1 100644 --- a/livekit-agents/livekit/agents/llm/utils.py +++ b/livekit-agents/livekit/agents/llm/utils.py @@ -1,4 +1,5 @@ from __future__ import annotations + from dataclasses import dataclass from .chat_context import ChatContext diff --git a/livekit-agents/livekit/agents/multimodal/__init__.py b/livekit-agents/livekit/agents/multimodal/__init__.py index fe6d5d654..c38056a68 100644 --- a/livekit-agents/livekit/agents/multimodal/__init__.py +++ b/livekit-agents/livekit/agents/multimodal/__init__.py @@ -1,11 +1,11 @@ from .realtime import ( - RealtimeModel, - RealtimeCapabilities, - RealtimeSession, + ErrorEvent, + GenerationCreatedEvent, InputSpeechStartedEvent, InputSpeechStoppedEvent, - GenerationCreatedEvent, - ErrorEvent, + RealtimeCapabilities, + RealtimeModel, + RealtimeSession, ) __all__ = [ diff --git a/livekit-agents/livekit/agents/multimodal/realtime.py b/livekit-agents/livekit/agents/multimodal/realtime.py index 114e33ca9..32a58671e 100644 --- a/livekit-agents/livekit/agents/multimodal/realtime.py +++ b/livekit-agents/livekit/agents/multimodal/realtime.py @@ -1,14 +1,13 @@ from __future__ import annotations -from dataclasses import dataclass from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import AsyncIterable, Generic, Literal, TypeVar, Union from livekit import rtc from .. import llm -from typing import AsyncIterable, Union, Literal, Generic, TypeVar - @dataclass class InputSpeechStartedEvent: diff --git a/livekit-agents/livekit/agents/pipeline/audio_recognition.py b/livekit-agents/livekit/agents/pipeline/audio_recognition.py index 7ae05430d..9cb6306f5 100644 --- a/livekit-agents/livekit/agents/pipeline/audio_recognition.py +++ b/livekit-agents/livekit/agents/pipeline/audio_recognition.py @@ -6,12 +6,11 @@ from livekit import rtc from .. import llm, stt, utils, vad +from ..debug import tracing from ..log import logger from ..utils import aio from . 
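# Usage sketch for the realtime event shape above: how a consumer might drain a
# GenerationCreatedEvent, assuming the fields shown in this hunk (message_id,
# text_stream, audio_stream, function_stream). The real wiring lives in
# PipelineAgent; this is only an illustration of the streaming contract.
import asyncio
from livekit.agents import multimodal

async def consume_generation(ev: multimodal.GenerationCreatedEvent) -> None:
    async def _text() -> None:
        async for delta in ev.text_stream:
            print("text delta:", delta)

    async def _audio() -> None:
        async for frame in ev.audio_stream:
            pass  # forward frames to the audio output / playout in a real agent

    async def _tools() -> None:
        async for fnc_call in ev.function_stream:
            print("tool call:", fnc_call.name, fnc_call.arguments)

    await asyncio.gather(_text(), _audio(), _tools())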
import io -from ..debug import tracing - if TYPE_CHECKING: from .pipeline2 import PipelineAgent diff --git a/livekit-agents/livekit/agents/pipeline/io.py b/livekit-agents/livekit/agents/pipeline/io.py index fea6d01a8..a93d5008b 100644 --- a/livekit-agents/livekit/agents/pipeline/io.py +++ b/livekit-agents/livekit/agents/pipeline/io.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from abc import ABC, abstractmethod from dataclasses import dataclass from typing import ( @@ -15,8 +16,6 @@ from .. import llm, stt -import asyncio - STTNode = Callable[ [AsyncIterable[rtc.AudioFrame]], Union[Awaitable[Optional[AsyncIterable[stt.SpeechEvent]]]], diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index f668476ec..1eb44a8a0 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -3,27 +3,26 @@ import asyncio import contextlib import heapq - from dataclasses import dataclass from typing import ( AsyncIterable, - Tuple, Literal, Optional, + Tuple, Union, ) from livekit import rtc -from .. import llm, stt, tts, utils, vad, debug, tokenize +from .. import debug, llm, stt, tokenize, tts, utils, vad from ..llm import ChatContext, FunctionContext from ..log import logger -from . import io, events +from . import events, io from .audio_recognition import AudioRecognition, _TurnDetector from .generation import ( + _TTSGenerationData, do_llm_inference, do_tts_inference, - _TTSGenerationData, ) diff --git a/livekit-agents/livekit/agents/worker.py b/livekit-agents/livekit/agents/worker.py index 7693b0406..494ff84db 100644 --- a/livekit-agents/livekit/agents/worker.py +++ b/livekit-agents/livekit/agents/worker.py @@ -14,7 +14,6 @@ from __future__ import annotations -import time import asyncio import contextlib import datetime @@ -24,6 +23,7 @@ import os import sys import threading +import time from dataclasses import dataclass, field from enum import Enum from functools import reduce @@ -39,12 +39,13 @@ import aiohttp import jwt +from aiohttp import web from livekit import api, rtc from livekit.protocol import agent, models from . 
import http_server, ipc, utils -from .debug import tracing from ._exceptions import AssignmentTimeoutError +from .debug import tracing from .inference_runner import _InferenceRunner from .job import ( JobAcceptArguments, @@ -58,8 +59,6 @@ from .utils.hw import get_cpu_monitor from .version import __version__ -from aiohttp import web - ASSIGNMENT_TIMEOUT = 7.5 UPDATE_STATUS_INTERVAL = 2.5 UPDATE_LOAD_INTERVAL = 0.5 diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py index d29ce6235..28ed56247 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py @@ -1,37 +1,40 @@ from __future__ import annotations import asyncio -import openai import base64 +from dataclasses import dataclass +from livekit import rtc +from livekit.agents import llm, multimodal, utils from livekit.agents.llm.function_context import _create_ai_function_info + +import openai from openai.resources.beta.realtime.realtime import AsyncRealtimeConnection from openai.types.beta.realtime import ( + ConversationItem, + ConversationItemContent, + ConversationItemCreateEvent, + ConversationItemDeleteEvent, + ConversationItemTruncateEvent, + ErrorEvent, + InputAudioBufferAppendEvent, InputAudioBufferSpeechStartedEvent, InputAudioBufferSpeechStoppedEvent, RealtimeClientEvent, - InputAudioBufferAppendEvent, ResponseAudioDeltaEvent, ResponseAudioDoneEvent, ResponseAudioTranscriptDeltaEvent, ResponseAudioTranscriptDoneEvent, ResponseCancelEvent, - ResponseCreateEvent, - ConversationItemTruncateEvent, ResponseCreatedEvent, + ResponseCreateEvent, ResponseDoneEvent, ResponseOutputItemAddedEvent, ResponseOutputItemDoneEvent, ) -from dataclasses import dataclass - -from livekit.agents import multimodal, llm, utils -from livekit import rtc - from .log import logger - # When a response is created with the OpenAI Realtime API, those events are sent in this order: # 1. response.created (contains resp_id) # 2. 
response.output_item.added (contains item_id) @@ -128,6 +131,8 @@ async def _listen_for_events() -> None: self._handle_response_output_item_done(event) elif event.type == "response.done": self._handle_response_done(event) + elif event.type == "error": + self._handle_error(event) @utils.log_exceptions(logger=logger) async def _forward_input_audio() -> None: @@ -240,13 +245,45 @@ def _handle_response_done(self, _: ResponseDoneEvent) -> None: self._current_generation.tool_calls_ch.close() self._current_generation = None + def _handle_error(self, event: ErrorEvent) -> None: + logger.error( + "OpenAI Realtime API returned an error", + extra={"error": event.error}, + ) + self.emit( + "error", + multimodal.ErrorEvent(type=event.error.type, message=event.error.message), + ) + @property def chat_ctx(self) -> llm.ChatContext: return self._chat_ctx.copy() async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None: - # get the difference between self._chat_ctx and chat_ctx - pass + diff_ops = llm.compute_chat_ctx_diff(self._chat_ctx, chat_ctx) + + for msg_id in diff_ops.to_remove: + self._msg_ch.send_nowait( + ConversationItemDeleteEvent( + type="conversation.item.delete", + item_id=msg_id, + ) + ) + + for previous_msg_id, msg_id in diff_ops.to_create: + chat_item = chat_ctx.get_by_id(msg_id) + assert chat_item is not None + self._msg_ch.send_nowait( + ConversationItemCreateEvent( + type="conversation.item.create", + item=_chat_item_to_conversation_item(chat_item), + previous_item_id=( + "root" if previous_msg_id is None else previous_msg_id + ), + ) + ) + + # TODO(theomonnom): wait for the server confirmation @property def fnc_ctx(self) -> llm.FunctionContext | None: @@ -259,7 +296,7 @@ def push_audio(self, frame: rtc.AudioFrame) -> None: self._msg_ch.send_nowait( InputAudioBufferAppendEvent( type="input_audio_buffer.append", - audio=base64.b64encode(frame.data).decode("ascii"), + audio=base64.b64encode(frame.data).decode("utf-8"), ) ) @@ -282,3 +319,63 @@ def truncate(self, *, message_id: str, audio_end_ms: int) -> None: async def aclose(self) -> None: if self._conn is not None: await self._conn.close() + + +def _chat_item_to_conversation_item(msg: llm.ChatItem) -> ConversationItem: + if not msg.content: + raise ValueError("ChatItem has no content") + + item = msg.content[0] + conversation_item = ConversationItem( + id=msg.id, + object="realtime.item", + ) + + if isinstance(item, llm.FunctionCall): + conversation_item.type = "function_call" + conversation_item.call_id = item.call_id + conversation_item.name = item.name + conversation_item.arguments = item.arguments + + elif isinstance(item, llm.FunctionCallOutput): + conversation_item.type = "function_call_output" + conversation_item.call_id = item.call_id + conversation_item.output = item.output + + elif isinstance(item, llm.ChatMessage): + role = "system" if item.role == "developer" else item.role + conversation_item.type = "message" + conversation_item.role = role + + content_list: list[ConversationItemContent] = [] + for c in item.content: + if isinstance(c, str): + content_list.append( + ConversationItemContent( + type=("text" if role == "assistant" else "input_text"), + text=c, + ) + ) + + elif isinstance(c, llm.ImageContent): + continue # not supported for now + elif isinstance(c, llm.AudioContent): + if conversation_item.role == "user": + encoded_audio = base64.b64encode( + rtc.combine_audio_frames(c.frame).data + ).decode("utf-8") + + content_list.append( + ConversationItemContent( + type="input_audio", + audio=encoded_audio, + 
transcript=c.transcript, + ) + ) + + conversation_item.content = content_list + + else: + raise ValueError(f"Unsupported ChatItem content: {item}") + + return conversation_item diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py index e3f19972a..1b59e1f98 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py @@ -138,6 +138,7 @@ async def _recognize_impl( conn_options: APIConnectOptions, ) -> stt.SpeechEvent: try: + print("buffer", buffer) config = self._sanitize_options(language=language) data = rtc.combine_audio_frames(buffer).to_wav_bytes() resp = await self._client.audio.transcriptions.create( From eb4a68eefd61f22e0a75a3d59bbd216079349e91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Mon, 20 Jan 2025 08:56:10 +0100 Subject: [PATCH 11/19] wip --- livekit-agents/livekit/agents/__init__.py | 2 - .../livekit/agents/ipc/job_thread_executor.py | 6 +- .../livekit/agents/ipc/proc_client.py | 6 +- .../livekit/agents/ipc/supervised_proc.py | 6 +- livekit-agents/livekit/agents/llm/__init__.py | 30 +- .../livekit/agents/llm/chat_context.py | 54 +- .../livekit/agents/llm/fallback_adapter.py | 1 - .../livekit/agents/llm/function_context.py | 501 +++++------------- livekit-agents/livekit/agents/llm/llm.py | 6 +- .../livekit/agents/multimodal/realtime.py | 8 +- .../livekit/agents/pipeline/__init__.py | 2 +- .../livekit/agents/pipeline/agent_task.py | 472 +++++++++++++++++ .../agents/pipeline/audio_recognition.py | 79 ++- .../livekit/agents/pipeline/chat_cli.py | 2 +- .../livekit/agents/pipeline/generation.py | 2 +- .../livekit/agents/pipeline/pipeline_agent.py | 497 +++-------------- .../agents/pipeline/speech_scheduler.py | 1 + .../livekit/agents/utils/aio/channel.py | 4 +- .../livekit/plugins/llama_index/llm.py | 6 +- .../plugins/openai/realtime/realtime_model.py | 98 +++- .../livekit/plugins/silero/vad.py | 4 +- .../livekit/plugins/turn_detector/eou.py | 6 +- tests/test_create_func.py | 30 +- tests/test_ipc.py | 6 +- tests/test_llm.py | 18 +- tests/test_stt.py | 6 +- tests/test_tts.py | 6 +- 27 files changed, 903 insertions(+), 956 deletions(-) create mode 100644 livekit-agents/livekit/agents/pipeline/agent_task.py create mode 100644 livekit-agents/livekit/agents/pipeline/speech_scheduler.py diff --git a/livekit-agents/livekit/agents/__init__.py b/livekit-agents/livekit/agents/__init__.py index 5c504b162..401e8cb56 100644 --- a/livekit-agents/livekit/agents/__init__.py +++ b/livekit-agents/livekit/agents/__init__.py @@ -21,7 +21,6 @@ pipeline, stt, tokenize, - transcription, tts, utils, vad, @@ -68,7 +67,6 @@ "tokenize", "llm", "metrics", - "transcription", "pipeline", "multimodal", "cli", diff --git a/livekit-agents/livekit/agents/ipc/job_thread_executor.py b/livekit-agents/livekit/agents/ipc/job_thread_executor.py index feace496d..3eefac28f 100644 --- a/livekit-agents/livekit/agents/ipc/job_thread_executor.py +++ b/livekit-agents/livekit/agents/ipc/job_thread_executor.py @@ -158,9 +158,9 @@ async def initialize(self) -> None: channel.arecv_message(self._pch, proto.IPC_MESSAGES), timeout=self._opts.initialize_timeout, ) - assert isinstance( - init_res, proto.InitializeResponse - ), "first message must be InitializeResponse" + assert isinstance(init_res, proto.InitializeResponse), ( + "first message must be InitializeResponse" + ) except asyncio.TimeoutError: 
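# Illustrative sketch of how the diff ops consumed by update_chat_ctx above turn
# into Realtime API client events. The op shapes (to_remove holding item ids,
# to_create holding (previous_id, id) pairs) follow from how the code iterates
# them; plain dicts stand in for the typed openai events to keep this self-contained.
to_remove = ["item_a"]
to_create = [(None, "item_e"), ("item_e", "item_f")]

events: list[dict] = []
for item_id in to_remove:
    events.append({"type": "conversation.item.delete", "item_id": item_id})
for previous_id, item_id in to_create:
    events.append(
        {
            "type": "conversation.item.create",
            "previous_item_id": previous_id if previous_id is not None else "root",
            "item": {"id": item_id},  # built by _chat_item_to_conversation_item in practice
        }
    )
print(events)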
self._initialize_fut.set_exception( asyncio.TimeoutError("runner initialization timed out") diff --git a/livekit-agents/livekit/agents/ipc/proc_client.py b/livekit-agents/livekit/agents/ipc/proc_client.py index 76b77fb88..97080e8cc 100644 --- a/livekit-agents/livekit/agents/ipc/proc_client.py +++ b/livekit-agents/livekit/agents/ipc/proc_client.py @@ -53,9 +53,9 @@ def initialize(self) -> None: cch = aio.duplex_unix._Duplex.open(self._mp_cch) first_req = recv_message(cch, IPC_MESSAGES) - assert isinstance( - first_req, InitializeRequest - ), "first message must be proto.InitializeRequest" + assert isinstance(first_req, InitializeRequest), ( + "first message must be proto.InitializeRequest" + ) self._init_req = first_req self._initialize_fnc(self._init_req, self) diff --git a/livekit-agents/livekit/agents/ipc/supervised_proc.py b/livekit-agents/livekit/agents/ipc/supervised_proc.py index e56119876..dfd18172d 100644 --- a/livekit-agents/livekit/agents/ipc/supervised_proc.py +++ b/livekit-agents/livekit/agents/ipc/supervised_proc.py @@ -165,9 +165,9 @@ async def initialize(self) -> None: channel.arecv_message(self._pch, proto.IPC_MESSAGES), timeout=self._opts.initialize_timeout, ) - assert isinstance( - init_res, proto.InitializeResponse - ), "first message must be InitializeResponse" + assert isinstance(init_res, proto.InitializeResponse), ( + "first message must be InitializeResponse" + ) except asyncio.TimeoutError: self._initialize_fut.set_exception( asyncio.TimeoutError("process initialization timed out") diff --git a/livekit-agents/livekit/agents/llm/__init__.py b/livekit-agents/livekit/agents/llm/__init__.py index 7c60fe701..c1adf8710 100644 --- a/livekit-agents/livekit/agents/llm/__init__.py +++ b/livekit-agents/livekit/agents/llm/__init__.py @@ -1,23 +1,19 @@ from .chat_context import ( + AudioContent, ChatContext, + ChatItem, ChatMessage, FunctionCall, FunctionCallOutput, - AudioContent, ImageContent, - ChatItem, ) from .fallback_adapter import AvailabilityChangedEvent, FallbackAdapter from .function_context import ( - USE_DOCSTRING, - CalledFunction, - FunctionArgInfo, - FunctionCallInfo, + AIFunction, FunctionContext, - FunctionInfo, - TypeInfo, - _create_ai_function_info, - ai_callable, + ai_function, + find_ai_functions, + is_ai_function, ) from .llm import ( LLM, @@ -46,18 +42,14 @@ "Choice", "ChatChunk", "CompletionUsage", - "FunctionContext", - "ai_callable", - "TypeInfo", - "FunctionArgInfo", - "FunctionInfo", - "FunctionCallInfo", - "CalledFunction", - "USE_DOCSTRING", "LLMCapabilities", "FallbackAdapter", "AvailabilityChangedEvent", "ToolChoice", - "_create_ai_function_info", "compute_chat_ctx_diff", + "is_ai_function", + "ai_function", + "find_ai_functions", + "AIFunction", + "FunctionContext", ] diff --git a/livekit-agents/livekit/agents/llm/chat_context.py b/livekit-agents/livekit/agents/llm/chat_context.py index f1d7e9c03..bf9fa2c7b 100644 --- a/livekit-agents/livekit/agents/llm/chat_context.py +++ b/livekit-agents/livekit/agents/llm/chat_context.py @@ -11,14 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from __future__ import annotations from typing import ( Literal, + Optional, + Union, ) from livekit import rtc -from pydantic import BaseModel +from pydantic import BaseModel, Field +from typing_extensions import TypeAlias + +from .. 
import utils class ImageContent(BaseModel): @@ -54,15 +60,17 @@ class ImageContent(BaseModel): ``` """ - image: str | rtc.VideoFrame + type: Literal["image_content"] = Field(default="image_content") + + image: Union[str, rtc.VideoFrame] """ Either a string URL or a VideoFrame object """ - inference_width: int | None = None + inference_width: Optional[int] = None """ Resizing parameter for rtc.VideoFrame inputs (ignored for URL images) """ - inference_height: int | None = None + inference_height: Optional[int] = None """ Resizing parameter for rtc.VideoFrame inputs (ignored for URL images) """ @@ -75,39 +83,45 @@ class ImageContent(BaseModel): class AudioContent(BaseModel): + type: Literal["audio_content"] = Field(default="audio_content") frame: list[rtc.AudioFrame] - transcript: str | None = None + transcript: Optional[str] = None + + +class ChatMessage(BaseModel): + id: str = Field(default_factory=lambda: utils.shortuuid("item_")) + type: Literal["message"] = "message" + role: Literal["developer", "system", "user", "assistant"] + content: list[Union[str, ImageContent, AudioContent]] + hash: Optional[bytes] = None class FunctionCall(BaseModel): - type: Literal["function_call"] + id: str = Field(default_factory=lambda: utils.shortuuid("item_")) + type: Literal["function_call"] = "function_call" call_id: str - name: str arguments: str + name: str class FunctionCallOutput(BaseModel): - type: Literal["function_call_output"] + id: str = Field(default_factory=lambda: utils.shortuuid("item_")) + type: Literal["function_call_output"] = Field(default="function_call_output") call_id: str output: str is_error: bool -class ChatMessage(BaseModel): - type: Literal["message"] - role: Literal["developer", "system", "user", "assistant"] - content: list[str | ImageContent | AudioContent] - hash: bytes | None = None - - -class ChatItem(BaseModel): - id: str - content: list[ChatMessage | FunctionCall | FunctionCallOutput] +ChatItem: TypeAlias = Union[ChatMessage, FunctionCall, FunctionCallOutput] class ChatContext: - def __init__(self, items: list[ChatItem] | None = None): - self._items: list[ChatItem] = items or [] + def __init__(self, items: list[ChatItem]): + self._items: list[ChatItem] = items + + @classmethod + def empty(cls) -> "ChatContext": + return cls([]) @property def items(self) -> list[ChatItem]: diff --git a/livekit-agents/livekit/agents/llm/fallback_adapter.py b/livekit-agents/livekit/agents/llm/fallback_adapter.py index fd5242e4d..6b8f15523 100644 --- a/livekit-agents/livekit/agents/llm/fallback_adapter.py +++ b/livekit-agents/livekit/agents/llm/fallback_adapter.py @@ -11,7 +11,6 @@ from ..log import logger from ..types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions from .chat_context import ChatContext -from .function_context import FunctionContext from .llm import LLM, ChatChunk, LLMStream, ToolChoice DEFAULT_FALLBACK_API_CONNECT_OPTIONS = APIConnectOptions( diff --git a/livekit-agents/livekit/agents/llm/function_context.py b/livekit-agents/livekit/agents/llm/function_context.py index 59604fc8d..9507d7e79 100644 --- a/livekit-agents/livekit/agents/llm/function_context.py +++ b/livekit-agents/livekit/agents/llm/function_context.py @@ -14,384 +14,153 @@ from __future__ import annotations -import asyncio -import enum -import functools import inspect -import json -import types -import typing -from dataclasses import dataclass -from typing import Any, Callable, Optional, Tuple - -from ..log import logger - - -class _UseDocMarker: - pass - - -METADATA_ATTR = "__livekit_ai_metadata__" 
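# Usage sketch for the pydantic chat items defined above. Field names follow this
# hunk (role/content on ChatMessage, call_id/name/arguments on FunctionCall); the
# import path assumes the re-exports added to livekit.agents.llm in this patch series.
from livekit.agents.llm import ChatContext, ChatMessage, FunctionCall, FunctionCallOutput

ctx = ChatContext.empty()
ctx.items.append(ChatMessage(role="system", content=["You are a helpful assistant."]))
ctx.items.append(ChatMessage(role="user", content=["What's the weather in Paris?"]))
ctx.items.append(
    FunctionCall(call_id="call_1", name="get_weather", arguments='{"location": "Paris"}')
)
ctx.items.append(FunctionCallOutput(call_id="call_1", output="Sunny, 21C", is_error=False))

assert ctx.get_by_id(ctx.items[0].id) is ctx.items[0]
snapshot = ctx.copy()  # shallow copy of the item list, as implemented above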
-USE_DOCSTRING = _UseDocMarker() +from typing import ( + Annotated, + Any, + Callable, + List, + Protocol, + Type, + get_args, + get_origin, + get_type_hints, + runtime_checkable, +) + +from pydantic import BaseModel, create_model +from pydantic.fields import FieldInfo +from typing_extensions import TypeGuard + + +@runtime_checkable +class AIFunction(Protocol): + __livekit_agents_ai_callable: bool + __name__: str + __doc__: str | None + + def __call__(self, *args: Any, **kwargs: Any) -> Any: ... + + +def ai_function(f: Callable | None = None) -> Callable[[Callable], AIFunction]: + def deco(f) -> AIFunction: + setattr(f, "__livekit_agents_ai_callable", True) + return f + if callable(f): + return deco(f) -@dataclass(frozen=True, init=False) -class TypeInfo: - description: str - choices: tuple + return deco - def __init__(self, description: str, choices: tuple | list[Any] = tuple()) -> None: - object.__setattr__(self, "description", description) - if isinstance(choices, list): - choices = tuple(choices) +def is_ai_function(f: Callable) -> TypeGuard[AIFunction]: + return getattr(f, "__livekit_agents_ai_callable", False) - object.__setattr__(self, "choices", choices) +def find_ai_functions(cls: Type) -> List[AIFunction]: + methods: list[AIFunction] = [] + for _, method in inspect.getmembers(cls, predicate=inspect.isfunction): + if is_ai_function(method): + methods.append(method) + return methods -@dataclass(frozen=True) -class FunctionArgInfo: - name: str - description: str - type: type - default: Any - choices: tuple | None +class FunctionContext: + """Stateless container for a set of AI functions""" -@dataclass(frozen=True) -class FunctionInfo: - name: str - description: str - auto_retry: bool - callable: Callable - arguments: dict[str, FunctionArgInfo] + def __init__(self, ai_functions: list[AIFunction]) -> None: + self.update_ai_functions(ai_functions) + @classmethod + def empty(cls) -> FunctionContext: + return cls([]) -@dataclass(frozen=True) -class FunctionCallInfo: - tool_call_id: str - function_info: FunctionInfo - raw_arguments: str - arguments: dict[str, Any] + @property + def ai_functions(self) -> dict[str, AIFunction]: + return self._ai_functions_map.copy() + + def update_ai_functions(self, ai_functions: list[AIFunction]) -> None: + self._ai_functions = ai_functions + + for method in find_ai_functions(self.__class__): + ai_functions.append(method) + + self._ai_functions_map = {} + for fnc in ai_functions: + if fnc.__name__ in self._ai_functions_map: + raise ValueError(f"duplicate function name: {fnc.__name__}") + + self._ai_functions_map[fnc.__name__] = fnc + + def copy(self) -> FunctionContext: + return FunctionContext(self._ai_functions.copy()) + + +def build_legacy_openai_schema( + ai_function: AIFunction, *, internally_tagged: bool = False +) -> dict[str, Any]: + """non-strict mode tool description + see https://serde.rs/enum-representations.html for the internally tagged representation""" + model = build_pydantic_model_from_function(ai_function) + schema = model.model_json_schema() + + fnc_name = ai_function.__name__ + fnc_description = ai_function.__doc__ + + if internally_tagged: + return { + "name": fnc_name, + "description": fnc_description or "", + "parameters": schema, + "type": "function", + } + else: + return { + "type": "function", + "function": { + "name": fnc_name, + "description": fnc_description or "", + "parameters": schema, + }, + } + + +def build_pydantic_model_from_function( + func: Callable, +) -> type[BaseModel]: + fnc_name = func.__name__.split("_") + 
fnc_name = "".join(x.capitalize() for x in fnc_name) + model_name = fnc_name + "Args" + + signature = inspect.signature(func) + type_hints = get_type_hints(func, include_extras=True) + + # field_name -> (type, FieldInfo or default) + fields: dict[str, Any] = {} + + for param_name, param in signature.parameters.items(): + annotation = type_hints[param_name] + default_value = param.default if param.default is not param.empty else ... + + # Annotated[str, Field(description="...")] + if get_origin(annotation) is Annotated: + annotated_args = get_args(annotation) + actual_type = annotated_args[0] + field_info = None + + for extra in annotated_args[1:]: + if isinstance(extra, FieldInfo): + field_info = extra # get the first FieldInfo + break + + if field_info: + if default_value is not ... and field_info.default is None: + field_info.default = default_value + fields[param_name] = (actual_type, field_info) + else: + fields[param_name] = (actual_type, default_value) - def execute(self) -> CalledFunction: - function_info = self.function_info - func = functools.partial(function_info.callable, **self.arguments) - if asyncio.iscoroutinefunction(function_info.callable): - task = asyncio.create_task(func()) else: - task = asyncio.create_task(asyncio.to_thread(func)) - - called_fnc = CalledFunction(call_info=self, task=task) - - def _on_done(fut): - try: - called_fnc.result = fut.result() - except BaseException as e: - called_fnc.exception = e - - task.add_done_callback(_on_done) - return called_fnc + fields[param_name] = (annotation, default_value) - -@dataclass -class CalledFunction: - call_info: FunctionCallInfo - task: asyncio.Task[Any] - result: Any | None = None - exception: BaseException | None = None - - -def ai_callable( - *, - name: str | None = None, - description: str | _UseDocMarker = USE_DOCSTRING, - auto_retry: bool = False, -) -> Callable: - def deco(f): - _set_metadata(f, name=name, desc=description, auto_retry=auto_retry) - return f - - return deco - - -class FunctionContext: - def __init__(self) -> None: - self._fncs = dict[str, FunctionInfo]() - - for _, member in inspect.getmembers(self, predicate=inspect.ismethod): - if hasattr(member, METADATA_ATTR): - self._register_ai_function(member) - - def ai_callable( - self, - *, - name: str | None = None, - description: str | _UseDocMarker = USE_DOCSTRING, - auto_retry: bool = True, - ) -> Callable: - def deco(f): - _set_metadata(f, name=name, desc=description, auto_retry=auto_retry) - self._register_ai_function(f) - - return deco - - def _register_ai_function(self, fnc: Callable) -> None: - if not hasattr(fnc, METADATA_ATTR): - logger.warning(f"function {fnc.__name__} does not have ai metadata") - return - - metadata: _AIFncMetadata = getattr(fnc, METADATA_ATTR) - fnc_name = metadata.name - if fnc_name in self._fncs: - raise ValueError(f"duplicate ai_callable name: {fnc_name}") - - sig = inspect.signature(fnc) - - # get_type_hints with include_extra=True is needed when using Annotated - # using typing.get_args with param.Annotated is returning an empty tuple for some reason - type_hints = typing.get_type_hints( - fnc, include_extras=True - ) # Annotated[T, ...] 
-> T - args = dict[str, FunctionArgInfo]() - - for name, param in sig.parameters.items(): - if param.kind not in ( - inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.KEYWORD_ONLY, - ): - raise ValueError(f"{fnc_name}: unsupported parameter kind {param.kind}") - - inner_th, type_info = _extract_types(type_hints[name]) - - if not is_type_supported(inner_th): - raise ValueError( - f"{fnc_name}: unsupported type {inner_th} for parameter {name}" - ) - - desc = type_info.description if type_info else "" - choices = type_info.choices if type_info else () - - if ( - isinstance(inner_th, type) - and issubclass(inner_th, enum.Enum) - and not choices - ): - # the enum must be a str or int (and at least one value) - # this is verified by is_type_supported - choices = tuple([item.value for item in inner_th]) - inner_th = type(choices[0]) - - args[name] = FunctionArgInfo( - name=name, - description=desc, - type=inner_th, - default=param.default, - choices=choices, - ) - - self._fncs[metadata.name] = FunctionInfo( - name=metadata.name, - description=metadata.description, - auto_retry=metadata.auto_retry, - callable=fnc, - arguments=args, - ) - - @property - def ai_functions(self) -> dict[str, FunctionInfo]: - return self._fncs - - -@dataclass(frozen=True) -class _AIFncMetadata: - name: str - description: str - auto_retry: bool - - -def _extract_types(annotation: type) -> tuple[type, TypeInfo | None]: - """Return inner_type, TypeInfo""" - if typing.get_origin(annotation) is not typing.Annotated: - # email: Annotated[ - # Optional[str], TypeInfo(description="The user address email") - # ] = None, - # - # An argument like the above will return us: - # `typing.Optional[typing.Annotated[typing.Optional[str], TypeInfo(description='The user address email', choices=())]]` - # So we ignore the first typing.Optional - - is_optional, optional_inner = _is_optional_type(annotation) - if is_optional: - inner_type, info = _extract_types(optional_inner) - return Optional[inner_type], info # type: ignore - - return annotation, None - - # assume the first argument is always the inner type the LLM will use - args = typing.get_args(annotation) - if len(args) < 2: - return args[0], None - - for a in args: - if isinstance(a, TypeInfo): - return args[0], a - - return args[0], None - - -def _set_metadata( - f: Callable, - name: str | None = None, - desc: str | _UseDocMarker = USE_DOCSTRING, - auto_retry: bool = False, -) -> None: - if isinstance(desc, _UseDocMarker): - docstring = inspect.getdoc(f) - if docstring is None: - raise ValueError( - f"missing docstring for function {f.__name__}, " - "use explicit description or provide docstring" - ) - desc = docstring - - metadata = _AIFncMetadata( - name=name or f.__name__, description=desc, auto_retry=auto_retry - ) - - setattr(f, METADATA_ATTR, metadata) - - -def is_type_supported(t: type) -> bool: - if t in (str, int, float, bool): - return True - - if typing.get_origin(t) is list: - in_type = typing.get_args(t)[0] - return is_type_supported(in_type) - - is_optional, ty = _is_optional_type(t) - if is_optional: - return is_type_supported(ty) - - if issubclass(t, enum.Enum): - initial_type = None - for e in t: - if initial_type is None: - initial_type = type(e.value) - if type(e.value) is not initial_type: - return False - - return initial_type in (str, int) - - return False - - -def _is_optional_type(typ) -> Tuple[bool, Any]: - """return is_optional, inner_type""" - origin = typing.get_origin(typ) - if origin is None or origin is list: - return False, typ - - if 
origin in {typing.Union, getattr(types, "UnionType", typing.Union)}: - args = typing.get_args(typ) - is_optional = type(None) in args - non_none_args = [a for a in args if a is not type(None)] - if is_optional and len(non_none_args) == 1: - # Exactly one non-None type + None means optional - return True, non_none_args[0] - - return False, None - - -def _create_ai_function_info( - fnc_ctx: FunctionContext, - tool_call_id: str, - fnc_name: str, - raw_arguments: str, # JSON string -) -> FunctionCallInfo: - if fnc_name not in fnc_ctx.ai_functions: - raise ValueError(f"AI function {fnc_name} not found") - - parsed_arguments: dict[str, Any] = {} - try: - if raw_arguments: # ignore empty string - parsed_arguments = json.loads(raw_arguments) - except json.JSONDecodeError: - raise ValueError( - f"AI function {fnc_name} received invalid JSON arguments - {raw_arguments}" - ) - - fnc_info = fnc_ctx.ai_functions[fnc_name] - - # Ensure all necessary arguments are present and of the correct type. - sanitized_arguments: dict[str, Any] = {} - for arg_info in fnc_info.arguments.values(): - if arg_info.name not in parsed_arguments: - if arg_info.default is inspect.Parameter.empty: - raise ValueError( - f"AI function {fnc_name} missing required argument {arg_info.name}" - ) - continue - - arg_value = parsed_arguments[arg_info.name] - is_optional, inner_th = _is_optional_type(arg_info.type) - - if typing.get_origin(inner_th) is not None: - if not isinstance(arg_value, list): - raise ValueError( - f"AI function {fnc_name} argument {arg_info.name} should be a list" - ) - - inner_type = typing.get_args(inner_th)[0] - sanitized_value = [ - _sanitize_primitive( - value=v, - expected_type=inner_type, - choices=arg_info.choices, - ) - for v in arg_value - ] - else: - sanitized_value = _sanitize_primitive( - value=arg_value, - expected_type=inner_th, - choices=arg_info.choices, - ) - - sanitized_arguments[arg_info.name] = sanitized_value - - return FunctionCallInfo( - tool_call_id=tool_call_id, - raw_arguments=raw_arguments, - function_info=fnc_info, - arguments=sanitized_arguments, - ) - - -def _sanitize_primitive( - *, value: Any, expected_type: type, choices: tuple | None -) -> Any: - if expected_type is str: - if not isinstance(value, str): - raise ValueError(f"expected str, got {type(value)}") - elif expected_type in (int, float): - if not isinstance(value, (int, float)): - raise ValueError(f"expected number, got {type(value)}") - - if expected_type is int: - if value % 1 != 0: - raise ValueError("expected int, got float") - - value = int(value) - elif expected_type is float: - value = float(value) - - elif expected_type is bool: - if not isinstance(value, bool): - raise ValueError(f"expected bool, got {type(value)}") - - if choices and value not in choices: - raise ValueError(f"invalid value {value}, not in {choices}") - - return value + return create_model(model_name, **fields) diff --git a/livekit-agents/livekit/agents/llm/llm.py b/livekit-agents/livekit/agents/llm/llm.py index 099e3139c..265898b26 100644 --- a/livekit-agents/livekit/agents/llm/llm.py +++ b/livekit-agents/livekit/agents/llm/llm.py @@ -24,14 +24,14 @@ from ..types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions from ..utils import aio from . 
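# Usage sketch for the new decorator and schema helpers above. The top-level import
# assumes livekit.agents.llm re-exports ai_function as shown in __init__.py;
# build_legacy_openai_schema / build_pydantic_model_from_function are taken from the
# module directly since they are not re-exported in this patch.
from typing import Annotated

from pydantic import Field

from livekit.agents.llm import ai_function
from livekit.agents.llm.function_context import (
    build_legacy_openai_schema,
    build_pydantic_model_from_function,
)

@ai_function
async def get_weather(
    location: Annotated[str, Field(description="City to look up")],
    unit: str = "celsius",
) -> str:
    """Return the current weather for a location."""
    return f"sunny in {location} (20 {unit})"

ArgsModel = build_pydantic_model_from_function(get_weather)  # model named "GetWeatherArgs"
print(ArgsModel.model_json_schema())
print(build_legacy_openai_schema(get_weather))  # non-strict OpenAI tool description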
import function_context -from .chat_context import ChatContext, ChatRole +from .chat_context import ChatContext @dataclass class ChoiceDelta: - role: ChatRole + # role: ChatRole content: str | None = None - tool_calls: list[function_context.FunctionCallInfo] | None = None + # tool_calls: list[function_context.FunctionCallInfo] | None = None @dataclass diff --git a/livekit-agents/livekit/agents/multimodal/realtime.py b/livekit-agents/livekit/agents/multimodal/realtime.py index 32a58671e..f3e08cb15 100644 --- a/livekit-agents/livekit/agents/multimodal/realtime.py +++ b/livekit-agents/livekit/agents/multimodal/realtime.py @@ -24,7 +24,7 @@ class GenerationCreatedEvent: message_id: str text_stream: AsyncIterable[str] audio_stream: AsyncIterable[rtc.AudioFrame] - tool_calls: AsyncIterable[llm.FunctionCallInfo] + function_stream: AsyncIterable[llm.FunctionCall] @dataclass @@ -85,10 +85,12 @@ async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None: ... @property @abstractmethod - def fnc_ctx(self) -> llm.FunctionContext | None: ... + def fnc_ctx(self) -> llm.FunctionContext: ... @abstractmethod - async def update_fnc_ctx(self, fnc_ctx: llm.FunctionContext | None) -> None: ... + async def update_fnc_ctx( + self, fnc_ctx: llm.FunctionContext | list[llm.AIFunction] + ) -> None: ... @abstractmethod def push_audio(self, frame: rtc.AudioFrame) -> None: ... diff --git a/livekit-agents/livekit/agents/pipeline/__init__.py b/livekit-agents/livekit/agents/pipeline/__init__.py index d6b34ab53..fed344c1f 100644 --- a/livekit-agents/livekit/agents/pipeline/__init__.py +++ b/livekit-agents/livekit/agents/pipeline/__init__.py @@ -1,4 +1,4 @@ from .chat_cli import ChatCLI -from .pipeline2 import PipelineAgent +from .pipeline_agent import PipelineAgent __all__ = ["ChatCLI", "PipelineAgent"] diff --git a/livekit-agents/livekit/agents/pipeline/agent_task.py b/livekit-agents/livekit/agents/pipeline/agent_task.py new file mode 100644 index 000000000..8705b1138 --- /dev/null +++ b/livekit-agents/livekit/agents/pipeline/agent_task.py @@ -0,0 +1,472 @@ +from __future__ import annotations + +import asyncio +from typing import ( + AsyncIterable, + Optional, + Union, +) + +from livekit import rtc + +from .. import llm, multimodal, stt, tokenize, tts, utils, vad, debug +from ..llm import ChatContext, FunctionContext, find_ai_functions +from ..log import logger +from .agent_task import AgentTask +from .audio_recognition import AudioRecognition, _TurnDetector +from .pipeline_agent import PipelineAgent, SpeechHandle +from .generation import ( + do_llm_inference, + do_tts_inference, + _TTSGenerationData, + _LLMGenerationData, +) + + +class AgentTask: + def __init__( + self, + *, + instructions: str, + chat_ctx: llm.ChatContext | None = None, + fnc_ctx: llm.FunctionContext | None = None, + turn_detector: _TurnDetector | None = None, + stt: stt.STT | None = None, + vad: vad.VAD | None = None, + llm: llm.LLM | multimodal.RealtimeModel | None = None, + tts: tts.TTS | None = None, + ) -> None: + if tts and not tts.capabilities.streaming: + from .. import tts as text_to_speech + + tts = text_to_speech.StreamAdapter( + tts=tts, sentence_tokenizer=tokenize.basic.SentenceTokenizer() + ) + + if stt and not stt.capabilities.streaming: + from .. 
import stt as speech_to_text + + if vad is None: + raise ValueError( + "VAD is required when streaming is not supported by the STT" + ) + + stt = speech_to_text.StreamAdapter( + stt=stt, + vad=vad, + ) + + self._instructions = instructions + self._chat_ctx = chat_ctx or ChatContext.empty() + self._fnc_ctx = fnc_ctx or FunctionContext.empty() + self._fnc_ctx.update_ai_functions( + list(self._fnc_ctx.ai_functions.values()) + + find_ai_functions(self.__class__) + ) + self._turn_detector = turn_detector + self._stt, self._llm, self._tts, self._vad = stt, llm, tts, vad + + self._agent: PipelineAgent | None = None + self._rt_session: multimodal.RealtimeSession | None = None + self._audio_recognition: AudioRecognition | None = None + + @property + def instructions(self) -> str: + return self._instructions + + @instructions.setter + def instructions(self, instructions: str) -> None: + self._instructions = instructions + + @property + def chat_ctx(self) -> llm.ChatContext: + return self._chat_ctx + + async def stt_node( + self, audio: AsyncIterable[rtc.AudioFrame] + ) -> Optional[AsyncIterable[stt.SpeechEvent]]: + assert self._stt is not None, "stt_node called but no STT node is available" + + async with self._stt.stream() as stream: + + async def _forward_input(): + async for frame in audio: + stream.push_frame(frame) + + forward_task = asyncio.create_task(_forward_input()) + try: + async for event in stream: + yield event + finally: + await utils.aio.gracefully_cancel(forward_task) + + async def llm_node( + self, chat_ctx: llm.ChatContext, fnc_ctx: llm.FunctionContext | None + ) -> Union[ + Optional[AsyncIterable[llm.ChatChunk]], + Optional[AsyncIterable[str]], + Optional[str], + ]: + assert self._llm is not None, "llm_node called but no LLM node is available" + assert isinstance(self._llm, llm.LLM), ( + "llm_node should only be used with LLM (non-multimodal/realtime APIs) nodes" + ) + + async with self._llm.chat(chat_ctx=chat_ctx, fnc_ctx=fnc_ctx) as stream: + async for chunk in stream: + yield chunk + + async def tts_node( + self, text: AsyncIterable[str] + ) -> Optional[AsyncIterable[rtc.AudioFrame]]: + assert self._tts is not None, "tts_node called but no TTS node is available" + + async with self._tts.stream() as stream: + + async def _forward_input(): + async for chunk in text: + stream.push_text(chunk) + + stream.end_input() + + forward_task = asyncio.create_task(_forward_input()) + try: + async for ev in stream: + yield ev.frame + finally: + await utils.aio.gracefully_cancel(forward_task) + + async def _on_start(self, agent: PipelineAgent) -> None: + if self._rt_session is not None: + logger.warning("starting a new task while rt_session is not None") + + self._audio_recognition = AudioRecognition( + task=self, + stt=self.stt_node, + vad=self._vad, + turn_detector=self._turn_detector, + min_endpointing_delay=agent.options.min_endpointing_delay, + ) + self._audio_recognition.start() + + if isinstance(self._llm, multimodal.RealtimeModel): + self._rt_session = self._llm.session() + self._rt_session.on("generation_created", self._on_generation_created) + await self._rt_session.update_chat_ctx(self._chat_ctx) + + async def _on_close(self) -> None: + if self._rt_session is not None: + await self._rt_session.aclose() + + if self._audio_recognition is not None: + await self._audio_recognition.aclose() + + def _on_generation_created(self, ev: multimodal.GenerationCreatedEvent) -> None: + pass + + def _on_input_audio_frame(self, frame: rtc.AudioFrame) -> None: + if self._rt_session is not None: + 
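# Sketch of what a user-defined task could look like against the WIP AgentTask
# above: AI functions declared on the subclass are picked up via find_ai_functions,
# and overriding llm_node is where per-task LLM behaviour can be customised. The
# plugin classes (openai.LLM, deepgram.STT, cartesia.TTS, silero.VAD) mirror the
# ones used in examples/minimal_worker.py; everything else follows the constructor
# shown above and is illustrative, not part of the patch.
from livekit.agents import llm
from livekit.agents.llm import ai_function
from livekit.plugins import cartesia, deepgram, openai, silero

class WeatherTask(AgentTask):
    def __init__(self) -> None:
        super().__init__(
            instructions="You are a concise weather assistant.",
            stt=deepgram.STT(),
            llm=openai.LLM(),
            tts=cartesia.TTS(),
            vad=silero.VAD.load(),
        )

    @ai_function
    async def get_weather(self, location: str) -> str:
        """Return the current weather for a location."""
        return f"sunny in {location}"

    async def llm_node(self, chat_ctx: llm.ChatContext, fnc_ctx: llm.FunctionContext | None):
        # delegate to the default pipeline behaviour; a custom task could filter or
        # post-process chunks here instead
        async for chunk in super().llm_node(chat_ctx, fnc_ctx):
            yield chunk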
self._rt_session.push_audio(frame) + + if self._audio_recognition is not None: + self._audio_recognition.push_audio(frame) + + def _on_audio_end_of_turn(self, new_transcript: str) -> None: + # When the audio recognition detects the end of a user turn: + # - check if there is no current generation happening + # - cancel the current generation if it allows interruptions (otherwise skip this current + # turn) + # - generate a reply to the user input + + if self._current_speech is not None: + if self._current_speech.allow_interruptions: + logger.warning( + "skipping user input, current speech generation cannot be interrupted", + extra={"user_input": new_transcript}, + ) + return + + debug.Tracing.log_event( + "speech interrupted, new user turn detected", + {"speech_id": self._current_speech.id}, + ) + self._current_speech.interrupt() + + self.generate_reply(new_transcript) + + def _on_vad_inference_done(self, ev: vad.VADEvent) -> None: + if ev.speech_duration > self._opts.min_interruption_duration: + if ( + self._current_speech is not None + and not self._current_speech.interrupted + and self._current_speech.allow_interruptions + ): + debug.Tracing.log_event( + "speech interrupted by vad", + {"speech_id": self._current_speech.id}, + ) + self._current_speech.interrupt() + + def _on_start_of_speech(self, _: vad.VADEvent) -> None: + self.emit("user_started_speaking", events.UserStartedSpeakingEvent()) + + def _on_end_of_speech(self, _: vad.VADEvent) -> None: + self.emit("user_stopped_speaking", events.UserStoppedSpeakingEvent()) + + + + @utils.log_exceptions(logger=logger) + async def _generate_pipeline_reply_task( + self, + *, + speech_handle: SpeechHandle, + ) -> None: + assert self._agent is not None, "agent is not set" + agent = self._agent + + @utils.log_exceptions(logger=logger) + async def _forward_llm_text(llm_output: AsyncIterable[str]) -> None: + """collect and forward the generated text to the current agent output""" + try: + async for delta in llm_output: + if agent.output.text is None: + break + + await agent.output.text.capture_text(delta) + finally: + if agent.output.text is not None: + agent.output.text.flush() + + @utils.log_exceptions(logger=logger) + async def _forward_tts_audio(tts_output: AsyncIterable[rtc.AudioFrame]) -> None: + """collect and forward the generated audio to the current agent output (generally playout)""" + try: + async for frame in tts_output: + if agent.output.audio is None: + break + await agent.output.audio.capture_frame(frame) + finally: + if agent.output.audio is not None: + agent.output.audio.flush() # always flush (even if the task is interrupted) + + @utils.log_exceptions(logger=logger) + async def _execute_tools( + tools_ch: utils.aio.Chan[llm.FunctionCallInfo], + called_functions: set[llm.CalledFunction], + ) -> None: + """execute tools, when cancelled, stop executing new tools but wait for the pending ones""" + try: + async for tool in tools_ch: + logger.debug( + "executing tool", + extra={ + "function": tool.function_info.name, + "speech_id": speech_handle.id, + }, + ) + debug.Tracing.log_event( + "executing tool", + { + "function": tool.function_info.name, + "speech_id": speech_handle.id, + }, + ) + cfnc = tool.execute() + called_functions.add(cfnc) + except asyncio.CancelledError: + # don't allow to cancel running function calla if they're still running + pending_tools = [cfn for cfn in called_functions if not cfn.task.done()] + + if pending_tools: + names = [cfn.call_info.function_info.name for cfn in pending_tools] + + logger.debug( + 
"waiting for function call to finish before cancelling", + extra={ + "functions": names, + "speech_id": speech_handle.id, + }, + ) + debug.Tracing.log_event( + "waiting for function call to finish before cancelling", + { + "functions": names, + "speech_id": speech_handle.id, + }, + ) + await asyncio.gather(*[cfn.task for cfn in pending_tools]) + finally: + if len(called_functions) > 0: + logger.debug( + "tools execution completed", + extra={"speech_id": speech_handle.id}, + ) + debug.Tracing.log_event( + "tools execution completed", + {"speech_id": speech_handle.id}, + ) + + debug.Tracing.log_event( + "generation started", + {"speech_id": speech_handle.id, "step_index": speech_handle.step_index}, + ) + + wg = utils.aio.WaitGroup() + tasks = [] + llm_task, llm_gen_data = do_llm_inference( + node=self.llm_node, + chat_ctx=self._chat_ctx, + fnc_ctx=( + self._fnc_ctx + if speech_handle.step_index < self._agent.options.max_fnc_steps - 1 + and speech_handle.step_index >= 2 + else None + ), + ) + tasks.append(llm_task) + wg.add(1) + llm_task.add_done_callback(lambda _: wg.done()) + tts_text_input, llm_output = utils.aio.itertools.tee(llm_gen_data.text_ch) + + tts_task: asyncio.Task | None = None + tts_gen_data: _TTSGenerationData | None = None + if self._agent.output.audio is not None: + tts_task, tts_gen_data = do_tts_inference( + node=self.tts_node, input=tts_text_input + ) + tasks.append(tts_task) + wg.add(1) + tts_task.add_done_callback(lambda _: wg.done()) + + # wait for the play() method to be called + await asyncio.wait( + [ + speech_handle._play_fut, + speech_handle._interrupt_fut, + ], + return_when=asyncio.FIRST_COMPLETED, + ) + + if speech_handle.interrupted: + await utils.aio.gracefully_cancel(*tasks) + speech_handle._mark_done() + return # return directly (the generated output wasn't used) + + # forward tasks are started after the play() method is called + # they redirect the generated text/audio to the output channels + forward_llm_task = asyncio.create_task( + _forward_llm_text(llm_output), + name="_generate_reply_task.forward_llm_text", + ) + tasks.append(forward_llm_task) + wg.add(1) + forward_llm_task.add_done_callback(lambda _: wg.done()) + + forward_tts_task: asyncio.Task | None = None + if tts_gen_data is not None: + forward_tts_task = asyncio.create_task( + _forward_tts_audio(tts_gen_data.audio_ch), + name="_generate_reply_task.forward_tts_audio", + ) + tasks.append(forward_tts_task) + wg.add(1) + forward_tts_task.add_done_callback(lambda _: wg.done()) + + # start to execute tools (only after play()) + called_functions: set[llm.CalledFunction] = set() + tools_task = asyncio.create_task( + _execute_tools(llm_gen_data.tools_ch, called_functions), + name="_generate_reply_task.execute_tools", + ) + tasks.append(tools_task) + wg.add(1) + tools_task.add_done_callback(lambda _: wg.done()) + + # wait for the tasks to finish + await asyncio.wait( + [ + wg.wait(), + speech_handle._interrupt_fut, + ], + return_when=asyncio.FIRST_COMPLETED, + ) + + # wait for the end of the playout if the audio is enabled + if forward_llm_task is not None and self._agent.output.audio is not None: + await asyncio.wait( + [ + self._agent.output.audio.wait_for_playout(), + speech_handle._interrupt_fut, + ], + return_when=asyncio.FIRST_COMPLETED, + ) + + if speech_handle.interrupted: + await utils.aio.gracefully_cancel(*tasks) + + if len(called_functions) > 0: + functions = [ + cfnc.call_info.function_info.name for cfnc in called_functions + ] + logger.debug( + "speech interrupted, ignoring generation of 
the function calls results", + extra={"speech_id": speech_handle.id, "functions": functions}, + ) + debug.Tracing.log_event( + "speech interrupted, ignoring generation of the function calls results", + {"speech_id": speech_handle.id, "functions": functions}, + ) + + # if the audio playout was enabled, clear the buffer + if forward_tts_task is not None and self._agent.output.audio is not None: + self._agent.output.audio.clear_buffer() + playback_ev = await self._agent.output.audio.wait_for_playout() + + debug.Tracing.log_event( + "playout interrupted", + { + "playback_position": playback_ev.playback_position, + "speech_id": speech_handle.id, + }, + ) + + speech_handle._mark_playout_done() + # TODO(theomonnom): calculate the played text based on playback_ev.playback_position + + speech_handle._mark_done() + return + + speech_handle._mark_playout_done() + debug.Tracing.log_event("playout completed", {"speech_id": speech_handle.id}) + + if len(called_functions) > 0: + if speech_handle.step_index >= self._agent.options.max_fnc_steps: + logger.warning( + "maximum number of function calls steps reached", + extra={"speech_id": speech_handle.id}, + ) + debug.Tracing.log_event( + "maximum number of function calls steps reached", + {"speech_id": speech_handle.id}, + ) + speech_handle._mark_done() + return + + # create a new SpeechHandle to generate the result of the function calls + speech_handle = SpeechHandle.create( + allow_interruptions=speech_handle.allow_interruptions, + step_index=speech_handle.step_index + 1, + ) + task = asyncio.create_task( + self._generate_pipeline_reply_task( + handle=speech_handle, + chat_ctx=self._chat_ctx, + fnc_ctx=self._fnc_ctx, + ), + name="_generate_pipeline_reply", + ) + self._agent._schedule_speech( + speech_handle, task, PipelineAgent.SPEECH_PRIORITY_NORMAL + ) + + speech_handle._mark_done() diff --git a/livekit-agents/livekit/agents/pipeline/audio_recognition.py b/livekit-agents/livekit/agents/pipeline/audio_recognition.py index 9cb6306f5..7f12e6ef2 100644 --- a/livekit-agents/livekit/agents/pipeline/audio_recognition.py +++ b/livekit-agents/livekit/agents/pipeline/audio_recognition.py @@ -12,7 +12,7 @@ from . 
import io if TYPE_CHECKING: - from .pipeline2 import PipelineAgent + from .pipeline_agent import AgentTask class _TurnDetector(Protocol): @@ -46,24 +46,20 @@ class AudioRecognition(rtc.EventEmitter[EventTypes]): def __init__( self, *, - agent: PipelineAgent, - stt: io.STTNode, + task: AgentTask, + stt: io.STTNode | None, vad: vad.VAD | None, turn_detector: _TurnDetector | None, min_endpointing_delay: float, - chat_ctx: llm.ChatContext, - loop: asyncio.AbstractEventLoop, ) -> None: super().__init__() - self._agent = agent + self._agent_task = task self._audio_input_atask: asyncio.Task[None] | None = None self._stt_atask: asyncio.Task[None] | None = None self._vad_atask: asyncio.Task[None] | None = None self._end_of_turn_task: asyncio.Task[None] | None = None self._audio_input: io.AudioStream | None = None self._min_endpointing_delay = min_endpointing_delay - self._chat_ctx = chat_ctx - self._loop = loop self._stt = stt self._vad = vad self._turn_detector = turn_detector @@ -87,28 +83,36 @@ def start(self) -> None: self.update_stt(self._stt) self.update_vad(self._vad) - @property - def audio_input(self) -> io.AudioStream | None: - return self._audio_input + def stop(self) -> None: + self.update_stt(None) + self.update_vad(None) - @audio_input.setter - def audio_input(self, audio_input: io.AudioStream | None) -> None: - self._audio_input = audio_input - self.update_stt(self._stt) - self.update_vad(self._vad) + def push_audio(self, frame: rtc.AudioFrame) -> None: + if self._stt_ch is not None: + self._stt_ch.send_nowait(frame) - if self._audio_input and self._audio_input_atask is None: - self._audio_input_atask = asyncio.create_task( - self._audio_input_task(self._audio_input) - ) - elif self._audio_input_atask is not None: - self._audio_input_atask.cancel() - self._audio_input_atask = None + if self._vad_ch is not None: + self._vad_ch.send_nowait(frame) - async def aclose(self) -> None: - if self._audio_input_atask is not None: - await aio.gracefully_cancel(self._audio_input_atask) + # @property + # def audio_input(self) -> io.AudioStream | None: + # return self._audio_input + + # @audio_input.setter + # def audio_input(self, audio_input: io.AudioStream | None) -> None: + # self._audio_input = audio_input + # self.update_stt(self._stt) + # self.update_vad(self._vad) + # if self._audio_input and self._audio_input_atask is None: + # self._audio_input_atask = asyncio.create_task( + # self._audio_input_task(self._audio_input) + # ) + # elif self._audio_input_atask is not None: + # self._audio_input_atask.cancel() + # self._audio_input_atask = None + + async def aclose(self) -> None: if self._stt_atask is not None: await aio.gracefully_cancel(self._stt_atask) @@ -166,7 +170,7 @@ async def _on_stt_event(self, ev: stt.SpeechEvent) -> None: self._audio_transcript = self._audio_transcript.lstrip() if not self._speaking: - self._run_eou_detection(self._agent.chat_ctx) + self._run_eou_detection(self._agent_task.chat_ctx) elif ev.type == stt.SpeechEventType.INTERIM_TRANSCRIPT: self.emit("interim_transcript", ev) @@ -187,14 +191,15 @@ async def _on_vad_event(self, ev: vad.VADEvent) -> None: self._speaking = False if not self._speaking: - self._run_eou_detection(self._agent.chat_ctx) + self._run_eou_detection(self._agent_task.chat_ctx) def _run_eou_detection(self, chat_ctx: llm.ChatContext) -> None: if not self._audio_transcript: return - chat_ctx = self._chat_ctx.copy() - chat_ctx.append(role="user", text=self._audio_transcript) + # TODO + # chat_ctx = self._agent._chat_ctx.copy() + # 
chat_ctx.append(role="user", text=self._audio_transcript) turn_detector = self._turn_detector @utils.log_exceptions(logger=logger) @@ -244,9 +249,9 @@ async def _stt_task( if isinstance(node, AsyncIterable): async for ev in node: - assert isinstance( - ev, stt.SpeechEvent - ), "STT node must yield SpeechEvent" + assert isinstance(ev, stt.SpeechEvent), ( + "STT node must yield SpeechEvent" + ) await self._on_stt_event(ev) async def _vad_task( @@ -269,11 +274,3 @@ async def _forward() -> None: finally: await stream.aclose() await aio.gracefully_cancel(forward_task) - - async def _audio_input_task(self, audio_input: io.AudioStream) -> None: - async for frame in audio_input: - if self._stt_ch is not None: - self._stt_ch.send_nowait(frame) - - if self._vad_ch is not None: - self._vad_ch.send_nowait(frame) diff --git a/livekit-agents/livekit/agents/pipeline/chat_cli.py b/livekit-agents/livekit/agents/pipeline/chat_cli.py index 7f197229c..d2c0cd9bb 100644 --- a/livekit-agents/livekit/agents/pipeline/chat_cli.py +++ b/livekit-agents/livekit/agents/pipeline/chat_cli.py @@ -15,7 +15,7 @@ from ..log import logger from ..utils import aio, log_exceptions from . import io -from .pipeline2 import PipelineAgent +from .pipeline_agent import PipelineAgent MAX_AUDIO_BAR = 30 INPUT_DB_MIN = -70.0 diff --git a/livekit-agents/livekit/agents/pipeline/generation.py b/livekit-agents/livekit/agents/pipeline/generation.py index 03b4e69ac..71224bee2 100644 --- a/livekit-agents/livekit/agents/pipeline/generation.py +++ b/livekit-agents/livekit/agents/pipeline/generation.py @@ -6,7 +6,7 @@ from livekit import rtc -from ..llm import ChatChunk, ChatContext, FunctionCallInfo, FunctionContext +from ..llm import ChatChunk, ChatContext, FunctionContext from ..utils import aio from . import io diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index 1eb44a8a0..25d20597a 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -7,18 +7,16 @@ from typing import ( AsyncIterable, Literal, - Optional, Tuple, - Union, ) from livekit import rtc -from .. import debug, llm, stt, tokenize, tts, utils, vad +from .. import debug, llm, utils from ..llm import ChatContext, FunctionContext from ..log import logger -from . import events, io -from .audio_recognition import AudioRecognition, _TurnDetector +from . 
import io +from .agent_task import AgentTask from .generation import ( _TTSGenerationData, do_llm_inference, @@ -26,10 +24,6 @@ ) -class AgentContext: - pass - - class SpeechHandle: def __init__( self, *, speech_id: str, allow_interruptions: bool, step_index: int @@ -106,14 +100,18 @@ def _mark_done(self) -> None: @dataclass -class _PipelineOptions: - language: str | None +class PipelineOptions: allow_interruptions: bool min_interruption_duration: float min_endpointing_delay: float max_fnc_steps: int +class AgentContext: + def __init__(self) -> None: + pass + + class PipelineAgent(rtc.EventEmitter[EventTypes]): SPEECH_PRIORITY_LOW = 0 """Priority for messages that should be played after all other messages in the queue""" @@ -125,14 +123,7 @@ class PipelineAgent(rtc.EventEmitter[EventTypes]): def __init__( self, *, - llm: llm.LLM | None = None, - vad: vad.VAD | None = None, - stt: stt.STT | None = None, - tts: tts.TTS | None = None, - turn_detector: _TurnDetector | None = None, - language: str | None = None, - chat_ctx: ChatContext | None = None, - fnc_ctx: FunctionContext | None = None, + task: AgentTask, allow_interruptions: bool = True, min_interruption_duration: float = 0.5, min_endpointing_delay: float = 0.5, @@ -142,54 +133,19 @@ def __init__( super().__init__() self._loop = loop or asyncio.get_event_loop() - self._chat_ctx = chat_ctx or ChatContext() - self._fnc_ctx = fnc_ctx - - self._stt, self._vad, self._llm, self._tts = stt, vad, llm, tts - - if tts and not tts.capabilities.streaming: - from .. import tts as text_to_speech - - tts = text_to_speech.StreamAdapter( - tts=tts, sentence_tokenizer=tokenize.basic.SentenceTokenizer() - ) - - if stt and not stt.capabilities.streaming: - from .. import stt as speech_to_text - - if vad is None: - raise ValueError( - "VAD is required when streaming is not supported by the STT" - ) - - stt = speech_to_text.StreamAdapter( - stt=stt, - vad=vad, - ) - - self._turn_detector = turn_detector - self._audio_recognition = AudioRecognition( - agent=self, - stt=self.stt_node, - vad=vad, - turn_detector=turn_detector, - min_endpointing_delay=min_endpointing_delay, - chat_ctx=self._chat_ctx, - loop=self._loop, - ) - - self._opts = _PipelineOptions( - language=language, + # This is the "global" chat_context, it holds the entire conversation history + self._chat_ctx = task.chat_ctx.copy() + self._opts = PipelineOptions( allow_interruptions=allow_interruptions, min_interruption_duration=min_interruption_duration, min_endpointing_delay=min_endpointing_delay, max_fnc_steps=max_fnc_steps, ) - self._audio_recognition.on("end_of_turn", self._on_audio_end_of_turn) - self._audio_recognition.on("start_of_speech", self._on_start_of_speech) - self._audio_recognition.on("end_of_speech", self._on_end_of_speech) - self._audio_recognition.on("vad_inference_done", self._on_vad_inference_done) + # self._audio_recognition.on("end_of_turn", self._on_audio_end_of_turn) + # self._audio_recognition.on("start_of_speech", self._on_start_of_speech) + # self._audio_recognition.on("end_of_speech", self._on_end_of_speech) + # self._audio_recognition.on("vad_inference_done", self._on_vad_inference_done) # configurable IO self._input = io.AgentInput( @@ -201,76 +157,29 @@ def __init__( self._on_text_output_changed, ) + # speech state self._current_speech: SpeechHandle | None = None self._speech_q: list[Tuple[int, SpeechHandle]] = [] self._speech_q_changed = asyncio.Event() self._speech_tasks = [] - - self._speech_scheduler_task: asyncio.Task | None = None + self._speech_scheduler_atask: 
asyncio.Task | None = None # -- Pipeline nodes -- # They can all be overriden by subclasses, by default they use the STT/LLM/TTS specified in the # constructor of the PipelineAgent - async def stt_node( - self, audio: AsyncIterable[rtc.AudioFrame] - ) -> Optional[AsyncIterable[stt.SpeechEvent]]: - assert self._stt is not None, "stt_node called but no STT node is available" - - async with self._stt.stream() as stream: - - async def _forward_input(): - async for frame in audio: - stream.push_frame(frame) - - forward_task = asyncio.create_task(_forward_input()) - try: - async for event in stream: - yield event - finally: - await utils.aio.gracefully_cancel(forward_task) - - async def llm_node( - self, chat_ctx: llm.ChatContext, fnc_ctx: llm.FunctionContext | None - ) -> Union[ - Optional[AsyncIterable[llm.ChatChunk]], - Optional[AsyncIterable[str]], - Optional[str], - ]: - assert self._llm is not None, "llm_node called but no LLM node is available" - - async with self._llm.chat(chat_ctx=chat_ctx, fnc_ctx=fnc_ctx) as stream: - async for chunk in stream: - yield chunk - - async def tts_node( - self, text: AsyncIterable[str] - ) -> Optional[AsyncIterable[rtc.AudioFrame]]: - assert self._tts is not None, "tts_node called but no TTS node is available" - - async with self._tts.stream() as stream: - - async def _forward_input(): - async for chunk in text: - stream.push_text(chunk) - - stream.end_input() - - forward_task = asyncio.create_task(_forward_input()) - try: - async for ev in stream: - yield ev.frame - finally: - await utils.aio.gracefully_cancel(forward_task) - def start(self) -> None: - self._audio_recognition.start() - self._speech_scheduler_task = asyncio.create_task( - self._playout_scheduler(), name="_playout_scheduler" + self._speech_scheduler_atask = asyncio.create_task( + self._speech_scheduler_task(), name="_playout_scheduler_task" ) async def aclose(self) -> None: - await self._audio_recognition.aclose() + if self._speech_scheduler_atask is not None: + await utils.aio.gracefully_cancel(self._speech_scheduler_atask) + + @property + def options(self) -> PipelineOptions: + return self._opts def emit(self, event: EventTypes, *args) -> None: debug.Tracing.log_event(f'agent.on("{event}")') @@ -284,7 +193,6 @@ def input(self) -> io.AgentInput: def output(self) -> io.AgentOutput: return self._output - # TODO(theomonnom): find a better name than `generation` @property def current_speech(self) -> SpeechHandle | None: return self._current_speech @@ -297,14 +205,20 @@ def update_options(self) -> None: pass def say(self, text: str | AsyncIterable[str]) -> SpeechHandle: - pass + raise NotImplementedError() def generate_reply(self, user_input: str) -> SpeechHandle: if self._current_speech is not None and not self._current_speech.interrupted: raise ValueError("another reply is already in progress") debug.Tracing.log_event("generate_reply", {"user_input": user_input}) - self._chat_ctx.append(role="user", text=user_input) # TODO(theomonnom) Remove + + # TODO(theomonnom): move to _generate_pipeline_reply_task + self._chat_ctx.items.append( + llm.ChatItem.create( + [llm.ChatMessage.create(role="user", content=user_input)] + ) + ) handle = SpeechHandle.create(allow_interruptions=self._opts.allow_interruptions) task = asyncio.create_task( @@ -330,7 +244,7 @@ def _schedule_speech( self._speech_q_changed.set() @utils.log_exceptions(logger=logger) - async def _playout_scheduler(self) -> None: + async def _speech_scheduler_task(self) -> None: while True: await self._speech_q_changed.wait() @@ -343,304 
+257,49 @@ async def _playout_scheduler(self) -> None: self._speech_q_changed.clear() - @utils.log_exceptions(logger=logger) - async def _generate_pipeline_reply_task( - self, - *, - handle: SpeechHandle, - chat_ctx: ChatContext, - fnc_ctx: FunctionContext | None, - ) -> None: - @utils.log_exceptions(logger=logger) - async def _forward_llm_text(llm_output: AsyncIterable[str]) -> None: - """collect and forward the generated text to the current agent output""" - if self.output.text is None: - return - - try: - async for delta in llm_output: - await self.output.text.capture_text(delta) - finally: - self.output.text.flush() - - @utils.log_exceptions(logger=logger) - async def _forward_tts_audio(tts_output: AsyncIterable[rtc.AudioFrame]) -> None: - """collect and forward the generated audio to the current agent output (generally playout)""" - if self.output.audio is None: - return - - try: - async for frame in tts_output: - await self.output.audio.capture_frame(frame) - finally: - self.output.audio.flush() # always flush (even if the task is interrupted) - - @utils.log_exceptions(logger=logger) - async def _execute_tools( - tools_ch: utils.aio.Chan[llm.FunctionCallInfo], - called_functions: set[llm.CalledFunction], - ) -> None: - """execute tools, when cancelled, stop executing new tools but wait for the pending ones""" - try: - async for tool in tools_ch: - logger.debug( - "executing tool", - extra={ - "function": tool.function_info.name, - "speech_id": handle.id, - }, - ) - debug.Tracing.log_event( - "executing tool", - { - "function": tool.function_info.name, - "speech_id": handle.id, - }, - ) - cfnc = tool.execute() - called_functions.add(cfnc) - except asyncio.CancelledError: - # don't allow to cancel running function calla if they're still running - pending_tools = [cfn for cfn in called_functions if not cfn.task.done()] - - if pending_tools: - names = [cfn.call_info.function_info.name for cfn in pending_tools] - - logger.debug( - "waiting for function call to finish before cancelling", - extra={ - "functions": names, - "speech_id": handle.id, - }, - ) - debug.Tracing.log_event( - "waiting for function call to finish before cancelling", - { - "functions": names, - "speech_id": handle.id, - }, - ) - await asyncio.gather(*[cfn.task for cfn in pending_tools]) - finally: - if len(called_functions) > 0: - logger.debug( - "tools execution completed", - extra={"speech_id": handle.id}, - ) - debug.Tracing.log_event( - "tools execution completed", - {"speech_id": handle.id}, - ) - - debug.Tracing.log_event( - "generation started", - {"speech_id": handle.id, "step_index": handle.step_index}, - ) - - wg = utils.aio.WaitGroup() - tasks = [] - llm_task, llm_gen_data = do_llm_inference( - node=self.llm_node, - chat_ctx=chat_ctx, - fnc_ctx=( - fnc_ctx - if handle.step_index < self._opts.max_fnc_steps - 1 - and handle.step_index >= 2 - else None - ), - ) - tasks.append(llm_task) - wg.add(1) - llm_task.add_done_callback(lambda _: wg.done()) - tts_text_input, llm_output = utils.aio.itertools.tee(llm_gen_data.text_ch) - - tts_task: asyncio.Task | None = None - tts_gen_data: _TTSGenerationData | None = None - if self._output.audio is not None: - tts_task, tts_gen_data = do_tts_inference( - node=self.tts_node, input=tts_text_input - ) - tasks.append(tts_task) - wg.add(1) - tts_task.add_done_callback(lambda _: wg.done()) - - # wait for the play() method to be called - await asyncio.wait( - [ - handle._play_fut, - handle._interrupt_fut, - ], - return_when=asyncio.FIRST_COMPLETED, - ) - - if handle.interrupted: 
- await utils.aio.gracefully_cancel(*tasks) - handle._mark_done() - return # return directly (the generated output wasn't used) - - # forward tasks are started after the play() method is called - # they redirect the generated text/audio to the output channels - forward_llm_task = asyncio.create_task( - _forward_llm_text(llm_output), - name="_generate_reply_task.forward_llm_text", - ) - tasks.append(forward_llm_task) - wg.add(1) - forward_llm_task.add_done_callback(lambda _: wg.done()) - - forward_tts_task: asyncio.Task | None = None - if tts_gen_data is not None: - forward_tts_task = asyncio.create_task( - _forward_tts_audio(tts_gen_data.audio_ch), - name="_generate_reply_task.forward_tts_audio", - ) - tasks.append(forward_tts_task) - wg.add(1) - forward_tts_task.add_done_callback(lambda _: wg.done()) - - # start to execute tools (only after play()) - called_functions: set[llm.CalledFunction] = set() - tools_task = asyncio.create_task( - _execute_tools(llm_gen_data.tools_ch, called_functions), - name="_generate_reply_task.execute_tools", - ) - tasks.append(tools_task) - wg.add(1) - tools_task.add_done_callback(lambda _: wg.done()) - - # wait for the tasks to finish - await asyncio.wait( - [ - wg.wait(), - handle._interrupt_fut, - ], - return_when=asyncio.FIRST_COMPLETED, - ) - - # wait for the end of the playout if the audio is enabled - if forward_llm_task is not None: - assert self._output.audio is not None - await asyncio.wait( - [ - self._output.audio.wait_for_playout(), - handle._interrupt_fut, - ], - return_when=asyncio.FIRST_COMPLETED, - ) - - if handle.interrupted: - await utils.aio.gracefully_cancel(*tasks) - - if len(called_functions) > 0: - functions = [ - cfnc.call_info.function_info.name for cfnc in called_functions - ] - logger.debug( - "speech interrupted, ignoring generation of the function calls results", - extra={"speech_id": handle.id, "functions": functions}, - ) - debug.Tracing.log_event( - "speech interrupted, ignoring generation of the function calls results", - {"speech_id": handle.id, "functions": functions}, - ) - - # if the audio playout was enabled, clear the buffer - if forward_tts_task is not None: - assert self._output.audio is not None - - self._output.audio.clear_buffer() - playback_ev = await self._output.audio.wait_for_playout() - - debug.Tracing.log_event( - "playout interrupted", - { - "playback_position": playback_ev.playback_position, - "speech_id": handle.id, - }, - ) - - handle._mark_playout_done() - # TODO(theomonnom): calculate the played text based on playback_ev.playback_position - - handle._mark_done() - return - - handle._mark_playout_done() - debug.Tracing.log_event("playout completed", {"speech_id": handle.id}) - - if len(called_functions) > 0: - if handle.step_index >= self._opts.max_fnc_steps: - logger.warning( - "maximum number of function calls steps reached", - extra={"speech_id": handle.id}, - ) - debug.Tracing.log_event( - "maximum number of function calls steps reached", - {"speech_id": handle.id}, - ) - handle._mark_done() - return - - # create a new SpeechHandle to generate the result of the function calls - handle = SpeechHandle.create( - allow_interruptions=self._opts.allow_interruptions, - step_index=handle.step_index + 1, - ) - task = asyncio.create_task( - self._generate_pipeline_reply_task( - handle=handle, - chat_ctx=chat_ctx, - fnc_ctx=fnc_ctx, - ), - name="_generate_pipeline_reply", - ) - self._schedule_speech(handle, task, self.SPEECH_PRIORITY_NORMAL) - - handle._mark_done() - # -- Audio recognition -- - def 
_on_audio_end_of_turn(self, new_transcript: str) -> None: - # When the audio recognition detects the end of a user turn: - # - check if there is no current generation happening - # - cancel the current generation if it allows interruptions (otherwise skip this current - # turn) - # - generate a reply to the user input - - if self._current_speech is not None: - if self._current_speech.allow_interruptions: - logger.warning( - "skipping user input, current speech generation cannot be interrupted", - extra={"user_input": new_transcript}, - ) - return - - debug.Tracing.log_event( - "speech interrupted, new user turn detected", - {"speech_id": self._current_speech.id}, - ) - self._current_speech.interrupt() - - self.generate_reply(new_transcript) - - def _on_vad_inference_done(self, ev: vad.VADEvent) -> None: - if ev.speech_duration > self._opts.min_interruption_duration: - if ( - self._current_speech is not None - and not self._current_speech.interrupted - and self._current_speech.allow_interruptions - ): - debug.Tracing.log_event( - "speech interrupted by vad", - {"speech_id": self._current_speech.id}, - ) - self._current_speech.interrupt() - - def _on_start_of_speech(self, _: vad.VADEvent) -> None: - self.emit("user_started_speaking", events.UserStartedSpeakingEvent()) - - def _on_end_of_speech(self, _: vad.VADEvent) -> None: - self.emit("user_stopped_speaking", events.UserStoppedSpeakingEvent()) + # def _on_audio_end_of_turn(self, new_transcript: str) -> None: + # # When the audio recognition detects the end of a user turn: + # # - check if there is no current generation happening + # # - cancel the current generation if it allows interruptions (otherwise skip this current + # # turn) + # # - generate a reply to the user input + + # if self._current_speech is not None: + # if self._current_speech.allow_interruptions: + # logger.warning( + # "skipping user input, current speech generation cannot be interrupted", + # extra={"user_input": new_transcript}, + # ) + # return + + # debug.Tracing.log_event( + # "speech interrupted, new user turn detected", + # {"speech_id": self._current_speech.id}, + # ) + # self._current_speech.interrupt() + + # self.generate_reply(new_transcript) + + # def _on_vad_inference_done(self, ev: vad.VADEvent) -> None: + # if ev.speech_duration > self._opts.min_interruption_duration: + # if ( + # self._current_speech is not None + # and not self._current_speech.interrupted + # and self._current_speech.allow_interruptions + # ): + # debug.Tracing.log_event( + # "speech interrupted by vad", + # {"speech_id": self._current_speech.id}, + # ) + # self._current_speech.interrupt() + + # def _on_start_of_speech(self, _: vad.VADEvent) -> None: + # self.emit("user_started_speaking", events.UserStartedSpeakingEvent()) + + # def _on_end_of_speech(self, _: vad.VADEvent) -> None: + # self.emit("user_stopped_speaking", events.UserStoppedSpeakingEvent()) # --- @@ -650,7 +309,7 @@ def _on_video_input_changed(self) -> None: pass def _on_audio_input_changed(self) -> None: - self._audio_recognition.audio_input = self._input.audio + pass def _on_video_output_changed(self) -> None: pass diff --git a/livekit-agents/livekit/agents/pipeline/speech_scheduler.py b/livekit-agents/livekit/agents/pipeline/speech_scheduler.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/livekit-agents/livekit/agents/pipeline/speech_scheduler.py @@ -0,0 +1 @@ + diff --git a/livekit-agents/livekit/agents/utils/aio/channel.py b/livekit-agents/livekit/agents/utils/aio/channel.py index 
6c0d4ff5a..5279d9715 100644 --- a/livekit-agents/livekit/agents/utils/aio/channel.py +++ b/livekit-agents/livekit/agents/utils/aio/channel.py @@ -47,7 +47,9 @@ async def __anext__(self) -> T_co: ... class Chan(Generic[T]): def __init__( - self, maxsize: int = 0, loop: asyncio.AbstractEventLoop | None = None + self, + maxsize: int = 0, + loop: asyncio.AbstractEventLoop | None = None, ) -> None: self._loop = loop or asyncio.get_event_loop() self._maxsize = max(maxsize, 0) diff --git a/livekit-plugins/livekit-plugins-llama-index/livekit/plugins/llama_index/llm.py b/livekit-plugins/livekit-plugins-llama-index/livekit/plugins/llama_index/llm.py index 9f674717d..6a5ec46db 100644 --- a/livekit-plugins/livekit-plugins-llama-index/livekit/plugins/llama_index/llm.py +++ b/livekit-plugins/livekit-plugins-llama-index/livekit/plugins/llama_index/llm.py @@ -74,9 +74,9 @@ async def _run(self) -> None: "The last message in the chat context must be from the user" ) - assert isinstance( - user_msg.content, str - ), "user message content must be a string" + assert isinstance(user_msg.content, str), ( + "user message content must be a string" + ) try: if not self._stream: diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py index 28ed56247..716db3ed9 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py @@ -6,7 +6,8 @@ from livekit import rtc from livekit.agents import llm, multimodal, utils -from livekit.agents.llm.function_context import _create_ai_function_info +from livekit.agents.llm.function_context import build_legacy_openai_schema +from pydantic import ValidationError import openai from openai.resources.beta.realtime.realtime import AsyncRealtimeConnection @@ -31,6 +32,8 @@ ResponseDoneEvent, ResponseOutputItemAddedEvent, ResponseOutputItemDoneEvent, + SessionUpdateEvent, + session_update_event, ) from .log import logger @@ -46,7 +49,7 @@ # 8. response.output_item.done (contains item_status: "completed/incomplete") # 9. response.done (contains status_details for cancelled/failed/turn_detected/content_filter) # -# Ourcode assumes a response will generate only one item with type "message." 
+# Ourcode assumes a response will generate only one item with type "message" SAMPLE_RATE = 24000 @@ -64,7 +67,7 @@ class _ResponseGeneration: item_id: str audio_ch: utils.aio.Chan[rtc.AudioFrame] text_ch: utils.aio.Chan[str] - tool_calls_ch: utils.aio.Chan[llm.FunctionCallInfo] + function_ch: utils.aio.Chan[llm.FunctionCall] class RealtimeModel(multimodal.RealtimeModel): @@ -92,7 +95,7 @@ def __init__(self, realtime_model: RealtimeModel) -> None: super().__init__(realtime_model) self._realtime_model = realtime_model self._chat_ctx = llm.ChatContext() - self._fnc_ctx: llm.FunctionContext | None = None + self._fnc_ctx = llm.FunctionContext.empty() self._msg_ch = utils.aio.Chan[RealtimeClientEvent]() self._conn: AsyncRealtimeConnection | None = None @@ -167,7 +170,7 @@ def _handle_response_created(self, event: ResponseCreatedEvent) -> None: item_id="", audio_ch=utils.aio.Chan(), text_ch=utils.aio.Chan(), - tool_calls_ch=utils.aio.Chan(), + function_ch=utils.aio.Chan(), ) def _handle_response_output_item_added( @@ -187,6 +190,16 @@ def _handle_response_output_item_added( self._current_generation.item_id = item_id + self.emit( + "generation_created", + multimodal.GenerationCreatedEvent( + message_id=item_id, + text_stream=self._current_generation.text_ch, + audio_stream=self._current_generation.audio_ch, + function_stream=self._current_generation.function_ch, + ), + ) + def _handle_response_audio_transcript_delta( self, event: ResponseAudioTranscriptDeltaEvent ) -> None: @@ -221,9 +234,9 @@ def _handle_response_output_item_done( item = event.item if item.type == "function_call": - if self.fnc_ctx is None: + if len(self.fnc_ctx.ai_functions) == 0: logger.warning( - "received a function_call item without a function context", + "received a function_call item without ai functions", extra={"item": item}, ) return @@ -232,17 +245,17 @@ def _handle_response_output_item_done( assert item.name is not None, "name is None" assert item.arguments is not None, "arguments is None" - fnc_call_info = _create_ai_function_info( - self.fnc_ctx, - item.call_id, - item.name, - item.arguments, + self._current_generation.function_ch.send_nowait( + llm.FunctionCall( + call_id=item.call_id, + name=item.name, + arguments=item.arguments, + ) ) - self._current_generation.tool_calls_ch.send_nowait(fnc_call_info) def _handle_response_done(self, _: ResponseDoneEvent) -> None: assert self._current_generation is not None, "current_generation is None" - self._current_generation.tool_calls_ch.close() + # self._current_generation.tool_calls_ch.close() self._current_generation = None def _handle_error(self, event: ErrorEvent) -> None: @@ -286,11 +299,45 @@ async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None: # TODO(theomonnom): wait for the server confirmation @property - def fnc_ctx(self) -> llm.FunctionContext | None: + def fnc_ctx(self) -> llm.FunctionContext: return self._fnc_ctx - async def update_fnc_ctx(self, fnc_ctx: llm.FunctionContext | None) -> None: - pass + async def update_fnc_ctx( + self, fnc_ctx: llm.FunctionContext | list[llm.AIFunction] + ) -> None: + if isinstance(fnc_ctx, list): + fnc_ctx = llm.FunctionContext(fnc_ctx) + + tools: list[session_update_event.SessionTool] = [] + retained_functions: list[llm.AIFunction] = [] + + for ai_fnc in fnc_ctx.ai_functions.values(): + tool_desc = build_legacy_openai_schema(ai_fnc, internally_tagged=True) + try: + session_tool = session_update_event.SessionTool.model_validate( + tool_desc + ) + tools.append(session_tool) + retained_functions.append(ai_fnc) + 
except ValidationError: + logger.error( + "OpenAI Realtime API doesn't support this tool", + extra={"tool": tool_desc}, + ) + continue + + self._msg_ch.send_nowait( + SessionUpdateEvent( + type="session.update", + session=session_update_event.Session( + model=self._realtime_model._opts.model, # type: ignore (str -> Literal) + tools=tools, + ), + ) + ) + + # TODO(theomonnom): wait for the server confirmation before updating the local state + self._fnc_ctx = llm.FunctionContext(retained_functions) def push_audio(self, frame: rtc.AudioFrame) -> None: self._msg_ch.send_nowait( @@ -321,28 +368,24 @@ async def aclose(self) -> None: await self._conn.close() -def _chat_item_to_conversation_item(msg: llm.ChatItem) -> ConversationItem: - if not msg.content: - raise ValueError("ChatItem has no content") - - item = msg.content[0] +def _chat_item_to_conversation_item(item: llm.ChatItem) -> ConversationItem: conversation_item = ConversationItem( - id=msg.id, + id=item.id, object="realtime.item", ) - if isinstance(item, llm.FunctionCall): + if item.type == "function_call": conversation_item.type = "function_call" conversation_item.call_id = item.call_id conversation_item.name = item.name conversation_item.arguments = item.arguments - elif isinstance(item, llm.FunctionCallOutput): + elif item.type == "function_call_output": conversation_item.type = "function_call_output" conversation_item.call_id = item.call_id conversation_item.output = item.output - elif isinstance(item, llm.ChatMessage): + elif item.type == "message": role = "system" if item.role == "developer" else item.role conversation_item.type = "message" conversation_item.role = role @@ -375,7 +418,4 @@ def _chat_item_to_conversation_item(msg: llm.ChatItem) -> ConversationItem: conversation_item.content = content_list - else: - raise ValueError(f"Unsupported ChatItem content: {item}") - return conversation_item diff --git a/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py b/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py index ee6d27599..1604a99a6 100644 --- a/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py +++ b/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py @@ -86,7 +86,9 @@ async def entrypoint(ctx: JobContext): if __name__ == "__main__": - cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint, prewarm_fnc=prewarm)) + cli.run_app( + WorkerOptions(entrypoint_fnc=entrypoint, prewarm_fnc=prewarm) + ) ``` Args: diff --git a/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/eou.py b/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/eou.py index 041918be9..f7a808159 100644 --- a/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/eou.py +++ b/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/eou.py @@ -190,9 +190,9 @@ async def predict_end_of_turn( timeout=timeout, ) - assert ( - result is not None - ), "end_of_utterance prediction should always returns a result" + assert result is not None, ( + "end_of_utterance prediction should always returns a result" + ) result_json = json.loads(result.decode()) return result_json["eou_probability"] diff --git a/tests/test_create_func.py b/tests/test_create_func.py index a81d31d93..202a4e4f2 100644 --- a/tests/test_create_func.py +++ b/tests/test_create_func.py @@ -16,9 +16,9 @@ def test_fn( pass fnc_ctx = TestFunctionContext() - assert ( - "test_function" in fnc_ctx.ai_functions - ), "Function should be registered in 
ai_functions" + assert "test_function" in fnc_ctx.ai_functions, ( + "Function should be registered in ai_functions" + ) fnc_info = fnc_ctx.ai_functions["test_function"] build_info = _oai_api.build_oai_function_description(fnc_info) @@ -69,9 +69,9 @@ def test_fn(self): pass fnc_ctx = TestFunctionContext() - assert ( - "test_fn" in fnc_ctx.ai_functions - ), "Function should be registered in ai_functions" + assert "test_fn" in fnc_ctx.ai_functions, ( + "Function should be registered in ai_functions" + ) assert fnc_ctx.ai_functions["test_fn"].description == "A simple test function" @@ -92,9 +92,9 @@ def optional_fn( pass fnc_ctx = TestFunctionContext() - assert ( - "optional_function" in fnc_ctx.ai_functions - ), "Function should be registered in ai_functions" + assert "optional_function" in fnc_ctx.ai_functions, ( + "Function should be registered in ai_functions" + ) fnc_info = fnc_ctx.ai_functions["optional_function"] build_info = _oai_api.build_oai_function_description(fnc_info) @@ -159,9 +159,9 @@ def list_fn( pass fnc_ctx = TestFunctionContext() - assert ( - "list_function" in fnc_ctx.ai_functions - ), "Function should be registered in ai_functions" + assert "list_function" in fnc_ctx.ai_functions, ( + "Function should be registered in ai_functions" + ) fnc_info = fnc_ctx.ai_functions["list_function"] build_info = _oai_api.build_oai_function_description(fnc_info) @@ -202,9 +202,9 @@ def enum_fn( pass fnc_ctx = TestFunctionContext() - assert ( - "enum_function" in fnc_ctx.ai_functions - ), "Function should be registered in ai_functions" + assert "enum_function" in fnc_ctx.ai_functions, ( + "Function should be registered in ai_functions" + ) fnc_info = fnc_ctx.ai_functions["enum_function"] build_info = _oai_api.build_oai_function_description(fnc_info) diff --git a/tests/test_ipc.py b/tests/test_ipc.py index 4e1fd4fe7..645808827 100644 --- a/tests/test_ipc.py +++ b/tests/test_ipc.py @@ -354,9 +354,9 @@ async def test_shutdown_no_job(): assert proc.exitcode == 0 assert not proc.killed - assert ( - start_args.shutdown_counter.value == 0 - ), "shutdown_cb isn't called when there is no job" + assert start_args.shutdown_counter.value == 0, ( + "shutdown_cb isn't called when there is no job" + ) async def test_job_slow_shutdown(): diff --git a/tests/test_llm.py b/tests/test_llm.py index 4b71c0324..974f70f38 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -194,9 +194,9 @@ async def test_cancelled_calls(llm_factory: Callable[[], llm.LLM]): await stream.aclose() assert len(calls) == 1 - assert isinstance( - calls[0].exception, asyncio.CancelledError - ), "toggle_light should have been cancelled" + assert isinstance(calls[0].exception, asyncio.CancelledError), ( + "toggle_light should have been cancelled" + ) @pytest.mark.parametrize("llm_factory", LLMS) @@ -219,9 +219,9 @@ async def test_calls_arrays(llm_factory: Callable[[], llm.LLM]): call = calls[0] currencies = call.call_info.arguments["currencies"] assert len(currencies) == 3, "select_currencies should have 3 currencies" - assert ( - "eur" in currencies and "gbp" in currencies and "sek" in currencies - ), "select_currencies should have eur, gbp, sek" + assert "eur" in currencies and "gbp" in currencies and "sek" in currencies, ( + "select_currencies should have eur, gbp, sek" + ) @pytest.mark.parametrize("llm_factory", LLMS) @@ -341,9 +341,9 @@ async def test_tool_choice_options( if tool_choice == "none" and isinstance(input_llm, anthropic.LLM): assert True else: - assert ( - call_names == expected_calls - ), f"Test '{description}' 
failed: Expected calls {expected_calls}, but got {call_names}" + assert call_names == expected_calls, ( + f"Test '{description}' failed: Expected calls {expected_calls}, but got {call_names}" + ) async def _request_fnc_call( diff --git a/tests/test_stt.py b/tests/test_stt.py index d1f340b1e..9a80aa4be 100644 --- a/tests/test_stt.py +++ b/tests/test_stt.py @@ -99,9 +99,9 @@ async def _stream_output(): async for event in stream: if event.type == agents.stt.SpeechEventType.START_OF_SPEECH: - assert ( - recv_end - ), "START_OF_SPEECH recv but no END_OF_SPEECH has been sent before" + assert recv_end, ( + "START_OF_SPEECH recv but no END_OF_SPEECH has been sent before" + ) assert not recv_start recv_end = False recv_start = True diff --git a/tests/test_tts.py b/tests/test_tts.py index 91f8035b5..f9ac4f4c9 100644 --- a/tests/test_tts.py +++ b/tests/test_tts.py @@ -37,9 +37,9 @@ async def _assert_valid_synthesized_audio( merged_frame = merge_frames(frames) assert merged_frame.sample_rate == tts.sample_rate, "sample rate should be the same" - assert ( - merged_frame.num_channels == tts.num_channels - ), "num channels should be the same" + assert merged_frame.num_channels == tts.num_channels, ( + "num channels should be the same" + ) SYNTHESIZE_TTS: list[Callable[[], tts.TTS]] = [ From a715b64db4038bb04594677655b0bee1e20f3989 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Mon, 20 Jan 2025 15:04:24 +0100 Subject: [PATCH 12/19] multimodal support & ActiveTask --- .../livekit/agents/pipeline/agent_task.py | 310 +++++++++++++++--- .../agents/pipeline/audio_recognition.py | 72 ++-- .../livekit/agents/pipeline/pipeline_agent.py | 217 +++--------- .../livekit/agents/pipeline/speech_handle.py | 70 ++++ livekit-agents/livekit/agents/stt/stt.py | 6 +- 5 files changed, 396 insertions(+), 279 deletions(-) create mode 100644 livekit-agents/livekit/agents/pipeline/speech_handle.py diff --git a/livekit-agents/livekit/agents/pipeline/agent_task.py b/livekit-agents/livekit/agents/pipeline/agent_task.py index 8705b1138..f50b5728f 100644 --- a/livekit-agents/livekit/agents/pipeline/agent_task.py +++ b/livekit-agents/livekit/agents/pipeline/agent_task.py @@ -9,18 +9,17 @@ from livekit import rtc -from .. import llm, multimodal, stt, tokenize, tts, utils, vad, debug +from .. 
import debug, llm, multimodal, stt, tokenize, tts, utils, vad from ..llm import ChatContext, FunctionContext, find_ai_functions from ..log import logger from .agent_task import AgentTask -from .audio_recognition import AudioRecognition, _TurnDetector -from .pipeline_agent import PipelineAgent, SpeechHandle +from .audio_recognition import AudioRecognition, RecognitionHooks, _TurnDetector from .generation import ( + _TTSGenerationData, do_llm_inference, do_tts_inference, - _TTSGenerationData, - _LLMGenerationData, ) +from .pipeline_agent import PipelineAgent, SpeechHandle class AgentTask: @@ -66,9 +65,7 @@ def __init__( self._turn_detector = turn_detector self._stt, self._llm, self._tts, self._vad = stt, llm, tts, vad - self._agent: PipelineAgent | None = None - self._rt_session: multimodal.RealtimeSession | None = None - self._audio_recognition: AudioRecognition | None = None + self._active_task: ActiveTask | None = None @property def instructions(self) -> str: @@ -82,6 +79,30 @@ def instructions(self, instructions: str) -> None: def chat_ctx(self) -> llm.ChatContext: return self._chat_ctx + @property + def fnc_ctx(self) -> llm.FunctionContext: + return self._fnc_ctx + + @property + def turn_detector(self) -> _TurnDetector | None: + return self._turn_detector + + @property + def stt(self) -> stt.STT | None: + return self._stt + + @property + def llm(self) -> llm.LLM | multimodal.RealtimeModel | None: + return self._llm + + @property + def tts(self) -> tts.TTS | None: + return self._tts + + @property + def vad(self) -> vad.VAD | None: + return self._vad + async def stt_node( self, audio: AsyncIterable[rtc.AudioFrame] ) -> Optional[AsyncIterable[stt.SpeechEvent]]: @@ -137,49 +158,138 @@ async def _forward_input(): await utils.aio.gracefully_cancel(forward_task) async def _on_start(self, agent: PipelineAgent) -> None: + if self._active_task is not None: + raise RuntimeError("task is already active") + + self._active_task = ActiveTask(task=self, agent=agent) + + async def _on_close(self) -> None: + if self._active_task is not None: + await self._active_task.aclose() + + +class ActiveTask(RecognitionHooks): + def __init__(self, task: AgentTask, agent: PipelineAgent) -> None: + self._task, self._agent = task, agent + self._rt_session: multimodal.RealtimeSession | None = None + self._audio_recognition: AudioRecognition | None = None + + async def start(self) -> None: if self._rt_session is not None: logger.warning("starting a new task while rt_session is not None") self._audio_recognition = AudioRecognition( - task=self, - stt=self.stt_node, - vad=self._vad, - turn_detector=self._turn_detector, - min_endpointing_delay=agent.options.min_endpointing_delay, + hooks=self, + stt=self._task.stt_node, + vad=self._task.vad, + turn_detector=self._task.turn_detector, + min_endpointing_delay=self._agent.options.min_endpointing_delay, ) self._audio_recognition.start() - if isinstance(self._llm, multimodal.RealtimeModel): - self._rt_session = self._llm.session() + if isinstance(self._task.llm, multimodal.RealtimeModel): + self._rt_session = self._task.llm.session() self._rt_session.on("generation_created", self._on_generation_created) - await self._rt_session.update_chat_ctx(self._chat_ctx) + await self._rt_session.update_chat_ctx(self._task.chat_ctx) - async def _on_close(self) -> None: + async def aclose(self) -> None: if self._rt_session is not None: await self._rt_session.aclose() if self._audio_recognition is not None: await self._audio_recognition.aclose() - def _on_generation_created(self, ev: 
multimodal.GenerationCreatedEvent) -> None: - pass - - def _on_input_audio_frame(self, frame: rtc.AudioFrame) -> None: + def push_audio(self, frame: rtc.AudioFrame) -> None: if self._rt_session is not None: self._rt_session.push_audio(frame) if self._audio_recognition is not None: self._audio_recognition.push_audio(frame) - def _on_audio_end_of_turn(self, new_transcript: str) -> None: + def generate_reply(self, user_input: str) -> SpeechHandle: + if ( + self._agent.current_speech is not None + and not self._agent.current_speech.interrupted + ): + raise ValueError("another reply is already in progress") + + debug.Tracing.log_event("generate_reply", {"user_input": user_input}) + + # TODO(theomonnom): move to _generate_pipeline_reply_task + # self._chat_ctx.items.append( + # llm.ChatItem.create( + # [llm.ChatMessage.create(role="user", content=user_input)] + # ) + # ) + + handle = SpeechHandle.create( + allow_interruptions=self._agent.options.allow_interruptions + ) + task = asyncio.create_task( + self._pipeline_reply_task( + handle=handle, + chat_ctx=self._task.chat_ctx, + fnc_ctx=self._task.fnc_ctx, + ), + name="_pipeline_reply_task", + ) + self._agent._schedule_speech(handle, PipelineAgent.SPEECH_PRIORITY_NORMAL) + return handle + + # -- Realtime Session events -- + + def _on_generation_created(self, ev: multimodal.GenerationCreatedEvent) -> None: + handle = SpeechHandle.create( + allow_interruptions=self._agent.options.allow_interruptions + ) + task = asyncio.create_task( + self._pipeline_reply_task( + handle=handle, + chat_ctx=self._task.chat_ctx, + fnc_ctx=self._task.fnc_ctx, + ), + name="_generate_pipeline_reply", + ) + self._agent._schedule_speech(handle, PipelineAgent.SPEECH_PRIORITY_NORMAL) + + # -- Recognition Hooks -- + + def on_start_of_speech(self, ev: vad.VADEvent) -> None: + pass + # self.emit("user_started_speaking", events.UserStartedSpeakingEvent()) + + def on_end_of_speech(self, ev: vad.VADEvent) -> None: + pass + # self.emit("user_stopped_speaking", events.UserStoppedSpeakingEvent()) + + def on_vad_inference_done(self, ev: vad.VADEvent) -> None: + if ev.speech_duration > self._agent.options.min_interruption_duration: + if ( + self._agent.current_speech is not None + and not self._agent.current_speech.interrupted + and self._agent.current_speech.allow_interruptions + ): + debug.Tracing.log_event( + "speech interrupted by vad", + {"speech_id": self._agent.current_speech.id}, + ) + self._agent.current_speech.interrupt() + + def on_interim_transcript(self, ev: stt.SpeechEvent) -> None: + pass + + def on_final_transcript(self, ev: stt.SpeechEvent) -> None: + pass + + def on_end_of_turn(self, new_transcript: str) -> None: # When the audio recognition detects the end of a user turn: # - check if there is no current generation happening # - cancel the current generation if it allows interruptions (otherwise skip this current # turn) # - generate a reply to the user input - if self._current_speech is not None: - if self._current_speech.allow_interruptions: + if self._agent.current_speech is not None: + if self._agent.current_speech.allow_interruptions: logger.warning( "skipping user input, current speech generation cannot be interrupted", extra={"user_input": new_transcript}, @@ -188,42 +298,23 @@ def _on_audio_end_of_turn(self, new_transcript: str) -> None: debug.Tracing.log_event( "speech interrupted, new user turn detected", - {"speech_id": self._current_speech.id}, + {"speech_id": self._agent.current_speech.id}, ) - self._current_speech.interrupt() + 
self._agent.current_speech.interrupt() self.generate_reply(new_transcript) - def _on_vad_inference_done(self, ev: vad.VADEvent) -> None: - if ev.speech_duration > self._opts.min_interruption_duration: - if ( - self._current_speech is not None - and not self._current_speech.interrupted - and self._current_speech.allow_interruptions - ): - debug.Tracing.log_event( - "speech interrupted by vad", - {"speech_id": self._current_speech.id}, - ) - self._current_speech.interrupt() - - def _on_start_of_speech(self, _: vad.VADEvent) -> None: - self.emit("user_started_speaking", events.UserStartedSpeakingEvent()) - - def _on_end_of_speech(self, _: vad.VADEvent) -> None: - self.emit("user_stopped_speaking", events.UserStoppedSpeakingEvent()) - + def retrieve_chat_ctx(self) -> llm.ChatContext: + return self._task.chat_ctx + # --- @utils.log_exceptions(logger=logger) - async def _generate_pipeline_reply_task( + async def _pipeline_reply_task( self, *, speech_handle: SpeechHandle, ) -> None: - assert self._agent is not None, "agent is not set" - agent = self._agent - @utils.log_exceptions(logger=logger) async def _forward_llm_text(llm_output: AsyncIterable[str]) -> None: """collect and forward the generated text to the current agent output""" @@ -392,7 +483,7 @@ async def _execute_tools( ) # wait for the end of the playout if the audio is enabled - if forward_llm_task is not None and self._agent.output.audio is not None: + if forward_tts_task is not None and self._agent.output.audio is not None: await asyncio.wait( [ self._agent.output.audio.wait_for_playout(), @@ -430,9 +521,9 @@ async def _execute_tools( }, ) - speech_handle._mark_playout_done() # TODO(theomonnom): calculate the played text based on playback_ev.playback_position + speech_handle._mark_playout_done() speech_handle._mark_done() return @@ -458,15 +549,126 @@ async def _execute_tools( step_index=speech_handle.step_index + 1, ) task = asyncio.create_task( - self._generate_pipeline_reply_task( + self._pipeline_reply_task( handle=speech_handle, chat_ctx=self._chat_ctx, fnc_ctx=self._fnc_ctx, ), - name="_generate_pipeline_reply", + name="_pipeline_fnc_reply_task", ) self._agent._schedule_speech( - speech_handle, task, PipelineAgent.SPEECH_PRIORITY_NORMAL + speech_handle, PipelineAgent.SPEECH_PRIORITY_NORMAL + ) + + speech_handle._mark_done() + + @utils.log_exceptions(logger=logger) + async def _realtime_reply_task( + self, + speech_handle: SpeechHandle, + generation_ev: multimodal.GenerationCreatedEvent, + ) -> None: + assert self._rt_session is not None, "rt_session is not available" + + @utils.log_exceptions(logger=logger) + async def _forward_text(llm_output: AsyncIterable[str]) -> None: + try: + async for delta in llm_output: + if self._agent.output.text is None: + break + + await self._agent.output.text.capture_text(delta) + finally: + if self._agent.output.text is not None: + self._agent.output.text.flush() + + @utils.log_exceptions(logger=logger) + async def _forward_audio(tts_output: AsyncIterable[rtc.AudioFrame]) -> None: + try: + async for frame in tts_output: + if self._agent.output.audio is None: + break + await self._agent.output.audio.capture_frame(frame) + finally: + if self._agent.output.audio is not None: + self._agent.output.audio.flush() # always flush (even if the task is interrupted) + + # wait for the play() method to be called + await asyncio.wait( + [ + speech_handle._play_fut, + speech_handle._interrupt_fut, + ], + return_when=asyncio.FIRST_COMPLETED, + ) + + if speech_handle.interrupted: + self._rt_session.interrupt() + 
speech_handle._mark_done() + + # TODO(theomonnom): Remove the message + return + + wg = utils.aio.WaitGroup() + tasks: list[asyncio.Task] = [] + + forward_text_task = asyncio.create_task( + _forward_text(generation_ev.text_stream), + name="_realtime_reply_task.forward_text", + ) + wg.add(1) + forward_text_task.add_done_callback(lambda _: wg.done()) + tasks.append(forward_text_task) + + forward_audio_task = asyncio.create_task( + _forward_audio(generation_ev.audio_stream), + name="_realtime_reply_task.forward_audio", + ) + wg.add(1) + forward_audio_task.add_done_callback(lambda _: wg.done()) + tasks.append(forward_text_task) + + await asyncio.wait( + [ + wg.wait(), + speech_handle._interrupt_fut, + ], + return_when=asyncio.FIRST_COMPLETED, + ) + + if forward_audio_task is not None and self._agent.output.audio is not None: + await asyncio.wait( + [ + self._agent.output.audio.wait_for_playout(), + speech_handle._interrupt_fut, + ], + return_when=asyncio.FIRST_COMPLETED, ) + if speech_handle.interrupted: + self._rt_session.interrupt() + await utils.aio.gracefully_cancel(*tasks) + + if self._agent.output.audio is not None: + self._agent.output.audio.clear_buffer() + playback_ev = await self._agent.output.audio.wait_for_playout() + + debug.Tracing.log_event( + "playout interrupted", + { + "playback_position": playback_ev.playback_position, + "speech_id": speech_handle.id, + }, + ) + + # TODO(theomonnom): truncate serverside message + speech_handle._mark_playout_done() + speech_handle._mark_done() + return + + # TODO(theomonnom): tools + + speech_handle._mark_playout_done() + debug.Tracing.log_event("playout completed", {"speech_id": speech_handle.id}) + speech_handle._mark_done() diff --git a/livekit-agents/livekit/agents/pipeline/audio_recognition.py b/livekit-agents/livekit/agents/pipeline/audio_recognition.py index 7f12e6ef2..792427510 100644 --- a/livekit-agents/livekit/agents/pipeline/audio_recognition.py +++ b/livekit-agents/livekit/agents/pipeline/audio_recognition.py @@ -1,7 +1,7 @@ from __future__ import annotations import asyncio -from typing import TYPE_CHECKING, AsyncIterable, Literal, Protocol +from typing import AsyncIterable, Protocol from livekit import rtc @@ -11,9 +11,6 @@ from ..utils import aio from . import io -if TYPE_CHECKING: - from .pipeline_agent import AgentTask - class _TurnDetector(Protocol): # TODO: Move those two functions to EOU ctor (capabilities dataclass) @@ -23,46 +20,39 @@ def supports_language(self, language: str | None) -> bool: ... async def predict_end_of_turn(self, chat_ctx: llm.ChatContext) -> float: ... -EventTypes = Literal[ - "start_of_speech", - "vad_inference_done", - "end_of_speech", - "interim_transcript", - "final_transcript", - "end_of_turn", -] - +class RecognitionHooks(Protocol): + def on_start_of_speech(self, ev: vad.VADEvent) -> None: ... + def on_vad_inference_done(self, ev: vad.VADEvent) -> None: ... + def on_end_of_speech(self, ev: vad.VADEvent) -> None: ... + def on_interim_transcript(self, ev: stt.SpeechEvent) -> None: ... + def on_final_transcript(self, ev: stt.SpeechEvent) -> None: ... + def on_end_of_turn(self, new_transcript: str) -> None: ... -class AudioRecognition(rtc.EventEmitter[EventTypes]): - """ - Audio recognition part of the PipelineAgent. - The class is always instantiated but no tasks may be running if STT/VAD is disabled + def retrieve_chat_ctx(self) -> llm.ChatContext: ... - This class is also responsible for the end of turn detection. 
- """ +class AudioRecognition: UNLIKELY_END_OF_TURN_EXTRA_DELAY = 6.0 def __init__( self, *, - task: AgentTask, + hooks: RecognitionHooks, stt: io.STTNode | None, vad: vad.VAD | None, turn_detector: _TurnDetector | None, min_endpointing_delay: float, ) -> None: - super().__init__() - self._agent_task = task + self._hooks = hooks self._audio_input_atask: asyncio.Task[None] | None = None self._stt_atask: asyncio.Task[None] | None = None self._vad_atask: asyncio.Task[None] | None = None self._end_of_turn_task: asyncio.Task[None] | None = None self._audio_input: io.AudioStream | None = None self._min_endpointing_delay = min_endpointing_delay + self._turn_detector = turn_detector self._stt = stt self._vad = vad - self._turn_detector = turn_detector self._speaking = False self._audio_transcript = "" @@ -94,24 +84,6 @@ def push_audio(self, frame: rtc.AudioFrame) -> None: if self._vad_ch is not None: self._vad_ch.send_nowait(frame) - # @property - # def audio_input(self) -> io.AudioStream | None: - # return self._audio_input - - # @audio_input.setter - # def audio_input(self, audio_input: io.AudioStream | None) -> None: - # self._audio_input = audio_input - # self.update_stt(self._stt) - # self.update_vad(self._vad) - - # if self._audio_input and self._audio_input_atask is None: - # self._audio_input_atask = asyncio.create_task( - # self._audio_input_task(self._audio_input) - # ) - # elif self._audio_input_atask is not None: - # self._audio_input_atask.cancel() - # self._audio_input_atask = None - async def aclose(self) -> None: if self._stt_atask is not None: await aio.gracefully_cancel(self._stt_atask) @@ -148,7 +120,7 @@ def update_vad(self, vad: vad.VAD | None) -> None: async def _on_stt_event(self, ev: stt.SpeechEvent) -> None: if ev.type == stt.SpeechEventType.FINAL_TRANSCRIPT: - self.emit("final_transcript", ev) + self._hooks.on_final_transcript(ev) transcript = ev.alternatives[0].text if not transcript: return @@ -170,13 +142,14 @@ async def _on_stt_event(self, ev: stt.SpeechEvent) -> None: self._audio_transcript = self._audio_transcript.lstrip() if not self._speaking: - self._run_eou_detection(self._agent_task.chat_ctx) + chat_ctx = self._hooks.retrieve_chat_ctx().copy() + self._run_eou_detection(chat_ctx) elif ev.type == stt.SpeechEventType.INTERIM_TRANSCRIPT: - self.emit("interim_transcript", ev) + self._hooks.on_interim_transcript(ev) async def _on_vad_event(self, ev: vad.VADEvent) -> None: if ev.type == vad.VADEventType.START_OF_SPEECH: - self.emit("start_of_speech", ev) + self._hooks.on_start_of_speech(ev) self._speaking = True if self._end_of_turn_task is not None: @@ -184,14 +157,15 @@ async def _on_vad_event(self, ev: vad.VADEvent) -> None: elif ev.type == vad.VADEventType.INFERENCE_DONE: self._vad_graph.plot(ev.timestamp, ev.probability) - self.emit("vad_inference_done", ev) + self._hooks.on_vad_inference_done(ev) elif ev.type == vad.VADEventType.END_OF_SPEECH: - self.emit("end_of_speech", ev) + self._hooks.on_end_of_speech(ev) self._speaking = False if not self._speaking: - self._run_eou_detection(self._agent_task.chat_ctx) + chat_ctx = self._hooks.retrieve_chat_ctx().copy() + self._run_eou_detection(chat_ctx) def _run_eou_detection(self, chat_ctx: llm.ChatContext) -> None: if not self._audio_transcript: @@ -223,7 +197,7 @@ async def _bounce_eou_task() -> None: tracing.Tracing.log_event( "end of user turn", {"transcript": self._audio_transcript} ) - self.emit("end_of_turn", self._audio_transcript) + self._hooks.on_end_of_turn(self._audio_transcript) self._audio_transcript = "" 
if self._end_of_turn_task is not None: diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index 25d20597a..7c43ab3d1 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -1,7 +1,6 @@ from __future__ import annotations, print_function import asyncio -import contextlib import heapq from dataclasses import dataclass from typing import ( @@ -13,80 +12,10 @@ from livekit import rtc from .. import debug, llm, utils -from ..llm import ChatContext, FunctionContext from ..log import logger from . import io -from .agent_task import AgentTask -from .generation import ( - _TTSGenerationData, - do_llm_inference, - do_tts_inference, -) - - -class SpeechHandle: - def __init__( - self, *, speech_id: str, allow_interruptions: bool, step_index: int - ) -> None: - self._id = speech_id - self._step_index = step_index - self._allow_interruptions = allow_interruptions - self._interrupt_fut = asyncio.Future() - self._done_fut = asyncio.Future() - self._play_fut = asyncio.Future() - self._playout_done_fut = asyncio.Future() - - @staticmethod - def create(allow_interruptions: bool = True, step_index: int = 0) -> SpeechHandle: - return SpeechHandle( - speech_id=utils.shortuuid("speech_"), - allow_interruptions=allow_interruptions, - step_index=step_index, - ) - - @property - def id(self) -> str: - return self._id - - @property - def step_index(self) -> int: - return self._step_index - - @property - def interrupted(self) -> bool: - return self._interrupt_fut.done() - - @property - def allow_interruptions(self) -> bool: - return self._allow_interruptions - - def play(self) -> None: - self._play_fut.set_result(None) - - def done(self) -> bool: - return self._done_fut.done() - - def interrupt(self) -> None: - if not self._allow_interruptions: - raise ValueError("This generation handle does not allow interruptions") - - if self.done(): - return - - self._done_fut.set_result(None) - self._interrupt_fut.set_result(None) - - async def wait_for_playout(self) -> None: - await asyncio.shield(self._playout_done_fut) - - def _mark_playout_done(self) -> None: - self._playout_done_fut.set_result(None) - - def _mark_done(self) -> None: - with contextlib.suppress(asyncio.InvalidStateError): - # will raise InvalidStateError if the future is already done (interrupted) - self._done_fut.set_result(None) - +from .agent_task import ActiveTask, AgentTask +from .speech_handle import SpeechHandle EventTypes = Literal[ "user_started_speaking", @@ -123,7 +52,7 @@ class PipelineAgent(rtc.EventEmitter[EventTypes]): def __init__( self, *, - task: AgentTask, + task: AgentTask, # TODO(theomonnom): move this, pretty sure there will be complaints about this lol allow_interruptions: bool = True, min_interruption_duration: float = 0.5, min_endpointing_delay: float = 0.5, @@ -142,11 +71,6 @@ def __init__( max_fnc_steps=max_fnc_steps, ) - # self._audio_recognition.on("end_of_turn", self._on_audio_end_of_turn) - # self._audio_recognition.on("start_of_speech", self._on_start_of_speech) - # self._audio_recognition.on("end_of_speech", self._on_end_of_speech) - # self._audio_recognition.on("vad_inference_done", self._on_vad_inference_done) - # configurable IO self._input = io.AgentInput( self._on_video_input_changed, self._on_audio_input_changed @@ -161,21 +85,23 @@ def __init__( self._current_speech: SpeechHandle | None = None self._speech_q: list[Tuple[int, SpeechHandle]] = [] 
self._speech_q_changed = asyncio.Event() - self._speech_tasks = [] - self._speech_scheduler_atask: asyncio.Task | None = None + + self._main_atask: asyncio.Task | None = None + + # agent tasks + self._current_task: AgentTask = task + self._active_task: ActiveTask # -- Pipeline nodes -- # They can all be overriden by subclasses, by default they use the STT/LLM/TTS specified in the # constructor of the PipelineAgent def start(self) -> None: - self._speech_scheduler_atask = asyncio.create_task( - self._speech_scheduler_task(), name="_playout_scheduler_task" - ) + self._main_atask = asyncio.create_task(self._main_task(), name="_main_task") async def aclose(self) -> None: - if self._speech_scheduler_atask is not None: - await utils.aio.gracefully_cancel(self._speech_scheduler_atask) + if self._main_atask is not None: + await utils.aio.gracefully_cancel(self._main_atask) @property def options(self) -> PipelineOptions: @@ -197,6 +123,10 @@ def output(self) -> io.AgentOutput: def current_speech(self) -> SpeechHandle | None: return self._current_speech + @property + def current_task(self) -> AgentTask: + return self._current_task + @property def chat_ctx(self) -> llm.ChatContext: return self._chat_ctx @@ -208,100 +138,41 @@ def say(self, text: str | AsyncIterable[str]) -> SpeechHandle: raise NotImplementedError() def generate_reply(self, user_input: str) -> SpeechHandle: - if self._current_speech is not None and not self._current_speech.interrupted: - raise ValueError("another reply is already in progress") - - debug.Tracing.log_event("generate_reply", {"user_input": user_input}) - - # TODO(theomonnom): move to _generate_pipeline_reply_task - self._chat_ctx.items.append( - llm.ChatItem.create( - [llm.ChatMessage.create(role="user", content=user_input)] - ) - ) - - handle = SpeechHandle.create(allow_interruptions=self._opts.allow_interruptions) - task = asyncio.create_task( - self._generate_pipeline_reply_task( - handle=handle, - chat_ctx=self._chat_ctx, - fnc_ctx=self._fnc_ctx, - ), - name="_generate_pipeline_reply", - ) - self._schedule_speech(handle, task, self.SPEECH_PRIORITY_NORMAL) - return handle - - # -- Main generation task -- + raise NotImplementedError() - def _schedule_speech( - self, speech: SpeechHandle, task: asyncio.Task, priority: int - ) -> None: - self._speech_tasks.append(task) - task.add_done_callback(lambda _: self._speech_tasks.remove(task)) + def _update_task(self, task: AgentTask) -> None: + pass + def _schedule_speech(self, speech: SpeechHandle, priority: int) -> None: heapq.heappush(self._speech_q, (priority, speech)) self._speech_q_changed.set() @utils.log_exceptions(logger=logger) - async def _speech_scheduler_task(self) -> None: - while True: - await self._speech_q_changed.wait() - - while self._speech_q: - _, speech = heapq.heappop(self._speech_q) - self._current_speech = speech - speech.play() - await speech.wait_for_playout() - self._current_speech = None - - self._speech_q_changed.clear() - - # -- Audio recognition -- - - # def _on_audio_end_of_turn(self, new_transcript: str) -> None: - # # When the audio recognition detects the end of a user turn: - # # - check if there is no current generation happening - # # - cancel the current generation if it allows interruptions (otherwise skip this current - # # turn) - # # - generate a reply to the user input - - # if self._current_speech is not None: - # if self._current_speech.allow_interruptions: - # logger.warning( - # "skipping user input, current speech generation cannot be interrupted", - # extra={"user_input": 
new_transcript}, - # ) - # return - - # debug.Tracing.log_event( - # "speech interrupted, new user turn detected", - # {"speech_id": self._current_speech.id}, - # ) - # self._current_speech.interrupt() - - # self.generate_reply(new_transcript) - - # def _on_vad_inference_done(self, ev: vad.VADEvent) -> None: - # if ev.speech_duration > self._opts.min_interruption_duration: - # if ( - # self._current_speech is not None - # and not self._current_speech.interrupted - # and self._current_speech.allow_interruptions - # ): - # debug.Tracing.log_event( - # "speech interrupted by vad", - # {"speech_id": self._current_speech.id}, - # ) - # self._current_speech.interrupt() - - # def _on_start_of_speech(self, _: vad.VADEvent) -> None: - # self.emit("user_started_speaking", events.UserStartedSpeakingEvent()) - - # def _on_end_of_speech(self, _: vad.VADEvent) -> None: - # self.emit("user_stopped_speaking", events.UserStoppedSpeakingEvent()) + async def _main_task(self) -> None: + @utils.log_exceptions(logger=logger) + async def _speech_scheduling_task() -> None: + while True: + await self._speech_q_changed.wait() + + while self._speech_q: + _, speech = heapq.heappop(self._speech_q) + self._current_speech = speech + speech._authorize_playout() + await speech.wait_for_playout() + self._current_speech = None + + self._speech_q_changed.clear() + + tasks = [ + asyncio.create_task( + _speech_scheduling_task(), name="_speech_scheduling_task" + ) + ] - # --- + try: + await asyncio.gather(*tasks) + finally: + await utils.aio.gracefully_cancel(*tasks) # -- User changed input/output streams/sinks -- diff --git a/livekit-agents/livekit/agents/pipeline/speech_handle.py b/livekit-agents/livekit/agents/pipeline/speech_handle.py new file mode 100644 index 000000000..eece0141a --- /dev/null +++ b/livekit-agents/livekit/agents/pipeline/speech_handle.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import asyncio +import contextlib + +from .. 
import utils + + +class SpeechHandle: + def __init__( + self, *, speech_id: str, allow_interruptions: bool, step_index: int + ) -> None: + self._id = speech_id + self._step_index = step_index + self._allow_interruptions = allow_interruptions + self._interrupt_fut = asyncio.Future() + self._done_fut = asyncio.Future() + self._play_fut = asyncio.Future() + self._playout_done_fut = asyncio.Future() + + @staticmethod + def create(allow_interruptions: bool = True, step_index: int = 0) -> SpeechHandle: + return SpeechHandle( + speech_id=utils.shortuuid("speech_"), + allow_interruptions=allow_interruptions, + step_index=step_index, + ) + + @property + def id(self) -> str: + return self._id + + @property + def step_index(self) -> int: + return self._step_index + + @property + def interrupted(self) -> bool: + return self._interrupt_fut.done() + + @property + def allow_interruptions(self) -> bool: + return self._allow_interruptions + + def _authorize_playout(self) -> None: + self._play_fut.set_result(None) + + def done(self) -> bool: + return self._done_fut.done() + + def interrupt(self) -> None: + if not self._allow_interruptions: + raise ValueError("This generation handle does not allow interruptions") + + if self.done(): + return + + self._done_fut.set_result(None) + self._interrupt_fut.set_result(None) + + async def wait_for_playout(self) -> None: + await asyncio.shield(self._playout_done_fut) + + def _mark_playout_done(self) -> None: + self._playout_done_fut.set_result(None) + + def _mark_done(self) -> None: + with contextlib.suppress(asyncio.InvalidStateError): + # will raise InvalidStateError if the future is already done (interrupted) + self._done_fut.set_result(None) diff --git a/livekit-agents/livekit/agents/stt/stt.py b/livekit-agents/livekit/agents/stt/stt.py index e2f79f93c..4bed361c3 100644 --- a/livekit-agents/livekit/agents/stt/stt.py +++ b/livekit-agents/livekit/agents/stt/stt.py @@ -242,9 +242,9 @@ async def _metrics_monitor_task( async for ev in event_aiter: if ev.type == SpeechEventType.RECOGNITION_USAGE: - assert ( - ev.recognition_usage is not None - ), "recognition_usage must be provided for RECOGNITION_USAGE event" + assert ev.recognition_usage is not None, ( + "recognition_usage must be provided for RECOGNITION_USAGE event" + ) duration = time.perf_counter() - start_time stt_metrics = STTMetrics( From d4d430aabc54ff7ab3a1a4116641a44877570e51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Tue, 21 Jan 2025 15:31:51 +0100 Subject: [PATCH 13/19] wip (currently segfaulting) --- livekit-agents/livekit/agents/llm/utils.py | 8 +- .../livekit/agents/pipeline/__init__.py | 4 +- .../livekit/agents/pipeline/agent_task.py | 186 +++++++++++++----- .../livekit/agents/pipeline/chat_cli.py | 34 ++-- .../livekit/agents/pipeline/pipeline_agent.py | 105 +++++----- .../livekit/agents/pipeline/speech_handle.py | 31 +-- .../livekit/plugins/openai/_oai_api.py | 1 - .../livekit/plugins/openai/llm.py | 2 +- .../plugins/openai/realtime/__init__.py | 24 --- .../plugins/openai/realtime/realtime_model.py | 2 +- 10 files changed, 242 insertions(+), 155 deletions(-) diff --git a/livekit-agents/livekit/agents/llm/utils.py b/livekit-agents/livekit/agents/llm/utils.py index 3e53ce0d1..a6ff16172 100644 --- a/livekit-agents/livekit/agents/llm/utils.py +++ b/livekit-agents/livekit/agents/llm/utils.py @@ -49,15 +49,15 @@ def compute_chat_ctx_diff(old_ctx: ChatContext, new_ctx: ChatContext) -> DiffOps """Computes the minimal list of create/remove operations to transform old_ctx into 
new_ctx.""" # TODO(theomonnom): Make ChatMessage hashable and also add update ops - old_ids = [m.id for m in old_ctx.messages] - new_ids = [m.id for m in new_ctx.messages] + old_ids = [m.id for m in old_ctx.items] + new_ids = [m.id for m in new_ctx.items] lcs_ids = set(_compute_lcs(old_ids, new_ids)) - to_remove = [msg.id for msg in old_ctx.messages if msg.id not in lcs_ids] + to_remove = [msg.id for msg in old_ctx.items if msg.id not in lcs_ids] to_create: list[tuple[str | None, str]] = [] last_id_in_sequence: str | None = None - for new_msg in new_ctx.messages: + for new_msg in new_ctx.items: if new_msg.id in lcs_ids: last_id_in_sequence = new_msg.id else: diff --git a/livekit-agents/livekit/agents/pipeline/__init__.py b/livekit-agents/livekit/agents/pipeline/__init__.py index fed344c1f..1c92d54c2 100644 --- a/livekit-agents/livekit/agents/pipeline/__init__.py +++ b/livekit-agents/livekit/agents/pipeline/__init__.py @@ -1,4 +1,6 @@ from .chat_cli import ChatCLI from .pipeline_agent import PipelineAgent +from .agent_task import AgentTask +from .speech_handle import SpeechHandle -__all__ = ["ChatCLI", "PipelineAgent"] +__all__ = ["ChatCLI", "PipelineAgent", "AgentTask", "SpeechHandle"] diff --git a/livekit-agents/livekit/agents/pipeline/agent_task.py b/livekit-agents/livekit/agents/pipeline/agent_task.py index f50b5728f..0b024c381 100644 --- a/livekit-agents/livekit/agents/pipeline/agent_task.py +++ b/livekit-agents/livekit/agents/pipeline/agent_task.py @@ -1,6 +1,8 @@ from __future__ import annotations +import time import asyncio +import heapq from typing import ( AsyncIterable, Optional, @@ -12,14 +14,18 @@ from .. import debug, llm, multimodal, stt, tokenize, tts, utils, vad from ..llm import ChatContext, FunctionContext, find_ai_functions from ..log import logger -from .agent_task import AgentTask from .audio_recognition import AudioRecognition, RecognitionHooks, _TurnDetector from .generation import ( _TTSGenerationData, do_llm_inference, do_tts_inference, ) -from .pipeline_agent import PipelineAgent, SpeechHandle +from typing import TYPE_CHECKING + +from .speech_handle import SpeechHandle + +if TYPE_CHECKING: + from .pipeline_agent import PipelineAgent class AgentTask: @@ -65,8 +71,6 @@ def __init__( self._turn_detector = turn_detector self._stt, self._llm, self._tts, self._vad = stt, llm, tts, vad - self._active_task: ActiveTask | None = None - @property def instructions(self) -> str: return self._instructions @@ -157,15 +161,8 @@ async def _forward_input(): finally: await utils.aio.gracefully_cancel(forward_task) - async def _on_start(self, agent: PipelineAgent) -> None: - if self._active_task is not None: - raise RuntimeError("task is already active") - - self._active_task = ActiveTask(task=self, agent=agent) - - async def _on_close(self) -> None: - if self._active_task is not None: - await self._active_task.aclose() + def _create_activity(self, agent: PipelineAgent) -> ActiveTask: + return ActiveTask(task=self, agent=agent) class ActiveTask(RecognitionHooks): @@ -173,33 +170,68 @@ def __init__(self, task: AgentTask, agent: PipelineAgent) -> None: self._task, self._agent = task, agent self._rt_session: multimodal.RealtimeSession | None = None self._audio_recognition: AudioRecognition | None = None + self._lock = asyncio.Lock() + + self._done_fut = asyncio.Future() + self._draining = False + + self._current_speech: SpeechHandle | None = None + self._speech_q: list[tuple[int, float, SpeechHandle]] = [] + self._speech_q_changed = asyncio.Event() + + self._main_atask: asyncio.Task | 
None = None + self._tasks: list[asyncio.Task] = [] + self._started = False + + @property + def draining(self) -> bool: + return self._draining + + async def drain(self) -> None: + self._speech_q_changed.set() # TODO(theomonnom): refactor so we don't need this here + self._draining = True + + if self._main_atask is not None: + await asyncio.shield(self._main_atask) async def start(self) -> None: - if self._rt_session is not None: - logger.warning("starting a new task while rt_session is not None") - - self._audio_recognition = AudioRecognition( - hooks=self, - stt=self._task.stt_node, - vad=self._task.vad, - turn_detector=self._task.turn_detector, - min_endpointing_delay=self._agent.options.min_endpointing_delay, - ) - self._audio_recognition.start() + async with self._lock: + self._main_atask = asyncio.create_task(self._main_task(), name="_main_task") + + self._audio_recognition = AudioRecognition( + hooks=self, + stt=self._task.stt_node, + vad=self._task.vad, + turn_detector=self._task.turn_detector, + min_endpointing_delay=self._agent.options.min_endpointing_delay, + ) + self._audio_recognition.start() + + if isinstance(self._task.llm, multimodal.RealtimeModel): + self._rt_session = self._task.llm.session() + self._rt_session.on("generation_created", self._on_generation_created) + self._rt_session.on( + "input_speech_started", self._on_input_speech_started + ) + await self._rt_session.update_chat_ctx(self._task.chat_ctx) - if isinstance(self._task.llm, multimodal.RealtimeModel): - self._rt_session = self._task.llm.session() - self._rt_session.on("generation_created", self._on_generation_created) - await self._rt_session.update_chat_ctx(self._task.chat_ctx) + self._started = True async def aclose(self) -> None: - if self._rt_session is not None: - await self._rt_session.aclose() + async with self._lock: + if self._rt_session is not None: + await self._rt_session.aclose() - if self._audio_recognition is not None: - await self._audio_recognition.aclose() + if self._audio_recognition is not None: + await self._audio_recognition.aclose() + + if self._main_atask is not None: + await utils.aio.gracefully_cancel(self._main_atask) def push_audio(self, frame: rtc.AudioFrame) -> None: + if not self._started: + return + if self._rt_session is not None: self._rt_session.push_audio(frame) @@ -233,24 +265,71 @@ def generate_reply(self, user_input: str) -> SpeechHandle: ), name="_pipeline_reply_task", ) - self._agent._schedule_speech(handle, PipelineAgent.SPEECH_PRIORITY_NORMAL) + self._tasks.append(task) + task.add_done_callback(lambda _: self._tasks.remove(task)) + self._schedule_speech(handle, SpeechHandle.SPEECH_PRIORITY_NORMAL) return handle + def interrupt(self) -> None: + if self._current_speech is not None: + self._current_speech.interrupt() + + for speech in self._speech_q: + _, _, speech = speech + speech.interrupt() + + def _schedule_speech(self, speech: SpeechHandle, priority: int) -> None: + if self._draining: + raise RuntimeError("cannot schedule new speech, task is draining") + + heapq.heappush(self._speech_q, (priority, time.time(), speech)) + self._speech_q_changed.set() + + @utils.log_exceptions(logger=logger) + async def _main_task(self) -> None: + try: + while True: + await self._speech_q_changed.wait() + while self._speech_q: + _, _, speech = heapq.heappop(self._speech_q) + self._current_speech = speech + speech._authorize_playout() + await speech.wait_for_playout() + self._current_speech = None + + if self._draining: # no more speech can be scheduled + break + + 
self._speech_q_changed.clear() + finally: + await asyncio.gather(*self._tasks) + debug.Tracing.log_event(f"task done, waiting for {len(self._tasks)} tasks") + debug.Tracing.log_event("marking agent task as done") + self._done_fut.set_result(None) + # -- Realtime Session events -- + def _on_input_speech_started(self, ev: multimodal.InputSpeechStartedEvent) -> None: + self.interrupt() + def _on_generation_created(self, ev: multimodal.GenerationCreatedEvent) -> None: + if self.draining: + logger.warning("skipping new generation, task is draining") + debug.Tracing.log_event("skipping new generation, task is draining") + return + handle = SpeechHandle.create( allow_interruptions=self._agent.options.allow_interruptions ) task = asyncio.create_task( - self._pipeline_reply_task( - handle=handle, - chat_ctx=self._task.chat_ctx, - fnc_ctx=self._task.fnc_ctx, + self._realtime_reply_task( + speech_handle=handle, + generation_ev=ev, ), - name="_generate_pipeline_reply", ) - self._agent._schedule_speech(handle, PipelineAgent.SPEECH_PRIORITY_NORMAL) + self._tasks.append(task) + task.add_done_callback(lambda _: self._tasks.remove(task)) + self._schedule_speech(handle, SpeechHandle.SPEECH_PRIORITY_NORMAL) # -- Recognition Hooks -- @@ -302,6 +381,17 @@ def on_end_of_turn(self, new_transcript: str) -> None: ) self._agent.current_speech.interrupt() + if self.draining: + logger.warning( + "skipping user input, task is draining", + extra={"user_input": new_transcript}, + ) + debug.Tracing.log_event( + "skipping user input, task is draining", + {"user_input": new_transcript}, + ) + return + self.generate_reply(new_transcript) def retrieve_chat_ctx(self) -> llm.ChatContext: @@ -557,7 +647,7 @@ async def _execute_tools( name="_pipeline_fnc_reply_task", ) self._agent._schedule_speech( - speech_handle, PipelineAgent.SPEECH_PRIORITY_NORMAL + speech_handle, SpeechHandle.SPEECH_PRIORITY_NORMAL ) speech_handle._mark_done() @@ -565,6 +655,7 @@ async def _execute_tools( @utils.log_exceptions(logger=logger) async def _realtime_reply_task( self, + *, speech_handle: SpeechHandle, generation_ev: multimodal.GenerationCreatedEvent, ) -> None: @@ -596,17 +687,15 @@ async def _forward_audio(tts_output: AsyncIterable[rtc.AudioFrame]) -> None: # wait for the play() method to be called await asyncio.wait( [ - speech_handle._play_fut, + speech_handle._wait_for_authorization(), speech_handle._interrupt_fut, ], return_when=asyncio.FIRST_COMPLETED, ) if speech_handle.interrupted: - self._rt_session.interrupt() - speech_handle._mark_done() - - # TODO(theomonnom): Remove the message + speech_handle._mark_playout_done() + # TODO(theomonnom): remove the message from the serverside history return wg = utils.aio.WaitGroup() @@ -646,11 +735,11 @@ async def _forward_audio(tts_output: AsyncIterable[rtc.AudioFrame]) -> None: ) if speech_handle.interrupted: - self._rt_session.interrupt() await utils.aio.gracefully_cancel(*tasks) if self._agent.output.audio is not None: self._agent.output.audio.clear_buffer() + print("wtf?") playback_ev = await self._agent.output.audio.wait_for_playout() debug.Tracing.log_event( @@ -663,12 +752,9 @@ async def _forward_audio(tts_output: AsyncIterable[rtc.AudioFrame]) -> None: # TODO(theomonnom): truncate serverside message speech_handle._mark_playout_done() - speech_handle._mark_done() return # TODO(theomonnom): tools - speech_handle._mark_playout_done() debug.Tracing.log_event("playout completed", {"speech_id": speech_handle.id}) - - speech_handle._mark_done() + speech_handle._mark_playout_done() diff --git 
a/livekit-agents/livekit/agents/pipeline/chat_cli.py b/livekit-agents/livekit/agents/pipeline/chat_cli.py index d2c0cd9bb..26692cbc3 100644 --- a/livekit-agents/livekit/agents/pipeline/chat_cli.py +++ b/livekit-agents/livekit/agents/pipeline/chat_cli.py @@ -3,6 +3,7 @@ import asyncio import sys import termios +import contextlib import time import tty from typing import Literal @@ -79,17 +80,30 @@ async def capture_frame(self, frame: rtc.AudioFrame) -> None: self._pushed_duration += frame.duration if self._cli._output_stream is not None: + + def _write(output_stream: sd.OutputStream, data: memoryview) -> None: + if ( + output_stream.active + and self._cli._output_stream == output_stream + ): + print("Writing data", output_stream) + with contextlib.suppress(sd.PortAudioError): + output_stream.write(data) + else: + print("Output stream closed") + await self._cli._loop.run_in_executor( - None, self._cli._output_stream.write, frame.data + None, _write, self._cli._output_stream, frame.data ) def clear_buffer(self) -> None: self._capturing = False - if self._cli._output_stream is not None and self._cli._output_stream.active: - # restarting the stream will clear the buffer - self._cli._output_stream.stop() - self._cli._output_stream.start() + if self._cli._output_stream is not None: + # hacky + print("Clearing buffer") + self._cli._update_speaker(enable=False) + self._cli._update_speaker(enable=True) if self._pushed_duration > 0.0: if self._dispatch_handle is not None: @@ -224,10 +238,9 @@ def _update_speaker(self, *, enable: bool) -> None: self._output_stream.start() self._agent.output.audio = self._audio_sink elif self._output_stream is not None: - self._output_stream.stop() self._output_stream.close() self._output_stream = None - self._agent.output.audio + self._agent.output.audio = None def _update_text_output(self, *, enable: bool) -> None: if enable: @@ -248,13 +261,6 @@ def _input_sd_callback(self, indata: np.ndarray, frame_count: int, *_) -> None: num_channels=1, ) self._saved_frames.append(frame) - - if len(self._saved_frames) > 20 * 5: - frmae = rtc.combine_audio_frames(self._saved_frames) - wav = frmae.to_wav_bytes() - with open("audio.wav", "wb") as f: - f.write(wav) - self._loop.call_soon_threadsafe(self._audio_input_ch.send_nowait, frame) @log_exceptions(logger=logger) diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index 7c43ab3d1..53172c0c8 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -42,13 +42,6 @@ def __init__(self) -> None: class PipelineAgent(rtc.EventEmitter[EventTypes]): - SPEECH_PRIORITY_LOW = 0 - """Priority for messages that should be played after all other messages in the queue""" - SPEECH_PRIORITY_NORMAL = 5 - """Every speech generates by the PipelineAgent defaults to this priority.""" - SPEECH_PRIORITY_HIGH = 10 - """Priority for important messages that should be played before others.""" - def __init__( self, *, @@ -70,6 +63,7 @@ def __init__( min_endpointing_delay=min_endpointing_delay, max_fnc_steps=max_fnc_steps, ) + self._started = False # configurable IO self._input = io.AgentInput( @@ -81,27 +75,39 @@ def __init__( self._on_text_output_changed, ) - # speech state - self._current_speech: SpeechHandle | None = None - self._speech_q: list[Tuple[int, SpeechHandle]] = [] - self._speech_q_changed = asyncio.Event() - - self._main_atask: asyncio.Task | None = None + self._forward_audio_atask: 
asyncio.Task | None = None + self._update_activity_atask: asyncio.Task | None = None + self._lock = asyncio.Lock() # agent tasks self._current_task: AgentTask = task - self._active_task: ActiveTask + self._active_task: ActiveTask | None = None # -- Pipeline nodes -- # They can all be overriden by subclasses, by default they use the STT/LLM/TTS specified in the # constructor of the PipelineAgent def start(self) -> None: - self._main_atask = asyncio.create_task(self._main_task(), name="_main_task") + if self._started: + return + + if self.input.audio is not None: + self._forward_audio_atask = asyncio.create_task( + self._forward_audio_task(), name="_forward_audio_task" + ) + + self._update_activity_atask = asyncio.create_task( + self._update_activity_task(self._current_task), name="_update_activity_task" + ) + + self._started = True async def aclose(self) -> None: - if self._main_atask is not None: - await utils.aio.gracefully_cancel(self._main_atask) + if not self._started: + return + + if self._forward_audio_atask is not None: + await utils.aio.gracefully_cancel(self._forward_audio_atask) @property def options(self) -> PipelineOptions: @@ -121,7 +127,7 @@ def output(self) -> io.AgentOutput: @property def current_speech(self) -> SpeechHandle | None: - return self._current_speech + raise NotImplementedError() @property def current_task(self) -> AgentTask: @@ -140,39 +146,34 @@ def say(self, text: str | AsyncIterable[str]) -> SpeechHandle: def generate_reply(self, user_input: str) -> SpeechHandle: raise NotImplementedError() - def _update_task(self, task: AgentTask) -> None: - pass + def update_task(self, task: AgentTask) -> None: + self._current_task = task - def _schedule_speech(self, speech: SpeechHandle, priority: int) -> None: - heapq.heappush(self._speech_q, (priority, speech)) - self._speech_q_changed.set() + if self._started: + self._update_activity_task = asyncio.create_task( + self._update_activity_task(self._current_task), + name="_update_activity_task", + ) @utils.log_exceptions(logger=logger) - async def _main_task(self) -> None: - @utils.log_exceptions(logger=logger) - async def _speech_scheduling_task() -> None: - while True: - await self._speech_q_changed.wait() - - while self._speech_q: - _, speech = heapq.heappop(self._speech_q) - self._current_speech = speech - speech._authorize_playout() - await speech.wait_for_playout() - self._current_speech = None - - self._speech_q_changed.clear() - - tasks = [ - asyncio.create_task( - _speech_scheduling_task(), name="_speech_scheduling_task" - ) - ] + async def _update_activity_task(self, task: AgentTask) -> None: + async with self._lock: + if self._active_task is not None: + await self._active_task.drain() + await self._active_task.aclose() + + self._active_task = task._create_activity(self) + await self._active_task.start() + + @utils.log_exceptions(logger=logger) + async def _forward_audio_task(self) -> None: + audio_input = self.input.audio + if audio_input is None: + return - try: - await asyncio.gather(*tasks) - finally: - await utils.aio.gracefully_cancel(*tasks) + async for frame in audio_input: + if self._active_task is not None: + self._active_task.push_audio(frame) # -- User changed input/output streams/sinks -- @@ -180,7 +181,15 @@ def _on_video_input_changed(self) -> None: pass def _on_audio_input_changed(self) -> None: - pass + if not self._started: + return + + if self._forward_audio_atask is not None: + self._forward_audio_atask.cancel() + + self._forward_audio_atask = asyncio.create_task( + self._forward_audio_task(), 
name="_forward_audio_task" + ) def _on_video_output_changed(self) -> None: pass diff --git a/livekit-agents/livekit/agents/pipeline/speech_handle.py b/livekit-agents/livekit/agents/pipeline/speech_handle.py index eece0141a..0ee1da512 100644 --- a/livekit-agents/livekit/agents/pipeline/speech_handle.py +++ b/livekit-agents/livekit/agents/pipeline/speech_handle.py @@ -2,11 +2,19 @@ import asyncio import contextlib +from typing import Callable from .. import utils class SpeechHandle: + SPEECH_PRIORITY_LOW = 0 + """Priority for messages that should be played after all other messages in the queue""" + SPEECH_PRIORITY_NORMAL = 5 + """Every speech generates by the PipelineAgent defaults to this priority.""" + SPEECH_PRIORITY_HIGH = 10 + """Priority for important messages that should be played before others.""" + def __init__( self, *, speech_id: str, allow_interruptions: bool, step_index: int ) -> None: @@ -14,8 +22,7 @@ def __init__( self._step_index = step_index self._allow_interruptions = allow_interruptions self._interrupt_fut = asyncio.Future() - self._done_fut = asyncio.Future() - self._play_fut = asyncio.Future() + self._authorize_fut = asyncio.Future() self._playout_done_fut = asyncio.Future() @staticmethod @@ -42,11 +49,8 @@ def interrupted(self) -> bool: def allow_interruptions(self) -> bool: return self._allow_interruptions - def _authorize_playout(self) -> None: - self._play_fut.set_result(None) - def done(self) -> bool: - return self._done_fut.done() + return self._playout_done_fut.done() def interrupt(self) -> None: if not self._allow_interruptions: @@ -55,16 +59,21 @@ def interrupt(self) -> None: if self.done(): return - self._done_fut.set_result(None) self._interrupt_fut.set_result(None) async def wait_for_playout(self) -> None: await asyncio.shield(self._playout_done_fut) - def _mark_playout_done(self) -> None: - self._playout_done_fut.set_result(None) + def add_done_callback(self, callback: Callable[[SpeechHandle], None]) -> None: + self._playout_done_fut.add_done_callback(lambda _: callback(self)) - def _mark_done(self) -> None: + def _authorize_playout(self) -> None: + self._authorize_fut.set_result(None) + + async def _wait_for_authorization(self) -> None: + await asyncio.shield(self._authorize_fut) + + def _mark_playout_done(self) -> None: with contextlib.suppress(asyncio.InvalidStateError): # will raise InvalidStateError if the future is already done (interrupted) - self._done_fut.set_result(None) + self._playout_done_fut.set_result(None) diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/_oai_api.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/_oai_api.py index 8dbc3a33e..76e8a1513 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/_oai_api.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/_oai_api.py @@ -19,7 +19,6 @@ from typing import Any from livekit.agents.llm import function_context, llm -from livekit.agents.llm.function_context import _is_optional_type __all__ = ["build_oai_function_description"] diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py index 37526dd4b..02b747b67 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py @@ -29,7 +29,7 @@ APITimeoutError, llm, ) -from livekit.agents.llm import ToolChoice, _create_ai_function_info +from livekit.agents.llm import ToolChoice 
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions import openai diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/__init__.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/__init__.py index fbb453609..bf2be937f 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/__init__.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/__init__.py @@ -1,33 +1,9 @@ -from . import api_proto from .realtime_model import ( - DEFAULT_INPUT_AUDIO_TRANSCRIPTION, - DEFAULT_SERVER_VAD_OPTIONS, - InputTranscriptionOptions, - RealtimeContent, - RealtimeError, RealtimeModel, - RealtimeOutput, - RealtimeResponse, RealtimeSession, - RealtimeSessionOptions, - RealtimeToolCall, - ServerVadOptions, ) __all__ = [ - "RealtimeContent", - "RealtimeOutput", - "RealtimeResponse", - "RealtimeToolCall", "RealtimeSession", "RealtimeModel", - "RealtimeError", - "RealtimeSessionOptions", - "ServerVadOptions", - "InputTranscriptionOptions", - "ConversationItemCreated", - "ConversationItemDeleted", - "api_proto", - "DEFAULT_INPUT_AUDIO_TRANSCRIPTION", - "DEFAULT_SERVER_VAD_OPTIONS", ] diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py index 716db3ed9..a947b0f66 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py @@ -94,7 +94,7 @@ class RealtimeSession(multimodal.RealtimeSession): def __init__(self, realtime_model: RealtimeModel) -> None: super().__init__(realtime_model) self._realtime_model = realtime_model - self._chat_ctx = llm.ChatContext() + self._chat_ctx = llm.ChatContext.empty() self._fnc_ctx = llm.FunctionContext.empty() self._msg_ch = utils.aio.Chan[RealtimeClientEvent]() From c63d862bb56715f0d908ad9f35d7dc5d99d13962 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Tue, 21 Jan 2025 17:41:13 +0100 Subject: [PATCH 14/19] fix segfault & interruptions --- .../livekit/agents/pipeline/agent_task.py | 10 +- .../livekit/agents/pipeline/chat_cli.py | 93 +++++++++++-------- livekit-agents/livekit/agents/pipeline/io.py | 26 +++--- .../livekit/plugins/openai/stt.py | 1 - 4 files changed, 74 insertions(+), 56 deletions(-) diff --git a/livekit-agents/livekit/agents/pipeline/agent_task.py b/livekit-agents/livekit/agents/pipeline/agent_task.py index 0b024c381..e4f23d4c3 100644 --- a/livekit-agents/livekit/agents/pipeline/agent_task.py +++ b/livekit-agents/livekit/agents/pipeline/agent_task.py @@ -213,6 +213,9 @@ async def start(self) -> None: self._rt_session.on( "input_speech_started", self._on_input_speech_started ) + self._rt_session.on( + "input_speech_stopped", self._on_input_speech_stopped + ) await self._rt_session.update_chat_ctx(self._task.chat_ctx) self._started = True @@ -309,9 +312,13 @@ async def _main_task(self) -> None: # -- Realtime Session events -- - def _on_input_speech_started(self, ev: multimodal.InputSpeechStartedEvent) -> None: + def _on_input_speech_started(self, _: multimodal.InputSpeechStartedEvent) -> None: + debug.Tracing.log_event("input_speech_started") self.interrupt() + def _on_input_speech_stopped(self, _: multimodal.InputSpeechStoppedEvent) -> None: + debug.Tracing.log_event("input_speech_stopped") + def _on_generation_created(self, ev: 
multimodal.GenerationCreatedEvent) -> None: if self.draining: logger.warning("skipping new generation, task is draining") @@ -739,7 +746,6 @@ async def _forward_audio(tts_output: AsyncIterable[rtc.AudioFrame]) -> None: if self._agent.output.audio is not None: self._agent.output.audio.clear_buffer() - print("wtf?") playback_ev = await self._agent.output.audio.wait_for_playout() debug.Tracing.log_event( diff --git a/livekit-agents/livekit/agents/pipeline/chat_cli.py b/livekit-agents/livekit/agents/pipeline/chat_cli.py index 26692cbc3..9da68a933 100644 --- a/livekit-agents/livekit/agents/pipeline/chat_cli.py +++ b/livekit-agents/livekit/agents/pipeline/chat_cli.py @@ -1,9 +1,9 @@ from __future__ import annotations import asyncio +import threading import sys import termios -import contextlib import time import tty from typing import Literal @@ -65,45 +65,49 @@ def __init__(self, cli: "ChatCLI") -> None: self._flush_complete = asyncio.Event() self._flush_complete.set() + self._output_buf = bytearray() + self._output_lock = threading.Lock() + + @property + def lock(self) -> threading.Lock: + return self._output_lock + + @property + def audio_buffer(self) -> bytearray: + return self._output_buf + async def capture_frame(self, frame: rtc.AudioFrame) -> None: await super().capture_frame(frame) await self._flush_complete.wait() - if not frame.duration: - return - if not self._capturing: self._capturing = True - self._buffer_duration = 0.0 + self._pushed_duration = 0.0 self._capture_start = time.monotonic() self._pushed_duration += frame.duration + print(f"Pushed audio frame, duration: {self._pushed_duration:.2f}s") + with self._output_lock: + self._output_buf += frame.data - if self._cli._output_stream is not None: - - def _write(output_stream: sd.OutputStream, data: memoryview) -> None: - if ( - output_stream.active - and self._cli._output_stream == output_stream - ): - print("Writing data", output_stream) - with contextlib.suppress(sd.PortAudioError): - output_stream.write(data) - else: - print("Output stream closed") - - await self._cli._loop.run_in_executor( - None, _write, self._cli._output_stream, frame.data + def flush(self) -> None: + super().flush() + if self._capturing: + self._flush_complete.clear() + self._capturing = False + to_wait = max( + 0.0, self._pushed_duration - (time.monotonic() - self._capture_start) + ) + print(f"Flushing audio buffer, waiting for {to_wait:.2f}s") + self._dispatch_handle = self._cli._loop.call_later( + to_wait, self._dispatch_playback_finished ) def clear_buffer(self) -> None: self._capturing = False - if self._cli._output_stream is not None: - # hacky - print("Clearing buffer") - self._cli._update_speaker(enable=False) - self._cli._update_speaker(enable=True) + with self._output_lock: + self._output_buf.clear() if self._pushed_duration > 0.0: if self._dispatch_handle is not None: @@ -119,18 +123,8 @@ def clear_buffer(self) -> None: interrupted=played_duration + 1.0 < self._pushed_duration, ) - def flush(self) -> None: - if self._capturing: - self._flush_complete.clear() - self._capturing = False - to_wait = min( - 0.0, self._pushed_duration - (time.monotonic() - self._capture_start) - ) - self._dispatch_handle = self._cli._loop.call_later( - to_wait, self._dispatch_playback_finished - ) - def _dispatch_playback_finished(self) -> None: + print("sending playback finished event") self.on_playback_finished( playback_position=self._pushed_duration, interrupted=False ) @@ -161,8 +155,6 @@ def __init__( self._text_sink = _TextSink(self) self._audio_sink = 
_AudioSink(self) - self._saved_frames = [] - def _print_welcome(self): print(_esc(34) + "=" * 50 + _esc(0)) print(_esc(34) + " Livekit Agents - ChatCLI" + _esc(0)) @@ -212,7 +204,7 @@ def _update_microphone(self, *, enable: bool) -> None: self._input_device_name = device_info.get("name", "Microphone") self._input_stream = sd.InputStream( - callback=self._input_sd_callback, + callback=self._sd_input_callback, dtype="int16", channels=1, device=input_device, @@ -230,10 +222,12 @@ def _update_speaker(self, *, enable: bool) -> None: _, output_device = sd.default.device if output_device is not None and enable: self._output_stream = sd.OutputStream( + callback=self._sd_output_callback, dtype="int16", channels=1, device=output_device, samplerate=24000, + blocksize=4800, ) self._output_stream.start() self._agent.output.audio = self._audio_sink @@ -249,7 +243,25 @@ def _update_text_output(self, *, enable: bool) -> None: self._agent.output.text = None self._text_input_buf = [] - def _input_sd_callback(self, indata: np.ndarray, frame_count: int, *_) -> None: + def _sd_output_callback(self, outdata: np.ndarray, frames: int, *_) -> None: + with self._audio_sink.lock: + bytes_needed = frames * 2 + if len(self._audio_sink.audio_buffer) < bytes_needed: + available_bytes = len(self._audio_sink.audio_buffer) + outdata[: available_bytes // 2, 0] = np.frombuffer( + self._audio_sink.audio_buffer, + dtype=np.int16, + count=available_bytes // 2, + ) + + outdata[available_bytes // 2 :, 0] = 0 + del self._audio_sink.audio_buffer[:available_bytes] + else: + chunk = self._audio_sink.audio_buffer[:bytes_needed] + outdata[:, 0] = np.frombuffer(chunk, dtype=np.int16, count=frames) + del self._audio_sink.audio_buffer[:bytes_needed] + + def _sd_input_callback(self, indata: np.ndarray, frame_count: int, *_) -> None: rms = np.sqrt(np.mean(indata.astype(np.float32) ** 2)) max_int16 = np.iinfo(np.int16).max self._micro_db = 20.0 * np.log10(rms / max_int16 + 1e-6) @@ -260,7 +272,6 @@ def _input_sd_callback(self, indata: np.ndarray, frame_count: int, *_) -> None: sample_rate=24000, num_channels=1, ) - self._saved_frames.append(frame) self._loop.call_soon_threadsafe(self._audio_input_ch.send_nowait, frame) @log_exceptions(logger=logger) diff --git a/livekit-agents/livekit/agents/pipeline/io.py b/livekit-agents/livekit/agents/pipeline/io.py index a93d5008b..512a24f71 100644 --- a/livekit-agents/livekit/agents/pipeline/io.py +++ b/livekit-agents/livekit/agents/pipeline/io.py @@ -16,6 +16,8 @@ from .. import llm, stt +from ..log import logger + STTNode = Callable[ [AsyncIterable[rtc.AudioFrame]], Union[Awaitable[Optional[AsyncIterable[stt.SpeechEvent]]]], @@ -61,7 +63,7 @@ def __init__(self, *, sample_rate: int | None = None) -> None: self.__capturing = False self.__playback_finished_event = asyncio.Event() - self.__nb_playback_finished_needed = 0 + self.__playback_segments_count = 0 self.__playback_finished_count = 0 def on_playback_finished( @@ -71,9 +73,13 @@ def on_playback_finished( Developers building audio sinks must call this method when a playback/segment is finished. 
Segments are segmented by calls to flush() or clear_buffer() """ - self.__nb_playback_finished_needed = max( - 0, self.__nb_playback_finished_needed - 1 - ) + + if self.__playback_finished_count >= self.__playback_segments_count: + logger.warning( + "playback_finished called more times than playback segments were captured" + ) + return + self.__playback_finished_count += 1 self.__playback_finished_event.set() @@ -91,11 +97,9 @@ async def wait_for_playout(self) -> PlaybackFinishedEvent: PlaybackFinishedEvent: The event that was emitted when the audio finished playing out (only the last segment information) """ - needed = self.__nb_playback_finished_needed - initial_count = self.__playback_finished_count - target_count = initial_count + needed + target = self.__playback_segments_count - while self.__playback_finished_count < target_count: + while self.__playback_finished_count < target: await self.__playback_finished_event.wait() self.__playback_finished_event.clear() @@ -111,18 +115,16 @@ async def capture_frame(self, frame: rtc.AudioFrame) -> None: """Capture an audio frame for playback, frames can be pushed faster than real-time""" if not self.__capturing: self.__capturing = True - self.__nb_playback_finished_needed += 1 + self.__playback_segments_count += 1 @abstractmethod def flush(self) -> None: """Flush any buffered audio, marking the current playback/segment as complete""" - if self.__capturing: - self.__capturing = False + self.__capturing = False @abstractmethod def clear_buffer(self) -> None: """Clear the buffer, stopping playback immediately""" - ... class TextSink(ABC): diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py index 1b59e1f98..e3f19972a 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py @@ -138,7 +138,6 @@ async def _recognize_impl( conn_options: APIConnectOptions, ) -> stt.SpeechEvent: try: - print("buffer", buffer) config = self._sanitize_options(language=language) data = rtc.combine_audio_frames(buffer).to_wav_bytes() resp = await self._client.audio.transcriptions.create( From 1bbbdfd048ea922d0202837cd41047e5f614b015 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Tue, 21 Jan 2025 23:31:43 +0100 Subject: [PATCH 15/19] multimodal wip --- examples/minimal_worker.py | 9 +- .../livekit/agents/multimodal/realtime.py | 15 +-- .../livekit/agents/pipeline/agent_task.py | 91 +++++++------------ .../livekit/agents/pipeline/speech_handle.py | 7 +- .../livekit/agents/utils/aio/__init__.py | 31 +------ .../livekit/agents/utils/aio/task_set.py | 17 ++-- .../livekit/agents/utils/aio/utils.py | 27 ++++++ .../livekit/plugins/openai/_oai_api.py | 91 ------------------- .../livekit/plugins/openai/llm.py | 1 - .../plugins/openai/realtime/realtime_model.py | 22 ++++- 10 files changed, 105 insertions(+), 206 deletions(-) create mode 100644 livekit-agents/livekit/agents/utils/aio/utils.py delete mode 100644 livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/_oai_api.py diff --git a/examples/minimal_worker.py b/examples/minimal_worker.py index d3d100513..fe5adb8f2 100644 --- a/examples/minimal_worker.py +++ b/examples/minimal_worker.py @@ -2,7 +2,7 @@ from dotenv import load_dotenv from livekit.agents import AutoSubscribe, JobContext, WorkerOptions, WorkerType, cli -from livekit.agents.pipeline import ChatCLI, PipelineAgent +from livekit.agents.pipeline 
import ChatCLI, PipelineAgent, AgentTask from livekit.plugins import cartesia, deepgram, openai, silero logger = logging.getLogger("my-worker") @@ -12,10 +12,11 @@ async def entrypoint(ctx: JobContext): - await ctx.connect(auto_subscribe=AutoSubscribe.SUBSCRIBE_ALL) - agent = PipelineAgent( - llm=openai.LLM(), stt=deepgram.STT(), tts=cartesia.TTS(), vad=silero.VAD.load() + task=AgentTask( + instructions="Talk to me!", + llm=openai.realtime.RealtimeModel(), + ) ) agent.start() diff --git a/livekit-agents/livekit/agents/multimodal/realtime.py b/livekit-agents/livekit/agents/multimodal/realtime.py index f3e08cb15..7705476ee 100644 --- a/livekit-agents/livekit/agents/multimodal/realtime.py +++ b/livekit-agents/livekit/agents/multimodal/realtime.py @@ -54,7 +54,7 @@ async def aclose(self) -> None: ... EventTypes = Literal[ - "input_speech_started", # serverside VAD + "input_speech_started", # serverside VAD (also used for interruptions) "input_speech_stopped", # serverside VAD "generation_created", "error", @@ -80,13 +80,16 @@ def realtime_model(self) -> RealtimeModel: @abstractmethod def chat_ctx(self) -> llm.ChatContext: ... - @abstractmethod - async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None: ... - @property @abstractmethod def fnc_ctx(self) -> llm.FunctionContext: ... + @abstractmethod + async def update_instructions(self, instructions: str) -> None: ... + + @abstractmethod + async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None: ... + @abstractmethod async def update_fnc_ctx( self, fnc_ctx: llm.FunctionContext | list[llm.AIFunction] @@ -100,9 +103,7 @@ def generate_reply(self) -> None: ... # when VAD is disabled # cancel the current generation (do nothing if no generation is in progress) @abstractmethod - def interrupt( - self, - ) -> None: ... + def interrupt(self) -> None: ... 
# message_id is the ID of the message to truncate (inside the ChatCtx) @abstractmethod diff --git a/livekit-agents/livekit/agents/pipeline/agent_task.py b/livekit-agents/livekit/agents/pipeline/agent_task.py index e4f23d4c3..77b60d2c9 100644 --- a/livekit-agents/livekit/agents/pipeline/agent_task.py +++ b/livekit-agents/livekit/agents/pipeline/agent_task.py @@ -75,10 +75,6 @@ def __init__( def instructions(self) -> str: return self._instructions - @instructions.setter - def instructions(self, instructions: str) -> None: - self._instructions = instructions - @property def chat_ctx(self) -> llm.ChatContext: return self._chat_ctx @@ -217,6 +213,7 @@ async def start(self) -> None: "input_speech_stopped", self._on_input_speech_stopped ) await self._rt_session.update_chat_ctx(self._task.chat_ctx) + await self._rt_session.update_instructions(self._task.instructions) self._started = True @@ -668,85 +665,61 @@ async def _realtime_reply_task( ) -> None: assert self._rt_session is not None, "rt_session is not available" + audio_output = self._agent.output.audio + text_output = self._agent.output.text + @utils.log_exceptions(logger=logger) async def _forward_text(llm_output: AsyncIterable[str]) -> None: + assert text_output is not None, "text_output is not available" try: async for delta in llm_output: - if self._agent.output.text is None: - break - - await self._agent.output.text.capture_text(delta) + await text_output.capture_text(delta) finally: - if self._agent.output.text is not None: - self._agent.output.text.flush() + text_output.flush() @utils.log_exceptions(logger=logger) async def _forward_audio(tts_output: AsyncIterable[rtc.AudioFrame]) -> None: + assert audio_output is not None, "audio_output is not available" try: async for frame in tts_output: - if self._agent.output.audio is None: - break - await self._agent.output.audio.capture_frame(frame) + await audio_output.capture_frame(frame) finally: - if self._agent.output.audio is not None: - self._agent.output.audio.flush() # always flush (even if the task is interrupted) + audio_output.flush() - # wait for the play() method to be called - await asyncio.wait( - [ - speech_handle._wait_for_authorization(), - speech_handle._interrupt_fut, - ], - return_when=asyncio.FIRST_COMPLETED, + await speech_handle.wait_until_interrupted( + [speech_handle._wait_for_authorization()] ) if speech_handle.interrupted: - speech_handle._mark_playout_done() - # TODO(theomonnom): remove the message from the serverside history + speech_handle._mark_playout_done() # TODO(theomonnom): remove the message from the serverside history return - wg = utils.aio.WaitGroup() - tasks: list[asyncio.Task] = [] - - forward_text_task = asyncio.create_task( - _forward_text(generation_ev.text_stream), - name="_realtime_reply_task.forward_text", - ) - wg.add(1) - forward_text_task.add_done_callback(lambda _: wg.done()) - tasks.append(forward_text_task) + ts = utils.aio.TaskSet() + if text_output is not None: + ts.create_task( + _forward_text(generation_ev.text_stream), + name="_realtime_reply_task.forward_text", + ) - forward_audio_task = asyncio.create_task( - _forward_audio(generation_ev.audio_stream), - name="_realtime_reply_task.forward_audio", - ) - wg.add(1) - forward_audio_task.add_done_callback(lambda _: wg.done()) - tasks.append(forward_text_task) + if audio_output is not None: + ts.create_task( + _forward_audio(generation_ev.audio_stream), + name="_realtime_reply_task.forward_audio", + ) - await asyncio.wait( - [ - wg.wait(), - speech_handle._interrupt_fut, - ], - 
return_when=asyncio.FIRST_COMPLETED, - ) + await speech_handle.wait_until_interrupted([*ts.tasks]) - if forward_audio_task is not None and self._agent.output.audio is not None: - await asyncio.wait( - [ - self._agent.output.audio.wait_for_playout(), - speech_handle._interrupt_fut, - ], - return_when=asyncio.FIRST_COMPLETED, + if audio_output is not None: + await speech_handle.wait_until_interrupted( + [audio_output.wait_for_playout()] ) if speech_handle.interrupted: - await utils.aio.gracefully_cancel(*tasks) + await utils.aio.gracefully_cancel(*ts.tasks) - if self._agent.output.audio is not None: - self._agent.output.audio.clear_buffer() - playback_ev = await self._agent.output.audio.wait_for_playout() + if audio_output is not None: + audio_output.clear_buffer() + playback_ev = await audio_output.wait_for_playout() debug.Tracing.log_event( "playout interrupted", diff --git a/livekit-agents/livekit/agents/pipeline/speech_handle.py b/livekit-agents/livekit/agents/pipeline/speech_handle.py index 0ee1da512..bd3971291 100644 --- a/livekit-agents/livekit/agents/pipeline/speech_handle.py +++ b/livekit-agents/livekit/agents/pipeline/speech_handle.py @@ -2,7 +2,7 @@ import asyncio import contextlib -from typing import Callable +from typing import Awaitable, Callable from .. import utils @@ -77,3 +77,8 @@ def _mark_playout_done(self) -> None: with contextlib.suppress(asyncio.InvalidStateError): # will raise InvalidStateError if the future is already done (interrupted) self._playout_done_fut.set_result(None) + + async def wait_until_interrupted(self, aw: list[Awaitable]) -> None: + await asyncio.wait( + [*aw, self._interrupt_fut], return_when=asyncio.FIRST_COMPLETED + ) diff --git a/livekit-agents/livekit/agents/utils/aio/__init__.py b/livekit-agents/livekit/agents/utils/aio/__init__.py index bdef55dd7..d30bc9292 100644 --- a/livekit-agents/livekit/agents/utils/aio/__init__.py +++ b/livekit-agents/livekit/agents/utils/aio/__init__.py @@ -1,38 +1,10 @@ -import asyncio -import functools - from . import debug, duplex_unix, itertools from .channel import Chan, ChanClosed, ChanReceiver, ChanSender from .interval import Interval, interval from .sleep import Sleep, SleepFinished, sleep from .task_set import TaskSet from .wait_group import WaitGroup - - -async def gracefully_cancel(*futures: asyncio.Future): - loop = asyncio.get_running_loop() - waiters = [] - - for fut in futures: - waiter = loop.create_future() - cb = functools.partial(_release_waiter, waiter) - waiters.append((waiter, cb)) - fut.add_done_callback(cb) - fut.cancel() - - try: - for waiter, _ in waiters: - await waiter - finally: - for i, fut in enumerate(futures): - _, cb = waiters[i] - fut.remove_done_callback(cb) - - -def _release_waiter(waiter, *args): - if not waiter.done(): - waiter.set_result(None) - +from .utils import gracefully_cancel __all__ = [ "ChanClosed", @@ -51,4 +23,5 @@ def _release_waiter(waiter, *args): "gracefully_cancel", "duplex_unix", "itertools", + "gracefully_cancel", ] diff --git a/livekit-agents/livekit/agents/utils/aio/task_set.py b/livekit-agents/livekit/agents/utils/aio/task_set.py index 848d87545..ff6437372 100644 --- a/livekit-agents/livekit/agents/utils/aio/task_set.py +++ b/livekit-agents/livekit/agents/utils/aio/task_set.py @@ -7,25 +7,24 @@ class TaskSet: - """ - Small utility to create task in a fire-and-forget fashion. 
- """ + """Small utility to create tasks in a fire-and-forget fashion.""" def __init__(self, loop: asyncio.AbstractEventLoop | None = None) -> None: self._loop = loop or asyncio.get_event_loop() self._set = set[asyncio.Task[Any]]() self._closed = False - def create_task(self, coro: Coroutine[Any, Any, _T]) -> asyncio.Task[_T]: + def create_task( + self, coro: Coroutine[Any, Any, _T], name: str | None = None + ) -> asyncio.Task[_T]: if self._closed: raise RuntimeError("TaskSet is closed") - task = self._loop.create_task(coro) + task = self._loop.create_task(coro, name=name) self._set.add(task) task.add_done_callback(self._set.remove) return task - async def aclose(self) -> None: - self._closed = True - await asyncio.gather(*self._set, return_exceptions=True) - self._set.clear() + @property + def tasks(self) -> set[asyncio.Task[Any]]: + return self._set.copy() diff --git a/livekit-agents/livekit/agents/utils/aio/utils.py b/livekit-agents/livekit/agents/utils/aio/utils.py new file mode 100644 index 000000000..19251c961 --- /dev/null +++ b/livekit-agents/livekit/agents/utils/aio/utils.py @@ -0,0 +1,27 @@ +import asyncio +import functools + + +async def gracefully_cancel(*futures: asyncio.Future): + loop = asyncio.get_running_loop() + waiters = [] + + for fut in futures: + waiter = loop.create_future() + cb = functools.partial(_release_waiter, waiter) + waiters.append((waiter, cb)) + fut.add_done_callback(cb) + fut.cancel() + + try: + for waiter, _ in waiters: + await waiter + finally: + for i, fut in enumerate(futures): + _, cb = waiters[i] + fut.remove_done_callback(cb) + + +def _release_waiter(waiter, *_): + if not waiter.done(): + waiter.set_result(None) diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/_oai_api.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/_oai_api.py deleted file mode 100644 index 76e8a1513..000000000 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/_oai_api.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2023 LiveKit, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import inspect -import typing -from typing import Any - -from livekit.agents.llm import function_context, llm - -__all__ = ["build_oai_function_description"] - - -def build_oai_function_description( - fnc_info: function_context.FunctionInfo, - capabilities: llm.LLMCapabilities | None = None, -) -> dict[str, Any]: - def build_oai_property(arg_info: function_context.FunctionArgInfo): - def type2str(t: type) -> str: - if t is str: - return "string" - elif t in (int, float): - return "number" - elif t is bool: - return "boolean" - - raise ValueError(f"unsupported type {t} for ai_property") - - p: dict[str, Any] = {} - - if arg_info.description: - p["description"] = arg_info.description - - is_optional, inner_th = _is_optional_type(arg_info.type) - - if typing.get_origin(inner_th) is list: - inner_type = typing.get_args(inner_th)[0] - p["type"] = "array" - p["items"] = {} - p["items"]["type"] = type2str(inner_type) - - if arg_info.choices: - p["items"]["enum"] = arg_info.choices - else: - p["type"] = type2str(inner_th) - if arg_info.choices: - p["enum"] = arg_info.choices - if ( - inner_th is int - and capabilities - and not capabilities.supports_choices_on_int - ): - raise ValueError( - f"Parameter '{arg_info.name}' uses 'choices' with 'int', which is not supported by this model." - ) - - return p - - properties_info: dict[str, dict[str, Any]] = {} - required_properties: list[str] = [] - - for arg_info in fnc_info.arguments.values(): - if arg_info.default is inspect.Parameter.empty: - required_properties.append(arg_info.name) - - properties_info[arg_info.name] = build_oai_property(arg_info) - - return { - "type": "function", - "function": { - "name": fnc_info.name, - "description": fnc_info.description, - "parameters": { - "type": "object", - "properties": properties_info, - "required": required_properties, - }, - }, - } diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py index 02b747b67..b97868a47 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py @@ -36,7 +36,6 @@ from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam from openai.types.chat.chat_completion_chunk import Choice -from ._oai_api import build_oai_function_description from .log import logger from .models import ( CerebrasChatModels, diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py index a947b0f66..687eb51a1 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py @@ -189,7 +189,6 @@ def _handle_response_output_item_added( return self._current_generation.item_id = item_id - self.emit( "generation_created", multimodal.GenerationCreatedEvent( @@ -272,6 +271,10 @@ def _handle_error(self, event: ErrorEvent) -> None: def chat_ctx(self) -> llm.ChatContext: return self._chat_ctx.copy() + @property + def fnc_ctx(self) -> llm.FunctionContext: + return self._fnc_ctx.copy() + async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None: diff_ops = llm.compute_chat_ctx_diff(self._chat_ctx, chat_ctx) @@ -298,10 +301,6 @@ async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None: # 
TODO(theomonnom): wait for the server confirmation - @property - def fnc_ctx(self) -> llm.FunctionContext: - return self._fnc_ctx - async def update_fnc_ctx( self, fnc_ctx: llm.FunctionContext | list[llm.AIFunction] ) -> None: @@ -339,6 +338,19 @@ async def update_fnc_ctx( # TODO(theomonnom): wait for the server confirmation before updating the local state self._fnc_ctx = llm.FunctionContext(retained_functions) + async def update_instructions(self, instructions: str) -> None: + self._msg_ch.send_nowait( + SessionUpdateEvent( + type="session.update", + session=session_update_event.Session( + model=self._realtime_model._opts.model, # type: ignore + instructions=instructions, + ), + ) + ) + + # TODO(theomonnom): wait for the server confirmation + def push_audio(self, frame: rtc.AudioFrame) -> None: self._msg_ch.send_nowait( InputAudioBufferAppendEvent( From 54ef4cec8372b9ef28ec1b99498a9379433de896 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Mon, 27 Jan 2025 16:56:04 +0100 Subject: [PATCH 16/19] multimodal sync & simple task transition example --- examples/minimal_worker.py | 36 +- livekit-agents/livekit/agents/cli/watcher.py | 2 +- .../livekit/agents/ipc/job_proc_executor.py | 2 +- .../livekit/agents/ipc/job_proc_lazy_main.py | 2 +- .../livekit/agents/ipc/job_thread_executor.py | 4 +- .../livekit/agents/ipc/proc_client.py | 4 +- .../livekit/agents/ipc/proc_pool.py | 2 +- .../livekit/agents/ipc/supervised_proc.py | 6 +- livekit-agents/livekit/agents/llm/__init__.py | 10 +- .../livekit/agents/llm/chat_context.py | 5 +- .../livekit/agents/llm/function_context.py | 146 ++--- livekit-agents/livekit/agents/llm/llm.py | 4 +- .../livekit/agents/llm/remote_chat_context.py | 96 +++ livekit-agents/livekit/agents/llm/utils.py | 137 +++++ .../livekit/agents/multimodal/__init__.py | 4 + .../livekit/agents/multimodal/realtime.py | 18 +- .../livekit/agents/pipeline/__init__.py | 5 +- .../livekit/agents/pipeline/agent_task.py | 553 +++++++++++++----- .../agents/pipeline/audio_recognition.py | 12 +- .../livekit/agents/pipeline/chat_cli.py | 7 +- .../livekit/agents/pipeline/context.py | 17 + livekit-agents/livekit/agents/pipeline/io.py | 6 +- .../livekit/agents/pipeline/pipeline_agent.py | 15 +- .../agents/pipeline/speech_scheduler.py | 1 - .../livekit/agents/stt/fallback_adapter.py | 6 +- .../livekit/agents/stt/stream_adapter.py | 2 +- livekit-agents/livekit/agents/stt/stt.py | 2 +- .../livekit/agents/tts/fallback_adapter.py | 8 +- .../livekit/agents/tts/stream_adapter.py | 2 +- livekit-agents/livekit/agents/tts/tts.py | 4 +- .../livekit/agents/utils/aio/__init__.py | 6 +- .../livekit/agents/utils/aio/utils.py | 2 +- livekit-agents/livekit/agents/vad.py | 2 +- livekit-agents/livekit/agents/worker.py | 6 +- .../plugins/openai/realtime/realtime_model.py | 469 +++++++++------ 35 files changed, 1128 insertions(+), 475 deletions(-) create mode 100644 livekit-agents/livekit/agents/llm/remote_chat_context.py create mode 100644 livekit-agents/livekit/agents/pipeline/context.py delete mode 100644 livekit-agents/livekit/agents/pipeline/speech_scheduler.py diff --git a/examples/minimal_worker.py b/examples/minimal_worker.py index fe5adb8f2..ea332f9da 100644 --- a/examples/minimal_worker.py +++ b/examples/minimal_worker.py @@ -1,9 +1,10 @@ import logging from dotenv import load_dotenv -from livekit.agents import AutoSubscribe, JobContext, WorkerOptions, WorkerType, cli -from livekit.agents.pipeline import ChatCLI, PipelineAgent, AgentTask -from livekit.plugins import cartesia, deepgram, openai, 
silero +from livekit.agents import JobContext, WorkerOptions, WorkerType, cli +from livekit.agents.llm import ai_function +from livekit.agents.pipeline import AgentTask, ChatCLI, PipelineAgent, AgentContext +from livekit.plugins import openai logger = logging.getLogger("my-worker") logger.setLevel(logging.INFO) @@ -11,12 +12,33 @@ load_dotenv() +class EchoTask(AgentTask): + def __init__(self) -> None: + super().__init__( + instructions="Always speak in English even if the user speaks in another language or wants to use another language.", + llm=openai.realtime.RealtimeModel(voice="echo"), + ) + + @ai_function + async def talk_to_alloy(self, context: AgentContext): + return AlloyTask(), "Transferring you to Alloy." + + +class AlloyTask(AgentTask): + def __init__(self) -> None: + super().__init__( + instructions="Always speak in English even if the user speaks in another language or wants to use another language.", + llm=openai.realtime.RealtimeModel(voice="alloy"), + ) + + @ai_function + async def talk_to_echo(self, context: AgentContext): + return EchoTask(), "Transferring you to Echo." + + async def entrypoint(ctx: JobContext): agent = PipelineAgent( - task=AgentTask( - instructions="Talk to me!", - llm=openai.realtime.RealtimeModel(), - ) + task=AlloyTask(), ) agent.start() diff --git a/livekit-agents/livekit/agents/cli/watcher.py b/livekit-agents/livekit/agents/cli/watcher.py index 5f4a60751..8b70c29fc 100644 --- a/livekit-agents/livekit/agents/cli/watcher.py +++ b/livekit-agents/livekit/agents/cli/watcher.py @@ -95,7 +95,7 @@ async def run(self) -> None: callback=self._on_reload, ) finally: - await utils.aio.gracefully_cancel(read_ipc_task) + await utils.aio.cancel_and_wait(read_ipc_task) await self._pch.aclose() async def _on_reload(self, _: Set[watchfiles.main.FileChange]) -> None: diff --git a/livekit-agents/livekit/agents/ipc/job_proc_executor.py b/livekit-agents/livekit/agents/ipc/job_proc_executor.py index 5c1ecb38b..f16836bc9 100644 --- a/livekit-agents/livekit/agents/ipc/job_proc_executor.py +++ b/livekit-agents/livekit/agents/ipc/job_proc_executor.py @@ -114,7 +114,7 @@ async def _main_task(self, ipc_ch: aio.ChanReceiver[channel.Message]) -> None: asyncio.create_task(self._do_inference_task(msg)) ) finally: - await aio.gracefully_cancel(*self._inference_tasks) + await aio.cancel_and_wait(*self._inference_tasks) @log_exceptions(logger=logger) async def _supervise_task(self) -> None: diff --git a/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py b/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py index ed0092012..03dc1aae0 100644 --- a/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py +++ b/livekit-agents/livekit/agents/ipc/job_proc_lazy_main.py @@ -189,7 +189,7 @@ async def _read_ipc_task(): read_task = asyncio.create_task(_read_ipc_task(), name="job_ipc_read") await self._exit_proc_flag.wait() - await aio.gracefully_cancel(read_task) + await aio.cancel_and_wait(read_task) def _start_job(self, msg: StartJobRequest) -> None: self._room = rtc.Room() diff --git a/livekit-agents/livekit/agents/ipc/job_thread_executor.py b/livekit-agents/livekit/agents/ipc/job_thread_executor.py index 3eefac28f..7d9a70a77 100644 --- a/livekit-agents/livekit/agents/ipc/job_thread_executor.py +++ b/livekit-agents/livekit/agents/ipc/job_thread_executor.py @@ -255,8 +255,8 @@ async def _main_task(self) -> None: monitor_task = asyncio.create_task(self._monitor_task()) await self._join_fut - await utils.aio.gracefully_cancel(ping_task, monitor_task) - await
utils.aio.gracefully_cancel(*self._inference_tasks) + await utils.aio.cancel_and_wait(ping_task, monitor_task) + await utils.aio.cancel_and_wait(*self._inference_tasks) with contextlib.suppress(duplex_unix.DuplexClosed): await self._pch.aclose() diff --git a/livekit-agents/livekit/agents/ipc/proc_client.py b/livekit-agents/livekit/agents/ipc/proc_client.py index 97080e8cc..869cb4ef4 100644 --- a/livekit-agents/livekit/agents/ipc/proc_client.py +++ b/livekit-agents/livekit/agents/ipc/proc_client.py @@ -155,8 +155,8 @@ def _done_cb(_: asyncio.Task) -> None: main_task.add_done_callback(_done_cb) await exit_flag.wait() - await aio.gracefully_cancel(read_task, main_task) + await aio.cancel_and_wait(read_task, main_task) if health_check_task is not None: - await aio.gracefully_cancel(health_check_task) + await aio.cancel_and_wait(health_check_task) finally: await self._acch.aclose() diff --git a/livekit-agents/livekit/agents/ipc/proc_pool.py b/livekit-agents/livekit/agents/ipc/proc_pool.py index 25a395f53..c05e6af69 100644 --- a/livekit-agents/livekit/agents/ipc/proc_pool.py +++ b/livekit-agents/livekit/agents/ipc/proc_pool.py @@ -83,7 +83,7 @@ async def aclose(self) -> None: return self._closed = True - await aio.gracefully_cancel(self._main_atask) + await aio.cancel_and_wait(self._main_atask) async def launch_job(self, info: RunningJobInfo) -> None: if self._num_idle_processes == 0: diff --git a/livekit-agents/livekit/agents/ipc/supervised_proc.py b/livekit-agents/livekit/agents/ipc/supervised_proc.py index dfd18172d..4226d2b8b 100644 --- a/livekit-agents/livekit/agents/ipc/supervised_proc.py +++ b/livekit-agents/livekit/agents/ipc/supervised_proc.py @@ -263,10 +263,10 @@ async def _supervise_task(self) -> None: await self._join_fut self._exitcode = self._proc.exitcode self._proc.close() - await aio.gracefully_cancel(ping_task, read_ipc_task, main_task) + await aio.cancel_and_wait(ping_task, read_ipc_task, main_task) if memory_monitor_task is not None: - await aio.gracefully_cancel(memory_monitor_task) + await aio.cancel_and_wait(memory_monitor_task) with contextlib.suppress(duplex_unix.DuplexClosed): await self._pch.aclose() @@ -334,7 +334,7 @@ async def _pong_timeout_co(): try: await asyncio.gather(*tasks) finally: - await aio.gracefully_cancel(*tasks) + await aio.cancel_and_wait(*tasks) @log_exceptions(logger=logger) async def _memory_monitor_task(self) -> None: diff --git a/livekit-agents/livekit/agents/llm/__init__.py b/livekit-agents/livekit/agents/llm/__init__.py index c1adf8710..0d7dc7410 100644 --- a/livekit-agents/livekit/agents/llm/__init__.py +++ b/livekit-agents/livekit/agents/llm/__init__.py @@ -1,5 +1,7 @@ +from . import utils from .chat_context import ( AudioContent, + ChatContent, ChatContext, ChatItem, ChatMessage, @@ -11,6 +13,7 @@ from .function_context import ( AIFunction, FunctionContext, + AIError, ai_function, find_ai_functions, is_ai_function, @@ -25,13 +28,14 @@ LLMStream, ToolChoice, ) -from .utils import compute_chat_ctx_diff +from . 
import remote_chat_context __all__ = [ "LLM", "LLMStream", "ChatContext", "ChatMessage", + "ChatContent", "FunctionCall", "FunctionCallOutput", "AudioContent", @@ -46,10 +50,12 @@ "FallbackAdapter", "AvailabilityChangedEvent", "ToolChoice", - "compute_chat_ctx_diff", "is_ai_function", "ai_function", "find_ai_functions", "AIFunction", "FunctionContext", + "AIError", + "utils", + "remote_chat_context", ] diff --git a/livekit-agents/livekit/agents/llm/chat_context.py b/livekit-agents/livekit/agents/llm/chat_context.py index bf9fa2c7b..f2ce4a31d 100644 --- a/livekit-agents/livekit/agents/llm/chat_context.py +++ b/livekit-agents/livekit/agents/llm/chat_context.py @@ -92,10 +92,13 @@ class ChatMessage(BaseModel): id: str = Field(default_factory=lambda: utils.shortuuid("item_")) type: Literal["message"] = "message" role: Literal["developer", "system", "user", "assistant"] - content: list[Union[str, ImageContent, AudioContent]] + content: list[ChatContent] hash: Optional[bytes] = None +ChatContent: TypeAlias = Union[str, ImageContent, AudioContent] + + class FunctionCall(BaseModel): id: str = Field(default_factory=lambda: utils.shortuuid("item_")) type: Literal["function_call"] = "function_call" diff --git a/livekit-agents/livekit/agents/llm/function_context.py b/livekit-agents/livekit/agents/llm/function_context.py index 9507d7e79..7ecd074ba 100644 --- a/livekit-agents/livekit/agents/llm/function_context.py +++ b/livekit-agents/livekit/agents/llm/function_context.py @@ -15,37 +15,60 @@ from __future__ import annotations import inspect +from dataclasses import dataclass from typing import ( - Annotated, Any, Callable, - List, Protocol, - Type, - get_args, - get_origin, - get_type_hints, runtime_checkable, ) -from pydantic import BaseModel, create_model -from pydantic.fields import FieldInfo from typing_extensions import TypeGuard +class AIError(Exception): + def __init__(self, message: str) -> None: + """ + Exception raised within AI functions. + + This exception should be raised by users when an error occurs + in the context of AI operations. The provided message will be + visible to the LLM, allowing it to understand the context of + the error during FunctionOutput generation. + """ + super().__init__(message) + self._message = message + + @property + def message(self) -> str: + return self._message + + +@dataclass +class _AIFunctionInfo: + name: str + description: str | None + + @runtime_checkable class AIFunction(Protocol): - __livekit_agents_ai_callable: bool - __name__: str - __doc__: str | None + __livekit_agents_ai_callable: _AIFunctionInfo def __call__(self, *args: Any, **kwargs: Any) -> Any: ... 
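# A minimal usage sketch (an assumption, not part of this patch) for the reworked decorator
# below: `name` and `description` override the Python function name and docstring that would
# otherwise be exposed to the LLM. The tool name "transfer_to_alloy" is illustrative; AlloyTask,
# AgentContext, and talk_to_alloy come from the example worker elsewhere in this series:
#
#   @ai_function(name="transfer_to_alloy", description="Transfer the caller to the Alloy agent.")
#   async def talk_to_alloy(self, context: AgentContext):
#       return AlloyTask(), "Transferring you to Alloy."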
-def ai_function(f: Callable | None = None) -> Callable[[Callable], AIFunction]: - def deco(f) -> AIFunction: - setattr(f, "__livekit_agents_ai_callable", True) - return f +def ai_function( + f: Callable | None = None, + *, + name: str | None = None, + description: str | None = None, +) -> Callable[[Callable], AIFunction]: + def deco(func) -> AIFunction: + info = _AIFunctionInfo( + name=name or func.__name__, description=description or func.__doc__ + ) + setattr(func, "__livekit_agents_ai_callable", info) + return func if callable(f): return deco(f) @@ -54,14 +77,18 @@ def deco(f) -> AIFunction: def is_ai_function(f: Callable) -> TypeGuard[AIFunction]: - return getattr(f, "__livekit_agents_ai_callable", False) + return hasattr(f, "__livekit_agents_ai_callable") + + +def get_function_info(f: AIFunction) -> _AIFunctionInfo: + return getattr(f, "__livekit_agents_ai_callable") -def find_ai_functions(cls: Type) -> List[AIFunction]: +def find_ai_functions(cls_or_obj: Any) -> list[AIFunction]: methods: list[AIFunction] = [] - for _, method in inspect.getmembers(cls, predicate=inspect.isfunction): - if is_ai_function(method): - methods.append(method) + for _, member in inspect.getmembers(cls_or_obj): + if is_ai_function(member): + methods.append(member) return methods @@ -82,85 +109,16 @@ def ai_functions(self) -> dict[str, AIFunction]: def update_ai_functions(self, ai_functions: list[AIFunction]) -> None: self._ai_functions = ai_functions - for method in find_ai_functions(self.__class__): + for method in find_ai_functions(self): ai_functions.append(method) self._ai_functions_map = {} for fnc in ai_functions: - if fnc.__name__ in self._ai_functions_map: - raise ValueError(f"duplicate function name: {fnc.__name__}") + info = get_function_info(fnc) + if info.name in self._ai_functions_map: + raise ValueError(f"duplicate function name: {info.name}") - self._ai_functions_map[fnc.__name__] = fnc + self._ai_functions_map[info.name] = fnc def copy(self) -> FunctionContext: return FunctionContext(self._ai_functions.copy()) - - -def build_legacy_openai_schema( - ai_function: AIFunction, *, internally_tagged: bool = False -) -> dict[str, Any]: - """non-strict mode tool description - see https://serde.rs/enum-representations.html for the internally tagged representation""" - model = build_pydantic_model_from_function(ai_function) - schema = model.model_json_schema() - - fnc_name = ai_function.__name__ - fnc_description = ai_function.__doc__ - - if internally_tagged: - return { - "name": fnc_name, - "description": fnc_description or "", - "parameters": schema, - "type": "function", - } - else: - return { - "type": "function", - "function": { - "name": fnc_name, - "description": fnc_description or "", - "parameters": schema, - }, - } - - -def build_pydantic_model_from_function( - func: Callable, -) -> type[BaseModel]: - fnc_name = func.__name__.split("_") - fnc_name = "".join(x.capitalize() for x in fnc_name) - model_name = fnc_name + "Args" - - signature = inspect.signature(func) - type_hints = get_type_hints(func, include_extras=True) - - # field_name -> (type, FieldInfo or default) - fields: dict[str, Any] = {} - - for param_name, param in signature.parameters.items(): - annotation = type_hints[param_name] - default_value = param.default if param.default is not param.empty else ... 
- - # Annotated[str, Field(description="...")] - if get_origin(annotation) is Annotated: - annotated_args = get_args(annotation) - actual_type = annotated_args[0] - field_info = None - - for extra in annotated_args[1:]: - if isinstance(extra, FieldInfo): - field_info = extra # get the first FieldInfo - break - - if field_info: - if default_value is not ... and field_info.default is None: - field_info.default = default_value - fields[param_name] = (actual_type, field_info) - else: - fields[param_name] = (actual_type, default_value) - - else: - fields[param_name] = (annotation, default_value) - - return create_model(model_name, **fields) diff --git a/livekit-agents/livekit/agents/llm/llm.py b/livekit-agents/livekit/agents/llm/llm.py index 265898b26..2e25e4b1a 100644 --- a/livekit-agents/livekit/agents/llm/llm.py +++ b/livekit-agents/livekit/agents/llm/llm.py @@ -226,8 +226,8 @@ def execute_functions(self) -> list[function_context.CalledFunction]: return called_functions async def aclose(self) -> None: - await aio.gracefully_cancel(self._task) - await utils.aio.gracefully_cancel(*self._function_tasks) + await aio.cancel_and_wait(self._task) + await utils.aio.cancel_and_wait(*self._function_tasks) await self._metrics_task async def __anext__(self) -> ChatChunk: diff --git a/livekit-agents/livekit/agents/llm/remote_chat_context.py b/livekit-agents/livekit/agents/llm/remote_chat_context.py new file mode 100644 index 000000000..10f404202 --- /dev/null +++ b/livekit-agents/livekit/agents/llm/remote_chat_context.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from .chat_context import ChatItem, ChatContext + +__all__ = ["RemoteChatContext"] + + +@dataclass +class _RemoteChatItem: + item: ChatItem + _prev: _RemoteChatItem | None = field(default=None, repr=False) + _next: _RemoteChatItem | None = field(default=None, repr=False) + + +class RemoteChatContext: + def __init__(self) -> None: + self._head: _RemoteChatItem | None = None + self._tail: _RemoteChatItem | None = None + self._id_to_item: dict[str, _RemoteChatItem] = {} + + def to_chat_ctx(self) -> ChatContext: + items: list[ChatItem] = [] + current_node = self._head + while current_node is not None: + items.append(current_node.item) + current_node = current_node._next + + return ChatContext(items=items) + + def insert(self, previous_item_id: str | None, message: ChatItem) -> None: + """ + Insert `message` after the node with ID `previous_item_id`. + If `previous_item_id` is None, insert at the head. + """ + item_id = message.id + + if item_id in self._id_to_item: + raise ValueError(f"Item with ID {item_id} already exists.") + + new_node = _RemoteChatItem(item=message) + + if previous_item_id is None: + if self._head is not None: + new_node._next = self._head + self._head._prev = new_node + else: + self._tail = new_node + + self._head = new_node + self._id_to_item[item_id] = new_node + return + + prev_node = self._id_to_item.get(previous_item_id) + if prev_node is None: + raise ValueError( + f"No item found with ID {previous_item_id} to insert after." 
+ ) + + new_node._prev = prev_node + new_node._next = prev_node._next + + prev_node._next = new_node + + if new_node._next is not None: + new_node._next._prev = new_node + else: + self._tail = new_node + + self._id_to_item[item_id] = new_node + + def delete(self, item_id: str) -> None: + node = self._id_to_item.get(item_id) + if node is None: + raise ValueError(f"No item found with ID {item_id} to delete.") + + prev_node = node._prev + next_node = node._next + + if self._head == node: + self._head = next_node + if self._head is not None: + self._head._prev = None + else: + if prev_node is not None: + prev_node._next = next_node + + if self._tail == node: + self._tail = prev_node + if self._tail is not None: + self._tail._next = None + else: + if next_node is not None: + next_node._prev = prev_node + + del self._id_to_item[item_id] diff --git a/livekit-agents/livekit/agents/llm/utils.py b/livekit-agents/livekit/agents/llm/utils.py index a6ff16172..68a1fa039 100644 --- a/livekit-agents/livekit/agents/llm/utils.py +++ b/livekit-agents/livekit/agents/llm/utils.py @@ -1,8 +1,28 @@ from __future__ import annotations +import inspect from dataclasses import dataclass +from typing import ( + Annotated, + Any, + Callable, + get_args, + get_origin, + get_type_hints, + TYPE_CHECKING, +) + +from pydantic import BaseModel, create_model +from pydantic.fields import FieldInfo + from .chat_context import ChatContext +from .function_context import AIFunction, get_function_info + + +if TYPE_CHECKING: + from ..pipeline.context import AgentContext + from ..pipeline.speech_handle import SpeechHandle def _compute_lcs(old_ids: list[str], new_ids: list[str]) -> list[str]: @@ -70,3 +90,120 @@ def compute_chat_ctx_diff(old_ctx: ChatContext, new_ctx: ChatContext) -> DiffOps last_id_in_sequence = new_msg.id return DiffOps(to_remove=to_remove, to_create=to_create) + + +# Convert FunctionContext to LLM API format + + +def is_context_type(ty: type) -> bool: + from ..pipeline.context import AgentContext + from ..pipeline.speech_handle import SpeechHandle + + return ty is AgentContext or ty is SpeechHandle + + +def build_legacy_openai_schema( + ai_function: AIFunction, *, internally_tagged: bool = False +) -> dict[str, Any]: + """non-strict mode tool description + see https://serde.rs/enum-representations.html for the internally tagged representation""" + model = function_arguments_to_pydantic_model(ai_function) + schema = model.model_json_schema() + + info = get_function_info(ai_function) + + if internally_tagged: + return { + "name": info.name, + "description": info.description or "", + "parameters": schema, + "type": "function", + } + else: + return { + "type": "function", + "function": { + "name": info.name, + "description": info.description or "", + "parameters": schema, + }, + } + + +def function_arguments_to_pydantic_model( + func: Callable, +) -> type[BaseModel]: + """ + Create a Pydantic model from a function’s signature. (excluding context types) + """ + fnc_name = func.__name__.split("_") + fnc_name = "".join(x.capitalize() for x in fnc_name) + model_name = fnc_name + "Args" + + signature = inspect.signature(func) + type_hints = get_type_hints(func, include_extras=True) + + # field_name -> (type, FieldInfo or default) + fields: dict[str, Any] = {} + + for param_name, param in signature.parameters.items(): + type_hint = type_hints[param_name] + + if is_context_type(type_hint): + continue + + default_value = param.default if param.default is not param.empty else ... 
+ + # Annotated[str, Field(description="...")] + if get_origin(type_hint) is Annotated: + annotated_args = get_args(type_hint) + actual_type = annotated_args[0] + field_info = None + + for extra in annotated_args[1:]: + if isinstance(extra, FieldInfo): + field_info = extra # get the first FieldInfo + break + + if field_info: + if default_value is not ... and field_info.default is None: + field_info.default = default_value + fields[param_name] = (actual_type, field_info) + else: + fields[param_name] = (actual_type, default_value) + + else: + fields[param_name] = (type_hint, default_value) + + return create_model(model_name, **fields) + + +def pydantic_model_to_function_arguments( + *, + ai_function: Callable, + model: BaseModel, + agent_ctx: AgentContext | None = None, + speech_handle: SpeechHandle | None = None, +) -> tuple[tuple[Any, ...], dict[str, Any]]: + """ + Convert a model’s fields into function args/kwargs. + Raises TypeError if required params are missing + """ + + from ..pipeline.context import AgentContext + from ..pipeline.speech_handle import SpeechHandle + + signature = inspect.signature(ai_function) + type_hints = get_type_hints(ai_function, include_extras=True) + + context_dict = {} + for param_name, _ in signature.parameters.items(): + type_hint = type_hints[param_name] + if type_hint is AgentContext and agent_ctx is not None: + context_dict[param_name] = agent_ctx + elif type_hint is SpeechHandle and speech_handle is not None: + context_dict[param_name] = speech_handle + + bound = signature.bind(**{**model.model_dump(), **context_dict}) + bound.apply_defaults() + return bound.args, bound.kwargs diff --git a/livekit-agents/livekit/agents/multimodal/__init__.py b/livekit-agents/livekit/agents/multimodal/__init__.py index c38056a68..bebf8faf2 100644 --- a/livekit-agents/livekit/agents/multimodal/__init__.py +++ b/livekit-agents/livekit/agents/multimodal/__init__.py @@ -1,6 +1,8 @@ from .realtime import ( ErrorEvent, + RealtimeError, GenerationCreatedEvent, + MessageGeneration, InputSpeechStartedEvent, InputSpeechStoppedEvent, RealtimeCapabilities, @@ -10,10 +12,12 @@ __all__ = [ "RealtimeModel", + "RealtimeError", "RealtimeCapabilities", "RealtimeSession", "InputSpeechStartedEvent", "InputSpeechStoppedEvent", "GenerationCreatedEvent", "ErrorEvent", + "MessageGeneration", ] diff --git a/livekit-agents/livekit/agents/multimodal/realtime.py b/livekit-agents/livekit/agents/multimodal/realtime.py index 7705476ee..bc39509bc 100644 --- a/livekit-agents/livekit/agents/multimodal/realtime.py +++ b/livekit-agents/livekit/agents/multimodal/realtime.py @@ -1,5 +1,7 @@ from __future__ import annotations +import asyncio + from abc import ABC, abstractmethod from dataclasses import dataclass from typing import AsyncIterable, Generic, Literal, TypeVar, Union @@ -20,10 +22,15 @@ class InputSpeechStoppedEvent: @dataclass -class GenerationCreatedEvent: +class MessageGeneration: message_id: str text_stream: AsyncIterable[str] audio_stream: AsyncIterable[rtc.AudioFrame] + + +@dataclass +class GenerationCreatedEvent: + message_stream: AsyncIterable[MessageGeneration] function_stream: AsyncIterable[llm.FunctionCall] @@ -38,6 +45,11 @@ class RealtimeCapabilities: message_truncation: bool +class RealtimeError(Exception): + def __init__(self, message: str) -> None: + super().__init__(message) + + class RealtimeModel: def __init__(self, *, capabilities: RealtimeCapabilities) -> None: self._capabilities = capabilities @@ -99,7 +111,9 @@ async def update_fnc_ctx( def push_audio(self, frame: 
rtc.AudioFrame) -> None: ... @abstractmethod - def generate_reply(self) -> None: ... # when VAD is disabled + def generate_reply( + self, + ) -> asyncio.Future[GenerationCreatedEvent]: ... # can raise RealtimeError # cancel the current generation (do nothing if no generation is in progress) @abstractmethod diff --git a/livekit-agents/livekit/agents/pipeline/__init__.py b/livekit-agents/livekit/agents/pipeline/__init__.py index 1c92d54c2..375e329ab 100644 --- a/livekit-agents/livekit/agents/pipeline/__init__.py +++ b/livekit-agents/livekit/agents/pipeline/__init__.py @@ -1,6 +1,7 @@ +from .agent_task import AgentTask from .chat_cli import ChatCLI from .pipeline_agent import PipelineAgent -from .agent_task import AgentTask from .speech_handle import SpeechHandle +from .context import AgentContext -__all__ = ["ChatCLI", "PipelineAgent", "AgentTask", "SpeechHandle"] +__all__ = ["ChatCLI", "PipelineAgent", "AgentTask", "SpeechHandle", "AgentContext"] diff --git a/livekit-agents/livekit/agents/pipeline/agent_task.py b/livekit-agents/livekit/agents/pipeline/agent_task.py index 77b60d2c9..7d928151f 100644 --- a/livekit-agents/livekit/agents/pipeline/agent_task.py +++ b/livekit-agents/livekit/agents/pipeline/agent_task.py @@ -1,28 +1,43 @@ from __future__ import annotations -import time import asyncio +import contextlib import heapq +import inspect +import time +from dataclasses import dataclass from typing import ( + Any, AsyncIterable, Optional, Union, + TYPE_CHECKING, ) from livekit import rtc +from pydantic import ValidationError from .. import debug, llm, multimodal, stt, tokenize, tts, utils, vad -from ..llm import ChatContext, FunctionContext, find_ai_functions +from ..llm import ( + ChatContext, + FunctionContext, + AIError, + find_ai_functions, + utils as llm_utils, +) from ..log import logger +from ..types import NOT_GIVEN, NotGivenOr +from ..utils import is_given +from . import io from .audio_recognition import AudioRecognition, RecognitionHooks, _TurnDetector from .generation import ( _TTSGenerationData, do_llm_inference, do_tts_inference, ) -from typing import TYPE_CHECKING - from .speech_handle import SpeechHandle +from .context import AgentContext + if TYPE_CHECKING: from .pipeline_agent import PipelineAgent @@ -33,13 +49,13 @@ def __init__( self, *, instructions: str, - chat_ctx: llm.ChatContext | None = None, - fnc_ctx: llm.FunctionContext | None = None, - turn_detector: _TurnDetector | None = None, - stt: stt.STT | None = None, - vad: vad.VAD | None = None, - llm: llm.LLM | multimodal.RealtimeModel | None = None, - tts: tts.TTS | None = None, + chat_ctx: NotGivenOr[llm.ChatContext] = NOT_GIVEN, + fnc_ctx: NotGivenOr[llm.FunctionContext] = NOT_GIVEN, + turn_detector: NotGivenOr[_TurnDetector] = NOT_GIVEN, + stt: NotGivenOr[stt.STT] = NOT_GIVEN, + vad: NotGivenOr[vad.VAD] = NOT_GIVEN, + llm: NotGivenOr[llm.LLM | multimodal.RealtimeModel] = NOT_GIVEN, + tts: NotGivenOr[tts.TTS] = NOT_GIVEN, ) -> None: if tts and not tts.capabilities.streaming: from .. import tts as text_to_speech @@ -51,7 +67,7 @@ def __init__( if stt and not stt.capabilities.streaming: from ..
import stt as speech_to_text - if vad is None: + if not is_given(vad): raise ValueError( "VAD is required when streaming is not supported by the STT" ) @@ -65,11 +81,13 @@ def __init__( self._chat_ctx = chat_ctx or ChatContext.empty() self._fnc_ctx = fnc_ctx or FunctionContext.empty() self._fnc_ctx.update_ai_functions( - list(self._fnc_ctx.ai_functions.values()) - + find_ai_functions(self.__class__) + list(self._fnc_ctx.ai_functions.values()) + find_ai_functions(self) ) - self._turn_detector = turn_detector - self._stt, self._llm, self._tts, self._vad = stt, llm, tts, vad + self._turn_detector = turn_detector or None + self._stt = stt or None + self._llm = llm or None + self._tts = tts or None + self._vad = vad or None @property def instructions(self) -> str: @@ -119,7 +137,7 @@ async def _forward_input(): async for event in stream: yield event finally: - await utils.aio.gracefully_cancel(forward_task) + await utils.aio.cancel_and_wait(forward_task) async def llm_node( self, chat_ctx: llm.ChatContext, fnc_ctx: llm.FunctionContext | None @@ -155,13 +173,13 @@ async def _forward_input(): async for ev in stream: yield ev.frame finally: - await utils.aio.gracefully_cancel(forward_task) + await utils.aio.cancel_and_wait(forward_task) - def _create_activity(self, agent: PipelineAgent) -> ActiveTask: - return ActiveTask(task=self, agent=agent) + def _create_activity(self, agent: PipelineAgent) -> TaskActivity: + return TaskActivity(task=self, agent=agent) -class ActiveTask(RecognitionHooks): +class TaskActivity(RecognitionHooks): def __init__(self, task: AgentTask, agent: PipelineAgent) -> None: self._task, self._agent = task, agent self._rt_session: multimodal.RealtimeSession | None = None @@ -184,7 +202,7 @@ def draining(self) -> bool: return self._draining async def drain(self) -> None: - self._speech_q_changed.set() # TODO(theomonnom): refactor so we don't need this here + self._speech_q_changed.set() # TODO(theomonnom): we shouldn't need this here self._draining = True if self._main_atask is not None: @@ -212,8 +230,9 @@ async def start(self) -> None: self._rt_session.on( "input_speech_stopped", self._on_input_speech_stopped ) - await self._rt_session.update_chat_ctx(self._task.chat_ctx) await self._rt_session.update_instructions(self._task.instructions) + await self._rt_session.update_chat_ctx(self._task.chat_ctx) + await self._rt_session.update_fnc_ctx(self._task.fnc_ctx) self._started = True @@ -226,7 +245,7 @@ async def aclose(self) -> None: await self._audio_recognition.aclose() if self._main_atask is not None: - await utils.aio.gracefully_cancel(self._main_atask) + await utils.aio.cancel_and_wait(self._main_atask) def push_audio(self, frame: rtc.AudioFrame) -> None: if not self._started: @@ -409,88 +428,6 @@ async def _pipeline_reply_task( *, speech_handle: SpeechHandle, ) -> None: - @utils.log_exceptions(logger=logger) - async def _forward_llm_text(llm_output: AsyncIterable[str]) -> None: - """collect and forward the generated text to the current agent output""" - try: - async for delta in llm_output: - if agent.output.text is None: - break - - await agent.output.text.capture_text(delta) - finally: - if agent.output.text is not None: - agent.output.text.flush() - - @utils.log_exceptions(logger=logger) - async def _forward_tts_audio(tts_output: AsyncIterable[rtc.AudioFrame]) -> None: - """collect and forward the generated audio to the current agent output (generally playout)""" - try: - async for frame in tts_output: - if agent.output.audio is None: - break - await 
agent.output.audio.capture_frame(frame) - finally: - if agent.output.audio is not None: - agent.output.audio.flush() # always flush (even if the task is interrupted) - - @utils.log_exceptions(logger=logger) - async def _execute_tools( - tools_ch: utils.aio.Chan[llm.FunctionCallInfo], - called_functions: set[llm.CalledFunction], - ) -> None: - """execute tools, when cancelled, stop executing new tools but wait for the pending ones""" - try: - async for tool in tools_ch: - logger.debug( - "executing tool", - extra={ - "function": tool.function_info.name, - "speech_id": speech_handle.id, - }, - ) - debug.Tracing.log_event( - "executing tool", - { - "function": tool.function_info.name, - "speech_id": speech_handle.id, - }, - ) - cfnc = tool.execute() - called_functions.add(cfnc) - except asyncio.CancelledError: - # don't allow to cancel running function calla if they're still running - pending_tools = [cfn for cfn in called_functions if not cfn.task.done()] - - if pending_tools: - names = [cfn.call_info.function_info.name for cfn in pending_tools] - - logger.debug( - "waiting for function call to finish before cancelling", - extra={ - "functions": names, - "speech_id": speech_handle.id, - }, - ) - debug.Tracing.log_event( - "waiting for function call to finish before cancelling", - { - "functions": names, - "speech_id": speech_handle.id, - }, - ) - await asyncio.gather(*[cfn.task for cfn in pending_tools]) - finally: - if len(called_functions) > 0: - logger.debug( - "tools execution completed", - extra={"speech_id": speech_handle.id}, - ) - debug.Tracing.log_event( - "tools execution completed", - {"speech_id": speech_handle.id}, - ) - debug.Tracing.log_event( "generation started", {"speech_id": speech_handle.id, "step_index": speech_handle.step_index}, @@ -533,7 +470,7 @@ async def _execute_tools( ) if speech_handle.interrupted: - await utils.aio.gracefully_cancel(*tasks) + await utils.aio.cancel_and_wait(*tasks) speech_handle._mark_done() return # return directly (the generated output wasn't used) @@ -587,7 +524,7 @@ async def _execute_tools( ) if speech_handle.interrupted: - await utils.aio.gracefully_cancel(*tasks) + await utils.aio.cancel_and_wait(*tasks) if len(called_functions) > 0: functions = [ @@ -665,47 +602,69 @@ async def _realtime_reply_task( ) -> None: assert self._rt_session is not None, "rt_session is not available" + debug.Tracing.log_event( + "realtime generation started", + {"speech_id": speech_handle.id, "step_index": speech_handle.step_index}, + ) + audio_output = self._agent.output.audio text_output = self._agent.output.text - @utils.log_exceptions(logger=logger) - async def _forward_text(llm_output: AsyncIterable[str]) -> None: - assert text_output is not None, "text_output is not available" - try: - async for delta in llm_output: - await text_output.capture_text(delta) - finally: - text_output.flush() - - @utils.log_exceptions(logger=logger) - async def _forward_audio(tts_output: AsyncIterable[rtc.AudioFrame]) -> None: - assert audio_output is not None, "audio_output is not available" - try: - async for frame in tts_output: - await audio_output.capture_frame(frame) - finally: - audio_output.flush() - await speech_handle.wait_until_interrupted( [speech_handle._wait_for_authorization()] ) if speech_handle.interrupted: - speech_handle._mark_playout_done() # TODO(theomonnom): remove the message from the serverside history + speech_handle._mark_playout_done() + # TODO(theomonnom): remove the message from the serverside history return - ts = utils.aio.TaskSet() - if 
text_output is not None: - ts.create_task( - _forward_text(generation_ev.text_stream), - name="_realtime_reply_task.forward_text", - ) + @utils.log_exceptions(logger=logger) + async def _read_messages(message_outputs: list[_MessageOutput]) -> None: + ts2 = utils.aio.TaskSet() + async for msg in generation_ev.message_stream: + if ts2.tasks: + logger.warning( + "expected to receive only one message generation from the realtime API" + ) + break - if audio_output is not None: - ts.create_task( - _forward_audio(generation_ev.audio_stream), - name="_realtime_reply_task.forward_audio", - ) + out = _MessageOutput(text="", audio=[]) + message_outputs.append(out) + + if text_output is not None: + ts2.create_task( + _forward_text(text_output, msg.text_stream, out), + name="_realtime_reply_task.forward_text", + ) + + if audio_output is not None: + ts2.create_task( + _forward_audio(audio_output, msg.audio_stream, out), + name="_realtime_reply_task.forward_audio", + ) + + try: + await asyncio.gather(*ts2.tasks) + finally: + await utils.aio.cancel_and_wait(*ts2.tasks) + + function_outputs: list[_FunctionCallOutput] = [] + message_outputs: list[_MessageOutput] = [] + ts = utils.aio.TaskSet() + ts.create_task( + _read_messages(message_outputs), name="_realtime_reply_task.read_messages" + ) + ts.create_task( + _execute_tools( + agent_ctx=AgentContext(self._agent), + fnc_ctx=self._task.fnc_ctx, + speech_handle=speech_handle, + function_stream=generation_ev.function_stream, + out=function_outputs, + ), + name="_realtime_reply_task.execute_tools", + ) await speech_handle.wait_until_interrupted([*ts.tasks]) @@ -715,7 +674,7 @@ async def _forward_audio(tts_output: AsyncIterable[rtc.AudioFrame]) -> None: ) if speech_handle.interrupted: - await utils.aio.gracefully_cancel(*ts.tasks) + await utils.aio.cancel_and_wait(*ts.tasks) if audio_output is not None: audio_output.clear_buffer() @@ -729,11 +688,323 @@ async def _forward_audio(tts_output: AsyncIterable[rtc.AudioFrame]) -> None: }, ) - # TODO(theomonnom): truncate serverside message speech_handle._mark_playout_done() + # TODO(theomonnom): truncate message (+ OAI serverside mesage) return - # TODO(theomonnom): tools + + new_agent_task: AgentTask | None = None + new_items: list[llm.ChatItem] = [] + if len(function_outputs) > 0: + for out in function_outputs: + if isinstance(out.exception, AIError): + new_items.append( + llm.FunctionCallOutput( + call_id=out.call_id, + output=out.exception.message, + is_error=True, + ) + ) + continue + elif out.exception is not None: + logger.error( + "exception occurred while executing tool", + extra={ + "function": out.name, + "speech_id": speech_handle.id, + }, + exc_info=out.exception, + ) + continue + + if out.output is not None: + if isinstance(out.output, tuple): + agent_tasks = [ + item for item in out.output if isinstance(item, AgentTask) + ] + if len(agent_tasks) > 1: + logger.error( + "Multiple AgentTask instances found in tuple output", + extra={ + "call_id": out.call_id, + "function": out.name, + "output": out.output, + }, + ) + continue + + new_agent_task = agent_tasks[0] if agent_tasks else None + out.output = tuple( + item + for item in out.output + if not isinstance(item, AgentTask) + ) + if len(out.output) == 1: + out.output = out.output[0] + elif isinstance(out.output, AgentTask): + new_agent_task = out.output + out.output = None + + if not _is_valid_function_output_type(out.output): + logger.error( + "invalid function output type", + extra={ + "call_id": out.call_id, + "function": out.name, + "output": 
out.output, + }, + ) + continue + + new_items.append( + llm.FunctionCallOutput( + call_id=out.call_id, + output=str(out.output), + is_error=False, + ) + ) + + if new_items: + chat_ctx = self._rt_session.chat_ctx.copy() + chat_ctx.items.extend(new_items) + await self._rt_session.update_chat_ctx(chat_ctx) + self._rt_session.interrupt() + self._rt_session.generate_reply() + + + if new_agent_task is not None: + if new_items: + await asyncio.sleep(1.0) + self._agent.update_task(new_agent_task) debug.Tracing.log_event("playout completed", {"speech_id": speech_handle.id}) speech_handle._mark_playout_done() + + +def _is_valid_function_output_type(value: Any) -> bool: + VALID_TYPES = (str, int, float, bool, complex, type(None)) + + if isinstance(value, VALID_TYPES): + return True + elif ( + isinstance(value, list) + or isinstance(value, set) + or isinstance(value, frozenset) + or isinstance(value, tuple) + ): + return all(_is_valid_function_output_type(item) for item in value) + elif isinstance(value, dict): + return all( + isinstance(key, VALID_TYPES) and _is_valid_function_output_type(val) + for key, val in value.items() + ) + return False + + +@dataclass +class _MessageOutput: + text: str + audio: list[rtc.AudioFrame] + + +@dataclass +class _FunctionCallOutput: + call_id: str + name: str + arguments: str + output: Any + exception: BaseException | None + + +@utils.log_exceptions(logger=logger) +async def _forward_text( + text_output: io.TextSink, llm_output: AsyncIterable[str], out: _MessageOutput +) -> None: + try: + async for delta in llm_output: + out.text += delta + await text_output.capture_text(delta) + finally: + text_output.flush() + + +@utils.log_exceptions(logger=logger) +async def _forward_audio( + audio_output: io.AudioSink, + tts_output: AsyncIterable[rtc.AudioFrame], + out: _MessageOutput, +) -> None: + try: + async for frame in tts_output: + out.audio.append(frame) + await audio_output.capture_frame(frame) + finally: + audio_output.flush() + + +@utils.log_exceptions(logger=logger) +async def _execute_tools( + *, + agent_ctx: AgentContext, + fnc_ctx: FunctionContext, + speech_handle: SpeechHandle, + function_stream: utils.aio.Chan[llm.FunctionCall], + out: list[_FunctionCallOutput] = [], +) -> None: + """execute tools, when cancelled, stop executing new tools but wait for the pending ones""" + tasks: list[tuple[str, asyncio.Task]] = [] + try: + async for fnc_call in function_stream: + ai_function = fnc_ctx.ai_functions.get(fnc_call.name, None) + if ai_function is None: + logger.warning( + f"LLM called function `{fnc_call.name}` but it was not found in the current task", + extra={ + "function": fnc_call.name, + "speech_id": speech_handle.id, + }, + ) + continue + + try: + function_model = llm_utils.function_arguments_to_pydantic_model( + ai_function + ) + parsed_args = function_model.model_validate_json(fnc_call.arguments) + except ValidationError: + logger.exception( + "LLM called function `{fnc.name}` with invalid arguments", + extra={ + "function": fnc_call.name, + "arguments": fnc_call.arguments, + "speech_id": speech_handle.id, + }, + ) + continue + + logger.debug( + "executing tool", + extra={ + "function": fnc_call.name, + "speech_id": speech_handle.id, + }, + ) + debug.Tracing.log_event( + "executing tool", + { + "function": fnc_call.name, + "speech_id": speech_handle.id, + }, + ) + + fnc_args, fnc_kwargs = llm_utils.pydantic_model_to_function_arguments( + ai_function=ai_function, + model=parsed_args, + agent_ctx=agent_ctx, + speech_handle=speech_handle, + ) + + if 
inspect.iscoroutinefunction(ai_function): + task = asyncio.create_task( + ai_function(*fnc_args, **fnc_kwargs), + name=f"ai_function_{fnc_call.name}", + ) + tasks.append((fnc_call.name, task)) + + def _log_exceptions(task: asyncio.Task) -> None: + if task.exception() is not None: + logger.error( + "exception occurred while executing tool", + extra={ + "function": fnc_call.name, + "speech_id": speech_handle.id, + }, + exc_info=task.exception(), + ) + out.append( + _FunctionCallOutput( + name=fnc_call.name, + arguments=fnc_call.arguments, + call_id=fnc_call.call_id, + output=None, + exception=task.exception(), + ) + ) + return + + out.append( + _FunctionCallOutput( + name=fnc_call.name, + arguments=fnc_call.arguments, + call_id=fnc_call.call_id, + output=task.result(), + exception=None, + ) + ) + + tasks.remove((fnc_call.name, task)) + + task.add_done_callback(_log_exceptions) + else: + start_time = time.monotonic() + try: + output = ai_function(*fnc_args, **fnc_kwargs) + out.append( + _FunctionCallOutput( + name=fnc_call.name, + arguments=fnc_call.arguments, + call_id=fnc_call.call_id, + output=output, + exception=None, + ) + ) + except Exception as e: + out.append( + _FunctionCallOutput( + name=fnc_call.name, + arguments=fnc_call.arguments, + call_id=fnc_call.call_id, + output=None, + exception=e, + ) + ) + + elapsed = time.monotonic() - start_time + if elapsed >= 1.5: + logger.warning( + f"function execution took too long ({elapsed:.2f}s), is `{fnc_call.name}` blocking?", + extra={ + "function": fnc_call.name, + "speech_id": speech_handle.id, + "elapsed": elapsed, + }, + ) + + except asyncio.CancelledError: + if len(tasks) > 0: + names = [name for name, _ in tasks] + logger.debug( + "waiting for function call to finish before fully cancelling", + extra={ + "functions": names, + "speech_id": speech_handle.id, + }, + ) + debug.Tracing.log_event( + "waiting for function call to finish before fully cancelling", + { + "functions": names, + "speech_id": speech_handle.id, + }, + ) + await asyncio.gather(*[task for _, task in tasks]) + finally: + if len(out) > 0: + logger.debug( + "tools execution completed", + extra={"speech_id": speech_handle.id}, + ) + debug.Tracing.log_event( + "tools execution completed", + {"speech_id": speech_handle.id}, + ) diff --git a/livekit-agents/livekit/agents/pipeline/audio_recognition.py b/livekit-agents/livekit/agents/pipeline/audio_recognition.py index 792427510..9ee31a244 100644 --- a/livekit-agents/livekit/agents/pipeline/audio_recognition.py +++ b/livekit-agents/livekit/agents/pipeline/audio_recognition.py @@ -86,13 +86,13 @@ def push_audio(self, frame: rtc.AudioFrame) -> None: async def aclose(self) -> None: if self._stt_atask is not None: - await aio.gracefully_cancel(self._stt_atask) + await aio.cancel_and_wait(self._stt_atask) if self._vad_atask is not None: - await aio.gracefully_cancel(self._vad_atask) + await aio.cancel_and_wait(self._vad_atask) if self._end_of_turn_task is not None: - await aio.gracefully_cancel(self._end_of_turn_task) + await aio.cancel_and_wait(self._end_of_turn_task) def update_stt(self, stt: io.STTNode | None) -> None: self._stt = stt @@ -212,7 +212,7 @@ async def _stt_task( task: asyncio.Task[None] | None, ) -> None: if task is not None: - await aio.gracefully_cancel(task) + await aio.cancel_and_wait(task) node = stt_node(audio_input) if asyncio.iscoroutine(node): @@ -232,7 +232,7 @@ async def _vad_task( self, vad: vad.VAD, audio_input: io.AudioStream, task: asyncio.Task[None] | None ) -> None: if task is not None: - await 
aio.gracefully_cancel(task) + await aio.cancel_and_wait(task) stream = vad.stream() @@ -247,4 +247,4 @@ async def _forward() -> None: await self._on_vad_event(ev) finally: await stream.aclose() - await aio.gracefully_cancel(forward_task) + await aio.cancel_and_wait(forward_task) diff --git a/livekit-agents/livekit/agents/pipeline/chat_cli.py b/livekit-agents/livekit/agents/pipeline/chat_cli.py index 9da68a933..1930bb92a 100644 --- a/livekit-agents/livekit/agents/pipeline/chat_cli.py +++ b/livekit-agents/livekit/agents/pipeline/chat_cli.py @@ -1,9 +1,9 @@ from __future__ import annotations import asyncio -import threading import sys import termios +import threading import time import tty from typing import Literal @@ -86,7 +86,6 @@ async def capture_frame(self, frame: rtc.AudioFrame) -> None: self._capture_start = time.monotonic() self._pushed_duration += frame.duration - print(f"Pushed audio frame, duration: {self._pushed_duration:.2f}s") with self._output_lock: self._output_buf += frame.data @@ -98,7 +97,6 @@ def flush(self) -> None: to_wait = max( 0.0, self._pushed_duration - (time.monotonic() - self._capture_start) ) - print(f"Flushing audio buffer, waiting for {to_wait:.2f}s") self._dispatch_handle = self._cli._loop.call_later( to_wait, self._dispatch_playback_finished ) @@ -124,7 +122,6 @@ def clear_buffer(self) -> None: ) def _dispatch_playback_finished(self) -> None: - print("sending playback finished event") self.on_playback_finished( playback_position=self._pushed_duration, interrupted=False ) @@ -188,7 +185,7 @@ def _on_input(): render_cli_task = asyncio.create_task(self._render_cli_task()) await self._done_fut - await aio.gracefully_cancel(render_cli_task) + await aio.cancel_and_wait(render_cli_task) self._update_microphone(enable=False) self._update_speaker(enable=False) diff --git a/livekit-agents/livekit/agents/pipeline/context.py b/livekit-agents/livekit/agents/pipeline/context.py new file mode 100644 index 000000000..7dc1a53c5 --- /dev/null +++ b/livekit-agents/livekit/agents/pipeline/context.py @@ -0,0 +1,17 @@ +from __future__ import annotations + + +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from .pipeline_agent import PipelineAgent + + +class AgentContext: + def __init__(self, agent: PipelineAgent) -> None: + self._agent = agent + + @property + def agent(self) -> PipelineAgent: + return self._agent diff --git a/livekit-agents/livekit/agents/pipeline/io.py b/livekit-agents/livekit/agents/pipeline/io.py index 512a24f71..c37939a78 100644 --- a/livekit-agents/livekit/agents/pipeline/io.py +++ b/livekit-agents/livekit/agents/pipeline/io.py @@ -15,12 +15,11 @@ from livekit import rtc from .. 
import llm, stt - from ..log import logger STTNode = Callable[ [AsyncIterable[rtc.AudioFrame]], - Union[Awaitable[Optional[AsyncIterable[stt.SpeechEvent]]]], + Union[Awaitable[Optional[AsyncIterable[stt.SpeechEvent]]]], # TODO: support str ] LLMNode = Callable[ [llm.ChatContext, Optional[llm.FunctionContext]], @@ -65,6 +64,9 @@ def __init__(self, *, sample_rate: int | None = None) -> None: self.__playback_segments_count = 0 self.__playback_finished_count = 0 + self.__last_playback_ev: PlaybackFinishedEvent = PlaybackFinishedEvent( + playback_position=0, interrupted=False + ) def on_playback_finished( self, *, playback_position: float, interrupted: bool diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index 53172c0c8..c7e70bdd7 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -1,12 +1,10 @@ from __future__ import annotations, print_function import asyncio -import heapq from dataclasses import dataclass from typing import ( AsyncIterable, Literal, - Tuple, ) from livekit import rtc @@ -14,7 +12,7 @@ from .. import debug, llm, utils from ..log import logger from . import io -from .agent_task import ActiveTask, AgentTask +from .agent_task import TaskActivity, AgentTask from .speech_handle import SpeechHandle EventTypes = Literal[ @@ -36,11 +34,6 @@ class PipelineOptions: max_fnc_steps: int -class AgentContext: - def __init__(self) -> None: - pass - - class PipelineAgent(rtc.EventEmitter[EventTypes]): def __init__( self, @@ -81,7 +74,7 @@ def __init__( # agent tasks self._current_task: AgentTask = task - self._active_task: ActiveTask | None = None + self._active_task: TaskActivity | None = None # -- Pipeline nodes -- # They can all be overriden by subclasses, by default they use the STT/LLM/TTS specified in the @@ -107,7 +100,7 @@ async def aclose(self) -> None: return if self._forward_audio_atask is not None: - await utils.aio.gracefully_cancel(self._forward_audio_atask) + await utils.aio.cancel_and_wait(self._forward_audio_atask) @property def options(self) -> PipelineOptions: @@ -150,7 +143,7 @@ def update_task(self, task: AgentTask) -> None: self._current_task = task if self._started: - self._update_activity_task = asyncio.create_task( + self._update_activity_atask = asyncio.create_task( self._update_activity_task(self._current_task), name="_update_activity_task", ) diff --git a/livekit-agents/livekit/agents/pipeline/speech_scheduler.py b/livekit-agents/livekit/agents/pipeline/speech_scheduler.py deleted file mode 100644 index 8b1378917..000000000 --- a/livekit-agents/livekit/agents/pipeline/speech_scheduler.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/livekit-agents/livekit/agents/stt/fallback_adapter.py b/livekit-agents/livekit/agents/stt/fallback_adapter.py index ac11a76db..9d5df7e13 100644 --- a/livekit-agents/livekit/agents/stt/fallback_adapter.py +++ b/livekit-agents/livekit/agents/stt/fallback_adapter.py @@ -237,10 +237,10 @@ def stream( async def aclose(self) -> None: for stt_status in self._status: if stt_status.recovering_synthesize_task is not None: - await aio.gracefully_cancel(stt_status.recovering_synthesize_task) + await aio.cancel_and_wait(stt_status.recovering_synthesize_task) if stt_status.recovering_stream_task is not None: - await aio.gracefully_cancel(stt_status.recovering_stream_task) + await aio.cancel_and_wait(stt_status.recovering_stream_task) class FallbackRecognizeStream(RecognizeStream): @@ -340,7 
+340,7 @@ async def _forward_input_task() -> None: self._try_recovery(stt) if forward_input_task is not None: - await aio.gracefully_cancel(forward_input_task) + await aio.cancel_and_wait(forward_input_task) await asyncio.gather(*[stream.aclose() for stream in self._recovering_streams]) diff --git a/livekit-agents/livekit/agents/stt/stream_adapter.py b/livekit-agents/livekit/agents/stt/stream_adapter.py index 0e69d65c5..b5812c10b 100644 --- a/livekit-agents/livekit/agents/stt/stream_adapter.py +++ b/livekit-agents/livekit/agents/stt/stream_adapter.py @@ -123,4 +123,4 @@ async def _recognize(): try: await asyncio.gather(*tasks) finally: - await utils.aio.gracefully_cancel(*tasks) + await utils.aio.cancel_and_wait(*tasks) diff --git a/livekit-agents/livekit/agents/stt/stt.py b/livekit-agents/livekit/agents/stt/stt.py index 4bed361c3..b4b063e99 100644 --- a/livekit-agents/livekit/agents/stt/stt.py +++ b/livekit-agents/livekit/agents/stt/stt.py @@ -302,7 +302,7 @@ def end_input(self) -> None: async def aclose(self) -> None: """Close ths stream immediately""" self._input_ch.close() - await aio.gracefully_cancel(self._task) + await aio.cancel_and_wait(self._task) if self._metrics_task is not None: await self._metrics_task diff --git a/livekit-agents/livekit/agents/tts/fallback_adapter.py b/livekit-agents/livekit/agents/tts/fallback_adapter.py index d990d5934..ade6ce672 100644 --- a/livekit-agents/livekit/agents/tts/fallback_adapter.py +++ b/livekit-agents/livekit/agents/tts/fallback_adapter.py @@ -141,7 +141,7 @@ def stream( async def aclose(self) -> None: for tts_status in self._status: if tts_status.recovering_task is not None: - await aio.gracefully_cancel(tts_status.recovering_task) + await aio.cancel_and_wait(tts_status.recovering_task) class FallbackChunkedStream(ChunkedStream): @@ -449,9 +449,9 @@ async def _input_task() -> None: raise finally: if next_audio_task is not None: - await utils.aio.gracefully_cancel(next_audio_task) + await utils.aio.cancel_and_wait(next_audio_task) - await utils.aio.gracefully_cancel(input_task) + await utils.aio.cancel_and_wait(input_task) async def _run(self) -> None: start_time = time.time() @@ -589,7 +589,7 @@ async def _forward_input_task(): ) ) finally: - await utils.aio.gracefully_cancel(input_task) + await utils.aio.cancel_and_wait(input_task) def _try_recovery(self, tts: TTS) -> None: assert isinstance(self._tts, FallbackAdapter) diff --git a/livekit-agents/livekit/agents/tts/stream_adapter.py b/livekit-agents/livekit/agents/tts/stream_adapter.py index fbb25df5d..318e9e1e1 100644 --- a/livekit-agents/livekit/agents/tts/stream_adapter.py +++ b/livekit-agents/livekit/agents/tts/stream_adapter.py @@ -105,4 +105,4 @@ async def _synthesize(): try: await asyncio.gather(*tasks) finally: - await utils.aio.gracefully_cancel(*tasks) + await utils.aio.cancel_and_wait(*tasks) diff --git a/livekit-agents/livekit/agents/tts/tts.py b/livekit-agents/livekit/agents/tts/tts.py index e641bf39d..d79b62ddc 100644 --- a/livekit-agents/livekit/agents/tts/tts.py +++ b/livekit-agents/livekit/agents/tts/tts.py @@ -199,7 +199,7 @@ async def _main_task(self) -> None: async def aclose(self) -> None: """Close is automatically called if the stream is completely collected""" - await aio.gracefully_cancel(self._synthesize_task) + await aio.cancel_and_wait(self._synthesize_task) self._event_ch.close() await self._metrics_task @@ -360,7 +360,7 @@ def end_input(self) -> None: async def aclose(self) -> None: """Close ths stream immediately""" self._input_ch.close() - await 
aio.gracefully_cancel(self._task) + await aio.cancel_and_wait(self._task) if self._metrics_task is not None: await self._metrics_task diff --git a/livekit-agents/livekit/agents/utils/aio/__init__.py b/livekit-agents/livekit/agents/utils/aio/__init__.py index d30bc9292..7bca086a6 100644 --- a/livekit-agents/livekit/agents/utils/aio/__init__.py +++ b/livekit-agents/livekit/agents/utils/aio/__init__.py @@ -3,8 +3,8 @@ from .interval import Interval, interval from .sleep import Sleep, SleepFinished, sleep from .task_set import TaskSet +from .utils import cancel_and_wait from .wait_group import WaitGroup -from .utils import gracefully_cancel __all__ = [ "ChanClosed", @@ -20,8 +20,8 @@ "TaskSet", "WaitGroup", "debug", - "gracefully_cancel", + "cancel_and_wait", "duplex_unix", "itertools", - "gracefully_cancel", + "cancel_and_wait", ] diff --git a/livekit-agents/livekit/agents/utils/aio/utils.py b/livekit-agents/livekit/agents/utils/aio/utils.py index 19251c961..8553fe246 100644 --- a/livekit-agents/livekit/agents/utils/aio/utils.py +++ b/livekit-agents/livekit/agents/utils/aio/utils.py @@ -2,7 +2,7 @@ import functools -async def gracefully_cancel(*futures: asyncio.Future): +async def cancel_and_wait(*futures: asyncio.Future): loop = asyncio.get_running_loop() waiters = [] diff --git a/livekit-agents/livekit/agents/vad.py b/livekit-agents/livekit/agents/vad.py index a65f8f9e8..13d02f747 100644 --- a/livekit-agents/livekit/agents/vad.py +++ b/livekit-agents/livekit/agents/vad.py @@ -152,7 +152,7 @@ def end_input(self) -> None: async def aclose(self) -> None: """Close ths stream immediately""" self._input_ch.close() - await aio.gracefully_cancel(self._task) + await aio.cancel_and_wait(self._task) self._event_ch.close() await self._metrics_task diff --git a/livekit-agents/livekit/agents/worker.py b/livekit-agents/livekit/agents/worker.py index 494ff84db..0fdfd0fa7 100644 --- a/livekit-agents/livekit/agents/worker.py +++ b/livekit-agents/livekit/agents/worker.py @@ -405,7 +405,7 @@ def load_fnc(): try: await asyncio.gather(*tasks) finally: - await utils.aio.gracefully_cancel(*tasks) + await utils.aio.cancel_and_wait(*tasks) if not self._close_future.done(): self._close_future.set_result(None) @@ -495,7 +495,7 @@ async def aclose(self) -> None: self._closed = True if self._conn_task is not None: - await utils.aio.gracefully_cancel(self._conn_task) + await utils.aio.cancel_and_wait(self._conn_task) await self._proc_pool.aclose() @@ -669,7 +669,7 @@ async def _recv_task(): try: await asyncio.gather(*tasks) finally: - await utils.aio.gracefully_cancel(*tasks) + await utils.aio.cancel_and_wait(*tasks) async def _reload_jobs(self, jobs: list[RunningJobInfo]) -> None: if not self._opts.api_secret: diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py index 687eb51a1..d8838f698 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py @@ -6,7 +6,6 @@ from livekit import rtc from livekit.agents import llm, multimodal, utils -from livekit.agents.llm.function_context import build_legacy_openai_schema from pydantic import ValidationError import openai @@ -15,7 +14,9 @@ ConversationItem, ConversationItemContent, ConversationItemCreateEvent, + ConversationItemCreatedEvent, ConversationItemDeleteEvent, + ConversationItemDeletedEvent, 
ConversationItemTruncateEvent, ErrorEvent, InputAudioBufferAppendEvent, @@ -59,29 +60,37 @@ @dataclass class _RealtimeOptions: model: str + voice: str @dataclass -class _ResponseGeneration: - response_id: str - item_id: str - audio_ch: utils.aio.Chan[rtc.AudioFrame] +class _MessageGeneration: + message_id: str text_ch: utils.aio.Chan[str] + audio_ch: utils.aio.Chan[rtc.AudioFrame] + + +@dataclass +class _ResponseGeneration: + message_ch: utils.aio.Chan[multimodal.MessageGeneration] function_ch: utils.aio.Chan[llm.FunctionCall] + messages: dict[str, _MessageGeneration] + class RealtimeModel(multimodal.RealtimeModel): def __init__( self, *, - model: str = "gpt-4o-realtime-preview-2024-12-17", + model: str = "gpt-4o-realtime-preview", + voice: str = "alloy", client: openai.AsyncClient | None = None, ) -> None: super().__init__( capabilities=multimodal.RealtimeCapabilities(message_truncation=True) ) - self._opts = _RealtimeOptions(model=model) + self._opts = _RealtimeOptions(model=model, voice=voice) self._client = client or openai.AsyncClient() def session(self) -> "RealtimeSession": @@ -94,7 +103,6 @@ class RealtimeSession(multimodal.RealtimeSession): def __init__(self, realtime_model: RealtimeModel) -> None: super().__init__(realtime_model) self._realtime_model = realtime_model - self._chat_ctx = llm.ChatContext.empty() self._fnc_ctx = llm.FunctionContext.empty() self._msg_ch = utils.aio.Chan[RealtimeClientEvent]() @@ -104,9 +112,14 @@ def __init__(self, realtime_model: RealtimeModel) -> None: ) self._current_generation: _ResponseGeneration | None = None + self._remote_chat_ctx = llm.remote_chat_context.RemoteChatContext() + + self._update_chat_ctx_lock = asyncio.Lock() + self._update_fnc_ctx_lock = asyncio.Lock() @utils.log_exceptions(logger=logger) async def _main_task(self) -> None: + # TODO(theomonnom): handle reconnections self._conn = conn = await self._realtime_model._client.beta.realtime.connect( model=self._realtime_model._opts.model ).enter() @@ -122,6 +135,10 @@ async def _listen_for_events() -> None: self._handle_response_created(event) elif event.type == "response.output_item.added": self._handle_response_output_item_added(event) + elif event.type == "conversation.item.created": + self._handle_conversion_item_created(event) + elif event.type == "conversation.item.deleted": + self._handle_conversion_item_deleted(event) elif event.type == "response.audio_transcript.delta": self._handle_response_audio_transcript_delta(event) elif event.type == "response.audio.delta": @@ -137,21 +154,178 @@ async def _listen_for_events() -> None: elif event.type == "error": self._handle_error(event) + if event.type != "response.audio.delta": + print(event) + @utils.log_exceptions(logger=logger) - async def _forward_input_audio() -> None: + async def _forward_input() -> None: async for msg in self._msg_ch: - await conn.send(msg) + try: + await conn.send(msg) + except Exception: + break + + self._msg_ch.send_nowait( + SessionUpdateEvent( + type="session.update", + session=session_update_event.Session( + model=self._realtime_model._opts.model, # type: ignore + voice=self._realtime_model._opts.voice, # type: ignore + ), + event_id=utils.shortuuid("session_update_"), + ) + ) tasks = [ asyncio.create_task(_listen_for_events(), name="_listen_for_events"), - asyncio.create_task(_forward_input_audio(), name="_forward_input_audio"), + asyncio.create_task(_forward_input(), name="_forward_input"), ] try: await asyncio.gather(*tasks) finally: - await utils.aio.gracefully_cancel(*tasks) + await 
utils.aio.cancel_and_wait(*tasks) await conn.close() + @property + def chat_ctx(self) -> llm.ChatContext: + return self._remote_chat_ctx.to_chat_ctx() + + @property + def fnc_ctx(self) -> llm.FunctionContext: + return self._fnc_ctx.copy() + + async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None: + async with self._update_chat_ctx_lock: + diff_ops = llm.utils.compute_chat_ctx_diff( + self._remote_chat_ctx.to_chat_ctx(), chat_ctx + ) + + # futs = [] + + for msg_id in diff_ops.to_remove: + event_id = utils.shortuuid("chat_ctx_delete_") + self._msg_ch.send_nowait( + ConversationItemDeleteEvent( + type="conversation.item.delete", + item_id=msg_id, + event_id=event_id, + ) + ) + # futs.append(f := asyncio.Future()) + # self._response_futures[event_id] = f + + for previous_msg_id, msg_id in diff_ops.to_create: + event_id = utils.shortuuid("chat_ctx_create_") + chat_item = chat_ctx.get_by_id(msg_id) + assert chat_item is not None + + self._msg_ch.send_nowait( + ConversationItemCreateEvent( + type="conversation.item.create", + item=_livekit_item_to_openai_item(chat_item), + previous_item_id=( + "root" if previous_msg_id is None else previous_msg_id + ), + event_id=event_id, + ) + ) + # futs.append(f := asyncio.Future()) + # self._response_futures[event_id] = f + + # await asyncio.gather(*futs, return_exceptions=True) + + async def update_fnc_ctx( + self, fnc_ctx: llm.FunctionContext | list[llm.AIFunction] + ) -> None: + async with self._update_fnc_ctx_lock: + if isinstance(fnc_ctx, list): + fnc_ctx = llm.FunctionContext(fnc_ctx) + + tools: list[session_update_event.SessionTool] = [] + retained_functions: list[llm.AIFunction] = [] + + for ai_fnc in fnc_ctx.ai_functions.values(): + tool_desc = llm.utils.build_legacy_openai_schema( + ai_fnc, internally_tagged=True + ) + try: + session_tool = session_update_event.SessionTool.model_validate( + tool_desc + ) + tools.append(session_tool) + retained_functions.append(ai_fnc) + except ValidationError: + logger.error( + "OpenAI Realtime API doesn't support this tool", + extra={"tool": tool_desc}, + ) + continue + + event_id = utils.shortuuid("fnc_ctx_update_") + # f = asyncio.Future() + # self._response_futures[event_id] = f + self._msg_ch.send_nowait( + SessionUpdateEvent( + type="session.update", + session=session_update_event.Session( + model=self._realtime_model._opts.model, # type: ignore (str -> Literal) + tools=tools, + ), + event_id=event_id, + ) + ) + + self._fnc_ctx = llm.FunctionContext(retained_functions) + + async def update_instructions(self, instructions: str) -> None: + event_id = utils.shortuuid("instructions_update_") + # f = asyncio.Future() + # self._response_futures[event_id] = f + self._msg_ch.send_nowait( + SessionUpdateEvent( + type="session.update", + session=session_update_event.Session( + model=self._realtime_model._opts.model, # type: ignore + instructions=instructions, + ), + event_id=event_id, + ) + ) + + def push_audio(self, frame: rtc.AudioFrame) -> None: + self._msg_ch.send_nowait( + InputAudioBufferAppendEvent( + type="input_audio_buffer.append", + audio=base64.b64encode(frame.data).decode("utf-8"), + ) + ) + + def generate_reply(self) -> asyncio.Future[multimodal.GenerationCreatedEvent]: + f = asyncio.Future() + event_id = utils.shortuuid("response_create_") + self._msg_ch.send_nowait( + ResponseCreateEvent(type="response.create", event_id=event_id) + ) + # self._response_futures[event_id] = f + return f + + def interrupt(self) -> None: + self._msg_ch.send_nowait(ResponseCancelEvent(type="response.cancel")) + + 
def truncate(self, *, message_id: str, audio_end_ms: int) -> None: + self._msg_ch.send_nowait( + ConversationItemTruncateEvent( + type="conversation.item.truncate", + content_index=0, + item_id=message_id, + audio_end_ms=audio_end_ms, + ) + ) + + async def aclose(self) -> None: + if self._conn is not None: + await self._conn.close() + def _handle_input_audio_buffer_speech_started( self, _: InputAudioBufferSpeechStartedEvent ) -> None: @@ -163,83 +337,97 @@ def _handle_input_audio_buffer_speech_stopped( self.emit("input_speech_stopped", multimodal.InputSpeechStoppedEvent()) def _handle_response_created(self, event: ResponseCreatedEvent) -> None: - response_id = event.response.id - assert response_id is not None, "response.id is None" + assert event.response.id is not None, "response.id is None" + self._current_generation = _ResponseGeneration( - response_id=response_id, - item_id="", - audio_ch=utils.aio.Chan(), - text_ch=utils.aio.Chan(), + message_ch=utils.aio.Chan(), function_ch=utils.aio.Chan(), + messages={}, ) + generation_ev = multimodal.GenerationCreatedEvent( + message_stream=self._current_generation.message_ch, + function_stream=self._current_generation.function_ch, + ) + + self.emit("generation_created", generation_ev) + + # fut = self._response_futures.pop(event.event_id, None) + # if fut is not None and not fut.done(): + # fut.set_result(generation_ev) + def _handle_response_output_item_added( self, event: ResponseOutputItemAddedEvent ) -> None: assert self._current_generation is not None, "current_generation is None" - item_id = event.item.id - assert item_id is not None, "item.id is None" - - # We assume only one "message" item in the current approach - if self._current_generation.item_id and event.item.type == "message": - logger.warning("Received an unexpected second item with type `message`") - return + assert (item_id := event.item.id) is not None, "item.id is None" + assert (item_type := event.item.type) is not None, "item.type is None" - if event.item.type == "function_call": - return - - self._current_generation.item_id = item_id - self.emit( - "generation_created", - multimodal.GenerationCreatedEvent( + if item_type == "message": + item_generation = _MessageGeneration( message_id=item_id, - text_stream=self._current_generation.text_ch, - audio_stream=self._current_generation.audio_ch, - function_stream=self._current_generation.function_ch, - ), + text_ch=utils.aio.Chan(), + audio_ch=utils.aio.Chan(), + ) + self._current_generation.message_ch.send_nowait( + multimodal.MessageGeneration( + message_id=item_id, + text_stream=item_generation.text_ch, + audio_stream=item_generation.audio_ch, + ) + ) + self._current_generation.messages[item_id] = item_generation + + def _handle_conversion_item_created( + self, event: ConversationItemCreatedEvent + ) -> None: + self._remote_chat_ctx.insert( + event.previous_item_id, _openai_item_to_livekit_item(event.item) ) + def _handle_conversion_item_deleted( + self, event: ConversationItemDeletedEvent + ) -> None: + self._remote_chat_ctx.delete(event.item_id) + def _handle_response_audio_transcript_delta( self, event: ResponseAudioTranscriptDeltaEvent ) -> None: assert self._current_generation is not None, "current_generation is None" - self._current_generation.text_ch.send_nowait(event.delta) + item_generation = self._current_generation.messages[event.item_id] + item_generation.text_ch.send_nowait(event.delta) def _handle_response_audio_delta(self, event: ResponseAudioDeltaEvent) -> None: assert self._current_generation is not None, 
"current_generation is None" + item_generation = self._current_generation.messages[event.item_id] + data = base64.b64decode(event.delta) - frame = rtc.AudioFrame( - data=data, - sample_rate=SAMPLE_RATE, - num_channels=NUM_CHANNELS, - samples_per_channel=len(data) // 2, + item_generation.audio_ch.send_nowait( + rtc.AudioFrame( + data=data, + sample_rate=SAMPLE_RATE, + num_channels=NUM_CHANNELS, + samples_per_channel=len(data) // 2, + ) ) - self._current_generation.audio_ch.send_nowait(frame) def _handle_response_audio_transcript_done( self, _: ResponseAudioTranscriptDoneEvent ) -> None: assert self._current_generation is not None, "current_generation is None" - self._current_generation.text_ch.close() def _handle_response_audio_done(self, _: ResponseAudioDoneEvent) -> None: assert self._current_generation is not None, "current_generation is None" - self._current_generation.audio_ch.close() def _handle_response_output_item_done( self, event: ResponseOutputItemDoneEvent ) -> None: assert self._current_generation is not None, "current_generation is None" + assert (item_id := event.item.id) is not None, "item.id is None" + assert (item_type := event.item.type) is not None, "item.type is None" - item = event.item - if item.type == "function_call": - if len(self.fnc_ctx.ai_functions) == 0: - logger.warning( - "received a function_call item without ai functions", - extra={"item": item}, - ) - return - + if item_type == "function_call": + item = event.item assert item.call_id is not None, "call_id is None" assert item.name is not None, "name is None" assert item.arguments is not None, "arguments is None" @@ -251,13 +439,20 @@ def _handle_response_output_item_done( arguments=item.arguments, ) ) + elif item_type == "message": + item_generation = self._current_generation.messages[item_id] + item_generation.text_ch.close() + item_generation.audio_ch.close() def _handle_response_done(self, _: ResponseDoneEvent) -> None: assert self._current_generation is not None, "current_generation is None" - # self._current_generation.tool_calls_ch.close() + self._current_generation.function_ch.close() self._current_generation = None def _handle_error(self, event: ErrorEvent) -> None: + if event.error.message.startswith("Cancellation failed"): + return + logger.error( "OpenAI Realtime API returned an error", extra={"error": event.error}, @@ -267,123 +462,15 @@ def _handle_error(self, event: ErrorEvent) -> None: multimodal.ErrorEvent(type=event.error.type, message=event.error.message), ) - @property - def chat_ctx(self) -> llm.ChatContext: - return self._chat_ctx.copy() - - @property - def fnc_ctx(self) -> llm.FunctionContext: - return self._fnc_ctx.copy() - - async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None: - diff_ops = llm.compute_chat_ctx_diff(self._chat_ctx, chat_ctx) - - for msg_id in diff_ops.to_remove: - self._msg_ch.send_nowait( - ConversationItemDeleteEvent( - type="conversation.item.delete", - item_id=msg_id, - ) - ) - - for previous_msg_id, msg_id in diff_ops.to_create: - chat_item = chat_ctx.get_by_id(msg_id) - assert chat_item is not None - self._msg_ch.send_nowait( - ConversationItemCreateEvent( - type="conversation.item.create", - item=_chat_item_to_conversation_item(chat_item), - previous_item_id=( - "root" if previous_msg_id is None else previous_msg_id - ), - ) - ) - - # TODO(theomonnom): wait for the server confirmation - - async def update_fnc_ctx( - self, fnc_ctx: llm.FunctionContext | list[llm.AIFunction] - ) -> None: - if isinstance(fnc_ctx, list): - fnc_ctx = 
llm.FunctionContext(fnc_ctx) - - tools: list[session_update_event.SessionTool] = [] - retained_functions: list[llm.AIFunction] = [] - - for ai_fnc in fnc_ctx.ai_functions.values(): - tool_desc = build_legacy_openai_schema(ai_fnc, internally_tagged=True) - try: - session_tool = session_update_event.SessionTool.model_validate( - tool_desc - ) - tools.append(session_tool) - retained_functions.append(ai_fnc) - except ValidationError: - logger.error( - "OpenAI Realtime API doesn't support this tool", - extra={"tool": tool_desc}, - ) - continue - - self._msg_ch.send_nowait( - SessionUpdateEvent( - type="session.update", - session=session_update_event.Session( - model=self._realtime_model._opts.model, # type: ignore (str -> Literal) - tools=tools, - ), - ) - ) - - # TODO(theomonnom): wait for the server confirmation before updating the local state - self._fnc_ctx = llm.FunctionContext(retained_functions) - - async def update_instructions(self, instructions: str) -> None: - self._msg_ch.send_nowait( - SessionUpdateEvent( - type="session.update", - session=session_update_event.Session( - model=self._realtime_model._opts.model, # type: ignore - instructions=instructions, - ), - ) - ) - - # TODO(theomonnom): wait for the server confirmation - - def push_audio(self, frame: rtc.AudioFrame) -> None: - self._msg_ch.send_nowait( - InputAudioBufferAppendEvent( - type="input_audio_buffer.append", - audio=base64.b64encode(frame.data).decode("utf-8"), - ) - ) - - def generate_reply(self) -> None: - self._msg_ch.send_nowait(ResponseCreateEvent(type="response.create")) - - def interrupt(self) -> None: - self._msg_ch.send_nowait(ResponseCancelEvent(type="response.cancel")) - - def truncate(self, *, message_id: str, audio_end_ms: int) -> None: - self._msg_ch.send_nowait( - ConversationItemTruncateEvent( - type="conversation.item.truncate", - content_index=0, - item_id=message_id, - audio_end_ms=audio_end_ms, - ) - ) - - async def aclose(self) -> None: - if self._conn is not None: - await self._conn.close() + # if event.error.event_id: + # fut = self._response_futures.pop(event.error.event_id, None) + # if fut is not None and not fut.done(): + # fut.set_exception(multimodal.RealtimeError(event.error.message)) -def _chat_item_to_conversation_item(item: llm.ChatItem) -> ConversationItem: +def _livekit_item_to_openai_item(item: llm.ChatItem) -> ConversationItem: conversation_item = ConversationItem( id=item.id, - object="realtime.item", ) if item.type == "function_call": @@ -431,3 +518,49 @@ def _chat_item_to_conversation_item(item: llm.ChatItem) -> ConversationItem: conversation_item.content = content_list return conversation_item + + +def _openai_item_to_livekit_item(item: ConversationItem) -> llm.ChatItem: + assert item.id is not None, "id is None" + + if item.type == "function_call": + assert item.call_id is not None, "call_id is None" + assert item.name is not None, "name is None" + assert item.arguments is not None, "arguments is None" + + return llm.FunctionCall( + id=item.id, + call_id=item.call_id, + name=item.name, + arguments=item.arguments, + ) + + if item.type == "function_call_output": + assert item.call_id is not None, "call_id is None" + assert item.output is not None, "output is None" + + return llm.FunctionCallOutput( + id=item.id, + call_id=item.call_id, + output=item.output, + is_error=False, + ) + + if item.type == "message": + assert item.role is not None, "role is None" + assert item.content is not None, "content is None" + + content: list[llm.ChatContent] = [] + + for c in item.content: + if 
c.type == "text" or c.type == "input_text": + assert c.text is not None, "text is None" + content.append(c.text) + + return llm.ChatMessage( + id=item.id, + role=item.role, + content=content, + ) + + raise ValueError(f"unsupported item type: {item.type}") From bd1bad210193702029adb35fc7e004af4602286b Mon Sep 17 00:00:00 2001 From: Long Chen Date: Tue, 28 Jan 2025 03:26:49 +0800 Subject: [PATCH 17/19] Room IO for agent v1.0 (#1404) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Théo Monnom --- examples/roomio_worker.py | 71 +++++ livekit-agents/livekit/agents/cli/cli.py | 2 +- .../livekit/agents/llm/chat_context.py | 8 +- .../livekit/agents/pipeline/pipeline_agent.py | 27 +- .../livekit/agents/pipeline/room_io.py | 265 ++++++++++++++++++ .../livekit/agents/pipeline/speech_handle.py | 13 +- .../livekit/plugins/openai/__init__.py | 4 +- 7 files changed, 384 insertions(+), 6 deletions(-) create mode 100644 examples/roomio_worker.py create mode 100644 livekit-agents/livekit/agents/pipeline/room_io.py diff --git a/examples/roomio_worker.py b/examples/roomio_worker.py new file mode 100644 index 000000000..426f25665 --- /dev/null +++ b/examples/roomio_worker.py @@ -0,0 +1,71 @@ +import logging + +from dotenv import load_dotenv +from livekit import rtc +from livekit.agents import JobContext, WorkerOptions, WorkerType, cli +from livekit.agents.pipeline import AgentTask, PipelineAgent +from livekit.agents.pipeline.io import PlaybackFinishedEvent +from livekit.agents.pipeline.room_io import RoomInput, RoomInputOptions, RoomOutput +from livekit.plugins import openai + +logger = logging.getLogger("my-worker") +logger.setLevel(logging.INFO) + +load_dotenv() + + +async def entrypoint(ctx: JobContext): + await ctx.connect() + + agent = PipelineAgent( + task=AgentTask( + instructions="Talk to me!", + llm=openai.realtime.RealtimeModel(), + ) + ) + + # default use RoomIO if room is provided + await agent.start( + room=ctx.room, + room_input_options=RoomInputOptions( + audio_enabled=True, + video_enabled=False, + audio_sample_rate=24000, + audio_num_channels=1, + ), + ) + + # # Or use RoomInput and RoomOutput explicitly + # room_input = RoomInput( + # ctx.room, + # options=RoomInputOptions( + # audio_enabled=True, + # video_enabled=False, + # audio_sample_rate=24000, + # audio_num_channels=1, + # ), + # ) + # room_output = RoomOutput(ctx.room, sample_rate=24000, num_channels=1) + + # agent.input.audio = room_input.audio + # agent.output.audio = room_output.audio + + # await room_input.wait_for_participant() + # await room_output.start() + + # TODO: the interrupted flag is not set correctly + @agent.output.audio.on("playback_finished") + def on_playback_finished(ev: PlaybackFinishedEvent) -> None: + logger.info( + "playback_finished", + extra={ + "playback_position": ev.playback_position, + "interrupted": ev.interrupted, + }, + ) + + +if __name__ == "__main__": + # WorkerType.ROOM is the default worker type which will create an agent for every room. + # You can also use WorkerType.PUBLISHER to create a single agent for all participants that publish a track. 
+ cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint, worker_type=WorkerType.ROOM)) diff --git a/livekit-agents/livekit/agents/cli/cli.py b/livekit-agents/livekit/agents/cli/cli.py index a29ef17e8..afee008b2 100644 --- a/livekit-agents/livekit/agents/cli/cli.py +++ b/livekit-agents/livekit/agents/cli/cli.py @@ -116,7 +116,7 @@ def dev( asyncio_debug=asyncio_debug, watch=watch, drain_timeout=0, - register=False, + register=True, ) _run_dev(args) diff --git a/livekit-agents/livekit/agents/llm/chat_context.py b/livekit-agents/livekit/agents/llm/chat_context.py index f2ce4a31d..d44e0f2f7 100644 --- a/livekit-agents/livekit/agents/llm/chat_context.py +++ b/livekit-agents/livekit/agents/llm/chat_context.py @@ -21,7 +21,7 @@ ) from livekit import rtc -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict from typing_extensions import TypeAlias from .. import utils @@ -81,12 +81,18 @@ class ImageContent(BaseModel): Currently only supported by OpenAI (see https://platform.openai.com/docs/guides/vision?lang=node#low-or-high-fidelity-image-understanding) """ + # temporary fix for pydantic + model_config = ConfigDict(arbitrary_types_allowed=True) + class AudioContent(BaseModel): type: Literal["audio_content"] = Field(default="audio_content") frame: list[rtc.AudioFrame] transcript: Optional[str] = None + # temporary fix for pydantic before rtc.AudioFrame is supported + model_config = ConfigDict(arbitrary_types_allowed=True) + class ChatMessage(BaseModel): id: str = Field(default_factory=lambda: utils.shortuuid("item_")) diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index c7e70bdd7..c2044d9c2 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -5,6 +5,7 @@ from typing import ( AsyncIterable, Literal, + Optional ) from livekit import rtc @@ -12,6 +13,7 @@ from .. import debug, llm, utils from ..log import logger from . import io +from .room_io import RoomInput, RoomInputOptions, RoomOutput from .agent_task import TaskActivity, AgentTask from .speech_handle import SpeechHandle @@ -80,10 +82,33 @@ def __init__( # They can all be overriden by subclasses, by default they use the STT/LLM/TTS specified in the # constructor of the PipelineAgent - def start(self) -> None: + async def start( + self, + room: Optional[rtc.Room] = None, + room_input_options: Optional[RoomInputOptions] = None, + ) -> None: + """Start the pipeline agent. + + Args: + room (Optional[rtc.Room]): The LiveKit room. If provided and no input/output audio + is set, automatically configures room audio I/O. + room_input_options (Optional[RoomInputOptions]): Options for the room input. 
+ """ if self._started: return + if room is not None: + # configure room I/O if not already set + if self.input.audio is None: + room_input = RoomInput(room=room, options=room_input_options) + self._input.audio = room_input.audio + await room_input.wait_for_participant() + + if self.output.audio is None: + room_output = RoomOutput(room=room) + self._output.audio = room_output.audio + await room_output.start() + if self.input.audio is not None: self._forward_audio_atask = asyncio.create_task( self._forward_audio_task(), name="_forward_audio_task" diff --git a/livekit-agents/livekit/agents/pipeline/room_io.py b/livekit-agents/livekit/agents/pipeline/room_io.py new file mode 100644 index 000000000..3be659833 --- /dev/null +++ b/livekit-agents/livekit/agents/pipeline/room_io.py @@ -0,0 +1,265 @@ +import asyncio +from dataclasses import dataclass +from typing import AsyncIterator, Optional + +from livekit import rtc + +from .io import AudioSink +from .log import logger + + +@dataclass +class RoomInputOptions: + audio_enabled: bool = True + """Whether to subscribe to audio""" + video_enabled: bool = False + """Whether to subscribe to video""" + audio_sample_rate: int = 16000 + """Sample rate of the input audio in Hz""" + audio_num_channels: int = 1 + """Number of audio channels""" + audio_queue_capacity: int = 0 + """Capacity of the internal audio queue, 0 means unlimited""" + video_queue_capacity: int = 0 + """Capacity of the internal video queue, 0 means unlimited""" + + +DEFAULT_ROOM_INPUT_OPTIONS = RoomInputOptions() + + +class RoomInput: + """Creates video and audio streams from a remote participant in a LiveKit room""" + + def __init__( + self, + room: rtc.Room, + participant_identity: Optional[str] = None, + options: RoomInputOptions = DEFAULT_ROOM_INPUT_OPTIONS, + ) -> None: + """ + Args: + room: The LiveKit room to get streams from + participant_identity: Optional identity of the participant to get streams from. + If None, will use the first participant that joins. 
+ options: RoomInputOptions + """ + self._options = options + self._room = room + self._expected_identity = participant_identity + self._participant: rtc.RemoteParticipant | None = None + self._closed = False + + # streams + self._audio_stream: Optional[rtc.AudioStream] = None + self._video_stream: Optional[rtc.VideoStream] = None + + self._participant_ready = asyncio.Event() + self._room.on("participant_connected", self._on_participant_connected) + + # try to find participant + if self._expected_identity is not None: + participant = self._room.remote_participants.get(self._expected_identity) + if participant is not None: + self._link_participant(participant) + else: + for participant in self._room.remote_participants.values(): + self._link_participant(participant) + if self._participant: + break + + async def wait_for_participant(self) -> rtc.RemoteParticipant: + await self._participant_ready.wait() + assert self._participant is not None + return self._participant + + @property + def audio(self) -> AsyncIterator[rtc.AudioFrame] | None: + if self._audio_stream is None: + return None + + async def _read_stream(): + async for event in self._audio_stream: + yield event.frame + + return _read_stream() + + @property + def video(self) -> AsyncIterator[rtc.VideoFrame] | None: + if self._video_stream is None: + return None + + async def _read_stream(): + async for event in self._video_stream: + yield event.frame + + return _read_stream() + + def _link_participant(self, participant: rtc.RemoteParticipant) -> None: + if ( + self._expected_identity is not None + and participant.identity != self._expected_identity + ): + return + + self._participant = participant + + # set up tracks + if self._options.audio_enabled: + self._audio_stream = rtc.AudioStream.from_participant( + participant=participant, + track_source=rtc.TrackSource.SOURCE_MICROPHONE, + sample_rate=self._options.audio_sample_rate, + num_channels=self._options.audio_num_channels, + capacity=self._options.audio_queue_capacity, + ) + if self._options.video_enabled: + self._video_stream = rtc.VideoStream.from_participant( + participant=participant, + track_source=rtc.TrackSource.SOURCE_CAMERA, + capacity=self._options.video_queue_capacity, + ) + + self._participant_ready.set() + + def _on_participant_connected(self, participant: rtc.RemoteParticipant) -> None: + if self._participant is not None: + return + self._link_participant(participant) + + async def aclose(self) -> None: + if self._closed: + raise RuntimeError("RoomInput already closed") + + self._closed = True + self._room.off("participant_connected", self._on_participant_connected) + self._participant = None + + if self._audio_stream is not None: + await self._audio_stream.aclose() + self._audio_stream = None + if self._video_stream is not None: + await self._video_stream.aclose() + self._video_stream = None + + +class RoomOutput: + """Manages audio output to a LiveKit room""" + + def __init__( + self, room: rtc.Room, *, sample_rate: int = 24000, num_channels: int = 1 + ) -> None: + """Initialize the RoomOutput + + Args: + room: The LiveKit room to publish media to + sample_rate: Sample rate of the audio in Hz + num_channels: Number of audio channels + """ + self._audio_sink = RoomAudioSink( + room=room, sample_rate=sample_rate, num_channels=num_channels + ) + + async def start(self) -> None: + await self._audio_sink.start() + + @property + def audio(self) -> "RoomAudioSink": + return self._audio_sink + + +class RoomAudioSink(AudioSink): + """AudioSink implementation that publishes 
audio to a LiveKit room""" + + def __init__( + self, + room: rtc.Room, + *, + sample_rate: int = 24000, + num_channels: int = 1, + queue_size_ms: int = 100_000, + ) -> None: + """Initialize the RoomAudioSink + + Args: + room: The LiveKit room to publish audio to + sample_rate: Sample rate of the audio in Hz + num_channels: Number of audio channels + queue_size_ms: Size of the internal audio queue in ms. + Default to 100s to capture as fast as possible. + """ + super().__init__(sample_rate=sample_rate) + self._room = room + + # buffer the audio frames as soon as they are captured + self._audio_source = rtc.AudioSource( + sample_rate=sample_rate, + num_channels=num_channels, + queue_size_ms=queue_size_ms, + ) + + self._publication: rtc.LocalTrackPublication | None = None + self._pushed_duration: Optional[float] = None + self._interrupted: bool = False + self._flush_task: Optional[asyncio.Task[None]] = None + + def _on_reconnected(self) -> None: + self._publication = None + asyncio.create_task(self.start()) + + self._room.on("reconnected", _on_reconnected) + + async def start(self) -> None: + """Start publishing the audio track to the room""" + if self._publication: + return + + track = rtc.LocalAudioTrack.create_audio_track( + "assistant_voice", self._audio_source + ) + self._publication = await self._room.local_participant.publish_track( + track=track, + options=rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE), + ) + await self._publication.wait_for_subscription() + + async def capture_frame(self, frame: rtc.AudioFrame) -> None: + """Capture an audio frame and publish it to the room""" + await super().capture_frame(frame) + + if self._pushed_duration is None: + self._pushed_duration = 0.0 + self._interrupted = False # reset interrupted flag + self._pushed_duration += frame.duration + await self._audio_source.capture_frame(frame) + + def flush(self) -> None: + """Flush the current audio segment and notify when complete""" + super().flush() + if self._pushed_duration is None: + return + if self._flush_task and not self._flush_task.done(): + # shouldn't happen if only one active speech handle at a time + logger.error("flush called while playback is in progress") + self._flush_task.cancel() + self._flush_task = None + + def _playback_finished(task: asyncio.Task[None]) -> None: + self.on_playback_finished( + playback_position=self._pushed_duration, interrupted=self._interrupted + ) + self._pushed_duration = None + self._interrupted = False + + self._flush_task = asyncio.create_task(self._audio_source.wait_for_playout()) + self._flush_task.add_done_callback(_playback_finished) + + def clear_buffer(self) -> None: + """Clear the audio buffer immediately""" + super().clear_buffer() + if self._pushed_duration is None: + return + + queued_duration = self._audio_source.queued_duration + self._pushed_duration = max(0, self._pushed_duration - queued_duration) + self._interrupted = True + self._audio_source.clear_queue() diff --git a/livekit-agents/livekit/agents/pipeline/speech_handle.py b/livekit-agents/livekit/agents/pipeline/speech_handle.py index bd3971291..8ae63ece0 100644 --- a/livekit-agents/livekit/agents/pipeline/speech_handle.py +++ b/livekit-agents/livekit/agents/pipeline/speech_handle.py @@ -79,6 +79,17 @@ def _mark_playout_done(self) -> None: self._playout_done_fut.set_result(None) async def wait_until_interrupted(self, aw: list[Awaitable]) -> None: + temp_tasks = [] + tasks = [] + for task in aw: + if not isinstance(task, asyncio.Task): + task = asyncio.create_task(task) + 
temp_tasks.append(task) + tasks.append(task) + await asyncio.wait( - [*aw, self._interrupt_fut], return_when=asyncio.FIRST_COMPLETED + [*tasks, self._interrupt_fut], return_when=asyncio.FIRST_COMPLETED ) + for task in temp_tasks: + if not task.done(): + task.cancel() diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/__init__.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/__init__.py index 1a6b7c00d..ae870a241 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/__init__.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. -from . import beta, realtime +from . import realtime from .embeddings import EmbeddingData, create_embeddings from .llm import LLM, LLMStream from .models import TTSModels, TTSVoices, WhisperModels @@ -27,7 +27,7 @@ "LLM", "LLMStream", "WhisperModels", - "beta", + # "beta", "TTSModels", "TTSVoices", "create_embeddings", From 96decf9183f44b60d99e3aee34a9900dcd544aba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Wed, 29 Jan 2025 19:16:42 +0100 Subject: [PATCH 18/19] better OAI sync --- examples/minimal_worker.py | 9 +- examples/roomio_worker.py | 3 +- livekit-agents/livekit/agents/llm/__init__.py | 7 +- .../livekit/agents/llm/chat_context.py | 8 +- .../livekit/agents/llm/function_context.py | 12 + .../livekit/agents/llm/remote_chat_context.py | 3 +- livekit-agents/livekit/agents/llm/utils.py | 4 +- .../livekit/agents/multimodal/__init__.py | 4 +- .../livekit/agents/multimodal/realtime.py | 9 +- .../livekit/agents/pipeline/__init__.py | 2 +- .../livekit/agents/pipeline/agent_task.py | 431 ++++-------------- .../livekit/agents/pipeline/context.py | 2 - .../livekit/agents/pipeline/generation.py | 73 ++- .../livekit/agents/pipeline/pipeline_agent.py | 31 +- .../livekit/agents/pipeline/room_io.py | 2 + .../livekit/agents/pipeline/speech_handle.py | 17 +- .../livekit/agents/pipeline/tools.py | 301 ++++++++++++ .../plugins/openai/realtime/realtime_model.py | 68 ++- 18 files changed, 540 insertions(+), 446 deletions(-) create mode 100644 livekit-agents/livekit/agents/pipeline/tools.py diff --git a/examples/minimal_worker.py b/examples/minimal_worker.py index ea332f9da..8a0e2a3b9 100644 --- a/examples/minimal_worker.py +++ b/examples/minimal_worker.py @@ -3,7 +3,7 @@ from dotenv import load_dotenv from livekit.agents import JobContext, WorkerOptions, WorkerType, cli from livekit.agents.llm import ai_function -from livekit.agents.pipeline import AgentTask, ChatCLI, PipelineAgent, AgentContext +from livekit.agents.pipeline import AgentContext, AgentTask, ChatCLI, PipelineAgent from livekit.plugins import openai logger = logging.getLogger("my-worker") @@ -15,7 +15,7 @@ class EchoTask(AgentTask): def __init__(self) -> None: super().__init__( - instructions="Always speak in English even if the user speaks in another language or wants to use another language.", + instructions="You are Echo, always speak in English even if the user speaks in another language or wants to use another language.", llm=openai.realtime.RealtimeModel(voice="echo"), ) @@ -27,7 +27,7 @@ async def talk_to_alloy(self, context: AgentContext): class AlloyTask(AgentTask): def __init__(self) -> None: super().__init__( - instructions="Always speak in English even if the user speaks in another language or wants to use another language.", + instructions="You are Alloy, always speak in English even if the user speaks in another language or wants to use 
another language.", llm=openai.realtime.RealtimeModel(voice="alloy"), ) @@ -40,7 +40,8 @@ async def entrypoint(ctx: JobContext): agent = PipelineAgent( task=AlloyTask(), ) - agent.start() + + await agent.start() # start a chat inside the CLI chat_cli = ChatCLI(agent) diff --git a/examples/roomio_worker.py b/examples/roomio_worker.py index 426f25665..e65f698eb 100644 --- a/examples/roomio_worker.py +++ b/examples/roomio_worker.py @@ -1,11 +1,10 @@ import logging from dotenv import load_dotenv -from livekit import rtc from livekit.agents import JobContext, WorkerOptions, WorkerType, cli from livekit.agents.pipeline import AgentTask, PipelineAgent from livekit.agents.pipeline.io import PlaybackFinishedEvent -from livekit.agents.pipeline.room_io import RoomInput, RoomInputOptions, RoomOutput +from livekit.agents.pipeline.room_io import RoomInputOptions from livekit.plugins import openai logger = logging.getLogger("my-worker") diff --git a/livekit-agents/livekit/agents/llm/__init__.py b/livekit-agents/livekit/agents/llm/__init__.py index 0d7dc7410..83b16d12d 100644 --- a/livekit-agents/livekit/agents/llm/__init__.py +++ b/livekit-agents/livekit/agents/llm/__init__.py @@ -1,4 +1,4 @@ -from . import utils +from . import remote_chat_context, utils from .chat_context import ( AudioContent, ChatContent, @@ -11,9 +11,10 @@ ) from .fallback_adapter import AvailabilityChangedEvent, FallbackAdapter from .function_context import ( + AIError, AIFunction, FunctionContext, - AIError, + StopResponse, ai_function, find_ai_functions, is_ai_function, @@ -28,7 +29,6 @@ LLMStream, ToolChoice, ) -from . import remote_chat_context __all__ = [ "LLM", @@ -56,6 +56,7 @@ "AIFunction", "FunctionContext", "AIError", + "StopResponse", "utils", "remote_chat_context", ] diff --git a/livekit-agents/livekit/agents/llm/chat_context.py b/livekit-agents/livekit/agents/llm/chat_context.py index d44e0f2f7..f2ce4a31d 100644 --- a/livekit-agents/livekit/agents/llm/chat_context.py +++ b/livekit-agents/livekit/agents/llm/chat_context.py @@ -21,7 +21,7 @@ ) from livekit import rtc -from pydantic import BaseModel, Field, ConfigDict +from pydantic import BaseModel, Field from typing_extensions import TypeAlias from .. import utils @@ -81,18 +81,12 @@ class ImageContent(BaseModel): Currently only supported by OpenAI (see https://platform.openai.com/docs/guides/vision?lang=node#low-or-high-fidelity-image-understanding) """ - # temporary fix for pydantic - model_config = ConfigDict(arbitrary_types_allowed=True) - class AudioContent(BaseModel): type: Literal["audio_content"] = Field(default="audio_content") frame: list[rtc.AudioFrame] transcript: Optional[str] = None - # temporary fix for pydantic before rtc.AudioFrame is supported - model_config = ConfigDict(arbitrary_types_allowed=True) - class ChatMessage(BaseModel): id: str = Field(default_factory=lambda: utils.shortuuid("item_")) diff --git a/livekit-agents/livekit/agents/llm/function_context.py b/livekit-agents/livekit/agents/llm/function_context.py index 7ecd074ba..b158d193c 100644 --- a/livekit-agents/livekit/agents/llm/function_context.py +++ b/livekit-agents/livekit/agents/llm/function_context.py @@ -44,6 +44,18 @@ def message(self) -> str: return self._message +class StopResponse(Exception): + def __init__(self) -> None: + """ + Exception raised within AI functions. + + This exception can be raised by the user to indicate that + the agent should not generate a response for the current + function call. 
+ """ + super().__init__() + + @dataclass class _AIFunctionInfo: name: str diff --git a/livekit-agents/livekit/agents/llm/remote_chat_context.py b/livekit-agents/livekit/agents/llm/remote_chat_context.py index 10f404202..7998f5588 100644 --- a/livekit-agents/livekit/agents/llm/remote_chat_context.py +++ b/livekit-agents/livekit/agents/llm/remote_chat_context.py @@ -1,7 +1,8 @@ from __future__ import annotations from dataclasses import dataclass, field -from .chat_context import ChatItem, ChatContext + +from .chat_context import ChatContext, ChatItem __all__ = ["RemoteChatContext"] diff --git a/livekit-agents/livekit/agents/llm/utils.py b/livekit-agents/livekit/agents/llm/utils.py index 68a1fa039..f16b4f651 100644 --- a/livekit-agents/livekit/agents/llm/utils.py +++ b/livekit-agents/livekit/agents/llm/utils.py @@ -3,23 +3,21 @@ import inspect from dataclasses import dataclass from typing import ( + TYPE_CHECKING, Annotated, Any, Callable, get_args, get_origin, get_type_hints, - TYPE_CHECKING, ) from pydantic import BaseModel, create_model from pydantic.fields import FieldInfo - from .chat_context import ChatContext from .function_context import AIFunction, get_function_info - if TYPE_CHECKING: from ..pipeline.context import AgentContext from ..pipeline.speech_handle import SpeechHandle diff --git a/livekit-agents/livekit/agents/multimodal/__init__.py b/livekit-agents/livekit/agents/multimodal/__init__.py index bebf8faf2..6df7fde6a 100644 --- a/livekit-agents/livekit/agents/multimodal/__init__.py +++ b/livekit-agents/livekit/agents/multimodal/__init__.py @@ -1,11 +1,11 @@ from .realtime import ( ErrorEvent, - RealtimeError, GenerationCreatedEvent, - MessageGeneration, InputSpeechStartedEvent, InputSpeechStoppedEvent, + MessageGeneration, RealtimeCapabilities, + RealtimeError, RealtimeModel, RealtimeSession, ) diff --git a/livekit-agents/livekit/agents/multimodal/realtime.py b/livekit-agents/livekit/agents/multimodal/realtime.py index bc39509bc..c2ddee9a6 100644 --- a/livekit-agents/livekit/agents/multimodal/realtime.py +++ b/livekit-agents/livekit/agents/multimodal/realtime.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio - from abc import ABC, abstractmethod from dataclasses import dataclass from typing import AsyncIterable, Generic, Literal, TypeVar, Union @@ -100,7 +99,9 @@ def fnc_ctx(self) -> llm.FunctionContext: ... async def update_instructions(self, instructions: str) -> None: ... @abstractmethod - async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None: ... + async def update_chat_ctx( + self, chat_ctx: llm.ChatContext + ) -> None: ... # can raise RealtimeError on Timeout @abstractmethod async def update_fnc_ctx( @@ -113,7 +114,9 @@ def push_audio(self, frame: rtc.AudioFrame) -> None: ... @abstractmethod def generate_reply( self, - ) -> asyncio.Future[GenerationCreatedEvent]: ... # can raise RealtimeError + ) -> asyncio.Future[ + GenerationCreatedEvent + ]: ... 
# can raise RealtimeError on Timeout # cancel the current generation (do nothing if no generation is in progress) @abstractmethod diff --git a/livekit-agents/livekit/agents/pipeline/__init__.py b/livekit-agents/livekit/agents/pipeline/__init__.py index 375e329ab..3ad306934 100644 --- a/livekit-agents/livekit/agents/pipeline/__init__.py +++ b/livekit-agents/livekit/agents/pipeline/__init__.py @@ -1,7 +1,7 @@ from .agent_task import AgentTask from .chat_cli import ChatCLI +from .context import AgentContext from .pipeline_agent import PipelineAgent from .speech_handle import SpeechHandle -from .context import AgentContext __all__ = ["ChatCLI", "PipelineAgent", "AgentTask", "SpeechHandle", "AgentContext"] diff --git a/livekit-agents/livekit/agents/pipeline/agent_task.py b/livekit-agents/livekit/agents/pipeline/agent_task.py index 7d928151f..dc888df29 100644 --- a/livekit-agents/livekit/agents/pipeline/agent_task.py +++ b/livekit-agents/livekit/agents/pipeline/agent_task.py @@ -3,42 +3,37 @@ import asyncio import contextlib import heapq -import inspect import time -from dataclasses import dataclass from typing import ( - Any, + TYPE_CHECKING, AsyncIterable, Optional, Union, - TYPE_CHECKING, ) from livekit import rtc -from pydantic import ValidationError -from torch import mul from .. import debug, llm, multimodal, stt, tokenize, tts, utils, vad +from .._exceptions import APITimeoutError from ..llm import ( ChatContext, FunctionContext, - AIError, find_ai_functions, - utils as llm_utils, ) from ..log import logger from ..types import NOT_GIVEN, NotGivenOr from ..utils import is_given -from . import io +from . import tools from .audio_recognition import AudioRecognition, RecognitionHooks, _TurnDetector +from .context import AgentContext from .generation import ( + _AudioOutput, + _TextOutput, _TTSGenerationData, - do_llm_inference, - do_tts_inference, + perform_audio_forwarding, + perform_text_forwarding, ) from .speech_handle import SpeechHandle -from .context import AgentContext - if TYPE_CHECKING: from .pipeline_agent import PipelineAgent @@ -610,71 +605,74 @@ async def _realtime_reply_task( audio_output = self._agent.output.audio text_output = self._agent.output.text - await speech_handle.wait_until_interrupted( - [speech_handle._wait_for_authorization()] + await speech_handle.wait_if_not_interrupted( + [asyncio.ensure_future(speech_handle._wait_for_authorization())] ) if speech_handle.interrupted: - speech_handle._mark_playout_done() # TODO(theomonnom): remove the message from the serverside history + speech_handle._mark_playout_done() return @utils.log_exceptions(logger=logger) - async def _read_messages(message_outputs: list[_MessageOutput]) -> None: - ts2 = utils.aio.TaskSet() + async def _read_messages( + outputs: list[tuple[_TextOutput | None, _AudioOutput | None]], + ) -> None: + forward_tasks: list[asyncio.Task] = [] async for msg in generation_ev.message_stream: - if ts2.tasks: + if len(forward_tasks) > 0: logger.warning( "expected to receive only one message generation from the realtime API" ) break - out = _MessageOutput(text="", audio=[]) - message_outputs.append(out) + text_out = None + audio_out = None if text_output is not None: - ts2.create_task( - _forward_text(text_output, msg.text_stream, out), - name="_realtime_reply_task.forward_text", + forward_task, text_out = perform_text_forwarding( + text_output=text_output, llm_output=msg.text_stream ) + forward_tasks.append(forward_task) if audio_output is not None: - ts2.create_task( - _forward_audio(audio_output, 
msg.audio_stream, out), - name="_realtime_reply_task.forward_audio", + forward_task, audio_out = perform_audio_forwarding( + audio_output=audio_output, tts_output=msg.audio_stream ) + forward_tasks.append(forward_task) + + outputs.append((text_out, audio_out)) try: - await asyncio.gather(*ts2.tasks) + await asyncio.gather(*forward_tasks) finally: - await utils.aio.cancel_and_wait(*ts2.tasks) + await utils.aio.cancel_and_wait(*forward_tasks) - function_outputs: list[_FunctionCallOutput] = [] - message_outputs: list[_MessageOutput] = [] - ts = utils.aio.TaskSet() - ts.create_task( - _read_messages(message_outputs), name="_realtime_reply_task.read_messages" - ) - ts.create_task( - _execute_tools( - agent_ctx=AgentContext(self._agent), - fnc_ctx=self._task.fnc_ctx, - speech_handle=speech_handle, - function_stream=generation_ev.function_stream, - out=function_outputs, - ), - name="_realtime_reply_task.execute_tools", + message_outputs: list[tuple[_TextOutput | None, _AudioOutput | None]] = [] + tasks = [ + asyncio.create_task( + _read_messages(message_outputs), + name="_realtime_reply_task.read_messages", + ) + ] + + exe_task, fnc_outputs = tools.perform_tool_executions( + agent_ctx=AgentContext(self._agent), + fnc_ctx=self._task.fnc_ctx, + speech_handle=speech_handle, + function_stream=generation_ev.function_stream, ) + tasks.append(exe_task) - await speech_handle.wait_until_interrupted([*ts.tasks]) + await speech_handle.wait_if_not_interrupted([*tasks]) if audio_output is not None: - await speech_handle.wait_until_interrupted( - [audio_output.wait_for_playout()] + await speech_handle.wait_if_not_interrupted( + [asyncio.ensure_future(audio_output.wait_for_playout())] ) if speech_handle.interrupted: - await utils.aio.cancel_and_wait(*ts.tasks) + await utils.aio.cancel_and_wait(*tasks) if audio_output is not None: audio_output.clear_buffer() @@ -692,319 +690,44 @@ async def _read_messages(message_outputs: list[_MessageOutput]) -> None: # TODO(theomonnom): truncate message (+ OAI serverside mesage) return - - new_agent_task: AgentTask | None = None - new_items: list[llm.ChatItem] = [] - if len(function_outputs) > 0: - for out in function_outputs: - if isinstance(out.exception, AIError): - new_items.append( - llm.FunctionCallOutput( - call_id=out.call_id, - output=out.exception.message, - is_error=True, - ) - ) - continue - elif out.exception is not None: - logger.error( - "exception occurred while executing tool", - extra={ - "function": out.name, - "speech_id": speech_handle.id, - }, - exc_info=out.exception, - ) - continue - - if out.output is not None: - if isinstance(out.output, tuple): - agent_tasks = [ - item for item in out.output if isinstance(item, AgentTask) - ] - if len(agent_tasks) > 1: - logger.error( - "Multiple AgentTask instances found in tuple output", - extra={ - "call_id": out.call_id, - "function": out.name, - "output": out.output, - }, - ) - continue - - new_agent_task = agent_tasks[0] if agent_tasks else None - out.output = tuple( - item - for item in out.output - if not isinstance(item, AgentTask) - ) - if len(out.output) == 1: - out.output = out.output[0] - elif isinstance(out.output, AgentTask): - new_agent_task = out.output - out.output = None - - if not _is_valid_function_output_type(out.output): - logger.error( - "invalid function output type", - extra={ - "call_id": out.call_id, - "function": out.name, - "output": out.output, - }, - ) - continue - - new_items.append( - llm.FunctionCallOutput( - call_id=out.call_id, - output=str(out.output), - is_error=False, - ) - ) - 
- if new_items: - chat_ctx = self._rt_session.chat_ctx.copy() - chat_ctx.items.extend(new_items) - await self._rt_session.update_chat_ctx(chat_ctx) - self._rt_session.interrupt() - self._rt_session.generate_reply() - - - if new_agent_task is not None: - if new_items: - await asyncio.sleep(1.0) - self._agent.update_task(new_agent_task) - - debug.Tracing.log_event("playout completed", {"speech_id": speech_handle.id}) - speech_handle._mark_playout_done() - - -def _is_valid_function_output_type(value: Any) -> bool: - VALID_TYPES = (str, int, float, bool, complex, type(None)) - - if isinstance(value, VALID_TYPES): - return True - elif ( - isinstance(value, list) - or isinstance(value, set) - or isinstance(value, frozenset) - or isinstance(value, tuple) - ): - return all(_is_valid_function_output_type(item) for item in value) - elif isinstance(value, dict): - return all( - isinstance(key, VALID_TYPES) and _is_valid_function_output_type(val) - for key, val in value.items() - ) - return False - - -@dataclass -class _MessageOutput: - text: str - audio: list[rtc.AudioFrame] - - -@dataclass -class _FunctionCallOutput: - call_id: str - name: str - arguments: str - output: Any - exception: BaseException | None - - -@utils.log_exceptions(logger=logger) -async def _forward_text( - text_output: io.TextSink, llm_output: AsyncIterable[str], out: _MessageOutput -) -> None: - try: - async for delta in llm_output: - out.text += delta - await text_output.capture_text(delta) - finally: - text_output.flush() - - -@utils.log_exceptions(logger=logger) -async def _forward_audio( - audio_output: io.AudioSink, - tts_output: AsyncIterable[rtc.AudioFrame], - out: _MessageOutput, -) -> None: - try: - async for frame in tts_output: - out.audio.append(frame) - await audio_output.capture_frame(frame) - finally: - audio_output.flush() - - -@utils.log_exceptions(logger=logger) -async def _execute_tools( - *, - agent_ctx: AgentContext, - fnc_ctx: FunctionContext, - speech_handle: SpeechHandle, - function_stream: utils.aio.Chan[llm.FunctionCall], - out: list[_FunctionCallOutput] = [], -) -> None: - """execute tools, when cancelled, stop executing new tools but wait for the pending ones""" - tasks: list[tuple[str, asyncio.Task]] = [] - try: - async for fnc_call in function_stream: - ai_function = fnc_ctx.ai_functions.get(fnc_call.name, None) - if ai_function is None: - logger.warning( - f"LLM called function `{fnc_call.name}` but it was not found in the current task", - extra={ - "function": fnc_call.name, - "speech_id": speech_handle.id, - }, - ) - continue - - try: - function_model = llm_utils.function_arguments_to_pydantic_model( - ai_function - ) - parsed_args = function_model.model_validate_json(fnc_call.arguments) - except ValidationError: - logger.exception( - "LLM called function `{fnc.name}` with invalid arguments", - extra={ - "function": fnc_call.name, - "arguments": fnc_call.arguments, - "speech_id": speech_handle.id, - }, - ) - continue - - logger.debug( - "executing tool", - extra={ - "function": fnc_call.name, - "speech_id": speech_handle.id, - }, - ) - debug.Tracing.log_event( - "executing tool", - { - "function": fnc_call.name, - "speech_id": speech_handle.id, - }, - ) + if len(fnc_outputs) > 0: + new_fnc_outputs: list[llm.FunctionCallOutput] = [] + new_agent_task: AgentTask | None = None + ignore_task_switch = False + for fnc_output, agent_task in fnc_outputs: + if fnc_output is not None: + new_fnc_outputs.append(fnc_output) - fnc_args, fnc_kwargs = llm_utils.pydantic_model_to_function_arguments( - 
ai_function=ai_function, - model=parsed_args, - agent_ctx=agent_ctx, - speech_handle=speech_handle, - ) - - if inspect.iscoroutinefunction(ai_function): - task = asyncio.create_task( - ai_function(*fnc_args, **fnc_kwargs), - name=f"ai_function_{fnc_call.name}", - ) - tasks.append((fnc_call.name, task)) - - def _log_exceptions(task: asyncio.Task) -> None: - if task.exception() is not None: - logger.error( - "exception occurred while executing tool", - extra={ - "function": fnc_call.name, - "speech_id": speech_handle.id, - }, - exc_info=task.exception(), - ) - out.append( - _FunctionCallOutput( - name=fnc_call.name, - arguments=fnc_call.arguments, - call_id=fnc_call.call_id, - output=None, - exception=task.exception(), - ) - ) - return - - out.append( - _FunctionCallOutput( - name=fnc_call.name, - arguments=fnc_call.arguments, - call_id=fnc_call.call_id, - output=task.result(), - exception=None, - ) + if new_agent_task is not None and agent_task is not None: + logger.error( + "expected to receive only one new task from the tool executions", ) + ignore_task_switch = True - tasks.remove((fnc_call.name, task)) + new_agent_task = agent_task - task.add_done_callback(_log_exceptions) - else: - start_time = time.monotonic() + if len(new_fnc_outputs) > 0: + chat_ctx = self._rt_session.chat_ctx.copy() + chat_ctx.items.extend(new_fnc_outputs) try: - output = ai_function(*fnc_args, **fnc_kwargs) - out.append( - _FunctionCallOutput( - name=fnc_call.name, - arguments=fnc_call.arguments, - call_id=fnc_call.call_id, - output=output, - exception=None, - ) - ) - except Exception as e: - out.append( - _FunctionCallOutput( - name=fnc_call.name, - arguments=fnc_call.arguments, - call_id=fnc_call.call_id, - output=None, - exception=e, - ) + await self._rt_session.update_chat_ctx(chat_ctx) + except multimodal.RealtimeError as e: + logger.warning( + "failed to update chat context before generating the function calls results", + extra={"error": str(e)}, ) - elapsed = time.monotonic() - start_time - if elapsed >= 1.5: + self._rt_session.interrupt() + try: + await self._rt_session.generate_reply() + except multimodal.RealtimeError as e: logger.warning( - f"function execution took too long ({elapsed:.2f}s), is `{fnc_call.name}` blocking?", - extra={ - "function": fnc_call.name, - "speech_id": speech_handle.id, - "elapsed": elapsed, - }, + "failed to generate the function calls results", + extra={"error": str(e)}, ) - except asyncio.CancelledError: - if len(tasks) > 0: - names = [name for name, _ in tasks] - logger.debug( - "waiting for function call to finish before fully cancelling", - extra={ - "functions": names, - "speech_id": speech_handle.id, - }, - ) - debug.Tracing.log_event( - "waiting for function call to finish before fully cancelling", - { - "functions": names, - "speech_id": speech_handle.id, - }, - ) - await asyncio.gather(*[task for _, task in tasks]) - finally: - if len(out) > 0: - logger.debug( - "tools execution completed", - extra={"speech_id": speech_handle.id}, - ) - debug.Tracing.log_event( - "tools execution completed", - {"speech_id": speech_handle.id}, - ) + if not ignore_task_switch and new_agent_task is not None: + self._agent.update_task(new_agent_task) + + debug.Tracing.log_event("playout completed", {"speech_id": speech_handle.id}) + speech_handle._mark_playout_done() diff --git a/livekit-agents/livekit/agents/pipeline/context.py b/livekit-agents/livekit/agents/pipeline/context.py index 7dc1a53c5..fcf0b59a0 100644 --- a/livekit-agents/livekit/agents/pipeline/context.py +++ 
b/livekit-agents/livekit/agents/pipeline/context.py @@ -1,9 +1,7 @@ from __future__ import annotations - from typing import TYPE_CHECKING - if TYPE_CHECKING: from .pipeline_agent import PipelineAgent diff --git a/livekit-agents/livekit/agents/pipeline/generation.py b/livekit-agents/livekit/agents/pipeline/generation.py index 71224bee2..cf4cc764c 100644 --- a/livekit-agents/livekit/agents/pipeline/generation.py +++ b/livekit-agents/livekit/agents/pipeline/generation.py @@ -2,11 +2,22 @@ import asyncio from dataclasses import dataclass, field -from typing import AsyncIterable, Protocol, Tuple, runtime_checkable +from typing import ( + AsyncIterable, + Protocol, + Tuple, + runtime_checkable, +) from livekit import rtc -from ..llm import ChatChunk, ChatContext, FunctionContext +from .. import utils +from ..llm import ( + ChatChunk, + ChatContext, + FunctionContext, +) +from ..log import logger from ..utils import aio from . import io @@ -24,7 +35,7 @@ class _LLMGenerationData: generated_tools: list[FunctionCallInfo] = field(default_factory=list) -def do_llm_inference( +def perform_llm_inference( *, node: io.LLMNode, chat_ctx: ChatContext, fnc_ctx: FunctionContext | None ) -> Tuple[asyncio.Task, _LLMGenerationData]: text_ch = aio.Chan() @@ -85,7 +96,7 @@ class _TTSGenerationData: audio_ch: aio.Chan[rtc.AudioFrame] -def do_tts_inference( +def perform_tts_inference( *, node: io.TTSNode, input: AsyncIterable[str] ) -> Tuple[asyncio.Task, _TTSGenerationData]: audio_ch = aio.Chan[rtc.AudioFrame]() @@ -107,3 +118,57 @@ async def _inference_task(): tts_task.add_done_callback(lambda _: audio_ch.close()) return tts_task, _TTSGenerationData(audio_ch=audio_ch) + + +@dataclass +class _TextOutput: + text: str + + +def perform_text_forwarding( + *, text_output: io.TextSink, llm_output: AsyncIterable[str] +) -> tuple[asyncio.Task, _TextOutput]: + out = _TextOutput(text="") + task = asyncio.create_task(_text_forwarding_task(text_output, llm_output, out)) + return task, out + + +@utils.log_exceptions(logger=logger) +async def _text_forwarding_task( + text_output: io.TextSink, llm_output: AsyncIterable[str], out: _TextOutput +) -> None: + try: + async for delta in llm_output: + out.text += delta + await text_output.capture_text(delta) + finally: + text_output.flush() + + +@dataclass +class _AudioOutput: + audio: list[rtc.AudioFrame] + + +def perform_audio_forwarding( + *, + audio_output: io.AudioSink, + tts_output: AsyncIterable[rtc.AudioFrame], +) -> tuple[asyncio.Task, _AudioOutput]: + out = _AudioOutput(audio=[]) + task = asyncio.create_task(_audio_forwarding_task(audio_output, tts_output, out)) + return task, out + + +@utils.log_exceptions(logger=logger) +async def _audio_forwarding_task( + audio_output: io.AudioSink, + tts_output: AsyncIterable[rtc.AudioFrame], + out: _AudioOutput, +) -> None: + try: + async for frame in tts_output: + out.audio.append(frame) + await audio_output.capture_frame(frame) + finally: + audio_output.flush() diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index c2044d9c2..a7b244d1c 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -2,19 +2,14 @@ import asyncio from dataclasses import dataclass -from typing import ( - AsyncIterable, - Literal, - Optional -) +from typing import AsyncIterable, Literal from livekit import rtc from .. import debug, llm, utils from ..log import logger from . 
import io -from .room_io import RoomInput, RoomInputOptions, RoomOutput -from .agent_task import TaskActivity, AgentTask +from .agent_task import AgentTask, TaskActivity from .speech_handle import SpeechHandle EventTypes = Literal[ @@ -84,31 +79,11 @@ def __init__( async def start( self, - room: Optional[rtc.Room] = None, - room_input_options: Optional[RoomInputOptions] = None, ) -> None: - """Start the pipeline agent. - - Args: - room (Optional[rtc.Room]): The LiveKit room. If provided and no input/output audio - is set, automatically configures room audio I/O. - room_input_options (Optional[RoomInputOptions]): Options for the room input. - """ + """Start the pipeline agent.""" if self._started: return - if room is not None: - # configure room I/O if not already set - if self.input.audio is None: - room_input = RoomInput(room=room, options=room_input_options) - self._input.audio = room_input.audio - await room_input.wait_for_participant() - - if self.output.audio is None: - room_output = RoomOutput(room=room) - self._output.audio = room_output.audio - await room_output.start() - if self.input.audio is not None: self._forward_audio_atask = asyncio.create_task( self._forward_audio_task(), name="_forward_audio_task" diff --git a/livekit-agents/livekit/agents/pipeline/room_io.py b/livekit-agents/livekit/agents/pipeline/room_io.py index 3be659833..1cca0edc1 100644 --- a/livekit-agents/livekit/agents/pipeline/room_io.py +++ b/livekit-agents/livekit/agents/pipeline/room_io.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio from dataclasses import dataclass from typing import AsyncIterator, Optional diff --git a/livekit-agents/livekit/agents/pipeline/speech_handle.py b/livekit-agents/livekit/agents/pipeline/speech_handle.py index 8ae63ece0..d4a49a50b 100644 --- a/livekit-agents/livekit/agents/pipeline/speech_handle.py +++ b/livekit-agents/livekit/agents/pipeline/speech_handle.py @@ -2,7 +2,7 @@ import asyncio import contextlib -from typing import Awaitable, Callable +from typing import Callable from .. import utils @@ -78,18 +78,7 @@ def _mark_playout_done(self) -> None: # will raise InvalidStateError if the future is already done (interrupted) self._playout_done_fut.set_result(None) - async def wait_until_interrupted(self, aw: list[Awaitable]) -> None: - temp_tasks = [] - tasks = [] - for task in aw: - if not isinstance(task, asyncio.Task): - task = asyncio.create_task(task) - temp_tasks.append(task) - tasks.append(task) - + async def wait_if_not_interrupted(self, aw: list[asyncio.futures.Future]) -> None: await asyncio.wait( - [*tasks, self._interrupt_fut], return_when=asyncio.FIRST_COMPLETED + [*aw, self._interrupt_fut], return_when=asyncio.FIRST_COMPLETED ) - for task in temp_tasks: - if not task.done(): - task.cancel() diff --git a/livekit-agents/livekit/agents/pipeline/tools.py b/livekit-agents/livekit/agents/pipeline/tools.py new file mode 100644 index 000000000..974759feb --- /dev/null +++ b/livekit-agents/livekit/agents/pipeline/tools.py @@ -0,0 +1,301 @@ +from __future__ import annotations + +import asyncio +import inspect +import time +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterable, +) + +from pydantic import ValidationError + +from .. 
import debug, llm, utils
+from ..llm import (
+    AIError,
+    FunctionContext,
+    StopResponse,
+)
+from ..llm import (
+    utils as llm_utils,
+)
+from ..log import logger
+from .context import AgentContext
+from .speech_handle import SpeechHandle
+
+if TYPE_CHECKING:
+    from .agent_task import AgentTask
+
+
+def perform_tool_executions(
+    *,
+    agent_ctx: AgentContext,
+    speech_handle: SpeechHandle,
+    fnc_ctx: FunctionContext,
+    function_stream: AsyncIterable[llm.FunctionCall],
+) -> tuple[asyncio.Task, list[tuple[llm.FunctionCallOutput | None, AgentTask | None]]]:
+    out: list[tuple[llm.FunctionCallOutput | None, AgentTask | None]] = []
+    task = asyncio.create_task(
+        _execute_tools_task(
+            agent_ctx=agent_ctx,
+            speech_handle=speech_handle,
+            fnc_ctx=fnc_ctx,
+            function_stream=function_stream,
+            out=out,
+        ),
+        name="execute_tools_task",
+    )
+    return task, out
+
+
+@utils.log_exceptions(logger=logger)
+async def _execute_tools_task(
+    *,
+    agent_ctx: AgentContext,
+    speech_handle: SpeechHandle,
+    fnc_ctx: FunctionContext,
+    function_stream: AsyncIterable[llm.FunctionCall],
+    out: list[tuple[llm.FunctionCallOutput | None, AgentTask | None]],
+) -> None:
+    """execute tools, when cancelled, stop executing new tools but wait for the pending ones"""
+    tasks: list[asyncio.Task] = []
+    try:
+        async for fnc_call in function_stream:
+            ai_function = fnc_ctx.ai_functions.get(fnc_call.name, None)
+            if ai_function is None:
+                logger.warning(
+                    f"LLM called function `{fnc_call.name}` but it was not found in the current task",
+                    extra={
+                        "function": fnc_call.name,
+                        "speech_id": speech_handle.id,
+                    },
+                )
+                continue
+
+            try:
+                function_model = llm_utils.function_arguments_to_pydantic_model(
+                    ai_function
+                )
+                parsed_args = function_model.model_validate_json(fnc_call.arguments)
+            except ValidationError:
+                logger.exception(
+                    f"LLM called function `{fnc_call.name}` with invalid arguments",
+                    extra={
+                        "function": fnc_call.name,
+                        "arguments": fnc_call.arguments,
+                        "speech_id": speech_handle.id,
+                    },
+                )
+                continue
+
+            logger.debug(
+                "executing tool",
+                extra={
+                    "function": fnc_call.name,
+                    "speech_id": speech_handle.id,
+                },
+            )
+            debug.Tracing.log_event(
+                "executing tool",
+                {
+                    "function": fnc_call.name,
+                    "speech_id": speech_handle.id,
+                },
+            )
+
+            fnc_args, fnc_kwargs = llm_utils.pydantic_model_to_function_arguments(
+                ai_function=ai_function,
+                model=parsed_args,
+                agent_ctx=agent_ctx,
+                speech_handle=speech_handle,
+            )
+
+            fnc_out = _FunctionCallOutput(
+                name=fnc_call.name,
+                arguments=fnc_call.arguments,
+                call_id=fnc_call.call_id,
+                output=None,
+                exception=None,
+            )
+
+            if inspect.iscoroutinefunction(ai_function):
+                task = asyncio.create_task(
+                    ai_function(*fnc_args, **fnc_kwargs),
+                    name=f"ai_function_{fnc_call.name}",
+                )
+                tasks.append(task)
+
+                def _log_exceptions(task: asyncio.Task) -> None:
+                    if task.exception() is not None:
+                        logger.error(
+                            "exception occurred while executing tool",
+                            extra={
+                                "function": fnc_call.name,
+                                "speech_id": speech_handle.id,
+                            },
+                            exc_info=task.exception(),
+                        )
+                        fnc_out.exception = task.exception()
+                        out.append(_sanitize_function_output(fnc_out))
+                        return
+
+                    fnc_out.output = task.result()
+                    out.append(_sanitize_function_output(fnc_out))
+                    tasks.remove(task)
+
+                task.add_done_callback(_log_exceptions)
+            else:
+                start_time = time.monotonic()
+                try:
+                    output = ai_function(*fnc_args, **fnc_kwargs)
+                    fnc_out.output = output
+                    out.append(_sanitize_function_output(fnc_out))
+                except Exception as e:
+                    fnc_out.exception = e
+                    out.append(_sanitize_function_output(fnc_out)) 
+
+                elapsed = time.monotonic() - start_time
+                if elapsed >= 1.5:
+                    logger.warning(
+                        f"function execution took too long ({elapsed:.2f}s), is `{fnc_call.name}` blocking?",
+                        extra={
+                            "function": fnc_call.name,
+                            "speech_id": speech_handle.id,
+                            "elapsed": elapsed,
+                        },
+                    )
+
+    except asyncio.CancelledError:
+        if len(tasks) > 0:
+            names = [task.get_name() for task in tasks]
+            logger.debug(
+                "waiting for function call to finish before fully cancelling",
+                extra={
+                    "functions": names,
+                    "speech_id": speech_handle.id,
+                },
+            )
+            debug.Tracing.log_event(
+                "waiting for function call to finish before fully cancelling",
+                {
+                    "functions": names,
+                    "speech_id": speech_handle.id,
+                },
+            )
+            await asyncio.gather(*tasks)
+    finally:
+        if len(out) > 0:
+            logger.debug(
+                "tools execution completed",
+                extra={"speech_id": speech_handle.id},
+            )
+            debug.Tracing.log_event(
+                "tools execution completed",
+                {"speech_id": speech_handle.id},
+            )
+
+
+def _is_valid_function_output(value: Any) -> bool:
+    VALID_TYPES = (str, int, float, bool, complex, type(None))
+
+    if isinstance(value, VALID_TYPES):
+        return True
+    elif (
+        isinstance(value, list)
+        or isinstance(value, set)
+        or isinstance(value, frozenset)
+        or isinstance(value, tuple)
+    ):
+        return all(_is_valid_function_output(item) for item in value)
+    elif isinstance(value, dict):
+        return all(
+            isinstance(key, VALID_TYPES) and _is_valid_function_output(val)
+            for key, val in value.items()
+        )
+    return False
+
+
+@dataclass
+class _FunctionCallOutput:
+    call_id: str
+    name: str
+    arguments: str
+    output: Any
+    exception: BaseException | None
+
+
+def _sanitize_function_output(
+    out: _FunctionCallOutput,
+) -> tuple[llm.FunctionCallOutput | None, AgentTask | None]:
+    from .agent_task import AgentTask
+
+    if isinstance(out.exception, AIError):
+        return llm.FunctionCallOutput(
+            call_id=out.call_id,
+            output=out.exception.message,
+            is_error=True,
+        ), None
+
+    if isinstance(out.exception, StopResponse):
+        return None, None
+
+    if out.exception is not None:
+        logger.error(
+            "exception occurred while executing tool",
+            extra={
+                "call_id": out.call_id,
+                "function": out.name,
+            },
+            exc_info=out.exception,
+        )
+        return llm.FunctionCallOutput(
+            call_id=out.call_id,
+            output="An internal error occurred",
+            is_error=True,
+        ), None
+
+    fnc_out = out.output
+
+    # find task if any
+    task: AgentTask | None = None
+    if isinstance(fnc_out, tuple):
+        agent_tasks = [item for item in fnc_out if isinstance(item, AgentTask)]
+        if len(agent_tasks) > 1:
+            logger.error(
+                "multiple AgentTask instances found in the function output tuple",
+                extra={
+                    "call_id": out.call_id,
+                    "function": out.name,
+                    "output": fnc_out,
+                },
+            )
+            return None, None
+
+        if agent_tasks:
+            task = agent_tasks[0]
+
+        fnc_out = [item for item in fnc_out if not isinstance(item, AgentTask)]
+
+    if isinstance(fnc_out, AgentTask):
+        task = fnc_out
+        fnc_out = None
+
+    # validate output without the task
+    if not _is_valid_function_output(fnc_out):
+        logger.error(
+            "invalid function output type",
+            extra={
+                "call_id": out.call_id,
+                "function": out.name,
+                "output": fnc_out,
+            },
+        )
+        return None, None
+
+    return llm.FunctionCallOutput(
+        call_id=out.call_id,
+        output=str(fnc_out),
+        is_error=False,
+    ), task
diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py
index d8838f698..a153f5d54 100644
--- 
a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py @@ -13,10 +13,10 @@ from openai.types.beta.realtime import ( ConversationItem, ConversationItemContent, - ConversationItemCreateEvent, ConversationItemCreatedEvent, - ConversationItemDeleteEvent, + ConversationItemCreateEvent, ConversationItemDeletedEvent, + ConversationItemDeleteEvent, ConversationItemTruncateEvent, ErrorEvent, InputAudioBufferAppendEvent, @@ -36,6 +36,7 @@ SessionUpdateEvent, session_update_event, ) +from openai.types.beta.realtime.response_create_event import Response from .log import logger @@ -111,6 +112,12 @@ def __init__(self, realtime_model: RealtimeModel) -> None: self._main_task(), name="RealtimeSession._main_task" ) + self._response_created_futures: dict[ + str, asyncio.Future[multimodal.GenerationCreatedEvent] + ] = {} + self._item_delete_future: dict[str, asyncio.Future] = {} + self._item_create_future: dict[str, asyncio.Future] = {} + self._current_generation: _ResponseGeneration | None = None self._remote_chat_ctx = llm.remote_chat_context.RemoteChatContext() @@ -154,9 +161,6 @@ async def _listen_for_events() -> None: elif event.type == "error": self._handle_error(event) - if event.type != "response.audio.delta": - print(event) - @utils.log_exceptions(logger=logger) async def _forward_input() -> None: async for msg in self._msg_ch: @@ -200,7 +204,7 @@ async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None: self._remote_chat_ctx.to_chat_ctx(), chat_ctx ) - # futs = [] + futs = [] for msg_id in diff_ops.to_remove: event_id = utils.shortuuid("chat_ctx_delete_") @@ -211,8 +215,8 @@ async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None: event_id=event_id, ) ) - # futs.append(f := asyncio.Future()) - # self._response_futures[event_id] = f + futs.append(f := asyncio.Future()) + self._item_delete_future[msg_id] = f for previous_msg_id, msg_id in diff_ops.to_create: event_id = utils.shortuuid("chat_ctx_create_") @@ -229,10 +233,15 @@ async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None: event_id=event_id, ) ) - # futs.append(f := asyncio.Future()) - # self._response_futures[event_id] = f + futs.append(f := asyncio.Future()) + self._item_create_future[msg_id] = f - # await asyncio.gather(*futs, return_exceptions=True) + try: + await asyncio.wait_for( + asyncio.gather(*futs, return_exceptions=True), timeout=5.0 + ) + except asyncio.TimeoutError: + raise multimodal.RealtimeError("update_chat_ctx timed out.") from None async def update_fnc_ctx( self, fnc_ctx: llm.FunctionContext | list[llm.AIFunction] @@ -301,13 +310,24 @@ def push_audio(self, frame: rtc.AudioFrame) -> None: ) def generate_reply(self) -> asyncio.Future[multimodal.GenerationCreatedEvent]: - f = asyncio.Future() event_id = utils.shortuuid("response_create_") + fut = asyncio.Future() + self._response_created_futures[event_id] = fut self._msg_ch.send_nowait( - ResponseCreateEvent(type="response.create", event_id=event_id) + ResponseCreateEvent( + type="response.create", + event_id=event_id, + response=Response(metadata={"client_event_id": event_id}), + ) ) - # self._response_futures[event_id] = f - return f + + def _on_timeout() -> None: + if fut and not fut.done(): + fut.set_exception(multimodal.RealtimeError("generate_reply timed out.")) + + handle = asyncio.get_event_loop().call_later(5.0, _on_timeout) + fut.add_done_callback(lambda _: handle.cancel()) + return fut def interrupt(self) 
-> None: self._msg_ch.send_nowait(ResponseCancelEvent(type="response.cancel")) @@ -352,9 +372,11 @@ def _handle_response_created(self, event: ResponseCreatedEvent) -> None: self.emit("generation_created", generation_ev) - # fut = self._response_futures.pop(event.event_id, None) - # if fut is not None and not fut.done(): - # fut.set_result(generation_ev) + if isinstance(event.response.metadata, dict) and ( + client_event_id := event.response.metadata.get("client_event_id") + ): + if fut := self._response_created_futures.pop(client_event_id, None): + fut.set_result(generation_ev) def _handle_response_output_item_added( self, event: ResponseOutputItemAddedEvent @@ -381,15 +403,24 @@ def _handle_response_output_item_added( def _handle_conversion_item_created( self, event: ConversationItemCreatedEvent ) -> None: + assert event.item.id is not None, "item.id is None" + self._remote_chat_ctx.insert( event.previous_item_id, _openai_item_to_livekit_item(event.item) ) + if fut := self._item_create_future.pop(event.item.id, None): + fut.set_result(None) def _handle_conversion_item_deleted( self, event: ConversationItemDeletedEvent ) -> None: + assert event.item_id is not None, "item_id is None" + self._remote_chat_ctx.delete(event.item_id) + if fut := self._item_delete_future.pop(event.item_id, None): + fut.set_result(None) + def _handle_response_audio_transcript_delta( self, event: ResponseAudioTranscriptDeltaEvent ) -> None: @@ -447,6 +478,7 @@ def _handle_response_output_item_done( def _handle_response_done(self, _: ResponseDoneEvent) -> None: assert self._current_generation is not None, "current_generation is None" self._current_generation.function_ch.close() + self._current_generation.message_ch.close() self._current_generation = None def _handle_error(self, event: ErrorEvent) -> None: From fcd471df4242de37692c2608b4aa7f617961439c Mon Sep 17 00:00:00 2001 From: David Zhao Date: Thu, 30 Jan 2025 22:08:00 -0800 Subject: [PATCH 19/19] fix roomio_worker --- examples/roomio_worker.py | 38 +++++-------------- .../livekit-plugins-openai/setup.py | 2 +- 2 files changed, 10 insertions(+), 30 deletions(-) diff --git a/examples/roomio_worker.py b/examples/roomio_worker.py index e65f698eb..76b310baf 100644 --- a/examples/roomio_worker.py +++ b/examples/roomio_worker.py @@ -4,10 +4,10 @@ from livekit.agents import JobContext, WorkerOptions, WorkerType, cli from livekit.agents.pipeline import AgentTask, PipelineAgent from livekit.agents.pipeline.io import PlaybackFinishedEvent -from livekit.agents.pipeline.room_io import RoomInputOptions +from livekit.agents.pipeline.room_io import RoomInput, RoomOutput from livekit.plugins import openai -logger = logging.getLogger("my-worker") +logger = logging.getLogger("roomio-example") logger.setLevel(logging.INFO) load_dotenv() @@ -23,34 +23,14 @@ async def entrypoint(ctx: JobContext): ) ) - # default use RoomIO if room is provided - await agent.start( - room=ctx.room, - room_input_options=RoomInputOptions( - audio_enabled=True, - video_enabled=False, - audio_sample_rate=24000, - audio_num_channels=1, - ), - ) - - # # Or use RoomInput and RoomOutput explicitly - # room_input = RoomInput( - # ctx.room, - # options=RoomInputOptions( - # audio_enabled=True, - # video_enabled=False, - # audio_sample_rate=24000, - # audio_num_channels=1, - # ), - # ) - # room_output = RoomOutput(ctx.room, sample_rate=24000, num_channels=1) - - # agent.input.audio = room_input.audio - # agent.output.audio = room_output.audio + room_input = RoomInput(ctx.room) + agent.input.audio = 
room_input.audio + room_output = RoomOutput(room=ctx.room, sample_rate=24000, num_channels=1) + agent.output.audio = room_output.audio + await room_input.wait_for_participant() + await room_output.start() - # await room_input.wait_for_participant() - # await room_output.start() + await agent.start() # TODO: the interrupted flag is not set correctly @agent.output.audio.on("playback_finished") diff --git a/livekit-plugins/livekit-plugins-openai/setup.py b/livekit-plugins/livekit-plugins-openai/setup.py index eb9d6d0fe..e8540f974 100644 --- a/livekit-plugins/livekit-plugins-openai/setup.py +++ b/livekit-plugins/livekit-plugins-openai/setup.py @@ -49,7 +49,7 @@ python_requires=">=3.9.0", install_requires=[ "livekit-agents[codecs, images]>=0.12.3", - "openai>=1.50", + "openai>=1.60", ], extras_require={ "vertex": ["google-auth>=2.0.0"],