Skip to content

Commit

Permalink
Reworked nested chat
Browse files Browse the repository at this point in the history
Signed-off-by: Mark Sze <[email protected]>
  • Loading branch information
marklysze committed Dec 2, 2024
1 parent d1e27c8 commit ced3c65
Showing 1 changed file with 177 additions and 94 deletions.
271 changes: 177 additions & 94 deletions autogen/agentchat/contrib/swarm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0
import copy
import inspect
import json
from dataclasses import dataclass
from enum import Enum
Expand Down Expand Up @@ -40,43 +41,10 @@ def __post_init__(self):
self.agent = AfterWorkOption(self.agent.upper())


@dataclass
class NESTED_CHAT_CONFIG:
chat_list: List[Dict[str, Any]]
starting_message_method: Optional[Union[str, Callable]] = None
starting_llm_summary_args: Optional[Dict[str, Any]] = None

def __post_init__(self):
assert isinstance(self.chat_list, list) and self.chat_list, "'chat_list' must be a non-empty list"
assert all(isinstance(chat, dict) for chat in self.chat_list), "'chat_list' must be a list of dictionaries"
assert isinstance(
self.starting_message_method, (str, Callable)
), "'starting_message_method' must be a string or callable"

if self.starting_llm_summary_args is not None:
assert (
self.starting_message_method == "llm_summary"
), "If 'starting_llm_summary_args' is provided, 'starting_message_method' must be 'carryover_llm_summary'"

if isinstance(self.starting_message_method, str):
assert self.starting_message_method in [
"carryover",
"carryover_last_msg",
"carryover_llm_summary",
], "'starting_message_method' must be 'carryover', 'carryover_last_msg', 'carryover_llm_summary' or a callable"
assert "message" in self.chat_list[0], "All carryovers need the first chat_list item to have a 'message'"

if isinstance(self.starting_message_method, Callable):
if "message" in self.chat_list[0]:
raise ValueError(
"If 'starting_message_method' is a callable, the first chat_list item can not have a 'message'. The callable will return the message."
)


@dataclass
class ON_CONDITION:
agent: Optional["SwarmAgent"] = None
nested_chat: Optional[NESTED_CHAT_CONFIG] = None
nested_chat: Optional[Dict[str, Any]] = None
condition: str = ""

def __post_init__(self):
Expand All @@ -85,7 +53,7 @@ def __post_init__(self):
assert isinstance(self.agent, SwarmAgent), "'agent' must be a SwarmAgent"

if self.nested_chat is not None:
assert isinstance(self.nested_chat, NESTED_CHAT_CONFIG), "'nested_chat' must be a NESTED_CHAT_CONFIG"
assert isinstance(self.nested_chat, Dict), "'nested_chat' must be a Dict"

# Ensure they have an agent or nested_chat
assert self.agent is not None or self.nested_chat is not None, "'agent' or 'nested_chat' must be provided"
Expand Down Expand Up @@ -405,72 +373,76 @@ def transfer_to_agent() -> "SwarmAgent":
# Transition to a nested chat

# Create closure (see above note)
def make_transfer_nested_function(nested_chat_config: NESTED_CHAT_CONFIG):
def make_transfer_nested_function(
chat_queue: List[Dict[str, Any]],
config: Optional[Any],
reply_func_from_nested_chats: Union[str, Callable],
use_async: bool,
):
# _reply_func = reply_func_from_nested_chats # Explicitly store parameter

def transfer_to_nested_chat() -> str:

# All messages excluding the tool call message to get here
current_messages = self._groupchatmanager.groupchat.messages[:-1]
starting_message = [{"content": "", "role": "user"}]

if "message" in nested_chat_config.chat_list[0]:
starting_message[0]["content"] = nested_chat_config.chat_list[0]["message"]

carry_over_message = ""

if nested_chat_config.starting_message_method == "carryover":
# Carryovers put a string concatenated value of messages into the first message
# All carryovers need the "message" parameter as well
# (e.g. message = <first nested chat message>\nContext: \n<swarm message 1>\n<swarm message 2>\n...)
carry_over_message = current_messages

elif nested_chat_config.starting_message_method == "carryover_last_msg":
# (e.g. message = <first nested chat message>\nContext: \n<last swarm message>)
carry_over_message = current_messages[-1]["content"]

elif nested_chat_config.starting_message_method == "carryover_llm_summary":
# We need to remove the last tool message from the messages before running inference, as the last message can't be a tool call
last_tool_message = self._oai_messages[self._groupchatmanager].pop()

carry_over_message = ConversableAgent._reflection_with_llm_as_summary(
sender=self._groupchatmanager,
recipient=self,
summary_args=(
nested_chat_config.starting_llm_summary_args
if nested_chat_config.starting_llm_summary_args
else {}
),
)

self._oai_messages[self._groupchatmanager].append(
last_tool_message
) # Restore the tool message

elif isinstance(nested_chat_config.starting_message_method, Callable):
nested_chat_config.chat_list[0]["message"] = nested_chat_config.starting_message_method(
context_variables=self.get_swarm_context_variables(),
messages=self._groupchatmanager.groupchat.messages,
)

if carry_over_message:
nested_chat_config.chat_list[0]["carryover"] = carry_over_message

print("In transfer_to_nested_chat")
self.register_nested_chats(
nested_chat_config.chat_list, trigger=lambda sender: True, position=0
# All messages, excluding the tool call message for swarm
base_messages = copy.deepcopy(self.chat_messages[self._groupchatmanager])
base_messages.pop()

# Note: This flow is based on ConversableAgent.register_nested_chats as we are doing this instead of registering a nested chat

if use_async:
for chat in chat_queue:
if chat.get("chat_id") is None:
raise ValueError("chat_id is required for async nested chats")

if use_async:
if callable(reply_func_from_nested_chats):
_reply_func = (
reply_func_from_nested_chats # Have to re-assign in this nested function
)
elif reply_func_from_nested_chats == "summary_from_nested_chats":
_reply_func = self._a_summary_from_nested_chats

if not callable(_reply_func) or not inspect.iscoroutinefunction(_reply_func):
raise ValueError("reply_func_from_nested_chats must be a callable and a coroutine")

else:
if callable(reply_func_from_nested_chats):
_reply_func = (
reply_func_from_nested_chats # Have to re-assign in this nested function
)
elif reply_func_from_nested_chats == "summary_from_nested_chats":
_reply_func = self._summary_from_nested_chats
if not callable(_reply_func):
raise ValueError("reply_func_from_nested_chats must be a callable")

# Run the summary_from_nested_chats, or equivalent callable, to get the final output of the nested chat
# Recipient will be the SwarmAgent the function is registered to.
_, reply_str = _reply_func(
chat_queue=chat_queue,
recipient=self,
messages=base_messages,
sender=self._groupchatmanager,
config=config,
)

# Note: If we pass a list of messages in, the nested chat always
# extracts and uses just the last message. This is the reason we use carryovers.
reply = self.generate_reply(sender=self, messages=starting_message)

# Remove the registered nested chat we added
self._reply_func_list.pop(0)

return reply
return reply_str

return transfer_to_nested_chat

transfer_func = make_transfer_nested_function(transit.nested_chat)
# Extract the nested chat configuration
chat_queue = transit.nested_chat["chat_queue"]
config = transit.nested_chat.get("config", None)
config_reply_func_from_nested_chats = transit.nested_chat.get("reply_func_from_nested_chats", None)
if not config_reply_func_from_nested_chats:
config_reply_func_from_nested_chats = "summary_from_nested_chats"
use_async = transit.nested_chat.get("use_async", False)

# Make the function for the nested chat
transfer_func = make_transfer_nested_function(
chat_queue, config, config_reply_func_from_nested_chats, use_async
)

# Add the function to the agent so it can be triggered as a tool call
self.add_single_function(
transfer_func, f"transfer_to_nested_chat_{len(self._function_map)}", transit.condition
)
Expand Down Expand Up @@ -604,6 +576,117 @@ def get_swarm_context_variables(self) -> Dict[str, Any]:

raise Exception("Tool Execution agent not found")

@staticmethod
def process_nested_chat_carryover(
chat: Dict[str, Any], recipient: ConversableAgent, messages: List[Dict[str, Any]], sender: ConversableAgent
) -> None:
"""Process carryover messages for a nested chat (typically for the first chat of a swarm)
The carryover_config key is a dictionary containing:
"summary_method": The method to use to summarise the messages, can be "all", "last_msg", "reflection_with_llm" or a Callable
"summary_args": Optional arguments for the summary method
Supported carryover 'summary_methods' are:
"all" - all messages will be incorporated
"last_msg" - the last message will be incorporated
"reflection_with_llm" - an llm will summarise all the messages and the summary will be incorporated as a single message
Callable - a callable with the signature: my_method(agent: ConversableAgent, messages: List[Dict[str, Any]]) -> str
"""

def concat_carryover(chat_message: str, carryover_message: Union[str, List[Dict[str, Any]]]) -> str:
"""Concatenate the carryover message to the chat message."""
prefix = f"{chat_message}\n" if chat_message else ""

if isinstance(carryover_message, str):
content = carryover_message
elif isinstance(carryover_message, list):
content = "\n".join(
msg["content"] for msg in carryover_message if "content" in msg and msg["content"] is not None
)
else:
raise ValueError("Carryover message must be a string or a list of dictionaries")

return f"{prefix}Context:\n{content}"

carryover_config = chat["carryover_config"]

if "summary_method" not in carryover_config:
raise ValueError("Carryover configuration must contain a 'summary_method' key")

carryover_summary_method = carryover_config["summary_method"]
carryover_summary_args = carryover_config.get("summary_args") or {}

chat_message = chat.get("message", "")

if carryover_summary_method == "all":
# Put a string concatenated value of all parent messages into the first message
# (e.g. message = <first nested chat message>\nContext: \n<swarm message 1>\n<swarm message 2>\n...)
carry_over_message = concat_carryover(chat_message, messages)

elif carryover_summary_method == "last_msg":
# (e.g. message = <first nested chat message>\nContext: \n<last swarm message>)
carry_over_message = concat_carryover(chat_message, messages[-1]["content"])

elif carryover_summary_method == "reflection_with_llm":
# If the last message is a tool call, we need to remove it (typical for Swarm as this is triggered by a tool call)
restore_tool_call = False
if "tool_calls" in recipient._oai_messages[sender][-1]:
last_tool_message = recipient._oai_messages[sender].pop()
restore_tool_call = True

carry_over_message_llm = ConversableAgent._reflection_with_llm_as_summary(
sender=sender,
recipient=recipient,
summary_args=carryover_summary_args,
)

carry_over_message = concat_carryover(chat_message, carry_over_message_llm)

# Restore the tool call message
if restore_tool_call:
recipient._oai_messages[sender].append(last_tool_message)

elif isinstance(carryover_summary_method, Callable):
carry_over_message_result = carryover_summary_method(
recipient,
messages=messages,
)

carry_over_message = concat_carryover(chat_message, carry_over_message_result)

chat["message"] = carry_over_message

@staticmethod
def _summary_from_nested_chats(
chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
) -> Tuple[bool, Union[str, None]]:
"""Overridden _summary_from_nested_chats method from ConversableAgent.
This function initiates one or a sequence of chats between the "recipient" and the agents in the chat_queue.
It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue.
Swarm Updates:
- the 'messages' parameter contains the parent chat's messages
- the first chat in the queue can contain a 'carryover_config' which is a dictionary that denotes how to carryover messages from the swarm chat into the first chat of the nested chats). Only applies to the first chat.
e.g.: carryover_summarize_chat_config = {"summary_method": "reflection_with_llm", "summary_args": None}
summary_method can be "last_msg", "all", "reflection_with_llm", Callable
The Callable signature: my_method(agent: ConversableAgent, messages: List[Dict[str, Any]]) -> str
The summary will be concatenated to the message of the first chat in the queue.
Returns:
Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated.
"""

# Carryover configuration allowed on the first chat in the queue only
if len(chat_queue) > 0 and "carryover_config" in chat_queue[0]:
SwarmAgent.process_nested_chat_carryover(chat_queue[0], recipient, messages, sender)

chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config)
if not chat_to_run:
return True, None
res = sender.initiate_chats(chat_to_run)
return True, res[-1].summary


# Forward references for SwarmAgent in SwarmResult
SwarmResult.update_forward_refs()

0 comments on commit ced3c65

Please sign in to comment.