Skip to content

Commit

Permalink
Minor cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
sternakt committed Jan 22, 2025
1 parent f5a8ebc commit c042f4c
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions autogen/agentchat/realtime_agent/clients/gemini/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def connection(self) -> "ClientConnection":
raise RuntimeError("Gemini WebSocket is not initialized")
return self._connection

async def send_function_result(self, call_id: str, result: str) -> None: # Looks like Gemini doesn't results.
async def send_function_result(self, call_id: str, result: str) -> None:
"""Send the result of a function call to the Gemini Realtime API.
Args:
Expand Down Expand Up @@ -128,6 +128,7 @@ async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: s
pass

async def _initialize_session(self) -> None:
"""Initialize the session with the Gemini Realtime API."""
session_config = {
"setup": {
"system_instruction": {
Expand All @@ -141,9 +142,7 @@ async def _initialize_session(self) -> None:
{
"name": tool_schema["name"],
"description": tool_schema["description"],
"parameters": tool_schema[
"parameters"
], # GeminiClient._create_gemini_function_parameters(tool_schema["parameters"]),
"parameters": tool_schema["parameters"],
}
for tool_schema in self._pending_session_updates.get("tools", [])
]
Expand All @@ -161,10 +160,13 @@ async def _initialize_session(self) -> None:
await self.connection.send(json.dumps(session_config))

async def session_update(self, session_options: dict[str, Any]) -> None:
"""Record or apply session updates."""
"""Record session updates to be applied when the connection is established.
Args:
session_options (dict[str, Any]): The session options to update.
"""
if self._is_reading_events:
self.logger.warning("Is reading events. Session update will be ignored.")
# Record session updates
else:
self._pending_session_updates.update(session_options)

Expand All @@ -180,7 +182,7 @@ async def connect(self) -> AsyncGenerator[None, None]:
self._connection = None

async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
"""Read Audio Events"""
"""Read Events from the Gemini Realtime API."""
if self._connection is None:
raise RuntimeError("Client is not connected, call connect() first.")
await self._initialize_session()
Expand All @@ -194,7 +196,14 @@ async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
yield event

def _parse_message(self, response: dict[str, Any]) -> list[RealtimeEvent]:
# Determine the type of message and dispatch it to the appropriate handler
"""Parse a message from the Gemini Realtime API.
Args:
response (dict[str, Any]): The response to parse.
Returns:
list[RealtimeEvent]: The parsed events.
"""
if "serverContent" in response and "modelTurn" in response["serverContent"]:
try:
b64data = response["serverContent"]["modelTurn"]["parts"][0]["inlineData"].pop("data")
Expand Down Expand Up @@ -232,8 +241,7 @@ def get_factory(
Args:
model (str): The model to create the client for.
voice (str): The voice to use.
system_message (str): The system message to use.
logger (Logger): The logger for the client.
kwargs (Any): Additional arguments.
Returns:
Expand Down

0 comments on commit c042f4c

Please sign in to comment.