Skip to content

Commit

Permalink
Add Gemini adapter for get_recent_tool_call_pairs
Browse files Browse the repository at this point in the history
  • Loading branch information
markbackman committed Jan 29, 2025
1 parent 9df3c27 commit b3d520a
Showing 1 changed file with 84 additions and 17 deletions.
101 changes: 84 additions & 17 deletions src/pipecat_flows/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"""

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List

from loguru import logger

Expand Down Expand Up @@ -126,7 +126,23 @@ def format_functions(self, functions: List[Dict[str, Any]]) -> List[Dict[str, An
return functions

def get_recent_tool_call_pairs(self, messages: List[dict]) -> List[dict]:
"""Get recent consecutive tool calls from OpenAI message format."""
"""Gets consecutive function call/response pairs from message history.
Processes messages in reverse order to find the most recent function
interactions. Stops at the first regular message to maintain
conversation context.
Args:
messages: List of messages in OpenAI format
Returns:
List of messages containing matched function call/response pairs
in chronological order.
Note:
Function calls must be from "assistant" role with tool_calls,
and responses must have "tool" role with matching tool_call_id.
"""
tool_messages = []

for i in range(len(messages) - 1, -1, -1):
Expand Down Expand Up @@ -223,7 +239,23 @@ def format_functions(self, functions: List[Dict[str, Any]]) -> List[Dict[str, An
return formatted

def get_recent_tool_call_pairs(self, messages: List[dict]) -> List[dict]:
"""Get recent consecutive tool calls from Anthropic message format."""
"""Gets consecutive function call/response pairs from message history.
Processes messages in reverse order to find the most recent function
interactions. Stops at the first regular message to maintain
conversation context.
Args:
messages: List of messages in Anthropic format
Returns:
List of messages containing matched function call/response pairs
in chronological order.
Note:
Function calls must be from "assistant" role with tool_use content,
and responses must be from "user" role with tool_result content.
"""
tool_messages = []

for i in range(len(messages) - 1, -1, -1):
Expand Down Expand Up @@ -331,25 +363,60 @@ def format_functions(self, functions: List[Dict[str, Any]]) -> List[Dict[str, An
return [{"function_declarations": all_declarations}] if all_declarations else []

def get_recent_tool_call_pairs(self, messages: List[dict]) -> List[dict]:
"""Get recent consecutive tool calls from Gemini message format."""
"""Gets consecutive function call/response pairs from message history.
Processes messages in reverse order to find the most recent function
interactions. Stops at the first regular text message to maintain
conversation context.
Args:
messages: List of messages in Gemini format
Returns:
List of messages containing matched function call/response pairs
in chronological order.
Note:
Function calls must be from "model" role and responses from "user" role.
Names must match between call and response.
"""
tool_messages = []

# Process messages in reverse order
for i in range(len(messages) - 1, -1, -1):
# If we hit a regular message, stop collecting
if not (
messages[i].get("parts", [{}])[0].get("function_response")
or messages[i].get("parts", [{}])[0].get("function_call")
):
current_message = messages[i]

# Stop at first regular text message
if len(current_message.parts) > 0 and str(current_message.parts[0]).startswith("text:"):
break

# Collect tool call pairs
if (
messages[i].get("role") == "user"
and messages[i].get("parts", [{}])[0].get("function_response")
and i > 0
and messages[i - 1].get("role") == "model"
):
tool_messages[0:0] = [messages[i - 1], messages[i]]
try:
part = current_message.parts[0]
# Check if current message is a function response
is_function_response = (
current_message.role == "user"
and len(current_message.parts) > 0
and str(part).startswith("function_response {")
)

if is_function_response and i > 0:
prev_message = messages[i - 1]
prev_part = prev_message.parts[0]

# Check if previous message is matching function call
is_matching_function_call = (
prev_message.role == "model"
and len(prev_message.parts) > 0
and str(prev_part).startswith("function_call {")
and prev_part.function_call.name == part.function_response.name
)

# Add pair to start of list to maintain chronological order
if is_matching_function_call:
tool_messages[0:0] = [prev_message, current_message]

except Exception:
continue

return tool_messages

Expand Down

0 comments on commit b3d520a

Please sign in to comment.