Skip to content

Commit

Permalink
tts_say action should use TTSSpeakFrame
Browse files Browse the repository at this point in the history
  • Loading branch information
markbackman committed Feb 3, 2025
1 parent 92694ab commit 45ab5bf
Show file tree
Hide file tree
Showing 18 changed files with 62 additions and 96 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ Example usage:
context = flow_manager.get_current_context()
```

### Deprecated

- The `tts` parameter in `FlowManager.__init__()` is now deprecated and will
be removed in a future version. The `tts_say` action now pushes a
`TTSSpeakFrame`.

## [0.0.12] - 2025-01-30

### Added
Expand Down
3 changes: 0 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ flow_manager = FlowManager(
task=task,
llm=llm,
context_aggregator=context_aggregator,
tts=tts,
flow_config=flow_config,
)

Expand Down Expand Up @@ -288,7 +287,6 @@ flow_manager = FlowManager(
task=task,
llm=llm,
context_aggregator=context_aggregator,
tts=tts,
flow_config=flow_config,
)
await flow_manager.initialize()
Expand Down Expand Up @@ -350,7 +348,6 @@ flow_manager = FlowManager(
task=task,
llm=llm,
context_aggregator=context_aggregator,
tts=tts,
)
await flow_manager.initialize()

Expand Down
1 change: 0 additions & 1 deletion examples/dynamic/insurance_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,6 @@ async def main():
task=task,
llm=llm,
context_aggregator=context_aggregator,
tts=tts,
)

@transport.event_handler("on_first_participant_joined")
Expand Down
1 change: 0 additions & 1 deletion examples/dynamic/insurance_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,6 @@ async def main():
task=task,
llm=llm,
context_aggregator=context_aggregator,
tts=tts,
)

@transport.event_handler("on_first_participant_joined")
Expand Down
4 changes: 3 additions & 1 deletion examples/dynamic/insurance_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,9 @@ async def main():

# Initialize flow manager with transition callback
flow_manager = FlowManager(
task=task, llm=llm, context_aggregator=context_aggregator, tts=tts
task=task,
llm=llm,
context_aggregator=context_aggregator,
)

@transport.event_handler("on_first_participant_joined")
Expand Down
1 change: 0 additions & 1 deletion examples/static/food_ordering.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,6 @@ async def main():
task=task,
llm=llm,
context_aggregator=context_aggregator,
tts=tts,
flow_config=flow_config,
)

Expand Down
1 change: 0 additions & 1 deletion examples/static/movie_explorer_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,6 @@ async def main():
task=task,
llm=llm,
context_aggregator=context_aggregator,
tts=tts,
flow_config=flow_config,
)

Expand Down
1 change: 0 additions & 1 deletion examples/static/movie_explorer_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,6 @@ async def main():
task=task,
llm=llm,
context_aggregator=context_aggregator,
tts=tts,
flow_config=flow_config,
)

Expand Down
1 change: 0 additions & 1 deletion examples/static/movie_explorer_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,6 @@ async def main():
task=task,
llm=llm,
context_aggregator=context_aggregator,
tts=tts,
flow_config=flow_config,
)

Expand Down
1 change: 0 additions & 1 deletion examples/static/patient_intake_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,6 @@ async def main():
task=task,
llm=llm,
context_aggregator=context_aggregator,
tts=tts,
flow_config=flow_config,
)

Expand Down
1 change: 0 additions & 1 deletion examples/static/patient_intake_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,6 @@ async def main():
task=task,
llm=llm,
context_aggregator=context_aggregator,
tts=tts,
flow_config=flow_config,
)

Expand Down
1 change: 0 additions & 1 deletion examples/static/patient_intake_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,6 @@ async def main():
task=task,
llm=llm,
context_aggregator=context_aggregator,
tts=tts,
flow_config=flow_config,
)

Expand Down
1 change: 0 additions & 1 deletion examples/static/restaurant_reservation.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,6 @@ async def main():
task=task,
llm=llm,
context_aggregator=context_aggregator,
tts=tts,
flow_config=flow_config,
)

Expand Down
1 change: 0 additions & 1 deletion examples/static/travel_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,6 @@ async def main():
task=task,
llm=llm,
context_aggregator=context_aggregator,
tts=tts,
flow_config=flow_config,
)

Expand Down
16 changes: 4 additions & 12 deletions src/pipecat_flows/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"""

import asyncio
from typing import Any, Callable, Dict, List, Optional
from typing import Callable, Dict, List, Optional

from loguru import logger
from pipecat.frames.frames import (
Expand All @@ -45,22 +45,20 @@ class ActionManager:
- Custom user-defined actions
Built-in actions:
- tts_say: Speak text using TTS
- tts_say: Speak text using TTSSpeakFrame
- end_conversation: End the current conversation
Custom actions can be registered using register_action().
"""

def __init__(self, task: PipelineTask, tts=None):
def __init__(self, task: PipelineTask):
"""Initialize the action manager.
Args:
task: PipelineTask instance used to queue frames
tts: Optional TTS service for voice actions
"""
self.action_handlers: Dict[str, Callable] = {}
self.task = task
self.tts = tts

# Register built-in actions
self._register_action("tts_say", self._handle_tts_action)
Expand Down Expand Up @@ -120,19 +118,13 @@ async def _handle_tts_action(self, action: dict) -> None:
Args:
action: Action configuration containing 'text' to speak
"""
if not self.tts:
logger.warning("TTS action called but no TTS service provided")
return

text = action.get("text")
if not text:
logger.error("TTS action missing 'text' field")
return

try:
await self.tts.say(text)
# TODO: Update to TTSSpeakFrame once Pipecat is fixed
# await self.task.queue_frame(TTSSpeakFrame(text=action["text"]))
await self.task.queue_frame(TTSSpeakFrame(text=action["text"]))
except Exception as e:
logger.error(f"TTS error: {e}")

Expand Down
16 changes: 12 additions & 4 deletions src/pipecat_flows/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import copy
import inspect
import sys
import warnings
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Union, cast

from loguru import logger
Expand Down Expand Up @@ -70,7 +71,6 @@ class FlowManager:
Attributes:
task: Pipeline task for frame queueing
llm: LLM service instance (OpenAI, Anthropic, or Google)
tts: Optional TTS service for voice actions
state: Shared state dictionary across nodes
current_node: Currently active node identifier
initialized: Whether the manager has been initialized
Expand All @@ -94,18 +94,26 @@ def __init__(
task: PipelineTask instance for queueing frames
llm: LLM service instance (e.g., OpenAI, Anthropic)
context_aggregator: Context aggregator for updating user context
tts: Optional TTS service for voice actions
tts: Optional TTS service for voice actions (deprecated)
flow_config: Optional static flow configuration. If provided,
operates in static mode with predefined nodes
context_strategy: Optional context strategy configuration
Raises:
ValueError: If any transition handler is not a valid async callable
Deprecated:
0.0.13: The `tts` parameter is deprecated and will be removed in a future version.
"""
if tts is not None:
warnings.warn(
"The 'tts' parameter is deprecated and will be removed in a future version.",
DeprecationWarning,
stacklevel=2,
)

self.task = task
self.llm = llm
self.tts = tts
self.action_manager = ActionManager(task, tts)
self.action_manager = ActionManager(task)
self.adapter = create_adapter(llm)
self.initialized = False
self._context_aggregator = context_aggregator
Expand Down
73 changes: 23 additions & 50 deletions tests/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
- Custom action registration and execution
- Error handling and validation
- Action sequencing
- TTS service integration
- Frame queueing
The tests use unittest.IsolatedAsyncioTestCase for async support and include
mocked dependencies for PipelineTask and TTS service.
mocked dependencies for PipelineTask.
"""

import unittest
Expand All @@ -38,7 +37,6 @@ class TestActionManager(unittest.IsolatedAsyncioTestCase):
- Custom action registration
- Action execution sequencing
- Error handling:
- Missing TTS service
- Invalid actions
- Failed handlers
- Multiple action execution
Expand All @@ -52,39 +50,32 @@ class TestActionManager(unittest.IsolatedAsyncioTestCase):
"""

def setUp(self):
"""
Set up test fixtures before each test.
"""Set up test fixtures before each test.
Creates:
- Mock PipelineTask for frame queueing
- Mock TTS service for speech synthesis
- ActionManager instance with mocked dependencies
"""
self.mock_task = AsyncMock()
self.mock_task.queue_frame = AsyncMock()

self.mock_tts = AsyncMock()
self.mock_tts.say = AsyncMock()

self.action_manager = ActionManager(self.mock_task, self.mock_tts)
self.action_manager = ActionManager(self.mock_task)

async def test_initialization(self):
"""Test ActionManager initialization and default handlers."""
# Verify built-in action handlers are registered
self.assertIn("tts_say", self.action_manager.action_handlers)
self.assertIn("end_conversation", self.action_manager.action_handlers)

# Test initialization without TTS service
action_manager_no_tts = ActionManager(self.mock_task, None)
self.assertIsNone(action_manager_no_tts.tts)

async def test_tts_action(self):
"""Test basic TTS action execution."""
action = {"type": "tts_say", "text": "Hello"}
await self.action_manager.execute_actions([action])

# Verify TTS service was called with correct text
self.mock_tts.say.assert_called_once_with("Hello")
# Verify TTSSpeakFrame was queued with correct text
self.mock_task.queue_frame.assert_called_once()
frame = self.mock_task.queue_frame.call_args[0][0]
self.assertIsInstance(frame, TTSSpeakFrame)
self.assertEqual(frame.text, "Hello")

@patch("loguru.logger.error")
async def test_tts_action_no_text(self, mock_logger):
Expand All @@ -97,22 +88,7 @@ async def test_tts_action_no_text(self, mock_logger):
# Verify error was logged
mock_logger.assert_called_with("TTS action missing 'text' field")

# Verify TTS service was not called
self.mock_tts.say.assert_not_called()

@patch("loguru.logger.warning")
async def test_tts_action_no_service(self, mock_logger):
"""Test TTS action when no TTS service is provided."""
action_manager = ActionManager(self.mock_task, None)
action = {"type": "tts_say", "text": "Hello"}

# Should log warning but not raise error
await action_manager.execute_actions([action])

# Verify warning was logged
mock_logger.assert_called_with("TTS action called but no TTS service provided")

# Verify no frames were queued
# Verify no frame was queued
self.mock_task.queue_frame.assert_not_called()

async def test_end_conversation_action(self):
Expand Down Expand Up @@ -177,10 +153,14 @@ async def test_multiple_actions(self):
]
await self.action_manager.execute_actions(actions)

# Verify TTS was called twice in correct order
self.assertEqual(self.mock_tts.say.call_count, 2)
expected_calls = [unittest.mock.call("First"), unittest.mock.call("Second")]
self.assertEqual(self.mock_tts.say.call_args_list, expected_calls)
# Verify frames were queued in correct order
self.assertEqual(self.mock_task.queue_frame.call_count, 2)

first_frame = self.mock_task.queue_frame.call_args_list[0][0][0]
self.assertEqual(first_frame.text, "First")

second_frame = self.mock_task.queue_frame.call_args_list[1][0][0]
self.assertEqual(second_frame.text, "Second")

def test_register_invalid_handler(self):
"""Test registering invalid action handlers."""
Expand All @@ -199,41 +179,34 @@ async def test_none_or_empty_actions(self):
# Test None actions
await self.action_manager.execute_actions(None)
self.mock_task.queue_frame.assert_not_called()
self.mock_tts.say.assert_not_called()

# Test empty list
await self.action_manager.execute_actions([])
self.mock_task.queue_frame.assert_not_called()
self.mock_tts.say.assert_not_called()

@patch("loguru.logger.error")
async def test_action_error_handling(self, mock_logger):
"""Test error handling during action execution."""
# Configure TTS mock to raise an error
self.mock_tts.say.side_effect = Exception("TTS error")
# Configure task mock to raise an error
self.mock_task.queue_frame.side_effect = Exception("Frame error")

action = {"type": "tts_say", "text": "Hello"}
await self.action_manager.execute_actions([action])

# Verify error was logged
mock_logger.assert_called_with("TTS error: TTS error")

# Verify action was still marked as executed (doesn't raise)
self.mock_tts.say.assert_called_once()
mock_logger.assert_called_with("TTS error: Frame error")

async def test_action_execution_error_handling(self):
"""Test error handling during action execution."""
action_manager = ActionManager(self.mock_task, self.mock_tts)

# Test action with missing handler
with self.assertRaises(ActionError):
await action_manager.execute_actions([{"type": "nonexistent_action"}])
await self.action_manager.execute_actions([{"type": "nonexistent_action"}])

# Test action handler that raises an exception
async def failing_handler(action):
raise Exception("Handler error")

action_manager._register_action("failing_action", failing_handler)
self.action_manager._register_action("failing_action", failing_handler)

with self.assertRaises(ActionError):
await action_manager.execute_actions([{"type": "failing_action"}])
await self.action_manager.execute_actions([{"type": "failing_action"}])
Loading

0 comments on commit 45ab5bf

Please sign in to comment.