diff --git a/autogen/__init__.py b/autogen/__init__.py index 707d98e080..4e81eef354 100644 --- a/autogen/__init__.py +++ b/autogen/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# Copyright (c) 2023 - 2025, Owners of https://github.com/ag2ai # # SPDX-License-Identifier: Apache-2.0 # @@ -10,6 +10,7 @@ AFTER_WORK, ON_CONDITION, UPDATE_SYSTEM_MESSAGE, + AfterWork, AfterWorkOption, Agent, AssistantAgent, @@ -17,16 +18,19 @@ ConversableAgent, GroupChat, GroupChatManager, + OnCondition, ReasoningAgent, SwarmAgent, SwarmResult, ThinkNode, + UpdateSystemMessage, UserProxyAgent, a_initiate_swarm_chat, gather_usage_summary, initiate_chats, initiate_swarm_chat, register_function, + register_hand_off, visualize_tree, ) from .code_utils import DEFAULT_MODEL, FAST_MODEL @@ -64,6 +68,7 @@ "FAST_MODEL", "ON_CONDITION", "UPDATE_SYSTEM_MESSAGE", + "AfterWork", "AfterWorkOption", "Agent", "AgentNameConflict", @@ -78,6 +83,7 @@ "InvalidCarryOverType", "ModelClient", "NoEligibleSpeaker", + "OnCondition", "OpenAIWrapper", "ReasoningAgent", "SenderRequired", @@ -85,6 +91,7 @@ "SwarmResult", "ThinkNode", "UndefinedNextAgent", + "UpdateSystemMessage", "UserProxyAgent", "__version__", "a_initiate_swarm_chat", @@ -99,5 +106,6 @@ "initiate_chats", "initiate_swarm_chat", "register_function", + "register_hand_off", "visualize_tree", ] diff --git a/autogen/agentchat/__init__.py b/autogen/agentchat/__init__.py index 64ac7873e9..4218f34093 100644 --- a/autogen/agentchat/__init__.py +++ b/autogen/agentchat/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# Copyright (c) 2023 - 2025, Owners of https://github.com/ag2ai # # SPDX-License-Identifier: Apache-2.0 # @@ -17,14 +17,17 @@ from .contrib.swarm_agent import ( AFTER_WORK, ON_CONDITION, - UPDATE_SYSTEM_MESSAGE, + AfterWork, AfterWorkOption, + OnCondition, SwarmAgent, SwarmResult, + UpdateCondition, a_initiate_swarm_chat, initiate_swarm_chat, + 
register_hand_off, ) -from .conversable_agent import ConversableAgent, register_function +from .conversable_agent import UPDATE_SYSTEM_MESSAGE, ConversableAgent, UpdateSystemMessage, register_function from .groupchat import GroupChat, GroupChatManager from .user_proxy_agent import UserProxyAgent from .utils import gather_usage_summary @@ -33,6 +36,7 @@ "AFTER_WORK", "ON_CONDITION", "UPDATE_SYSTEM_MESSAGE", + "AfterWork", "AfterWorkOption", "Agent", "AssistantAgent", @@ -40,15 +44,19 @@ "ConversableAgent", "GroupChat", "GroupChatManager", + "OnCondition", "ReasoningAgent", "SwarmAgent", "SwarmResult", "ThinkNode", + "UpdateCondition", + "UpdateSystemMessage", "UserProxyAgent", "a_initiate_swarm_chat", "gather_usage_summary", "initiate_chats", "initiate_swarm_chat", "register_function", + "register_hand_off", "visualize_tree", ] diff --git a/autogen/agentchat/contrib/swarm_agent.py b/autogen/agentchat/contrib/swarm_agent.py index 9f3f36e14b..4c13a49559 100644 --- a/autogen/agentchat/contrib/swarm_agent.py +++ b/autogen/agentchat/contrib/swarm_agent.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# Copyright (c) 2023 - 2025, Owners of https://github.com/ag2ai # # SPDX-License-Identifier: Apache-2.0 import copy @@ -8,25 +8,55 @@ from dataclasses import dataclass from enum import Enum from inspect import signature -from typing import Any, Callable, Literal, Optional, Union +from types import MethodType +from typing import Any, Callable, Optional, Union from pydantic import BaseModel from autogen.oai import OpenAIWrapper -from autogen.tools import get_function_schema from ..agent import Agent from ..chat import ChatResult -from ..conversable_agent import ConversableAgent +from ..conversable_agent import __CONTEXT_VARIABLES_PARAM_NAME__, ConversableAgent from ..groupchat import GroupChat, GroupChatManager from ..user_proxy_agent import UserProxyAgent -# Parameter name for context variables -# Use the value in functions and they will be 
substituted with the context variables: -# e.g. def my_function(context_variables: Dict[str, Any], my_other_parameters: Any) -> Any: -__CONTEXT_VARIABLES_PARAM_NAME__ = "context_variables" -__TOOL_EXECUTOR_NAME__ = "Tool_Execution" +@dataclass +class UpdateCondition: + """Update the condition string before they reply + + Args: + update_function: The string or function to update the condition string. Can be a string or a Callable. + If a string, it will be used as a template and substitute the context variables. + If a Callable, it should have the signature: + def my_update_function(agent: ConversableAgent, messages: List[Dict[str, Any]]) -> str + """ + + update_function: Union[Callable, str] + + def __post_init__(self): + if isinstance(self.update_function, str): + assert self.update_function.strip(), " please provide a non-empty string or a callable" + # find all {var} in the string + vars = re.findall(r"\{(\w+)\}", self.update_function) + if len(vars) == 0: + warnings.warn("Update function string contains no variables. This is probably unintended.") + + elif isinstance(self.update_function, Callable): + sig = signature(self.update_function) + if len(sig.parameters) != 2: + raise ValueError( + "Update function must accept two parameters of type ConversableAgent and List[Dict[str, Any]], respectively" + ) + if sig.return_annotation != str: + raise ValueError("Update function must return a string") + else: + raise ValueError("Update function must be either a string or a callable") + + +# Created tool executor's name +__TOOL_EXECUTOR_NAME__ = "_Swarm_Tool_Executor" class AfterWorkOption(Enum): @@ -37,113 +67,142 @@ class AfterWorkOption(Enum): @dataclass -class AFTER_WORK: # noqa: N801 +class AfterWork: """Handles the next step in the conversation when an agent doesn't suggest a tool call or a handoff Args: - agent (Union[AfterWorkOption, SwarmAgent, str, Callable]): The agent to hand off to or the after work option. 
Can be a SwarmAgent, a string name of a SwarmAgent, an AfterWorkOption, or a Callable. + agent: The agent to hand off to or the after work option. Can be a ConversableAgent, a string name of a ConversableAgent, an AfterWorkOption, or a Callable. The Callable signature is: - def my_after_work_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any]], groupchat: GroupChat) -> Union[AfterWorkOption, SwarmAgent, str]: + def my_after_work_func(last_speaker: ConversableAgent, messages: List[Dict[str, Any]], groupchat: GroupChat) -> Union[AfterWorkOption, ConversableAgent, str]: """ - agent: Union[AfterWorkOption, "SwarmAgent", str, Callable] + agent: Union[AfterWorkOption, ConversableAgent, str, Callable] def __post_init__(self): if isinstance(self.agent, str): self.agent = AfterWorkOption(self.agent.upper()) +class AFTER_WORK(AfterWork): # noqa: N801 + """Deprecated: Use AfterWork instead. This class will be removed in a future version (TBD).""" + + def __init__(self, *args, **kwargs): + warnings.warn( + "AFTER_WORK is deprecated and will be removed in a future version (TBD). Use AfterWork instead.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(*args, **kwargs) + + @dataclass -class ON_CONDITION: # noqa: N801 +class OnCondition: """Defines a condition for transitioning to another agent or nested chats Args: - target (Union[SwarmAgent, dict[str, Any]]): The agent to hand off to or the nested chat configuration. Can be a SwarmAgent or a Dict. + target: The agent to hand off to or the nested chat configuration. Can be a ConversableAgent or a Dict. If a Dict, it should follow the convention of the nested chat configuration, with the exception of a carryover configuration which is unique to Swarms. 
Swarm Nested chat documentation: https://docs.ag2.ai/docs/topics/swarm#registering-handoffs-to-a-nested-chat condition (str): The condition for transitioning to the target agent, evaluated by the LLM to determine whether to call the underlying function/tool which does the transition. - available (Union[Callable, str]): Optional condition to determine if this ON_CONDITION is available. Can be a Callable or a string. + available (Union[Callable, str]): Optional condition to determine if this OnCondition is available. Can be a Callable or a string. If a string, it will look up the value of the context variable with that name, which should be a bool. """ - target: Union["SwarmAgent", dict[str, Any]] = None - condition: str = "" + target: Union[ConversableAgent, dict[str, Any]] = None + condition: Union[str, UpdateCondition] = "" available: Optional[Union[Callable, str]] = None def __post_init__(self): # Ensure valid types if self.target is not None: - assert isinstance(self.target, SwarmAgent) or isinstance(self.target, dict), ( - "'target' must be a SwarmAgent or a Dict" + assert isinstance(self.target, ConversableAgent) or isinstance(self.target, dict), ( + "'target' must be a ConversableAgent or a dict" ) # Ensure they have a condition - assert isinstance(self.condition, str) and self.condition.strip(), "'condition' must be a non-empty string" + if isinstance(self.condition, str): + assert self.condition.strip(), "'condition' must be a non-empty string" + else: + assert isinstance(self.condition, UpdateCondition), "'condition' must be a string or UpdateCondition" if self.available is not None: assert isinstance(self.available, (Callable, str)), "'available' must be a callable or a string" -@dataclass -class UPDATE_SYSTEM_MESSAGE: # noqa: N801 - """Update the agent's system message before they reply +class ON_CONDITION(OnCondition): # noqa: N801 + """Deprecated: Use OnCondition instead. 
This class will be removed in a future version (TBD).""" + + def __init__(self, *args, **kwargs): + warnings.warn( + "ON_CONDITION is deprecated and will be removed in a future version (TBD). Use OnCondition instead.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(*args, **kwargs) + + +def _establish_swarm_agent(agent: ConversableAgent): + """Establish the swarm agent with the swarm-related attributes and hooks. Not for the tool executor. Args: - update_function (Union[Callable, str]): The string or function to update the agent's system message. Can be a string or a Callable. - If a string, it will be used as a template and substitute the context variables. - If a Callable, it should have the signature: - def my_update_function(agent: ConversableAgent, messages: List[Dict[str, Any]]) -> str + agent (ConversableAgent): The agent to establish as a swarm agent. """ - update_function: Union[Callable, str] + def _swarm_agent_str(self: ConversableAgent) -> str: + """Return the display name for this agent (bound to _get_display_name), used in transition messages.""" + return f"Swarm agent --> {self.name}" - def __post_init__(self): - if isinstance(self.update_function, str): - # find all {var} in the string - vars = re.findall(r"\{(\w+)\}", self.update_function) - if len(vars) == 0: - warnings.warn("Update function string contains no variables. 
This is probably unintended.") + agent._swarm_after_work = None - elif isinstance(self.update_function, Callable): - sig = signature(self.update_function) - if len(sig.parameters) != 2: - raise ValueError( - "Update function must accept two parameters of type ConversableAgent and List[Dict[str Any]], respectively" - ) - if sig.return_annotation != str: - raise ValueError("Update function must return a string") - else: - raise ValueError("Update function must be either a string or a callable") + # Store nested chats hand offs as we'll establish these in the initiate_swarm_chat + # List of Dictionaries containing the nested_chats and condition + agent._swarm_nested_chat_handoffs = [] + + # Store conditional functions (and their OnCondition instances) to add/remove later when transitioning to this agent + agent._swarm_conditional_functions = {} + + # Register the hook to update agent state (except tool executor) + agent.register_hook("update_agent_state", _update_conditional_functions) + + agent._get_display_name = MethodType(_swarm_agent_str, agent) + + # Mark this agent as established as a swarm agent + agent._swarm_is_established = True def _prepare_swarm_agents( - initial_agent: "SwarmAgent", - agents: list["SwarmAgent"], -) -> tuple["SwarmAgent", list["SwarmAgent"]]: + initial_agent: ConversableAgent, + agents: list[ConversableAgent], +) -> tuple[ConversableAgent, list[ConversableAgent]]: """Validates agents, create the tool executor, configure nested chats. Args: - initial_agent (SwarmAgent): The first agent in the conversation. - agents (list[SwarmAgent]): List of all agents in the conversation. + initial_agent (ConversableAgent): The first agent in the conversation. + agents (list[ConversableAgent]): List of all agents in the conversation. Returns: - SwarmAgent: The tool executor agent. - list[SwarmAgent]: List of nested chat agents. + ConversableAgent: The tool executor agent. + list[ConversableAgent]: List of nested chat agents. 
""" - assert isinstance(initial_agent, SwarmAgent), "initial_agent must be a SwarmAgent" - assert all(isinstance(agent, SwarmAgent) for agent in agents), "Agents must be a list of SwarmAgents" + assert isinstance(initial_agent, ConversableAgent), "initial_agent must be a ConversableAgent" + assert all(isinstance(agent, ConversableAgent) for agent in agents), "Agents must be a list of ConversableAgents" + + # Initialize all agents as swarm agents + for agent in agents: + if not hasattr(agent, "_swarm_is_established"): + _establish_swarm_agent(agent) # Ensure all agents in hand-off after-works are in the passed in agents list for agent in agents: - if agent.after_work is not None: - if isinstance(agent.after_work.agent, SwarmAgent): - assert agent.after_work.agent in agents, "Agent in hand-off must be in the agents list" + if agent._swarm_after_work is not None: + if isinstance(agent._swarm_after_work.agent, ConversableAgent): + assert agent._swarm_after_work.agent in agents, "Agent in hand-off must be in the agents list" - tool_execution = SwarmAgent( + tool_execution = ConversableAgent( name=__TOOL_EXECUTOR_NAME__, - system_message="Tool Execution", + system_message="Tool Execution, do not use this agent directly.", ) - tool_execution._set_to_tool_execution() + _set_to_tool_execution(tool_execution) nested_chat_agents = [] for agent in agents: @@ -153,26 +212,26 @@ def _prepare_swarm_agents( for agent in agents + nested_chat_agents: tool_execution._function_map.update(agent._function_map) # Add conditional functions to the tool_execution agent - for func_name, (func, _) in agent._conditional_functions.items(): + for func_name, (func, _) in agent._swarm_conditional_functions.items(): tool_execution._function_map[func_name] = func return tool_execution, nested_chat_agents -def _create_nested_chats(agent: "SwarmAgent", nested_chat_agents: list["SwarmAgent"]): +def _create_nested_chats(agent: ConversableAgent, nested_chat_agents: list[ConversableAgent]): """Create 
nested chat agents and register nested chats. Args: - agent (SwarmAgent): The agent to create nested chat agents for, including registering the hand offs. - nested_chat_agents (list[SwarmAgent]): List for all nested chat agents, appends to this. + agent (ConversableAgent): The agent to create nested chat agents for, including registering the hand offs. + nested_chat_agents (list[ConversableAgent]): List for all nested chat agents, appends to this. """ - for i, nested_chat_handoff in enumerate(agent._nested_chat_handoffs): + for i, nested_chat_handoff in enumerate(agent._swarm_nested_chat_handoffs): nested_chats: dict[str, Any] = nested_chat_handoff["nested_chats"] condition = nested_chat_handoff["condition"] available = nested_chat_handoff["available"] # Create a nested chat agent specifically for this nested chat - nested_chat_agent = SwarmAgent(name=f"nested_chat_{agent.name}_{i + 1}") + nested_chat_agent = ConversableAgent(name=f"nested_chat_{agent.name}_{i + 1}") nested_chat_agent.register_nested_chats( nested_chats["chat_queue"], @@ -185,19 +244,19 @@ def _create_nested_chats(agent: "SwarmAgent", nested_chat_agents: list["SwarmAge ) # After the nested chat is complete, transfer back to the parent agent - nested_chat_agent.register_hand_off(AFTER_WORK(agent=agent)) + register_hand_off(nested_chat_agent, AfterWork(agent=agent)) nested_chat_agents.append(nested_chat_agent) # Nested chat is triggered through an agent transfer to this nested chat agent - agent.register_hand_off(ON_CONDITION(nested_chat_agent, condition, available)) + register_hand_off(agent, OnCondition(nested_chat_agent, condition, available)) def _process_initial_messages( messages: Union[list[dict[str, Any]], str], user_agent: Optional[UserProxyAgent], - agents: list["SwarmAgent"], - nested_chat_agents: list["SwarmAgent"], + agents: list[ConversableAgent], + nested_chat_agents: list[ConversableAgent], ) -> tuple[list[dict], Optional[Agent], list[str], list[Agent]]: """Process initial messages, 
validating agent names against messages, and determining the last agent to speak. @@ -242,8 +301,8 @@ def _process_initial_messages( def _setup_context_variables( - tool_execution: "SwarmAgent", - agents: list["SwarmAgent"], + tool_execution: ConversableAgent, + agents: list[ConversableAgent], manager: GroupChatManager, context_variables: dict[str, Any], ) -> None: @@ -270,11 +329,11 @@ def _cleanup_temp_user_messages(chat_result: ChatResult) -> None: def _determine_next_agent( - last_speaker: "SwarmAgent", + last_speaker: ConversableAgent, groupchat: GroupChat, initial_agent: ConversableAgent, use_initial_agent: bool, - tool_execution: "SwarmAgent", + tool_execution: ConversableAgent, swarm_agent_names: list[str], user_agent: Optional[UserProxyAgent], swarm_after_work: Optional[Union[AfterWorkOption, Callable]], @@ -282,11 +341,11 @@ def _determine_next_agent( """Determine the next agent in the conversation. Args: - last_speaker (SwarmAgent): The last agent to speak. + last_speaker (ConversableAgent): The last agent to speak. groupchat (GroupChat): GroupChat instance. initial_agent (ConversableAgent): The initial agent in the conversation. use_initial_agent (bool): Whether to use the initial agent straight away. - tool_execution (SwarmAgent): The tool execution agent. + tool_execution (ConversableAgent): The tool execution agent. swarm_agent_names (list[str]): List of agent names. user_agent (UserProxyAgent): Optional user proxy agent. swarm_after_work (Union[AfterWorkOption, Callable]): Method to handle conversation continuation when an agent doesn't select the next agent. 
@@ -297,9 +356,9 @@ def _determine_next_agent( if "tool_calls" in groupchat.messages[-1]: return tool_execution - if tool_execution._next_agent is not None: - next_agent = tool_execution._next_agent - tool_execution._next_agent = None + if tool_execution._swarm_next_agent is not None: + next_agent = tool_execution._swarm_next_agent + tool_execution._swarm_next_agent = None # Check for string, access agent from group chat. @@ -316,7 +375,7 @@ def _determine_next_agent( for message in reversed(groupchat.messages): if "name" in message and message["name"] in swarm_agent_names: agent = groupchat.agent_by_name(name=message["name"]) - if isinstance(agent, SwarmAgent): + if isinstance(agent, ConversableAgent): last_swarm_speaker = agent break if last_swarm_speaker is None: @@ -328,9 +387,9 @@ def _determine_next_agent( # Resolve after_work condition (agent-level overrides global) after_work_condition = ( - last_swarm_speaker.after_work if last_swarm_speaker.after_work is not None else swarm_after_work + last_swarm_speaker._swarm_after_work if last_swarm_speaker._swarm_after_work is not None else swarm_after_work ) - if isinstance(after_work_condition, AFTER_WORK): + if isinstance(after_work_condition, AfterWork): after_work_condition = after_work_condition.agent # Evaluate callable after_work @@ -342,7 +401,7 @@ def _determine_next_agent( return groupchat.agent_by_name(name=after_work_condition) else: raise ValueError(f"Invalid agent name in after_work: {after_work_condition}") - elif isinstance(after_work_condition, SwarmAgent): + elif isinstance(after_work_condition, ConversableAgent): return after_work_condition elif isinstance(after_work_condition, AfterWorkOption): if after_work_condition == AfterWorkOption.TERMINATE: @@ -358,17 +417,17 @@ def _determine_next_agent( def create_swarm_transition( - initial_agent: "SwarmAgent", - tool_execution: "SwarmAgent", + initial_agent: ConversableAgent, + tool_execution: ConversableAgent, swarm_agent_names: list[str], user_agent: 
Optional[UserProxyAgent], swarm_after_work: Optional[Union[AfterWorkOption, Callable]], -) -> Callable[["SwarmAgent", GroupChat], Optional[Agent]]: +) -> Callable[[ConversableAgent, GroupChat], Optional[Agent]]: """Creates a transition function for swarm chat with enclosed state for the use_initial_agent. Args: - initial_agent (SwarmAgent): The first agent to speak - tool_execution (SwarmAgent): The tool execution agent + initial_agent (ConversableAgent): The first agent to speak + tool_execution (ConversableAgent): The tool execution agent swarm_agent_names (list[str]): List of all agent names user_agent (UserProxyAgent): Optional user proxy agent swarm_after_work (Union[AfterWorkOption, Callable]): Swarm-level after work @@ -380,7 +439,7 @@ def create_swarm_transition( # of swarm_transition state = {"use_initial_agent": True} - def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat) -> Optional[Agent]: + def swarm_transition(last_speaker: ConversableAgent, groupchat: GroupChat) -> Optional[Agent]: result = _determine_next_agent( last_speaker=last_speaker, groupchat=groupchat, @@ -398,14 +457,14 @@ def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat) -> Optional def initiate_swarm_chat( - initial_agent: "SwarmAgent", + initial_agent: ConversableAgent, messages: Union[list[dict[str, Any]], str], - agents: list["SwarmAgent"], + agents: list[ConversableAgent], user_agent: Optional[UserProxyAgent] = None, max_rounds: int = 20, context_variables: Optional[dict[str, Any]] = None, - after_work: Optional[Union[AfterWorkOption, Callable]] = AFTER_WORK(AfterWorkOption.TERMINATE), -) -> tuple[ChatResult, dict[str, Any], "SwarmAgent"]: + after_work: Optional[Union[AfterWorkOption, Callable]] = AfterWork(AfterWorkOption.TERMINATE), +) -> tuple[ChatResult, dict[str, Any], ConversableAgent]: """Initialize and run a swarm chat Args: @@ -416,20 +475,20 @@ def initiate_swarm_chat( max_rounds: Maximum number of conversation rounds. 
context_variables: Starting context variables. after_work: Method to handle conversation continuation when an agent doesn't select the next agent. If no agent is selected and no tool calls are output, we will use this method to determine the next agent. - Must be a AFTER_WORK instance (which is a dataclass accepting a SwarmAgent, AfterWorkOption, A str (of the AfterWorkOption)) or a callable. + Must be a AfterWork instance (which is a dataclass accepting a ConversableAgent, AfterWorkOption, A str (of the AfterWorkOption)) or a callable. AfterWorkOption: - TERMINATE (Default): Terminate the conversation. - REVERT_TO_USER : Revert to the user agent if a user agent is provided. If not provided, terminate the conversation. - STAY : Stay with the last speaker. - Callable: A custom function that takes the current agent, messages, and groupchat as arguments and returns an AfterWorkOption or a SwarmAgent (by reference or string name). + Callable: A custom function that takes the current agent, messages, and groupchat as arguments and returns an AfterWorkOption or a ConversableAgent (by reference or string name). ```python - def custom_afterwork_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any]], groupchat: GroupChat) -> Union[AfterWorkOption, SwarmAgent, str]: + def custom_afterwork_func(last_speaker: ConversableAgent, messages: List[Dict[str, Any]], groupchat: GroupChat) -> Union[AfterWorkOption, ConversableAgent, str]: ``` Returns: ChatResult: Conversations chat history. Dict[str, Any]: Updated Context variables. - SwarmAgent: Last speaker. + ConversableAgent: Last speaker. 
""" tool_execution, nested_chat_agents = _prepare_swarm_agents(initial_agent, agents) @@ -455,7 +514,7 @@ def custom_afterwork_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any manager = GroupChatManager(groupchat) - # Point all SwarmAgent's context variables to this function's context_variables + # Point all ConversableAgent's context variables to this function's context_variables _setup_context_variables(tool_execution, agents, manager, context_variables or {}) if len(processed_messages) > 1: @@ -477,14 +536,14 @@ def custom_afterwork_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any async def a_initiate_swarm_chat( - initial_agent: "SwarmAgent", + initial_agent: ConversableAgent, messages: Union[list[dict[str, Any]], str], - agents: list["SwarmAgent"], + agents: list[ConversableAgent], user_agent: Optional[UserProxyAgent] = None, max_rounds: int = 20, context_variables: Optional[dict[str, Any]] = None, - after_work: Optional[Union[AfterWorkOption, Callable]] = AFTER_WORK(AfterWorkOption.TERMINATE), -) -> tuple[ChatResult, dict[str, Any], "SwarmAgent"]: + after_work: Optional[Union[AfterWorkOption, Callable]] = AfterWork(AfterWorkOption.TERMINATE), +) -> tuple[ChatResult, dict[str, Any], ConversableAgent]: """Initialize and run a swarm chat asynchronously Args: @@ -495,20 +554,20 @@ async def a_initiate_swarm_chat( max_rounds: Maximum number of conversation rounds. context_variables: Starting context variables. after_work: Method to handle conversation continuation when an agent doesn't select the next agent. If no agent is selected and no tool calls are output, we will use this method to determine the next agent. - Must be a AFTER_WORK instance (which is a dataclass accepting a SwarmAgent, AfterWorkOption, A str (of the AfterWorkOption)) or a callable. + Must be a AfterWork instance (which is a dataclass accepting a ConversableAgent, AfterWorkOption, A str (of the AfterWorkOption)) or a callable. 
AfterWorkOption: - TERMINATE (Default): Terminate the conversation. - REVERT_TO_USER : Revert to the user agent if a user agent is provided. If not provided, terminate the conversation. - STAY : Stay with the last speaker. - Callable: A custom function that takes the current agent, messages, and groupchat as arguments and returns an AfterWorkOption or a SwarmAgent (by reference or string name). + Callable: A custom function that takes the current agent, messages, and groupchat as arguments and returns an AfterWorkOption or a ConversableAgent (by reference or string name). ```python - def custom_afterwork_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any]], groupchat: GroupChat) -> Union[AfterWorkOption, SwarmAgent, str]: + def custom_afterwork_func(last_speaker: ConversableAgent, messages: List[Dict[str, Any]], groupchat: GroupChat) -> Union[AfterWorkOption, ConversableAgent, str]: ``` Returns: ChatResult: Conversations chat history. Dict[str, Any]: Updated Context variables. - SwarmAgent: Last speaker. + ConversableAgent: Last speaker. """ tool_execution, nested_chat_agents = _prepare_swarm_agents(initial_agent, agents) @@ -534,7 +593,7 @@ def custom_afterwork_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any manager = GroupChatManager(groupchat) - # Point all SwarmAgent's context variables to this function's context_variables + # Point all ConversableAgent's context variables to this function's context_variables _setup_context_variables(tool_execution, agents, manager, context_variables or {}) if len(processed_messages) > 1: @@ -560,12 +619,12 @@ class SwarmResult(BaseModel): Args: values (str): The result values as a string. - agent (SwarmAgent): The swarm agent instance, if applicable. + agent (ConversableAgent): The agent instance, if applicable. context_variables (dict): A dictionary of context variables. 
""" values: str = "" - agent: Optional[Union["SwarmAgent", str]] = None + agent: Optional[Union[ConversableAgent, str]] = None context_variables: dict[str, Any] = {} class Config: # Add this inner class @@ -575,463 +634,204 @@ def __str__(self): return self.values -class SwarmAgent(ConversableAgent): - """Swarm agent for participating in a swarm. +def _set_to_tool_execution(agent: ConversableAgent): + """Set to a special instance of ConversableAgent that is responsible for executing tool calls from other swarm agents. + This agent will be used internally and should not be visible to the user. - SwarmAgent is a subclass of ConversableAgent. - - Additional args: - functions (List[Callable]): A list of functions to register with the agent. - update_agent_state_before_reply (List[Callable]): A list of functions, including UPDATE_SYSTEM_MESSAGEs, called to update the agent before it replies. + It will execute the tool calls and update the referenced context_variables and next_agent accordingly. 
""" + agent._swarm_next_agent = None + agent._reply_func_list.clear() + agent.register_reply([Agent, None], _generate_swarm_tool_reply) - def __init__( - self, - name: str, - system_message: Optional[str] = "You are a helpful AI Assistant.", - llm_config: Optional[Union[dict, Literal[False]]] = None, - functions: Union[list[Callable], Callable] = None, - is_termination_msg: Optional[Callable[[dict], bool]] = None, - max_consecutive_auto_reply: Optional[int] = None, - human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER", - description: Optional[str] = None, - code_execution_config=False, - update_agent_state_before_reply: Optional[ - Union[list[Union[Callable, UPDATE_SYSTEM_MESSAGE]], Callable, UPDATE_SYSTEM_MESSAGE] - ] = None, - **kwargs, - ) -> None: - super().__init__( - name, - system_message, - is_termination_msg, - max_consecutive_auto_reply, - human_input_mode, - llm_config=llm_config, - description=description, - code_execution_config=code_execution_config, - **kwargs, - ) - if isinstance(functions, list): - if not all(isinstance(func, Callable) for func in functions): - raise TypeError("All elements in the functions list must be callable") - self.add_functions(functions) - elif isinstance(functions, Callable): - self.add_single_function(functions) - elif functions is not None: - raise TypeError("Functions must be a callable or a list of callables") - - self.after_work = None - - # Used in the tool execution agent to transfer to the next agent - self._next_agent = None - - # Store nested chats hand offs as we'll establish these in the initiate_swarm_chat - # List of Dictionaries containing the nested_chats and condition - self._nested_chat_handoffs = [] - - self.register_update_agent_state_before_reply(update_agent_state_before_reply) - - # Store conditional functions (and their ON_CONDITION instances) to add/remove later when transitioning to this agent - self._conditional_functions = {} - - # Register the hook to update agent state (except 
tool executor) - if name != __TOOL_EXECUTOR_NAME__: - self.register_hook("update_agent_state", self._update_conditional_functions) - - def register_update_agent_state_before_reply(self, functions: Optional[Union[list[Callable], Callable]]): - """Register functions that will be called when the agent is selected and before it speaks. - You can add your own validation or precondition functions here. - - Args: - functions (List[Callable[[], None]]): A list of functions to be registered. Each function - is called when the agent is selected and before it speaks. - """ - if functions is None: - return - if not isinstance(functions, list) and type(functions) not in [UPDATE_SYSTEM_MESSAGE, Callable]: - raise ValueError("functions must be a list of callables") - - if not isinstance(functions, list): - functions = [functions] - - for func in functions: - if isinstance(func, UPDATE_SYSTEM_MESSAGE): - # Wrapper function that allows this to be used in the update_agent_state hook - # Its primary purpose, however, is just to update the agent's system message - # Outer function to create a closure with the update function - def create_wrapper(update_func: UPDATE_SYSTEM_MESSAGE): - def update_system_message_wrapper( - agent: ConversableAgent, messages: list[dict[str, Any]] - ) -> list[dict[str, Any]]: - if isinstance(update_func.update_function, str): - # Templates like "My context variable passport is {passport}" will - # use the context_variables for substitution - sys_message = OpenAIWrapper.instantiate( - template=update_func.update_function, - context=agent._context_variables, - allow_format_str_template=True, - ) - else: - sys_message = update_func.update_function(agent, messages) - - agent.update_system_message(sys_message) - return messages - - return update_system_message_wrapper - - self.register_hook(hookable_method="update_agent_state", hook=create_wrapper(func)) +def register_hand_off( + agent: ConversableAgent, + hand_to: Union[list[Union[OnCondition, AfterWork]], 
OnCondition, AfterWork], +): + """Register a function to hand off to another agent. - else: - self.register_hook(hookable_method="update_agent_state", hook=func) + Args: + agent: The agent to register the hand off with. + hand_to: A list of OnCondition's and an, optional, AfterWork condition + + Hand off template: + def transfer_to_agent_name() -> ConversableAgent: + return agent_name + 1. register the function with the agent + 2. register the schema with the agent, description set to the condition + """ + # If the agent hasn't been established as a swarm agent, do so first + if not hasattr(agent, "_swarm_is_established"): + _establish_swarm_agent(agent) - def _set_to_tool_execution(self): - """Set to a special instance of SwarmAgent that is responsible for executing tool calls from other swarm agents. - This agent will be used internally and should not be visible to the user. + # Ensure that hand_to is a list or OnCondition or AfterWork + if not isinstance(hand_to, (list, OnCondition, AfterWork)): + raise ValueError("hand_to must be a list of OnCondition or AfterWork") - It will execute the tool calls and update the referenced context_variables and next_agent accordingly. - """ - self._next_agent = None - self._reply_func_list.clear() - self.register_reply([Agent, None], SwarmAgent.generate_swarm_tool_reply) + if isinstance(hand_to, (OnCondition, AfterWork)): + hand_to = [hand_to] - def __str__(self): - return f"SwarmAgent --> {self.name}" - - def register_hand_off( - self, - hand_to: Union[list[Union[ON_CONDITION, AFTER_WORK]], ON_CONDITION, AFTER_WORK], - ): - """Register a function to hand off to another agent. - - Args: - hand_to: A list of ON_CONDITIONs and an, optional, AFTER_WORK condition - - Hand off template: - def transfer_to_agent_name() -> SwarmAgent: - return agent_name - 1. register the function with the agent - 2. 
register the schema with the agent, description set to the condition - """ - # Ensure that hand_to is a list or ON_CONDITION or AFTER_WORK - if not isinstance(hand_to, (list, ON_CONDITION, AFTER_WORK)): - raise ValueError("hand_to must be a list of ON_CONDITION or AFTER_WORK") - - if isinstance(hand_to, (ON_CONDITION, AFTER_WORK)): - hand_to = [hand_to] - - for transit in hand_to: - if isinstance(transit, AFTER_WORK): - assert isinstance(transit.agent, (AfterWorkOption, SwarmAgent, str, Callable)), ( - "Invalid After Work value" + for transit in hand_to: + if isinstance(transit, AfterWork): + assert isinstance(transit.agent, (AfterWorkOption, ConversableAgent, str, Callable)), ( + "Invalid After Work value" + ) + agent._swarm_after_work = transit + elif isinstance(transit, OnCondition): + if isinstance(transit.target, ConversableAgent): + # Transition to agent + + # Create closure with current loop transit value + # to ensure the condition matches the one in the loop + def make_transfer_function(current_transit: OnCondition): + def transfer_to_agent() -> ConversableAgent: + return current_transit.target + + return transfer_to_agent + + transfer_func = make_transfer_function(transit) + + # Store function to add/remove later based on it being 'available' + # Function names are made unique and allow multiple OnCondition's to the same agent + base_func_name = f"transfer_{agent.name}_to_{transit.target.name}" + func_name = base_func_name + count = 2 + while func_name in agent._swarm_conditional_functions: + func_name = f"{base_func_name}_{count}" + count += 1 + + # Store function to add/remove later based on it being 'available' + agent._swarm_conditional_functions[func_name] = (transfer_func, transit) + + elif isinstance(transit.target, dict): + # Transition to a nested chat + # We will store them here and establish them in the initiate_swarm_chat + agent._swarm_nested_chat_handoffs.append( + {"nested_chats": transit.target, "condition": transit.condition, "available": 
transit.available} ) - self.after_work = transit - elif isinstance(transit, ON_CONDITION): - if isinstance(transit.target, SwarmAgent): - # Transition to agent - - # Create closure with current loop transit value - # to ensure the condition matches the one in the loop - def make_transfer_function(current_transit: ON_CONDITION): - def transfer_to_agent() -> "SwarmAgent": - return current_transit.target - - return transfer_to_agent - - transfer_func = make_transfer_function(transit) - - # Store function to add/remove later based on it being 'available' - # Function names are made unique and allow multiple ON_CONDITIONS to the same agent - base_func_name = f"transfer_{self.name}_to_{transit.target.name}" - func_name = base_func_name - count = 2 - while func_name in self._conditional_functions: - func_name = f"{base_func_name}_{count}" - count += 1 - - # Store function to add/remove later based on it being 'available' - self._conditional_functions[func_name] = (transfer_func, transit) - - elif isinstance(transit.target, dict): - # Transition to a nested chat - # We will store them here and establish them in the initiate_swarm_chat - self._nested_chat_handoffs.append( - {"nested_chats": transit.target, "condition": transit.condition, "available": transit.available} - ) - - else: - raise ValueError("Invalid hand off condition, must be either ON_CONDITION or AFTER_WORK") - - @staticmethod - def _update_conditional_functions(agent: Agent, messages: Optional[list[dict]] = None) -> None: - """Updates the agent's functions based on the ON_CONDITION's available condition.""" - for func_name, (func, on_condition) in agent._conditional_functions.items(): - is_available = True - - if on_condition.available is not None: - if isinstance(on_condition.available, Callable): - is_available = on_condition.available(agent, next(iter(agent.chat_messages.values()))) - elif isinstance(on_condition.available, str): - is_available = agent.get_context(on_condition.available) or False - - if 
is_available: - if func_name not in agent._function_map: - agent.add_single_function(func, func_name, on_condition.condition) - else: - # Remove function using the stored name - if func_name in agent._function_map: - agent.update_tool_signature(func_name, is_remove=True) - del agent._function_map[func_name] - - def generate_swarm_tool_reply( - self, - messages: Optional[list[dict]] = None, - sender: Optional[Agent] = None, - config: Optional[OpenAIWrapper] = None, - ) -> tuple[bool, dict]: - """Pre-processes and generates tool call replies. - - This function: - 1. Adds context_variables back to the tool call for the function, if necessary. - 2. Generates the tool calls reply. - 3. Updates context_variables and next_agent based on the tool call response. - """ - if config is None: - config = self - if messages is None: - messages = self._oai_messages[sender] - - message = messages[-1] - if "tool_calls" in message: - tool_call_count = len(message["tool_calls"]) - - # Loop through tool calls individually (so context can be updated after each function call) - next_agent = None - tool_responses_inner = [] - contents = [] - for index in range(tool_call_count): - # Deep copy to ensure no changes to messages when we insert the context variables - message_copy = copy.deepcopy(message) - - # 1. 
add context_variables to the tool call arguments - tool_call = message_copy["tool_calls"][index] - - if tool_call["type"] == "function": - function_name = tool_call["function"]["name"] - - # Check if this function exists in our function map - if function_name in self._function_map: - func = self._function_map[function_name] # Get the original function - - # Inject the context variables into the tool call if it has the parameter - sig = signature(func) - if __CONTEXT_VARIABLES_PARAM_NAME__ in sig.parameters: - current_args = json.loads(tool_call["function"]["arguments"]) - current_args[__CONTEXT_VARIABLES_PARAM_NAME__] = self._context_variables - tool_call["function"]["arguments"] = json.dumps(current_args) - - # Ensure we are only executing the one tool at a time - message_copy["tool_calls"] = [tool_call] - - # 2. generate tool calls reply - _, tool_message = self.generate_tool_calls_reply([message_copy]) - - # 3. update context_variables and next_agent, convert content to string - for tool_response in tool_message["tool_responses"]: - content = tool_response.get("content") - if isinstance(content, SwarmResult): - if content.context_variables != {}: - self._context_variables.update(content.context_variables) - if content.agent is not None: - next_agent = content.agent - elif isinstance(content, Agent): - next_agent = content - - tool_responses_inner.append(tool_response) - contents.append(str(tool_response["content"])) - - self._next_agent = next_agent - - # Put the tool responses and content strings back into the response message - # Caters for multiple tool calls - tool_message["tool_responses"] = tool_responses_inner - tool_message["content"] = "\n".join(contents) - - return True, tool_message - return False, None - - def add_single_function(self, func: Callable, name=None, description=""): - """Add a single function to the agent, removing context variables for LLM use""" - if name: - func._name = name - else: - func._name = func.__name__ - if description: - 
func._description = description else: - # Use function's docstring, strip whitespace, fall back to empty string - func._description = (func.__doc__ or "").strip() - - f = get_function_schema(func, name=func._name, description=func._description) - - # Remove context_variables parameter from function schema - f_no_context = f.copy() - if __CONTEXT_VARIABLES_PARAM_NAME__ in f_no_context["function"]["parameters"]["properties"]: - del f_no_context["function"]["parameters"]["properties"][__CONTEXT_VARIABLES_PARAM_NAME__] - if "required" in f_no_context["function"]["parameters"]: - required = f_no_context["function"]["parameters"]["required"] - f_no_context["function"]["parameters"]["required"] = [ - param for param in required if param != __CONTEXT_VARIABLES_PARAM_NAME__ - ] - # If required list is empty, remove it - if not f_no_context["function"]["parameters"]["required"]: - del f_no_context["function"]["parameters"]["required"] - - self.update_tool_signature(f_no_context, is_remove=False) - self.register_function({func._name: func}) - - def add_functions(self, func_list: list[Callable]): - for func in func_list: - self.add_single_function(func) - - @staticmethod - def process_nested_chat_carryover( - chat: dict[str, Any], - recipient: ConversableAgent, - messages: list[dict[str, Any]], - sender: ConversableAgent, - config: Any, - trim_n_messages: int = 0, - ) -> None: - """Process carryover messages for a nested chat (typically for the first chat of a swarm) - - The carryover_config key is a dictionary containing: - "summary_method": The method to use to summarise the messages, can be "all", "last_msg", "reflection_with_llm" or a Callable - "summary_args": Optional arguments for the summary method - - Supported carryover 'summary_methods' are: - "all" - all messages will be incorporated - "last_msg" - the last message will be incorporated - "reflection_with_llm" - an llm will summarise all the messages and the summary will be incorporated as a single message - 
Callable - a callable with the signature: my_method(agent: ConversableAgent, messages: List[Dict[str, Any]]) -> str - - Args: - chat: The chat dictionary containing the carryover configuration - recipient: The recipient agent - messages: The messages from the parent chat - sender: The sender agent - trim_n_messages: The number of latest messages to trim from the messages list - """ - - def concat_carryover(chat_message: str, carryover_message: Union[str, list[dict[str, Any]]]) -> str: - """Concatenate the carryover message to the chat message.""" - prefix = f"{chat_message}\n" if chat_message else "" - - if isinstance(carryover_message, str): - content = carryover_message - elif isinstance(carryover_message, list): - content = "\n".join( - msg["content"] for msg in carryover_message if "content" in msg and msg["content"] is not None - ) - else: - raise ValueError("Carryover message must be a string or a list of dictionaries") - - return f"{prefix}Context:\n{content}" - - carryover_config = chat["carryover_config"] - - if "summary_method" not in carryover_config: - raise ValueError("Carryover configuration must contain a 'summary_method' key") - - carryover_summary_method = carryover_config["summary_method"] - carryover_summary_args = carryover_config.get("summary_args") or {} - - chat_message = "" - message = chat.get("message") - - # If the message is a callable, run it and get the result - if message: - chat_message = message(recipient, messages, sender, config) if callable(message) else message - - # deep copy and trim the latest messages - content_messages = copy.deepcopy(messages) - content_messages = content_messages[:-trim_n_messages] - - if carryover_summary_method == "all": - # Put a string concatenated value of all parent messages into the first message - # (e.g. message = \nContext: \n\n\n...) - carry_over_message = concat_carryover(chat_message, content_messages) - - elif carryover_summary_method == "last_msg": - # (e.g. 
message = \nContext: \n) - carry_over_message = concat_carryover(chat_message, content_messages[-1]["content"]) - - elif carryover_summary_method == "reflection_with_llm": - # (e.g. message = \nContext: \n) - - # Add the messages to the nested chat agent for reflection (we'll clear after reflection) - chat["recipient"]._oai_messages[sender] = content_messages - - carry_over_message_llm = ConversableAgent._reflection_with_llm_as_summary( - sender=sender, - recipient=chat["recipient"], # Chat recipient LLM config will be used for the reflection - summary_args=carryover_summary_args, - ) - - recipient._oai_messages[sender] = [] - - carry_over_message = concat_carryover(chat_message, carry_over_message_llm) - - elif isinstance(carryover_summary_method, Callable): - # (e.g. message = \nContext: \n) - carry_over_message_result = carryover_summary_method(recipient, content_messages, carryover_summary_args) - - carry_over_message = concat_carryover(chat_message, carry_over_message_result) - - chat["message"] = carry_over_message - - @staticmethod - def _summary_from_nested_chats( - chat_queue: list[dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any - ) -> tuple[bool, Union[str, None]]: - """Overridden _summary_from_nested_chats method from ConversableAgent. - This function initiates one or a sequence of chats between the "recipient" and the agents in the chat_queue. - - It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue. - - Swarm Updates: - - the 'messages' parameter contains the parent chat's messages - - the first chat in the queue can contain a 'carryover_config' which is a dictionary that denotes how to carryover messages from the swarm chat into the first chat of the nested chats). Only applies to the first chat. 
- e.g.: carryover_summarize_chat_config = {"summary_method": "reflection_with_llm", "summary_args": None} - summary_method can be "last_msg", "all", "reflection_with_llm", Callable - The Callable signature: my_method(agent: ConversableAgent, messages: List[Dict[str, Any]]) -> str - The summary will be concatenated to the message of the first chat in the queue. - - Returns: - Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated. - """ - # Carryover configuration allowed on the first chat in the queue only, trim the last two messages specifically for swarm nested chat carryover as these are the messages for the transition to the nested chat agent - restore_chat_queue_message = False - if len(chat_queue) > 0 and "carryover_config" in chat_queue[0]: - if "message" in chat_queue[0]: - # As we're updating the message in the nested chat queue, we need to restore it after finishing this nested chat. 
- restore_chat_queue_message = True - original_chat_queue_message = chat_queue[0]["message"] - SwarmAgent.process_nested_chat_carryover(chat_queue[0], recipient, messages, sender, config, 2) - - chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config) - if not chat_to_run: - return True, None - res = sender.initiate_chats(chat_to_run) + raise ValueError("Invalid hand off condition, must be either OnCondition or AfterWork") + + +def _update_conditional_functions(agent: ConversableAgent, messages: Optional[list[dict]] = None) -> None: + """Updates the agent's functions based on the OnCondition's available condition.""" + for func_name, (func, on_condition) in agent._swarm_conditional_functions.items(): + is_available = True + + if on_condition.available is not None: + if isinstance(on_condition.available, Callable): + is_available = on_condition.available(agent, next(iter(agent.chat_messages.values()))) + elif isinstance(on_condition.available, str): + is_available = agent.get_context(on_condition.available) or False + + # first remove the function if it exists + if func_name in agent._function_map: + agent.update_tool_signature(func_name, is_remove=True) + del agent._function_map[func_name] + + # then add the function if it is available, so that the function signature is updated + if is_available: + condition = on_condition.condition + if isinstance(condition, UpdateCondition): + if isinstance(condition.update_function, str): + condition = OpenAIWrapper.instantiate( + template=condition.update_function, + context=agent._context_variables, + allow_format_str_template=True, + ) + else: + condition = condition.update_function(agent, messages) + agent._add_single_function(func, func_name, condition) + + +def _generate_swarm_tool_reply( + agent: ConversableAgent, + messages: Optional[list[dict]] = None, + sender: Optional[Agent] = None, + config: Optional[OpenAIWrapper] = None, +) -> tuple[bool, dict]: + """Pre-processes and 
generates tool call replies. + + This function: + 1. Adds context_variables back to the tool call for the function, if necessary. + 2. Generates the tool calls reply. + 3. Updates context_variables and next_agent based on the tool call response.""" + + if config is None: + config = agent + if messages is None: + messages = agent._oai_messages[sender] + + message = messages[-1] + if "tool_calls" in message: + tool_call_count = len(message["tool_calls"]) + + # Loop through tool calls individually (so context can be updated after each function call) + next_agent = None + tool_responses_inner = [] + contents = [] + for index in range(tool_call_count): + # Deep copy to ensure no changes to messages when we insert the context variables + message_copy = copy.deepcopy(message) + + # 1. add context_variables to the tool call arguments + tool_call = message_copy["tool_calls"][index] + + if tool_call["type"] == "function": + function_name = tool_call["function"]["name"] + + # Check if this function exists in our function map + if function_name in agent._function_map: + func = agent._function_map[function_name] # Get the original function + + # Inject the context variables into the tool call if it has the parameter + sig = signature(func) + if __CONTEXT_VARIABLES_PARAM_NAME__ in sig.parameters: + current_args = json.loads(tool_call["function"]["arguments"]) + current_args[__CONTEXT_VARIABLES_PARAM_NAME__] = agent._context_variables + tool_call["function"]["arguments"] = json.dumps(current_args) + + # Ensure we are only executing the one tool at a time + message_copy["tool_calls"] = [tool_call] + + # 2. generate tool calls reply + _, tool_message = agent.generate_tool_calls_reply([message_copy]) + + # 3. 
update context_variables and next_agent, convert content to string + for tool_response in tool_message["tool_responses"]: + content = tool_response.get("content") + if isinstance(content, SwarmResult): + if content.context_variables != {}: + agent._context_variables.update(content.context_variables) + if content.agent is not None: + next_agent = content.agent + elif isinstance(content, Agent): + next_agent = content + + tool_responses_inner.append(tool_response) + contents.append(str(tool_response["content"])) + + agent._swarm_next_agent = next_agent + + # Put the tool responses and content strings back into the response message + # Caters for multiple tool calls + tool_message["tool_responses"] = tool_responses_inner + tool_message["content"] = "\n".join(contents) + + return True, tool_message + return False, None - # We need to restore the chat queue message if it has been modified so that it will be the original message for subsequent uses - if restore_chat_queue_message: - chat_queue[0]["message"] = original_chat_queue_message - return True, res[-1].summary +class SwarmAgent(ConversableAgent): + """SwarmAgent is deprecated and has been incorporated into ConversableAgent, use ConversableAgent instead. SwarmAgent will be removed in a future version (TBD)""" + def __init__(self, *args, **kwargs): + warnings.warn( + "SwarmAgent is deprecated and has been incorporated into ConversableAgent, use ConversableAgent instead. 
SwarmAgent will be removed in a future version (TBD).", + DeprecationWarning, + stacklevel=2, + ) -# Forward references for SwarmAgent in SwarmResult -SwarmResult.update_forward_refs() + super().__init__(*args, **kwargs) diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index 1812c56600..ce7a89a282 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# Copyright (c) 2023 - 2025, Owners of https://github.com/ag2ai # # SPDX-License-Identifier: Apache-2.0 # @@ -13,6 +13,8 @@ import re import warnings from collections import defaultdict +from dataclasses import dataclass +from inspect import signature from typing import ( Any, Callable, @@ -55,7 +57,7 @@ ) from ..oai.client import ModelClient, OpenAIWrapper from ..runtime_logging import log_event, log_function_use, log_new_agent, logging_enabled -from ..tools import ChatContext, Tool, load_basemodels_if_needed, serialize_to_str +from ..tools import ChatContext, Tool, get_function_schema, load_basemodels_if_needed, serialize_to_str from .agent import Agent, LLMAgent from .chat import ChatResult, _post_process_carryover_item, a_initiate_chats, initiate_chats from .utils import consolidate_chat_info, gather_usage_summary @@ -66,6 +68,55 @@ F = TypeVar("F", bound=Callable[..., Any]) +# Parameter name for context variables +# Use the value in functions and they will be substituted with the context variables: +# e.g. def my_function(context_variables: Dict[str, Any], my_other_parameters: Any) -> Any: +__CONTEXT_VARIABLES_PARAM_NAME__ = "context_variables" + + +@dataclass +class UpdateSystemMessage: + """Update the agent's system message before they reply + + Args: + content_updater: The format string or function to update the agent's system message. Can be a format string or a Callable. 
+ If a string, it will be used as a template and substitute the context variables. + If a Callable, it should have the signature: + def my_content_updater(agent: ConversableAgent, messages: List[Dict[str, Any]]) -> str + """ + + content_updater: Union[Callable, str] + + def __post_init__(self): + if isinstance(self.content_updater, str): + # find all {var} in the string + vars = re.findall(r"\{(\w+)\}", self.content_updater) + if len(vars) == 0: + warnings.warn("Update function string contains no variables. This is probably unintended.") + + elif isinstance(self.content_updater, Callable): + sig = signature(self.content_updater) + if len(sig.parameters) != 2: + raise ValueError( + "The update function must accept two parameters of type ConversableAgent and List[Dict[str, Any]], respectively" + ) + if sig.return_annotation != str: + raise ValueError("The update function must return a string") + else: + raise ValueError("The update function must be either a string or a callable") + + +class UPDATE_SYSTEM_MESSAGE(UpdateSystemMessage): # noqa: N801 + """Deprecated: Use UpdateSystemMessage instead. This class will be removed in a future version (TBD).""" + + def __init__(self, *args, **kwargs): + warnings.warn( + "UPDATE_SYSTEM_MESSAGE is deprecated and will be removed in a future version (TBD). Use UpdateSystemMessage instead.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(*args, **kwargs) + class ConversableAgent(LLMAgent): """(In preview) A class for generic conversable agents which can be configured as assistant or user proxy. @@ -103,6 +154,10 @@ def __init__( chat_messages: Optional[dict[Agent, list[dict]]] = None, silent: Optional[bool] = None, context_variables: Optional[dict[str, Any]] = None, + functions: Union[list[Callable], Callable] = None, + update_agent_state_before_reply: Optional[ + Union[list[Union[Callable, UpdateSystemMessage]], Callable, UpdateSystemMessage] + ] = None, ): """Args: name (str): name of the agent. 
@@ -156,6 +211,11 @@ def __init__( Note: Will maintain a reference to the passed in context variables (enabling a shared context) Only used in Swarms at this stage: https://docs.ag2.ai/docs/reference/agentchat/contrib/swarm_agent + functions (List[Callable]): A list of functions to register with the agent. + These functions will be provided to the LLM, however they won't, by default, be executed by the agent. + If the agent is in a swarm, the swarm's tool executor will execute the function. + When not in a swarm, you can have another agent execute the tools by adding them to that agent's function_map. + update_agent_state_before_reply (List[Callable]): A list of functions, including UpdateSystemMessage's, called to update the agent before it replies. """ # we change code_execution_config below and we have to make sure we don't change the input # in case of UserProxyAgent, without this we could even change the default value {} @@ -179,6 +239,7 @@ def __init__( else (lambda x: content_str(x.get("content")) == "TERMINATE") ) self.silent = silent + # Take a copy to avoid modifying the given dict if isinstance(llm_config, dict): try: @@ -217,6 +278,16 @@ def __init__( self._context_variables = context_variables if context_variables is not None else {} + # Register functions to the agent + if isinstance(functions, list): + if not all(isinstance(func, Callable) for func in functions): + raise TypeError("All elements in the functions list must be callable") + self._add_functions(functions) + elif isinstance(functions, Callable): + self._add_single_function(functions) + elif functions is not None: + raise TypeError("Functions must be a callable or a list of callables") + # Setting up code execution. # Do not register code execution reply if code execution is disabled. 
if code_execution_config is not False: @@ -284,11 +355,119 @@ def __init__( "update_agent_state": [], } + # Associate agent update state hooks + self._register_update_agent_state_before_reply(update_agent_state_before_reply) + def _validate_name(self, name: str) -> None: # Validation for name using regex to detect any whitespace if re.search(r"\s", name): raise ValueError(f"The name of the agent cannot contain any whitespace. The name provided is: '{name}'") + def _get_display_name(self): + """Get the string representation of the agent. + + If you would like to change the standard string representation for an + instance of ConversableAgent, you can point it to another function. + In this example a function called _swarm_agent_str that returns a string: + agent._get_display_name = MethodType(_swarm_agent_str, agent) + """ + return self.name + + def __str__(self): + return self._get_display_name() + + def _add_functions(self, func_list: list[Callable]): + """Add (Register) a list of functions to the agent + + Args: + func_list (list[Callable]): A list of functions to register with the agent.""" + for func in func_list: + self._add_single_function(func) + + def _add_single_function(self, func: Callable, name: Optional[str] = None, description: Optional[str] = ""): + """Add a single function to the agent, removing context variables for LLM use. + + Args: + func (Callable): The function to register. + name (str): The name of the function. If not provided, the function's name will be used. + description (str): The description of the function, used by the LLM. If not provided, the function's docstring will be used. 
+ """ + if name: + func._name = name + else: + func._name = func.__name__ + + if description: + func._description = description + else: + # Use function's docstring, strip whitespace, fall back to empty string + func._description = (func.__doc__ or "").strip() + + f = get_function_schema(func, name=func._name, description=func._description) + + # Remove context_variables parameter from function schema + f_no_context = f.copy() + if __CONTEXT_VARIABLES_PARAM_NAME__ in f_no_context["function"]["parameters"]["properties"]: + del f_no_context["function"]["parameters"]["properties"][__CONTEXT_VARIABLES_PARAM_NAME__] + if "required" in f_no_context["function"]["parameters"]: + required = f_no_context["function"]["parameters"]["required"] + f_no_context["function"]["parameters"]["required"] = [ + param for param in required if param != __CONTEXT_VARIABLES_PARAM_NAME__ + ] + # If required list is empty, remove it + if not f_no_context["function"]["parameters"]["required"]: + del f_no_context["function"]["parameters"]["required"] + + self.update_tool_signature(f_no_context, is_remove=False) + self.register_function({func._name: func}) + + def _register_update_agent_state_before_reply(self, functions: Optional[Union[list[Callable], Callable]]): + """ + Register functions that will be called when the agent is selected and before it speaks. + You can add your own validation or precondition functions here. + + Args: + functions (List[Callable[[], None]]): A list of functions to be registered. Each function + is called when the agent is selected and before it speaks. 
+ """ + if functions is None: + return + if not isinstance(functions, list) and type(functions) not in [UpdateSystemMessage, Callable]: + raise ValueError("functions must be a list of callables") + + if not isinstance(functions, list): + functions = [functions] + + for func in functions: + if isinstance(func, UpdateSystemMessage): + # Wrapper function that allows this to be used in the update_agent_state hook + # Its primary purpose, however, is just to update the agent's system message + # Outer function to create a closure with the update function + def create_wrapper(update_func: UpdateSystemMessage): + def update_system_message_wrapper( + agent: ConversableAgent, messages: list[dict[str, Any]] + ) -> list[dict[str, Any]]: + if isinstance(update_func.content_updater, str): + # Templates like "My context variable passport is {passport}" will + # use the context_variables for substitution + sys_message = OpenAIWrapper.instantiate( + template=update_func.content_updater, + context=agent._context_variables, + allow_format_str_template=True, + ) + else: + sys_message = update_func.content_updater(agent, messages) + + agent.update_system_message(sys_message) + return messages + + return update_system_message_wrapper + + self.register_hook(hookable_method="update_agent_state", hook=create_wrapper(func)) + + else: + self.register_hook(hookable_method="update_agent_state", hook=func) + def _validate_llm_config(self, llm_config): assert llm_config in (None, False) or isinstance(llm_config, dict), ( "llm_config must be a dict or False or None." 
@@ -444,6 +623,148 @@ def _get_chats_to_run( chat_to_run.append(current_c) return chat_to_run + @staticmethod + def _process_nested_chat_carryover( + chat: dict[str, Any], + recipient: Agent, + messages: list[dict[str, Any]], + sender: Agent, + config: Any, + trim_n_messages: int = 0, + ) -> None: + """Process carryover messages for a nested chat (typically for the first chat of a swarm) + + The carryover_config key is a dictionary containing: + "summary_method": The method to use to summarise the messages, can be "all", "last_msg", "reflection_with_llm" or a Callable + "summary_args": Optional arguments for the summary method + + Supported carryover 'summary_methods' are: + "all" - all messages will be incorporated + "last_msg" - the last message will be incorporated + "reflection_with_llm" - an llm will summarise all the messages and the summary will be incorporated as a single message + Callable - a callable with the signature: my_method(agent: ConversableAgent, messages: List[Dict[str, Any]]) -> str + + Args: + chat: The chat dictionary containing the carryover configuration + recipient: The recipient agent + messages: The messages from the parent chat + sender: The sender agent + trim_n_messages: The number of latest messages to trim from the messages list + """ + + def concat_carryover(chat_message: str, carryover_message: Union[str, list[dict[str, Any]]]) -> str: + """Concatenate the carryover message to the chat message.""" + prefix = f"{chat_message}\n" if chat_message else "" + + if isinstance(carryover_message, str): + content = carryover_message + elif isinstance(carryover_message, list): + content = "\n".join( + msg["content"] for msg in carryover_message if "content" in msg and msg["content"] is not None + ) + else: + raise ValueError("Carryover message must be a string or a list of dictionaries") + + return f"{prefix}Context:\n{content}" + + carryover_config = chat["carryover_config"] + + if "summary_method" not in carryover_config: + raise 
ValueError("Carryover configuration must contain a 'summary_method' key") + + carryover_summary_method = carryover_config["summary_method"] + carryover_summary_args = carryover_config.get("summary_args") or {} + + chat_message = "" + message = chat.get("message") + + # If the message is a callable, run it and get the result + if message: + chat_message = message(recipient, messages, sender, config) if callable(message) else message + + # deep copy and trim the latest messages + content_messages = copy.deepcopy(messages) + content_messages = content_messages[:-trim_n_messages] + + if carryover_summary_method == "all": + # Put a string concatenated value of all parent messages into the first message + # (e.g. message = \nContext: \n\n\n...) + carry_over_message = concat_carryover(chat_message, content_messages) + + elif carryover_summary_method == "last_msg": + # (e.g. message = \nContext: \n) + carry_over_message = concat_carryover(chat_message, content_messages[-1]["content"]) + + elif carryover_summary_method == "reflection_with_llm": + # (e.g. message = \nContext: \n) + + # Add the messages to the nested chat agent for reflection (we'll clear after reflection) + chat["recipient"]._oai_messages[sender] = content_messages + + carry_over_message_llm = ConversableAgent._reflection_with_llm_as_summary( + sender=sender, + recipient=chat["recipient"], # Chat recipient LLM config will be used for the reflection + summary_args=carryover_summary_args, + ) + + recipient._oai_messages[sender] = [] + + carry_over_message = concat_carryover(chat_message, carry_over_message_llm) + + elif isinstance(carryover_summary_method, Callable): + # (e.g. 
message = \nContext: \n) + carry_over_message_result = carryover_summary_method(recipient, content_messages, carryover_summary_args) + + carry_over_message = concat_carryover(chat_message, carry_over_message_result) + + chat["message"] = carry_over_message + + @staticmethod + def _process_chat_queue_carryover( + chat_queue: list[dict[str, Any]], + recipient: Agent, + messages: Union[str, Callable], + sender: Agent, + config: Any, + trim_messages: int = 2, + ) -> tuple[bool, Optional[str]]: + """Process carryover configuration for the first chat in the queue. + + Args: + chat_queue: List of chat configurations + recipient: Receiving agent + messages: Chat messages + sender: Sending agent + config: LLM configuration + trim_messages: Number of messages to trim for nested chat carryover (default 2 for swarm chats) + + Returns: + Tuple containing: + - restore_flag: Whether the original message needs to be restored + - original_message: The original message to restore (if any) + """ + restore_chat_queue_message = False + original_chat_queue_message = None + + # Carryover configuration allowed on the first chat in the queue only, trim the last two messages specifically for swarm nested chat carryover as these are the messages for the transition to the nested chat agent + if len(chat_queue) > 0 and "carryover_config" in chat_queue[0]: + if "message" in chat_queue[0]: + # As we're updating the message in the nested chat queue, we need to restore it after finishing this nested chat. + restore_chat_queue_message = True + original_chat_queue_message = chat_queue[0]["message"] + + # TODO Check the trimming required if not a swarm chat, it may not be 2 because other chats don't have the swarm transition messages. We may need to add as a carryover_config parameter. 
+ ConversableAgent._process_nested_chat_carryover( + chat=chat_queue[0], + recipient=recipient, + messages=messages, + sender=sender, + config=config, + trim_n_messages=trim_messages, + ) + + return restore_chat_queue_message, original_chat_queue_message + @staticmethod def _summary_from_nested_chats( chat_queue: list[dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any @@ -454,13 +775,29 @@ def _summary_from_nested_chats( It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue. + The first chat in the queue can contain a 'carryover_config' which is a dictionary that denotes how to carryover messages from the parent chat into the first chat of the nested chats). Only applies to the first chat. + e.g.: carryover_summarize_chat_config = {"summary_method": "reflection_with_llm", "summary_args": None} + summary_method can be "last_msg", "all", "reflection_with_llm", Callable + The Callable signature: my_method(agent: ConversableAgent, messages: List[Dict[str, Any]]) -> str + The summary will be concatenated to the message of the first chat in the queue. + Returns: Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated. 
""" + # Process carryover configuration + restore_chat_queue_message, original_chat_queue_message = ConversableAgent._process_chat_queue_carryover( + chat_queue, recipient, messages, sender, config + ) + chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config) if not chat_to_run: return True, None res = initiate_chats(chat_to_run) + + # We need to restore the chat queue message if it has been modified so that it will be the original message for subsequent uses + if restore_chat_queue_message: + chat_queue[0]["message"] = original_chat_queue_message + return True, res[-1].summary @staticmethod @@ -473,14 +810,30 @@ async def _a_summary_from_nested_chats( It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue. + The first chat in the queue can contain a 'carryover_config' which is a dictionary that denotes how to carryover messages from the parent chat into the first chat of the nested chats). Only applies to the first chat. + e.g.: carryover_summarize_chat_config = {"summary_method": "reflection_with_llm", "summary_args": None} + summary_method can be "last_msg", "all", "reflection_with_llm", Callable + The Callable signature: my_method(agent: ConversableAgent, messages: List[Dict[str, Any]]) -> str + The summary will be concatenated to the message of the first chat in the queue. + Returns: Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated. 
""" + # Process carryover configuration + restore_chat_queue_message, original_chat_queue_message = ConversableAgent._process_chat_queue_carryover( + chat_queue, recipient, messages, sender, config + ) + chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config) if not chat_to_run: return True, None res = await a_initiate_chats(chat_to_run) index_of_last_chat = chat_to_run[-1]["chat_id"] + + # We need to restore the chat queue message if it has been modified so that it will be the original message for subsequent uses + if restore_chat_queue_message: + chat_queue[0]["message"] = original_chat_queue_message + return True, res[index_of_last_chat].summary def register_nested_chats( diff --git a/autogen/agentchat/realtime_agent/realtime_agent.py b/autogen/agentchat/realtime_agent/realtime_agent.py index 984191d44e..f372832d41 100644 --- a/autogen/agentchat/realtime_agent/realtime_agent.py +++ b/autogen/agentchat/realtime_agent/realtime_agent.py @@ -10,7 +10,6 @@ from fastapi import WebSocket from ...tools import Tool -from .. 
import SwarmAgent from ..agent import Agent from ..contrib.swarm_agent import AfterWorkOption, initiate_swarm_chat from ..conversable_agent import ConversableAgent @@ -105,8 +104,8 @@ def __init__( self._answer_event: anyio.Event = anyio.Event() self._answer: str = "" self._start_swarm_chat = False - self._initial_agent: Optional[SwarmAgent] = None - self._agents: Optional[list[SwarmAgent]] = None + self._initial_agent: Optional[ConversableAgent] = None + self._agents: Optional[list[ConversableAgent]] = None def _validate_name(self, name: str) -> None: # RealtimeAgent does not need to validate the name @@ -138,15 +137,15 @@ def register_observer(self, observer: RealtimeObserver) -> None: def register_swarm( self, *, - initial_agent: SwarmAgent, - agents: list[SwarmAgent], + initial_agent: ConversableAgent, + agents: list[ConversableAgent], system_message: Optional[str] = None, ) -> None: """Register a swarm of agents with the Realtime Agent. Args: - initial_agent (SwarmAgent): The initial agent. - agents (list[SwarmAgent]): The agents in the swarm. + initial_agent (ConversableAgent): The initial agent. + agents (list[ConversableAgent]): The agents in the swarm. system_message (str): The system message for the agent. """ logger = self.logger diff --git a/notebook/agentchat_realtime_swarm.ipynb b/notebook/agentchat_realtime_swarm.ipynb index b7ade2c755..16774d43d7 100644 --- a/notebook/agentchat_realtime_swarm.ipynb +++ b/notebook/agentchat_realtime_swarm.ipynb @@ -14,6 +14,17 @@ "In this notebook, we implement OpenAI's [airline customer service example](https://github.com/openai/swarm/tree/main/examples/airline) in AG2 using the RealtimeAgent for enhanced interaction." 
] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "````mdx-code-block\n", + ":::note\n", + "This notebook has been updated as swarms can now accommodate any ConversableAgent.\n", + ":::\n", + "````" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -270,14 +281,14 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "from autogen import ON_CONDITION, SwarmAgent\n", + "from autogen import ConversableAgent, OnCondition, register_hand_off\n", "\n", "# Triage Agent\n", - "triage_agent = SwarmAgent(\n", + "triage_agent = ConversableAgent(\n", " name=\"Triage_Agent\",\n", " system_message=triage_instructions(context_variables=context_variables),\n", " llm_config=llm_config,\n", @@ -285,7 +296,7 @@ ")\n", "\n", "# Flight Modification Agent\n", - "flight_modification = SwarmAgent(\n", + "flight_modification = ConversableAgent(\n", " name=\"Flight_Modification_Agent\",\n", " system_message=\"\"\"You are a Flight Modification Agent for a customer service airline.\n", " Your task is to determine if the user wants to cancel or change their flight.\n", @@ -295,7 +306,7 @@ ")\n", "\n", "# Flight Cancel Agent\n", - "flight_cancel = SwarmAgent(\n", + "flight_cancel = ConversableAgent(\n", " name=\"Flight_Cancel_Traversal\",\n", " system_message=STARTER_PROMPT + FLIGHT_CANCELLATION_POLICY,\n", " llm_config=llm_config,\n", @@ -303,7 +314,7 @@ ")\n", "\n", "# Flight Change Agent\n", - "flight_change = SwarmAgent(\n", + "flight_change = ConversableAgent(\n", " name=\"Flight_Change_Traversal\",\n", " system_message=STARTER_PROMPT + FLIGHT_CHANGE_POLICY,\n", " llm_config=llm_config,\n", @@ -311,7 +322,7 @@ ")\n", "\n", "# Lost Baggage Agent\n", - "lost_baggage = SwarmAgent(\n", + "lost_baggage = ConversableAgent(\n", " name=\"Lost_Baggage_Traversal\",\n", " system_message=STARTER_PROMPT + LOST_BAGGAGE_POLICY,\n", " llm_config=llm_config,\n", @@ -325,33 +336,35 @@ "source": [ "### Register 
Handoffs\n", "\n", - "Now we register the handoffs for the agents. Note that you don't need to define the transfer functions and pass them in. Instead, you can directly register the handoffs using the `ON_CONDITION` class." + "Now we register the handoffs for the agents. Note that you don't need to define the transfer functions and pass them in. Instead, you can directly register the handoffs using the `OnCondition` class." ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Register hand-offs\n", - "triage_agent.register_hand_off(\n", - " [\n", - " ON_CONDITION(flight_modification, \"To modify a flight\"),\n", - " ON_CONDITION(lost_baggage, \"To find lost baggage\"),\n", - " ]\n", + "register_hand_off(\n", + " agent=triage_agent,\n", + " hand_to=[\n", + " OnCondition(flight_modification, \"To modify a flight\"),\n", + " OnCondition(lost_baggage, \"To find lost baggage\"),\n", + " ],\n", ")\n", "\n", - "flight_modification.register_hand_off(\n", - " [\n", - " ON_CONDITION(flight_cancel, \"To cancel a flight\"),\n", - " ON_CONDITION(flight_change, \"To change a flight\"),\n", - " ]\n", + "register_hand_off(\n", + " agent=flight_modification,\n", + " hand_to=[\n", + " OnCondition(flight_cancel, \"To cancel a flight\"),\n", + " OnCondition(flight_change, \"To change a flight\"),\n", + " ],\n", ")\n", "\n", "transfer_to_triage_description = \"Call this function when a user needs to be transferred to a different agent and a different policy.\\nFor instance, if a user is asking about a topic that is not handled by the current agent, call this function.\"\n", "for agent in [flight_modification, flight_cancel, flight_change, lost_baggage]:\n", - " agent.register_hand_off(ON_CONDITION(triage_agent, transfer_to_triage_description))" + " register_hand_off(agent=agent, hand_to=OnCondition(triage_agent, transfer_to_triage_description))" ] }, { diff --git a/notebook/agentchat_swarm.ipynb 
b/notebook/agentchat_swarm.ipynb index 56141046c2..02878c3e22 100644 --- a/notebook/agentchat_swarm.ipynb +++ b/notebook/agentchat_swarm.ipynb @@ -16,6 +16,17 @@ "In this notebook, we implement OpenAI's [airline customer service example](https://github.com/openai/swarm/tree/main/examples/airline) in AG2." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "````mdx-code-block\n", + ":::note\n", + "This notebook has been updated as swarms can now accommodate any ConversableAgent.\n", + ":::\n", + "````" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -222,10 +233,10 @@ "metadata": {}, "outputs": [], "source": [ - "from autogen import ON_CONDITION, AfterWorkOption, SwarmAgent, initiate_swarm_chat\n", + "from autogen import AfterWorkOption, ConversableAgent, OnCondition, initiate_swarm_chat, register_hand_off\n", "\n", "# Triage Agent\n", - "triage_agent = SwarmAgent(\n", + "triage_agent = ConversableAgent(\n", " name=\"Triage_Agent\",\n", " system_message=triage_instructions(context_variables=context_variables),\n", " llm_config=llm_config,\n", @@ -233,7 +244,7 @@ ")\n", "\n", "# Flight Modification Agent\n", - "flight_modification = SwarmAgent(\n", + "flight_modification = ConversableAgent(\n", " name=\"Flight_Modification_Agent\",\n", " system_message=\"\"\"You are a Flight Modification Agent for a customer service airline.\n", " Your task is to determine if the user wants to cancel or change their flight.\n", @@ -243,7 +254,7 @@ ")\n", "\n", "# Flight Cancel Agent\n", - "flight_cancel = SwarmAgent(\n", + "flight_cancel = ConversableAgent(\n", " name=\"Flight_Cancel_Traversal\",\n", " system_message=STARTER_PROMPT + FLIGHT_CANCELLATION_POLICY,\n", " llm_config=llm_config,\n", @@ -251,7 +262,7 @@ ")\n", "\n", "# Flight Change Agent\n", - "flight_change = SwarmAgent(\n", + "flight_change = ConversableAgent(\n", " name=\"Flight_Change_Traversal\",\n", " system_message=STARTER_PROMPT + FLIGHT_CHANGE_POLICY,\n", " llm_config=llm_config,\n", @@ 
-259,7 +270,7 @@ ")\n", "\n", "# Lost Baggage Agent\n", - "lost_baggage = SwarmAgent(\n", + "lost_baggage = ConversableAgent(\n", " name=\"Lost_Baggage_Traversal\",\n", " system_message=STARTER_PROMPT + LOST_BAGGAGE_POLICY,\n", " llm_config=llm_config,\n", @@ -273,7 +284,7 @@ "source": [ "### Register Handoffs\n", "\n", - "Now we register the handoffs for the agents. Note that you don't need to define the transfer functions and pass them in. Instead, you can directly register the handoffs using the `ON_CONDITION` class." + "Now we register the handoffs for the agents. Note that you don't need to define the transfer functions and pass them in. Instead, you can directly register the handoffs using the `OnCondition` class." ] }, { @@ -283,23 +294,25 @@ "outputs": [], "source": [ "# Register hand-offs\n", - "triage_agent.register_hand_off(\n", - " [\n", - " ON_CONDITION(flight_modification, \"To modify a flight\"),\n", - " ON_CONDITION(lost_baggage, \"To find lost baggage\"),\n", - " ]\n", + "register_hand_off(\n", + " agent=triage_agent,\n", + " hand_to=[\n", + " OnCondition(flight_modification, \"To modify a flight\"),\n", + " OnCondition(lost_baggage, \"To find lost baggage\"),\n", + " ],\n", ")\n", "\n", - "flight_modification.register_hand_off(\n", - " [\n", - " ON_CONDITION(flight_cancel, \"To cancel a flight\"),\n", - " ON_CONDITION(flight_change, \"To change a flight\"),\n", - " ]\n", + "register_hand_off(\n", + " agent=flight_modification,\n", + " hand_to=[\n", + " OnCondition(flight_cancel, \"To cancel a flight\"),\n", + " OnCondition(flight_change, \"To change a flight\"),\n", + " ],\n", ")\n", "\n", "transfer_to_triage_description = \"Call this function when a user needs to be transferred to a different agent and a different policy.\\nFor instance, if a user is asking about a topic that is not handled by the current agent, call this function.\"\n", "for agent in [flight_modification, flight_cancel, flight_change, lost_baggage]:\n", - " 
agent.register_hand_off(ON_CONDITION(triage_agent, transfer_to_triage_description))" + " register_hand_off(agent=agent, hand_to=OnCondition(triage_agent, transfer_to_triage_description))" ] }, { @@ -373,7 +386,7 @@ "\u001b[33mTool_Execution\u001b[0m (to chat_manager):\n", "\n", "\u001b[32m***** Response from calling tool (call_Qgji9KAw1e3ktxykLU8v1wg7) *****\u001b[0m\n", - "SwarmAgent --> Flight_Modification_Agent\n", + "Swarm agent --> Flight_Modification_Agent\n", "\u001b[32m**********************************************************************\u001b[0m\n", "\n", "--------------------------------------------------------------------------------\n", @@ -397,7 +410,7 @@ "\u001b[33mTool_Execution\u001b[0m (to chat_manager):\n", "\n", "\u001b[32m***** Response from calling tool (call_QYu7uBko1EaEZ7VzxPwx2jNO) *****\u001b[0m\n", - "SwarmAgent --> Flight_Cancel_Traversal\n", + "Swarm agent --> Flight_Cancel_Traversal\n", "\u001b[32m**********************************************************************\u001b[0m\n", "\n", "--------------------------------------------------------------------------------\n", diff --git a/notebook/agentchat_swarm_enhanced.ipynb b/notebook/agentchat_swarm_enhanced.ipynb index 40d55ec46e..46e18dfc03 100644 --- a/notebook/agentchat_swarm_enhanced.ipynb +++ b/notebook/agentchat_swarm_enhanced.ipynb @@ -19,6 +19,17 @@ "- Nested chats" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "````mdx-code-block\n", + ":::note\n", + "This notebook has been updated as swarms can now accommodate any ConversableAgent.\n", + ":::\n", + "````" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -95,22 +106,22 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from typing import Any, Dict, List\n", "\n", "from autogen import (\n", - " AFTER_WORK,\n", - " ON_CONDITION,\n", - " UPDATE_SYSTEM_MESSAGE,\n", + " AfterWork,\n", " AfterWorkOption,\n", " ConversableAgent,\n", 
- " SwarmAgent,\n", + " OnCondition,\n", " SwarmResult,\n", + " UpdateSystemMessage,\n", " UserProxyAgent,\n", " initiate_swarm_chat,\n", + " register_hand_off,\n", ")" ] }, @@ -274,7 +285,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -299,10 +310,10 @@ "Enquiring for Order ID: {order_id}\n", "\"\"\"\n", "\n", - "order_triage_agent = SwarmAgent(\n", + "order_triage_agent = ConversableAgent(\n", " name=\"order_triage_agent\",\n", " update_agent_state_before_reply=[\n", - " UPDATE_SYSTEM_MESSAGE(order_triage_prompt),\n", + " UpdateSystemMessage(order_triage_prompt),\n", " ],\n", " functions=[check_order_id, record_order_id],\n", " llm_config=llm_config,\n", @@ -310,7 +321,7 @@ "\n", "authentication_prompt = \"You are an authentication agent that verifies the identity of the customer.\"\n", "\n", - "authentication_agent = SwarmAgent(\n", + "authentication_agent = ConversableAgent(\n", " name=\"authentication_agent\",\n", " system_message=authentication_prompt,\n", " functions=[login_customer_by_username],\n", @@ -329,10 +340,10 @@ "Enquiring for Order ID: {order_id}\n", "\"\"\"\n", "\n", - "order_mgmt_agent = SwarmAgent(\n", + "order_mgmt_agent = ConversableAgent(\n", " name=\"order_mgmt_agent\",\n", " update_agent_state_before_reply=[\n", - " UPDATE_SYSTEM_MESSAGE(order_management_prompt),\n", + " UpdateSystemMessage(order_management_prompt),\n", " ],\n", " functions=[check_order_id, record_order_id],\n", " llm_config=llm_config,\n", @@ -397,65 +408,68 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Handoffs (ON_CONDITIONS and AFTER_WORKS)" + "### Handoffs (OnCondition and AfterWork)" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# HANDOFFS\n", - "order_triage_agent.register_hand_off(\n", - " [\n", - " ON_CONDITION(\n", + "register_hand_off(\n", + " agent=order_triage_agent,\n", + " hand_to=[\n", 
+ " OnCondition(\n", " target=authentication_agent,\n", " condition=\"The customer is not logged in, authenticate the customer.\",\n", " available=\"requires_login\",\n", " ),\n", - " ON_CONDITION(\n", + " OnCondition(\n", " target=order_mgmt_agent,\n", " condition=\"The customer is logged in, continue with the order triage.\",\n", " available=\"logged_in\",\n", " ),\n", - " AFTER_WORK(AfterWorkOption.REVERT_TO_USER),\n", - " ]\n", + " AfterWork(AfterWorkOption.REVERT_TO_USER),\n", + " ],\n", ")\n", "\n", - "authentication_agent.register_hand_off(\n", - " [\n", - " ON_CONDITION(\n", + "register_hand_off(\n", + " agent=authentication_agent,\n", + " hand_to=[\n", + " OnCondition(\n", " target=order_triage_agent,\n", " condition=\"The customer is logged in, continue with the order triage.\",\n", " available=\"logged_in\",\n", " ),\n", - " AFTER_WORK(AfterWorkOption.REVERT_TO_USER),\n", - " ]\n", + " AfterWork(AfterWorkOption.REVERT_TO_USER),\n", + " ],\n", ")\n", "\n", "\n", - "def has_order_in_context(agent: SwarmAgent, messages: List[Dict[str, Any]]) -> bool:\n", + "def has_order_in_context(agent: ConversableAgent, messages: List[Dict[str, Any]]) -> bool:\n", " return agent.get_context(\"has_order_id\")\n", "\n", "\n", - "order_mgmt_agent.register_hand_off(\n", - " [\n", - " ON_CONDITION(\n", + "register_hand_off(\n", + " agent=order_mgmt_agent,\n", + " hand_to=[\n", + " OnCondition(\n", " target={\n", " \"chat_queue\": chat_queue,\n", " },\n", " condition=\"Retrieve the status of the order\",\n", " available=has_order_in_context,\n", " ),\n", - " ON_CONDITION(\n", + " OnCondition(\n", " target=authentication_agent,\n", " condition=\"The customer is not logged in, authenticate the customer.\",\n", " available=\"requires_login\",\n", " ),\n", - " ON_CONDITION(target=order_triage_agent, condition=\"The customer has no more enquiries about this order.\"),\n", - " AFTER_WORK(AfterWorkOption.REVERT_TO_USER),\n", - " ]\n", + " OnCondition(target=order_triage_agent, 
condition=\"The customer has no more enquiries about this order.\"),\n", + " AfterWork(AfterWorkOption.REVERT_TO_USER),\n", + " ],\n", ")" ] }, @@ -499,7 +513,7 @@ "\u001b[33mTool_Execution\u001b[0m (to chat_manager):\n", "\n", "\u001b[32m***** Response from calling tool (call_RhIdaMav5FoXxvXiYhyDoivV) *****\u001b[0m\n", - "SwarmAgent --> authentication_agent\n", + "Swarm agent --> authentication_agent\n", "\u001b[32m**********************************************************************\u001b[0m\n", "\n", "--------------------------------------------------------------------------------\n", @@ -600,7 +614,7 @@ "\u001b[33mTool_Execution\u001b[0m (to chat_manager):\n", "\n", "\u001b[32m***** Response from calling tool (call_mXHJHDzVPTXWDhll0UH7w3QI) *****\u001b[0m\n", - "SwarmAgent --> order_mgmt_agent\n", + "Swarm agent --> order_mgmt_agent\n", "\u001b[32m**********************************************************************\u001b[0m\n", "\n", "--------------------------------------------------------------------------------\n", @@ -762,7 +776,7 @@ "\u001b[33mTool_Execution\u001b[0m (to chat_manager):\n", "\n", "\u001b[32m***** Response from calling tool (call_sYsVS1U3k3Cf2KbqKJ4hhyRa) *****\u001b[0m\n", - "SwarmAgent --> nested_chat_order_mgmt_agent_1\n", + "Swarm agent --> nested_chat_order_mgmt_agent_1\n", "\u001b[32m**********************************************************************\u001b[0m\n", "\n", "--------------------------------------------------------------------------------\n", @@ -868,7 +882,7 @@ "\u001b[33mTool_Execution\u001b[0m (to chat_manager):\n", "\n", "\u001b[32m***** Response from calling tool (call_VtBmcKhDAhh7JUz9aXyPq9Aj) *****\u001b[0m\n", - "SwarmAgent --> order_triage_agent\n", + "Swarm agent --> order_triage_agent\n", "\u001b[32m**********************************************************************\u001b[0m\n", "\n", "--------------------------------------------------------------------------------\n", diff --git 
a/notebook/agentchat_swarm_graphrag_telemetry_trip_planner.ipynb b/notebook/agentchat_swarm_graphrag_telemetry_trip_planner.ipynb index 00bac06f29..259bca191d 100644 --- a/notebook/agentchat_swarm_graphrag_telemetry_trip_planner.ipynb +++ b/notebook/agentchat_swarm_graphrag_telemetry_trip_planner.ipynb @@ -17,6 +17,17 @@ "- Swarm orchestration utilising context variables" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "````mdx-code-block\n", + ":::note\n", + "This notebook has been updated as swarms can now accommodate any ConversableAgent.\n", + ":::\n", + "````" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -416,7 +427,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -430,13 +441,14 @@ "from pydantic import BaseModel\n", "\n", "from autogen import (\n", - " AFTER_WORK,\n", - " ON_CONDITION,\n", + " AfterWork,\n", " AfterWorkOption,\n", - " SwarmAgent,\n", + " ConversableAgent,\n", + " OnCondition,\n", " SwarmResult,\n", " UserProxyAgent,\n", " initiate_swarm_chat,\n", + " register_hand_off,\n", ")" ] }, @@ -638,17 +650,17 @@ "source": [ "### Agents\n", "\n", - "Our SwarmAgents and a UserProxyAgent (human) which the swarm will interact with." + "Our Swarm agents and a UserProxyAgent (human) which the swarm will interact with." ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Planner agent, interacting with the customer and GraphRag agent, to create an itinerary\n", - "planner_agent = SwarmAgent(\n", + "planner_agent = ConversableAgent(\n", " name=\"planner_agent\",\n", " system_message=\"You are a trip planner agent. It is important to know where the customer is going, how many days, what they want to do.\"\n", " + \"You will work with another agent, graphrag_agent, to get information about restaurant and attractions. 
\"\n", @@ -661,7 +673,7 @@ ")\n", "\n", "# FalkorDB GraphRAG agent, utilising the FalkorDB to gather data for the Planner agent\n", - "graphrag_agent = SwarmAgent(\n", + "graphrag_agent = ConversableAgent(\n", " name=\"graphrag_agent\",\n", " system_message=\"Return a list of restaurants and/or attractions. List them separately and provide ALL the options in the location. Do not provide travel advice.\",\n", ")\n", @@ -675,7 +687,7 @@ "for config in structured_config_list:\n", " config[\"response_format\"] = Itinerary\n", "\n", - "structured_output_agent = SwarmAgent(\n", + "structured_output_agent = ConversableAgent(\n", " name=\"structured_output_agent\",\n", " system_message=\"You are a data formatting agent, format the provided itinerary in the context below into the provided format.\",\n", " llm_config={\"config_list\": structured_config_list, \"timeout\": 120},\n", @@ -683,7 +695,7 @@ ")\n", "\n", "# Route Timing agent, adding estimated travel times to the itinerary by utilising the Google Maps Platform\n", - "route_timing_agent = SwarmAgent(\n", + "route_timing_agent = ConversableAgent(\n", " name=\"route_timing_agent\",\n", " system_message=\"You are a route timing agent. YOU MUST call the update_itinerary_with_travel_times tool if you do not see the exact phrase 'Timed itinerary added to context with travel times' is seen in this conversation. Only after this please tell the customer 'Your itinerary is ready!'.\",\n", " llm_config=llm_config,\n", @@ -707,32 +719,31 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "planner_agent.register_hand_off(\n", + "register_hand_off(\n", + " agent=planner_agent,\n", " hand_to=[\n", - " ON_CONDITION(\n", + " OnCondition(\n", " graphrag_agent,\n", " \"Need information on the restaurants and attractions for a location. 
DO NOT call more than once at a time.\",\n", " ), # Get info from FalkorDB GraphRAG\n", - " ON_CONDITION(structured_output_agent, \"Itinerary is confirmed by the customer\"),\n", - " AFTER_WORK(AfterWorkOption.REVERT_TO_USER), # Revert to the customer for more information on their plans\n", - " ]\n", + " OnCondition(structured_output_agent, \"Itinerary is confirmed by the customer\"),\n", + " AfterWork(AfterWorkOption.REVERT_TO_USER), # Revert to the customer for more information on their plans\n", + " ],\n", ")\n", "\n", "\n", "# Back to the Planner when information has been retrieved\n", - "graphrag_agent.register_hand_off(hand_to=[AFTER_WORK(planner_agent)])\n", + "register_hand_off(agent=graphrag_agent, hand_to=[AfterWork(planner_agent)])\n", "\n", "# Once we have formatted our itinerary, we can hand off to the route timing agent to add in the travel timings\n", - "structured_output_agent.register_hand_off(hand_to=[AFTER_WORK(route_timing_agent)])\n", + "register_hand_off(agent=structured_output_agent, hand_to=[AfterWork(route_timing_agent)])\n", "\n", "# Finally, once the route timing agent has finished, we can terminate the swarm\n", - "route_timing_agent.register_hand_off(\n", - " hand_to=[AFTER_WORK(AfterWorkOption.TERMINATE)] # Once this agent has finished, the swarm can terminate\n", - ")" + "register_hand_off(agent=route_timing_agent, hand_to=[AfterWork(AfterWorkOption.TERMINATE)])" ] }, { diff --git a/notebook/agentchat_swarm_graphrag_trip_planner.ipynb b/notebook/agentchat_swarm_graphrag_trip_planner.ipynb index 4a33b9c209..ca4ce1aa16 100644 --- a/notebook/agentchat_swarm_graphrag_trip_planner.ipynb +++ b/notebook/agentchat_swarm_graphrag_trip_planner.ipynb @@ -16,6 +16,17 @@ "- Swarm orchestration utilising context variables" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "````mdx-code-block\n", + ":::note\n", + "This notebook has been updated as swarms can now accommodate any ConversableAgent.\n", + ":::\n", + "````" + ] + }, { 
"cell_type": "markdown", "metadata": {}, @@ -312,7 +323,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -326,13 +337,14 @@ "from pydantic import BaseModel\n", "\n", "from autogen import (\n", - " AFTER_WORK,\n", - " ON_CONDITION,\n", + " AfterWork,\n", " AfterWorkOption,\n", - " SwarmAgent,\n", + " ConversableAgent,\n", + " OnCondition,\n", " SwarmResult,\n", " UserProxyAgent,\n", " initiate_swarm_chat,\n", + " register_hand_off,\n", ")" ] }, @@ -534,17 +546,17 @@ "source": [ "### Agents\n", "\n", - "Our SwarmAgents and a UserProxyAgent (human) which the swarm will interact with." + "Our Swarm agents and a UserProxyAgent (human) which the swarm will interact with." ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Planner agent, interacting with the customer and GraphRag agent, to create an itinerary\n", - "planner_agent = SwarmAgent(\n", + "planner_agent = ConversableAgent(\n", " name=\"planner_agent\",\n", " system_message=\"You are a trip planner agent. It is important to know where the customer is going, how many days, what they want to do.\"\n", " + \"You will work with another agent, graphrag_agent, to get information about restaurant and attractions. \"\n", @@ -557,7 +569,7 @@ ")\n", "\n", "# FalkorDB GraphRAG agent, utilising the FalkorDB to gather data for the Planner agent\n", - "graphrag_agent = SwarmAgent(\n", + "graphrag_agent = ConversableAgent(\n", " name=\"graphrag_agent\",\n", " system_message=\"Return a list of restaurants and/or attractions. List them separately and provide ALL the options in the location. 
Do not provide travel advice.\",\n", ")\n", @@ -571,7 +583,7 @@ "for config in structured_config_list:\n", " config[\"response_format\"] = Itinerary\n", "\n", - "structured_output_agent = SwarmAgent(\n", + "structured_output_agent = ConversableAgent(\n", " name=\"structured_output_agent\",\n", " system_message=\"You are a data formatting agent, format the provided itinerary in the context below into the provided format.\",\n", " llm_config={\"config_list\": structured_config_list, \"timeout\": 120},\n", @@ -579,7 +591,7 @@ ")\n", "\n", "# Route Timing agent, adding estimated travel times to the itinerary by utilising the Google Maps Platform\n", - "route_timing_agent = SwarmAgent(\n", + "route_timing_agent = ConversableAgent(\n", " name=\"route_timing_agent\",\n", " system_message=\"You are a route timing agent. YOU MUST call the update_itinerary_with_travel_times tool if you do not see the exact phrase 'Timed itinerary added to context with travel times' is seen in this conversation. Only after this please tell the customer 'Your itinerary is ready!'.\",\n", " llm_config=llm_config,\n", @@ -603,32 +615,31 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "planner_agent.register_hand_off(\n", + "register_hand_off(\n", + " agent=planner_agent,\n", " hand_to=[\n", - " ON_CONDITION(\n", + " OnCondition(\n", " graphrag_agent,\n", " \"Need information on the restaurants and attractions for a location. 
DO NOT call more than once at a time.\",\n", " ), # Get info from FalkorDB GraphRAG\n", - " ON_CONDITION(structured_output_agent, \"Itinerary is confirmed by the customer\"),\n", - " AFTER_WORK(AfterWorkOption.REVERT_TO_USER), # Revert to the customer for more information on their plans\n", - " ]\n", + " OnCondition(structured_output_agent, \"Itinerary is confirmed by the customer\"),\n", + " AfterWork(AfterWorkOption.REVERT_TO_USER), # Revert to the customer for more information on their plans\n", + " ],\n", ")\n", "\n", "\n", "# Back to the Planner when information has been retrieved\n", - "graphrag_agent.register_hand_off(hand_to=[AFTER_WORK(planner_agent)])\n", + "register_hand_off(agent=graphrag_agent, hand_to=[AfterWork(planner_agent)])\n", "\n", "# Once we have formatted our itinerary, we can hand off to the route timing agent to add in the travel timings\n", - "structured_output_agent.register_hand_off(hand_to=[AFTER_WORK(route_timing_agent)])\n", + "register_hand_off(agent=structured_output_agent, hand_to=[AfterWork(route_timing_agent)])\n", "\n", "# Finally, once the route timing agent has finished, we can terminate the swarm\n", - "route_timing_agent.register_hand_off(\n", - " hand_to=[AFTER_WORK(AfterWorkOption.TERMINATE)] # Once this agent has finished, the swarm can terminate\n", - ")" + "register_hand_off(agent=route_timing_agent, hand_to=[AfterWork(AfterWorkOption.TERMINATE)])" ] }, { @@ -689,7 +700,7 @@ "\u001b[33mTool_Execution\u001b[0m (to chat_manager):\n", "\n", "\u001b[32m***** Response from calling tool (call_vQMpso8aOomdfq8S2uCRlnzj) *****\u001b[0m\n", - "SwarmAgent --> graphrag_agent\n", + "Swarm agent --> graphrag_agent\n", "\u001b[32m**********************************************************************\u001b[0m\n", "\n", "--------------------------------------------------------------------------------\n", @@ -905,7 +916,7 @@ "\n", "--------------------------------------------------------------------------------\n", "\u001b[32m***** 
Response from calling tool (call_NBw71N4pS66h8VLlgu5nvveN) *****\u001b[0m\n", - "SwarmAgent --> structured_output_agent\n", + "Swarm agent --> structured_output_agent\n", "\u001b[32m**********************************************************************\u001b[0m\n", "\n", "--------------------------------------------------------------------------------\n", diff --git a/test/agentchat/contrib/test_swarm.py b/test/agentchat/contrib/test_swarm.py index 56be9f96b9..d5f48ddaee 100644 --- a/test/agentchat/contrib/test_swarm.py +++ b/test/agentchat/contrib/test_swarm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# Copyright (c) 2023 - 2025, Owners of https://github.com/ag2ai # # SPDX-License-Identifier: Apache-2.0 from typing import Any, Union @@ -6,14 +6,14 @@ import pytest +from autogen.agentchat.agent import Agent from autogen.agentchat.contrib.swarm_agent import ( - AFTER_WORK, - ON_CONDITION, - UPDATE_SYSTEM_MESSAGE, __TOOL_EXECUTOR_NAME__, + AfterWork, AfterWorkOption, - SwarmAgent, + OnCondition, SwarmResult, + UpdateCondition, _cleanup_temp_user_messages, _create_nested_chats, _prepare_swarm_agents, @@ -21,19 +21,23 @@ _setup_context_variables, a_initiate_swarm_chat, initiate_swarm_chat, + register_hand_off, ) -from autogen.agentchat.conversable_agent import ConversableAgent +from autogen.agentchat.conversable_agent import ConversableAgent, UpdateSystemMessage from autogen.agentchat.groupchat import GroupChat, GroupChatManager from autogen.agentchat.user_proxy_agent import UserProxyAgent TEST_MESSAGES = [{"role": "user", "content": "Initial message"}] -def test_swarm_agent_initialization(): - """Test SwarmAgent initialization with valid and invalid parameters""" - # Invalid functions parameter - with pytest.raises(TypeError): - SwarmAgent("test_agent", functions="invalid") +class NotConversableAgent(Agent): + """Dummy class to test invalid agent types""" + + def __init__( + self, + name: str, + ): + super().__init__(name) def 
test_swarm_result(): @@ -49,44 +53,45 @@ def test_swarm_result(): assert result.context_variables == context # Test with agent - agent = SwarmAgent("test") + agent = ConversableAgent("test") result = SwarmResult(values="test", agent=agent) assert result.agent == agent def test_after_work_initialization(): - """Test AFTER_WORK initialization with different options""" + """Test AfterWork initialization with different options""" # Test with AfterWorkOption - after_work = AFTER_WORK(AfterWorkOption.TERMINATE) + after_work = AfterWork(AfterWorkOption.TERMINATE) assert after_work.agent == AfterWorkOption.TERMINATE # Test with string - after_work = AFTER_WORK("TERMINATE") + after_work = AfterWork("TERMINATE") assert after_work.agent == AfterWorkOption.TERMINATE - # Test with SwarmAgent - agent = SwarmAgent("test") - after_work = AFTER_WORK(agent) + # Test with ConversableAgent + agent = ConversableAgent("test") + after_work = AfterWork(agent) assert after_work.agent == agent # Test with Callable - def test_callable(x: int) -> SwarmAgent: + def test_callable(x: int) -> ConversableAgent: return agent - after_work = AFTER_WORK(test_callable) + after_work = AfterWork(test_callable) assert after_work.agent == test_callable # Test with invalid option with pytest.raises(ValueError): - AFTER_WORK("INVALID_OPTION") + AfterWork("INVALID_OPTION") def test_on_condition(): - """Test ON_CONDITION initialization""" - # Test with a ConversableAgent - test_conversable_agent = ConversableAgent("test_conversable_agent") - with pytest.raises(AssertionError, match="'target' must be a SwarmAgent or a Dict"): - _ = ON_CONDITION(target=test_conversable_agent, condition="test condition") + """Test OnCondition initialization""" + + # Test with a base Agent + test_conversable_agent = NotConversableAgent("test_conversable_agent") + with pytest.raises(AssertionError, match="'target' must be a ConversableAgent or a dict"): + _ = OnCondition(target=test_conversable_agent, condition="test condition") 
def test_receiving_agent(): @@ -94,7 +99,7 @@ def test_receiving_agent(): # 1. Test with a single message - should always be the initial agent messages_one_no_name = [{"role": "user", "content": "Initial message"}] - test_initial_agent = SwarmAgent("InitialAgent") + test_initial_agent = ConversableAgent("InitialAgent") # Test the chat chat_result, context_vars, last_speaker = initiate_swarm_chat( @@ -106,7 +111,7 @@ def test_receiving_agent(): assert chat_result.chat_history[1].get("name") == "InitialAgent" # 2. Test with a single message from an existing agent (should still be initial agent) - test_second_agent = SwarmAgent("SecondAgent") + test_second_agent = ConversableAgent("SecondAgent") messages_one_w_name = [{"role": "user", "content": "Initial message", "name": "SecondAgent"}] @@ -138,8 +143,9 @@ def test_receiving_agent(): def test_resume_speaker(): """Tests resumption of chat with multiple messages""" - test_initial_agent = SwarmAgent("InitialAgent") - test_second_agent = SwarmAgent("SecondAgent") + + test_initial_agent = ConversableAgent("InitialAgent") + test_second_agent = ConversableAgent("SecondAgent") # For multiple messages, last agent initiates the chat multiple_messages = [ @@ -173,8 +179,9 @@ def test_resume_speaker(): def test_after_work_options(): """Test different after work options""" - agent1 = SwarmAgent("agent1") - agent2 = SwarmAgent("agent2") + + agent1 = ConversableAgent("agent1") + agent2 = ConversableAgent("agent2") user_agent = UserProxyAgent("test_user") # Fake generate_oai_reply @@ -186,14 +193,14 @@ def mock_generate_oai_reply(*args, **kwargs): agent2.register_reply([ConversableAgent, None], mock_generate_oai_reply) # 1. Test TERMINATE - agent1.after_work = AFTER_WORK(AfterWorkOption.TERMINATE) + agent1._swarm_after_work = AfterWork(AfterWorkOption.TERMINATE) chat_result, context_vars, last_speaker = initiate_swarm_chat( initial_agent=agent1, messages=TEST_MESSAGES, agents=[agent1, agent2] ) assert last_speaker == agent1 # 2. 
Test REVERT_TO_USER - agent1.after_work = AFTER_WORK(AfterWorkOption.REVERT_TO_USER) + agent1._swarm_after_work = AfterWork(AfterWorkOption.REVERT_TO_USER) test_messages = [ {"role": "user", "content": "Initial message"}, @@ -209,7 +216,7 @@ def mock_generate_oai_reply(*args, **kwargs): assert chat_result.chat_history[3]["name"] == "test_user" # 3. Test STAY - agent1.after_work = AFTER_WORK(AfterWorkOption.STAY) + agent1._swarm_after_work = AfterWork(AfterWorkOption.STAY) chat_result, context_vars, last_speaker = initiate_swarm_chat( initial_agent=agent1, messages=test_messages, agents=[agent1, agent2], max_rounds=4 ) @@ -223,7 +230,7 @@ def mock_generate_oai_reply(*args, **kwargs): def test_callable(last_speaker, messages, groupchat): return agent2 - agent1.after_work = AFTER_WORK(test_callable) + agent1._swarm_after_work = AfterWork(test_callable) chat_result, context_vars, last_speaker = initiate_swarm_chat( initial_agent=agent1, messages=test_messages, agents=[agent1, agent2], max_rounds=4 @@ -234,7 +241,8 @@ def test_callable(last_speaker, messages, groupchat): def test_on_condition_handoff(): - """Test ON_CONDITION in handoffs""" + """Test OnCondition in handoffs""" + testing_llm_config = { "config_list": [ { @@ -244,10 +252,10 @@ def test_on_condition_handoff(): ] } - agent1 = SwarmAgent("agent1", llm_config=testing_llm_config) - agent2 = SwarmAgent("agent2", llm_config=testing_llm_config) + agent1 = ConversableAgent("agent1", llm_config=testing_llm_config) + agent2 = ConversableAgent("agent2", llm_config=testing_llm_config) - agent1.register_hand_off(hand_to=ON_CONDITION(target=agent2, condition="always take me to agent 2")) + register_hand_off(agent1, hand_to=OnCondition(target=agent2, condition="always take me to agent 2")) # Fake generate_oai_reply def mock_generate_oai_reply(*args, **kwargs): @@ -278,8 +286,8 @@ def mock_generate_oai_reply_tool(*args, **kwargs): def test_temporary_user_proxy(): """Test that temporary user proxy agent name is cleared""" 
- agent1 = SwarmAgent("agent1") - agent2 = SwarmAgent("agent2") + agent1 = ConversableAgent("agent1") + agent2 = ConversableAgent("agent2") chat_result, context_vars, last_speaker = initiate_swarm_chat( initial_agent=agent1, messages=TEST_MESSAGES, agents=[agent1, agent2] @@ -314,8 +322,8 @@ def test_func_2(context_variables: dict[str, Any], param2: str) -> str: context_variables["my_key"] += 100 return SwarmResult(values=f"Test 2 {param2}", context_variables=context_variables, agent=agent1) - agent1 = SwarmAgent("agent1", llm_config=testing_llm_config) - agent2 = SwarmAgent("agent2", functions=[test_func_1, test_func_2], llm_config=testing_llm_config) + agent1 = ConversableAgent("agent1", llm_config=testing_llm_config) + agent2 = ConversableAgent("agent2", functions=[test_func_1, test_func_2], llm_config=testing_llm_config) # Fake generate_oai_reply def mock_generate_oai_reply(*args, **kwargs): @@ -369,8 +377,8 @@ def test_func_1(context_variables: dict[str, Any], param1: str) -> str: context_variables["my_key"] += 1 return SwarmResult(values=f"Test 1 {param1}", context_variables=context_variables, agent=agent1) - agent1 = SwarmAgent("agent1", llm_config=testing_llm_config) - agent2 = SwarmAgent("agent2", functions=[test_func_1], llm_config=testing_llm_config) + agent1 = ConversableAgent("agent1", llm_config=testing_llm_config) + agent2 = ConversableAgent("agent2", functions=[test_func_1], llm_config=testing_llm_config) # Fake generate_oai_reply def mock_generate_oai_reply(*args, **kwargs): @@ -403,8 +411,8 @@ def mock_generate_oai_reply_tool(*args, **kwargs): def test_invalid_parameters(): """Test various invalid parameter combinations""" - agent1 = SwarmAgent("agent1") - agent2 = SwarmAgent("agent2") + agent1 = ConversableAgent("agent1") + agent2 = ConversableAgent("agent2") # Test invalid initial agent type with pytest.raises(AssertionError): @@ -420,38 +428,40 @@ def test_invalid_parameters(): def test_non_swarm_in_hand_off(): - """Test that SwarmAgents in the 
group chat are the only agents in hand-offs""" - agent1 = SwarmAgent("agent1") - bad_agent = ConversableAgent("bad_agent") + """Test that agents in the group chat are the only agents in hand-offs""" + + agent1 = ConversableAgent("agent1") + bad_agent = NotConversableAgent("bad_agent") with pytest.raises(AssertionError, match="Invalid After Work value"): - agent1.register_hand_off(hand_to=AFTER_WORK(bad_agent)) + register_hand_off(agent1, hand_to=AfterWork(bad_agent)) - with pytest.raises(AssertionError, match="'target' must be a SwarmAgent or a Dict"): - agent1.register_hand_off(hand_to=ON_CONDITION(target=bad_agent, condition="Testing")) + with pytest.raises(AssertionError, match="'target' must be a ConversableAgent or a dict"): + register_hand_off(agent1, hand_to=OnCondition(target=bad_agent, condition="Testing")) - with pytest.raises(ValueError, match="hand_to must be a list of ON_CONDITION or AFTER_WORK"): - agent1.register_hand_off(0) + with pytest.raises(ValueError, match="hand_to must be a list of OnCondition or AfterWork"): + register_hand_off(agent1, 0) def test_initialization(): """Test initiate_swarm_chat""" - agent1 = SwarmAgent("agent1") - agent2 = SwarmAgent("agent2") - agent3 = SwarmAgent("agent3") - bad_agent = ConversableAgent("bad_agent") - with pytest.raises(AssertionError, match="Agents must be a list of SwarmAgents"): + agent1 = ConversableAgent("agent1") + agent2 = ConversableAgent("agent2") + agent3 = ConversableAgent("agent3") + bad_agent = NotConversableAgent("bad_agent") + + with pytest.raises(AssertionError, match="Agents must be a list of ConversableAgent"): chat_result, context_vars, last_speaker = initiate_swarm_chat( initial_agent=agent2, messages=TEST_MESSAGES, agents=[agent1, agent2, bad_agent], max_rounds=3 ) - with pytest.raises(AssertionError, match="initial_agent must be a SwarmAgent"): + with pytest.raises(AssertionError, match="initial_agent must be a ConversableAgent"): chat_result, context_vars, last_speaker = 
initiate_swarm_chat( initial_agent=bad_agent, messages=TEST_MESSAGES, agents=[agent1, agent2], max_rounds=3 ) - agent1.register_hand_off(hand_to=AFTER_WORK(agent3)) + register_hand_off(agent1, hand_to=AfterWork(agent3)) with pytest.raises(AssertionError, match="Agent in hand-off must be in the agents list"): chat_result, context_vars, last_speaker = initiate_swarm_chat( @@ -477,9 +487,9 @@ def custom_update_function(agent: ConversableAgent, messages: list[dict]) -> str template_message = "Template message with {test_var}" # Create agents with different update configurations - agent1 = SwarmAgent("agent1", update_agent_state_before_reply=UPDATE_SYSTEM_MESSAGE(custom_update_function)) + agent1 = ConversableAgent("agent1", update_agent_state_before_reply=UpdateSystemMessage(custom_update_function)) - agent2 = SwarmAgent("agent2", update_agent_state_before_reply=UPDATE_SYSTEM_MESSAGE(template_message)) + agent2 = ConversableAgent("agent2", update_agent_state_before_reply=UpdateSystemMessage(template_message)) # Mock the reply function to capture the system message def mock_generate_oai_reply(*args, **kwargs): @@ -514,33 +524,15 @@ def mock_generate_oai_reply(*args, **kwargs): # Verify template result assert message_container.captured_sys_message == "Template message with test_value" - # Test invalid update function - with pytest.raises(ValueError, match="Update function must be either a string or a callable"): - SwarmAgent("agent3", update_agent_state_before_reply=UPDATE_SYSTEM_MESSAGE(123)) - - # Test invalid callable (wrong number of parameters) - def invalid_update_function(context_variables): - return "Invalid function" - - with pytest.raises(ValueError, match="Update function must accept two parameters"): - SwarmAgent("agent4", update_agent_state_before_reply=UPDATE_SYSTEM_MESSAGE(invalid_update_function)) - - # Test invalid callable (wrong return type) - def invalid_return_function(context_variables, messages) -> dict: - return {} - - with 
pytest.raises(ValueError, match="Update function must return a string"): - SwarmAgent("agent5", update_agent_state_before_reply=UPDATE_SYSTEM_MESSAGE(invalid_return_function)) - # Test multiple update functions def another_update_function(context_variables: dict[str, Any], messages: list[dict]) -> str: return "Another update" - agent6 = SwarmAgent( + agent6 = ConversableAgent( "agent6", update_agent_state_before_reply=[ - UPDATE_SYSTEM_MESSAGE(custom_update_function), - UPDATE_SYSTEM_MESSAGE(another_update_function), + UpdateSystemMessage(custom_update_function), + UpdateSystemMessage(another_update_function), ], ) @@ -571,14 +563,14 @@ def hello_world(context_variables: dict) -> SwarmResult: value = "Hello, World!" return SwarmResult(values=value, context_variables=context_variables, agent="agent_2") - # Create SwarmAgent instances - agent_1 = SwarmAgent( + # Create agent instances + agent_1 = ConversableAgent( name="agent_1", system_message="Your task is to call hello_world() function.", llm_config=testing_llm_config, functions=[hello_world], ) - agent_2 = SwarmAgent( + agent_2 = ConversableAgent( name="agent_2", system_message="Your task is to let the user know what the previous agent said.", llm_config=testing_llm_config, @@ -610,7 +602,7 @@ def mock_generate_oai_reply_agent2(*args, **kwargs): agents=[agent_1, agent_2], context_variables={}, messages="Begin by calling the hello_world() function.", - after_work=AFTER_WORK(AfterWorkOption.TERMINATE), + after_work=AfterWork(AfterWorkOption.TERMINATE), max_rounds=5, ) @@ -623,13 +615,13 @@ def hello_world(context_variables: dict) -> SwarmResult: value = "Hello, World!" 
return SwarmResult(values=value, context_variables=context_variables, agent="agent_unknown") - agent_1 = SwarmAgent( + agent_1 = ConversableAgent( name="agent_1", system_message="Your task is to call hello_world() function.", llm_config=testing_llm_config, functions=[hello_world], ) - agent_2 = SwarmAgent( + agent_2 = ConversableAgent( name="agent_2", system_message="Your task is to let the user know what the previous agent said.", llm_config=testing_llm_config, @@ -647,13 +639,14 @@ def hello_world(context_variables: dict) -> SwarmResult: agents=[agent_1, agent_2], context_variables={}, messages="Begin by calling the hello_world() function.", - after_work=AFTER_WORK(AfterWorkOption.TERMINATE), + after_work=AfterWork(AfterWorkOption.TERMINATE), max_rounds=5, ) def test_after_work_callable(): - """Test Callable in an AFTER_WORK handoff""" + """Test Callable in an AfterWork handoff""" + testing_llm_config = { "config_list": [ { @@ -663,41 +656,44 @@ def test_after_work_callable(): ] } - agent1 = SwarmAgent("agent1", llm_config=testing_llm_config) - agent2 = SwarmAgent("agent2", llm_config=testing_llm_config) - agent3 = SwarmAgent("agent3", llm_config=testing_llm_config) + agent1 = ConversableAgent("agent1", llm_config=testing_llm_config) + agent2 = ConversableAgent("agent2", llm_config=testing_llm_config) + agent3 = ConversableAgent("agent3", llm_config=testing_llm_config) def return_agent( - last_speaker: SwarmAgent, messages: list[dict[str, Any]], groupchat: GroupChat - ) -> Union[AfterWorkOption, SwarmAgent, str]: + last_speaker: ConversableAgent, messages: list[dict[str, Any]], groupchat: GroupChat + ) -> Union[AfterWorkOption, ConversableAgent, str]: return agent2 def return_agent_str( - last_speaker: SwarmAgent, messages: list[dict[str, Any]], groupchat: GroupChat - ) -> Union[AfterWorkOption, SwarmAgent, str]: + last_speaker: ConversableAgent, messages: list[dict[str, Any]], groupchat: GroupChat + ) -> Union[AfterWorkOption, ConversableAgent, str]: return 
"agent3" def return_after_work_option( - last_speaker: SwarmAgent, messages: list[dict[str, Any]], groupchat: GroupChat - ) -> Union[AfterWorkOption, SwarmAgent, str]: + last_speaker: ConversableAgent, messages: list[dict[str, Any]], groupchat: GroupChat + ) -> Union[AfterWorkOption, ConversableAgent, str]: return AfterWorkOption.TERMINATE - agent1.register_hand_off( + register_hand_off( + agent=agent1, hand_to=[ - AFTER_WORK(agent=return_agent), - ] + AfterWork(agent=return_agent), + ], ) - agent2.register_hand_off( + register_hand_off( + agent=agent2, hand_to=[ - AFTER_WORK(agent=return_agent_str), - ] + AfterWork(agent=return_agent_str), + ], ) - agent3.register_hand_off( + register_hand_off( + agent=agent3, hand_to=[ - AFTER_WORK(agent=return_after_work_option), - ] + AfterWork(agent=return_after_work_option), + ], ) # Fake generate_oai_reply @@ -724,7 +720,8 @@ def mock_generate_oai_reply(*args, **kwargs): def test_on_condition_unique_function_names(): - """Test that ON_CONDITION in handoffs generate unique function names""" + """Test that OnCondition in handoffs generate unique function names""" + testing_llm_config = { "config_list": [ { @@ -734,15 +731,16 @@ def test_on_condition_unique_function_names(): ] } - agent1 = SwarmAgent("agent1", llm_config=testing_llm_config) - agent2 = SwarmAgent("agent2", llm_config=testing_llm_config) + agent1 = ConversableAgent("agent1", llm_config=testing_llm_config) + agent2 = ConversableAgent("agent2", llm_config=testing_llm_config) - agent1.register_hand_off( + register_hand_off( + agent=agent1, hand_to=[ - ON_CONDITION(target=agent2, condition="always take me to agent 2"), - ON_CONDITION(target=agent2, condition="sometimes take me there"), - ON_CONDITION(target=agent2, condition="always take me there"), - ] + OnCondition(target=agent2, condition="always take me to agent 2"), + OnCondition(target=agent2, condition="sometimes take me there"), + OnCondition(target=agent2, condition="always take me there"), + ], ) # Fake 
generate_oai_reply @@ -786,9 +784,9 @@ def test_prepare_swarm_agents(): } # Create test agents - agent1 = SwarmAgent("agent1", llm_config=testing_llm_config) - agent2 = SwarmAgent("agent2", llm_config=testing_llm_config) - agent3 = SwarmAgent("agent3", llm_config=testing_llm_config) + agent1 = ConversableAgent("agent1", llm_config=testing_llm_config) + agent2 = ConversableAgent("agent2", llm_config=testing_llm_config) + agent3 = ConversableAgent("agent3", llm_config=testing_llm_config) # Add some functions to test tool executor aggregation def test_func1(): @@ -797,11 +795,11 @@ def test_func1(): def test_func2(): pass - agent1.add_single_function(test_func1) - agent2.add_single_function(test_func2) + agent1._add_single_function(test_func1) + agent2._add_single_function(test_func2) # Add handoffs to test validation - agent1.register_hand_off(AFTER_WORK(agent=agent2)) + register_hand_off(agent=agent1, hand_to=AfterWork(agent=agent2)) # Test valid preparation tool_executor, nested_chat_agents = _prepare_swarm_agents(agent1, [agent1, agent2]) @@ -813,14 +811,14 @@ def test_func2(): # Test invalid initial agent type with pytest.raises(AssertionError): - _prepare_swarm_agents(ConversableAgent("invalid"), [agent1, agent2]) + _prepare_swarm_agents(NotConversableAgent("invalid"), [agent1, agent2]) # Test invalid agents list with pytest.raises(AssertionError): - _prepare_swarm_agents(agent1, [agent1, ConversableAgent("invalid")]) + _prepare_swarm_agents(agent1, [agent1, NotConversableAgent("invalid")]) # Test missing handoff agent - agent3.register_hand_off(AFTER_WORK(agent=SwarmAgent("missing"))) + register_hand_off(agent=agent3, hand_to=AfterWork(agent=ConversableAgent("missing"))) with pytest.raises(AssertionError): _prepare_swarm_agents(agent1, [agent1, agent2, agent3]) @@ -836,8 +834,8 @@ def test_create_nested_chats(): ] } - test_agent = SwarmAgent("test_agent", llm_config=testing_llm_config) - test_agent_2 = SwarmAgent("test_agent_2", llm_config=testing_llm_config) + 
test_agent = ConversableAgent("test_agent", llm_config=testing_llm_config) + test_agent_2 = ConversableAgent("test_agent_2", llm_config=testing_llm_config) nested_chat_agents = [] nested_chat_one = { @@ -857,7 +855,7 @@ def test_create_nested_chats(): "use_async": False, } - test_agent.register_hand_off(ON_CONDITION(target=nested_chat_config, condition="test condition")) + register_hand_off(agent=test_agent, hand_to=OnCondition(target=nested_chat_config, condition="test condition")) # Create nested chats _create_nested_chats(test_agent, nested_chat_agents) @@ -868,14 +866,15 @@ def test_create_nested_chats(): # Verify nested chat configuration # The nested chat agent should have a handoff back to the passed in agent - assert nested_chat_agents[0].after_work.agent == test_agent + assert nested_chat_agents[0]._swarm_after_work.agent == test_agent def test_process_initial_messages(): """Test processing of initial messages in different scenarios""" - agent1 = SwarmAgent("agent1") - agent2 = SwarmAgent("agent2") - nested_agent = SwarmAgent("nested_chat_agent1_1") + + agent1 = ConversableAgent("agent1") + agent2 = ConversableAgent("agent2") + nested_agent = ConversableAgent("nested_chat_agent1_1") user_agent = UserProxyAgent("test_user") # Test single string message @@ -915,9 +914,10 @@ def test_process_initial_messages(): def test_setup_context_variables(): """Test setup of context variables across agents""" - tool_execution = SwarmAgent(__TOOL_EXECUTOR_NAME__) - agent1 = SwarmAgent("agent1") - agent2 = SwarmAgent("agent2") + + tool_execution = ConversableAgent(__TOOL_EXECUTOR_NAME__) + agent1 = ConversableAgent("agent1") + agent2 = ConversableAgent("agent2") groupchat = GroupChat(agents=[tool_execution, agent1, agent2], messages=[]) manager = GroupChatManager(groupchat) @@ -953,8 +953,9 @@ def test_cleanup_temp_user_messages(): @pytest.mark.asyncio async def test_a_initiate_swarm_chat(): """Test async swarm chat""" - agent1 = SwarmAgent("agent1") - agent2 = 
SwarmAgent("agent2") + + agent1 = ConversableAgent("agent1") + agent2 = ConversableAgent("agent2") user_agent = UserProxyAgent("test_user") # Mock async reply function @@ -990,5 +991,112 @@ async def mock_a_generate_oai_reply(*args, **kwargs): assert context_vars == test_context +def test_update_on_condition_str(): + """Test UpdateOnConditionStr updates condition strings properly for handoffs""" + + testing_llm_config = { + "config_list": [ + { + "model": "gpt-4o", + "api_key": "SAMPLE_API_KEY", + } + ] + } + + agent1 = ConversableAgent("agent1", llm_config=testing_llm_config) + agent2 = ConversableAgent("agent2", llm_config=testing_llm_config) + + # Test container to capture condition + class ConditionContainer: + def __init__(self): + self.captured_condition = None + + condition_container = ConditionContainer() + + # Test with string template + register_hand_off( + agent1, + hand_to=OnCondition( + target=agent2, condition=UpdateCondition(update_function="Transfer when {test_var} is active") + ), + ) + + # Mock LLM responses + def mock_generate_oai_reply_tool_1_2(*args, **kwargs): + # Get the function description (condition) from the agent's function map + func_name = "transfer_agent1_to_agent2" + # Store the condition for verification by accessing the function's description + func = args[0]._function_map[func_name] + condition_container.captured_condition = func._description + return True, { + "role": "assistant", + "name": "agent1", + "tool_calls": [{"type": "function", "function": {"name": func_name}}], + } + + agent1.register_reply([ConversableAgent, None], mock_generate_oai_reply_tool_1_2) + agent2.register_reply([ConversableAgent, None], lambda *args, **kwargs: (True, "Response from agent2")) + + # Test string template substitution + chat_result, context_vars, last_speaker = initiate_swarm_chat( + initial_agent=agent1, + messages=TEST_MESSAGES, + agents=[agent1, agent2], + context_variables={"test_var": "condition1"}, + max_rounds=3, + ) + + assert 
condition_container.captured_condition == "Transfer when condition1 is active" + + # Test with callable function + def custom_update_function(agent: ConversableAgent, messages: list[dict]) -> str: + return f"Transfer based on {agent.get_context('test_var')} with {len(messages)} messages" + + agent3 = ConversableAgent("agent3", llm_config=testing_llm_config) + register_hand_off( + agent2, hand_to=OnCondition(target=agent3, condition=UpdateCondition(update_function=custom_update_function)) + ) + + # Reset condition container + condition_container.captured_condition = None + + def mock_generate_oai_reply_tool_2_3(*args, **kwargs): + # Get the function description (condition) from the agent's function map + func_name = "transfer_agent2_to_agent3" + # Store the condition for verification by accessing the function's description + func = args[0]._function_map[func_name] + condition_container.captured_condition = func._description + return True, { + "role": "assistant", + "name": "agent1", + "tool_calls": [{"type": "function", "function": {"name": func_name}}], + } + + agent2.register_reply([ConversableAgent, None], mock_generate_oai_reply_tool_2_3) + agent3.register_reply([ConversableAgent, None], lambda *args, **kwargs: (True, "Response from agent3")) + + # Test callable function update + chat_result, context_vars, last_speaker = initiate_swarm_chat( + initial_agent=agent2, + messages=TEST_MESSAGES, + agents=[agent2, agent3], + context_variables={"test_var": "condition2"}, + max_rounds=3, + ) + + assert condition_container.captured_condition == "Transfer based on condition2 with 1 messages" + + # Test invalid update function + with pytest.raises(ValueError, match="Update function must be either a string or a callable"): + UpdateCondition(update_function=123) + + # Test invalid callable signature + def invalid_update_function(x: int) -> str: + return "test" + + with pytest.raises(ValueError, match="Update function must accept two parameters"): + 
UpdateCondition(update_function=invalid_update_function) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/agentchat/realtime_agent/test_swarm_start.py b/test/agentchat/realtime_agent/test_swarm_start.py index e7d5ecf384..026c7ac2e9 100644 --- a/test/agentchat/realtime_agent/test_swarm_start.py +++ b/test/agentchat/realtime_agent/test_swarm_start.py @@ -12,7 +12,7 @@ from fastapi import FastAPI, WebSocket from fastapi.testclient import TestClient -from autogen.agentchat.contrib.swarm_agent import SwarmAgent +from autogen.agentchat.conversable_agent import ConversableAgent from autogen.agentchat.realtime_agent import RealtimeAgent, RealtimeObserver, WebSocketAudioAdapter from autogen.tools.dependency_injection import Field as AG2Field @@ -57,7 +57,7 @@ async def handle_media_stream(websocket: WebSocket) -> None: def get_weather(location: Annotated[str, AG2Field(description="city")]) -> str: return "The weather is cloudy." if location == "Seattle" else "The weather is sunny." - weatherman = SwarmAgent( + weatherman = ConversableAgent( name="Weatherman", system_message="You are a weatherman. 
You can answer questions about the weather.", llm_config=credentials_gpt_4o_mini.llm_config, diff --git a/test/agentchat/test_conversable_agent.py b/test/agentchat/test_conversable_agent.py index cfa4977157..df6069507e 100755 --- a/test/agentchat/test_conversable_agent.py +++ b/test/agentchat/test_conversable_agent.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# Copyright (c) 2023 - 2025, Owners of https://github.com/ag2ai # # SPDX-License-Identifier: Apache-2.0 # @@ -19,7 +19,7 @@ from pydantic import BaseModel, Field import autogen -from autogen.agentchat import ConversableAgent, UserProxyAgent +from autogen.agentchat import ConversableAgent, UpdateSystemMessage, UserProxyAgent from autogen.agentchat.conversable_agent import register_function from autogen.exception_utils import InvalidCarryOverType, SenderRequired @@ -588,7 +588,7 @@ def test__wrap_function_sync(): class Currency(BaseModel): currency: CurrencySymbol = Field(description="Currency code") - amount: Annotated[float, Field(default=100.0, description="Amount of money in the currency")] + amount: float = Field(default=100.0, description="Amount of money in the currency") Currency(currency="USD", amount=100.0) @@ -626,7 +626,7 @@ async def test__wrap_function_async(): class Currency(BaseModel): currency: CurrencySymbol = Field(description="Currency code") - amount: Annotated[float, Field(default=100.0, description="Amount of money in the currency")] + amount: float = Field(default=100.0, description="Amount of money in the currency") Currency(currency="USD", amount=100.0) @@ -1537,6 +1537,36 @@ def test_context_variables(): assert agent._context_variables == expected_final_context +def test_invalid_functions_parameter(): + """Test initialization with valid and invalid parameters""" + + # Invalid functions parameter + with pytest.raises(TypeError): + ConversableAgent("test_agent", functions="invalid") + + +def test_update_system_message(): + """Tests the 
update_agent_state_before_reply functionality with multiple scenarios""" + + # Test invalid update function + with pytest.raises(ValueError, match="The update function must be either a string or a callable"): + ConversableAgent("agent3", update_agent_state_before_reply=UpdateSystemMessage(123)) + + # Test invalid callable (wrong number of parameters) + def invalid_update_function(context_variables): + return "Invalid function" + + with pytest.raises(ValueError, match="The update function must accept two parameters"): + ConversableAgent("agent4", update_agent_state_before_reply=UpdateSystemMessage(invalid_update_function)) + + # Test invalid callable (wrong return type) + def invalid_return_function(context_variables, messages) -> dict: + return {} + + with pytest.raises(ValueError, match="The update function must return a string"): + ConversableAgent("agent5", update_agent_state_before_reply=UpdateSystemMessage(invalid_return_function)) + + if __name__ == "__main__": # test_trigger() # test_context() @@ -1549,4 +1579,5 @@ def test_context_variables(): # test_function_registration_e2e_sync() # test_process_gemini_carryover() # test_process_carryover() - test_context_variables() + # test_context_variables() + test_invalid_functions_parameter() diff --git a/test/test_logging.py b/test/test_logging.py index 552757dbf5..d6a73f5022 100644 --- a/test/test_logging.py +++ b/test/test_logging.py @@ -295,9 +295,9 @@ def build(self): assert result["foo_val"] == expected_foo_val_field assert result["o"] == expected_o_field assert len(result["agents"]) == 2 - for agent in result["agents"]: - assert "autogen.agentchat.conversable_agent.ConversableAgent" in agent - assert "autogen.agentchat.conversable_agent.ConversableAgent" in result["first_agent"] + assert result["agents"][0] == "alice" + assert result["agents"][1] == "bob" + assert "alice" in result["first_agent"] @patch("logging.Logger.error") diff --git a/website/_blogs/2024-11-17-Swarm/index.mdx 
b/website/_blogs/2024-11-17-Swarm/index.mdx index 0af80b7b83..ebcd16c672 100644 --- a/website/_blogs/2024-11-17-Swarm/index.mdx +++ b/website/_blogs/2024-11-17-Swarm/index.mdx @@ -23,6 +23,12 @@ Besides these core features, AG2 provides: This feature builds on GroupChat, offering a simpler interface to use swarm orchestration. For comparison, see two implementations of the same example: one [using swarm orchestration](/notebooks/agentchat_swarm) and another [naive implementation with GroupChat (Legacy)](/notebooks/agentchat_swarm_w_groupchat_legacy). +````mdx-code-block +:::note +This blog has been updated as swarms can now accommodate any ConversableAgent. +::: +```` + ## Handoffs Before we dive into a swarm example, an important concept in swarm orchestration is when and how an agent hands off to another agent. @@ -32,19 +38,21 @@ Providing additional flexibility, we introduce the capability to define an after The following are the prioritized handoffs for each iteration of the swarm. 1. **Agent-level: Calls a tool that returns a swarm agent**: A swarm agent's tool call returns the next agent to hand off to. -2. **Agent-level: Calls a pre-defined conditional handoff**: A swarm agent has an `ON_CONDITION` handoff that is chosen by the LLM (behaves like a tool call). +2. **Agent-level: Calls a pre-defined conditional handoff**: A swarm agent has an `OnCondition` handoff that is chosen by the LLM (behaves like a tool call). 3. **Agent-level: After work hand off**: When no tool calls are made it can use an, optional, `AFTER_WORK` handoff that is a preset option or a nominated swarm agent. 4. **Swarm-level: After work handoff**: If the agent does not have an `AFTER_WORK` handoff, the swarm's `AFTER_WORK` handoff will be used. -In the following code sample a `SwarmAgent` named `responder` has: -- Two conditional handoffs registered (`ON_CONDITION`), specifying the agent to hand off to and the condition to trigger the handoff. 
+In the following code sample, a `ConversableAgent` named `responder` has: +- Two conditional handoffs registered (`OnCondition`), specifying the agent to hand off to and the condition to trigger the handoff. - An after-work handoff (`AFTER_WORK`) nominated using one of the preset options (`TERMINATE`, `REVERT_TO_USER`, `STAY`). This could also be a swarm agent. ```python -responder.register_hand_off( +from autogen import register_hand_off, OnCondition, AFTER_WORK, AfterWorkOption +register_hand_off( + agent=responder, hand_to=[ - ON_CONDITION(weather, "If you need weather data, hand off to the Weather_Agent"), - ON_CONDITION(travel_advisor, "If you have weather data but need formatted recommendations, hand off to the Travel_Advisor_Agent"), + OnCondition(weather, "If you need weather data, hand off to the Weather_Agent"), + OnCondition(travel_advisor, "If you have weather data but need formatted recommendations, hand off to the Travel_Advisor_Agent"), AFTER_WORK(AfterWorkOption.REVERT_TO_USER), ] ) @@ -63,8 +71,8 @@ history, context, last_agent = initiate_swarm_chat( ## Creating a swarm -1. Define the functions that can be used by your `SwarmAgent`s. -2. Create your `SwarmAgent`s (which derives from `ConversableAgent`). +1. Define the functions that can be used by your `ConversableAgent`s. +2. Create your `ConversableAgent`s. 3. For each swarm agent, specify the handoffs (transitions to another agent) and what to do when they have finished their work (termed *After Work*). 4. Optionally, create your context dictionary. 5. Call `initiate_swarm_chat`. @@ -74,7 +82,7 @@ history, context, last_agent = initiate_swarm_chat( This example of managing refunds demonstrates the context handling, swarm and agent-level conditional and after work hand offs, and the human-in-the-loop feature. 
```python -from autogen import initiate_swarm_chat, SwarmAgent, SwarmResult, ON_CONDITION, AFTER_WORK, AfterWorkOption +from autogen import initiate_swarm_chat, ConversableAgent, SwarmResult, OnCondition, AFTER_WORK, AfterWorkOption from autogen import UserProxyAgent import os @@ -90,7 +98,7 @@ context_variables = { } # Functions that our swarm agents will be assigned -# They can return a SwarmResult, a SwarmAgent, or a string +# They can return a SwarmResult, a ConversableAgent, or a string # SwarmResult allows you to update context_variables and/or hand off to another agent def verify_customer_identity(passport_number: str, context_variables: dict) -> str: context_variables["passport_number"] = passport_number @@ -106,7 +114,7 @@ def process_refund_payment(context_variables: dict) -> str: return SwarmResult(values="Payment processed successfully", context_variables=context_variables) # Swarm Agents, similar to ConversableAgent, but with functions and hand offs (specified later) -customer_service = SwarmAgent( +customer_service = ConversableAgent( name="CustomerServiceRep", system_message="""You are a customer service representative. First verify the customer's identity by asking for the customer's passport number, @@ -116,7 +124,7 @@ customer_service = SwarmAgent( functions=[verify_customer_identity], ) -refund_specialist = SwarmAgent( +refund_specialist = ConversableAgent( name="RefundSpecialist", system_message="""You are a refund specialist. Review the case and approve the refund, then transfer to the payment processor.""", @@ -124,7 +132,7 @@ refund_specialist = SwarmAgent( functions=[approve_refund_and_transfer], ) -payment_processor = SwarmAgent( +payment_processor = ConversableAgent( name="PaymentProcessor", system_message="""You are a payment processor. 
Process the refund payment and provide a confirmation message to the customer.""", @@ -132,7 +140,7 @@ payment_processor = SwarmAgent( functions=[process_refund_payment], ) -satisfaction_surveyor = SwarmAgent( +satisfaction_surveyor = ConversableAgent( name="SatisfactionSurveyor", system_message="""You are a customer satisfaction specialist. Ask the customer to rate their experience with the refund process.""", @@ -141,14 +149,16 @@ satisfaction_surveyor = SwarmAgent( # Conditional and After work hand offs -customer_service.register_hand_off( +register_hand_off( + agent=customer_service, hand_to=[ - ON_CONDITION(refund_specialist, "After customer verification, transfer to refund specialist"), + OnCondition(refund_specialist, "After customer verification, transfer to refund specialist"), AFTER_WORK(AfterWorkOption.REVERT_TO_USER) ] ) -payment_processor.register_hand_off( +register_hand_off( + agent=payment_processor, hand_to=[ AFTER_WORK(satisfaction_surveyor), ] @@ -245,7 +255,7 @@ Next speaker: Tool_Execution Tool_Execution (to chat_manager): ***** Response from calling tool (call_Jz1viRLeJuOltPRcKfYZ8bgH) ***** -SwarmAgent --> RefundSpecialist +Swarm agent --> RefundSpecialist ********************************************************************** -------------------------------------------------------------------------------- @@ -330,7 +340,6 @@ Context Variables: ### Notes - Behind-the-scenes, swarm agents are supported by a tool execution agent, that executes tools on their behalf. Hence, the appearance of `Tool Execution` in the output. -- Currently only swarm agents can be added to a swarm. This is to maintain their ability to manage context variables, auto-execute functions, and support hand offs. Eventually, we may allow ConversableAgent to have the same capability and make "SwarmAgent" a simpler subclass with certain defaults changed (like AssistantAgent and UserProxyAgent). - Would you like to enhance the swarm feature or have found a bug? 
Please let us know by creating an issue on the [AG2 GitHub](https://github.com/ag2ai/ag2/issues). ## For Further Reading diff --git a/website/_blogs/2024-12-20-RealtimeAgent/index.mdx b/website/_blogs/2024-12-20-RealtimeAgent/index.mdx index 5180e4c083..559d1c0b95 100644 --- a/website/_blogs/2024-12-20-RealtimeAgent/index.mdx +++ b/website/_blogs/2024-12-20-RealtimeAgent/index.mdx @@ -165,7 +165,7 @@ FLIGHT_CANCELLATION_POLICY = """ #### **Agents Definition** ```python -triage_agent = SwarmAgent( +triage_agent = ConversableAgent( name="Triage_Agent", system_message=triage_instructions(context_variables=context_variables), llm_config=llm_config, @@ -175,7 +175,7 @@ triage_agent = SwarmAgent( - **Triage Agent:** Routes the user's request to the appropriate specialized agent based on the topic. ```python -flight_cancel = SwarmAgent( +flight_cancel = ConversableAgent( name="Flight_Cancel_Traversal", system_message=STARTER_PROMPT + FLIGHT_CANCELLATION_POLICY, llm_config=llm_config, @@ -185,10 +185,11 @@ flight_cancel = SwarmAgent( - **Flight Cancel Agent:** Handles cancellations, including refunds and flight credits, while ensuring policy steps are strictly followed. ```python -flight_modification.register_hand_off( - [ - ON_CONDITION(flight_cancel, "To cancel a flight"), - ON_CONDITION(flight_change, "To change a flight"), +register_hand_off( + agent=flight_modification, + hand_to=[ + OnCondition(flight_cancel, "To cancel a flight"), + OnCondition(flight_change, "To change a flight"), ] ) ``` diff --git a/website/docs/topics/swarm.ipynb b/website/docs/topics/swarm.ipynb index 53caaee7e4..60d7905ba0 100644 --- a/website/docs/topics/swarm.ipynb +++ b/website/docs/topics/swarm.ipynb @@ -11,7 +11,7 @@ "- **Headoffs**: Agents can transfer control to another agent via function calls, enabling smooth transitions within workflows. 
\n", "- **Context Variables**: Agents can dynamically update shared variables through function calls, maintaining context and adaptability throughout the process.\n", "\n", - "Instead of sending a task to a single LLM agent, you can assign it to a swarm of agents. Each agent in the swarm can decide whether to hand off the task to another agent. The chat terminates when the last active agent's response is a plain string (i.e., it doesn't suggest a tool call or handoff). " + "Instead of sending a task to a single LLM agent, you can assign it to a swarm of agents. Each agent in the swarm can decide whether to hand off the task to another agent. The chat terminates when the last active agent's response is a plain string (i.e., it doesn't suggest a tool call or handoff)." ] }, { @@ -21,9 +21,11 @@ "## Components\n", "We now introduce the main components that need to be used to create a swarm chat. \n", "\n", - "### Create a `SwarmAgent`\n", + "### Agents\n", "\n", - "All the agents passed to the swarm chat should be instances of `SwarmAgent`. `SwarmAgent` is very similar to `AssistantAgent`, but it has some additional features to support function registration and handoffs. When creating a `SwarmAgent`, you can pass in a list of functions. These functions will be converted to schemas to be passed to the LLMs, and you don't need to worry about registering the functions for execution. You can also pass back a `SwarmResult` class, where you can return a value, the next agent to call, and update context variables at the same time.\n", + "Any ConversableAgent-based agent can participate in a swarm. Agents will automatically be given additional features to support their participation in the swarm.\n", + "\n", + "When creating an agent, you can pass in a list of functions (through the `functions` parameter upon initialization). These functions will be converted to schemas to be passed to the LLMs, and you don't need to worry about registering the functions for execution. 
You can also pass back a `SwarmResult` class, where you can return a value, the next agent to call, and update context variables at the same time.\n", "\n", "**Notes for creating the function calls** \n", "- For input arguments, you must define the type of the argument, otherwise, the registration will fail (e.g. `arg_name: str`). \n", @@ -32,15 +34,17 @@ "- The function name will be used as the tool name.\n", "\n", "### Registering Handoffs to agents\n", - "While you can create a function to decide what next agent to call, we provide a quick way to register the handoff using `ON_CONDITION`. We will craft this transition function and add it to the LLM config directly.\n", + "While you can create a function to decide what next agent to call, we provide a quick way to register the handoff using `OnCondition`. We will craft this transition function and add it to the LLM config directly.\n", "\n", "```python\n", - "agent_2 = SwarmAgent(...)\n", - "agent_3 = SwarmAgent(...)\n", + "from autogen import register_hand_off, ConversableAgent, OnCondition\n", + "\n", + "agent_2 = ConversableAgent(...)\n", + "agent_3 = ConversableAgent(...)\n", "\n", - "# Register the handoff\n", - "agent_1 = SwarmAgent(...)\n", - "agent_1.handoff(hand_to=[ON_CONDITION(agent_2, \"condition_1\"), ON_CONDITION(agent_3, \"condition_2\")])\n", + "# Register the handoff using register_hand_off\n", + "agent_1 = ConversableAgent(...)\n", + "register_hand_off(agent=agent_1, hand_to=[OnCondition(agent_2, \"condition_1\"), OnCondition(agent_3, \"condition_2\")])\n", "\n", "# This is equivalent to:\n", "def transfer_to_agent_2():\n", @@ -51,12 +55,30 @@ " \"\"\"condition_2\"\"\"\n", " return agent_3\n", " \n", - "agent_1 = SwarmAgent(..., functions=[transfer_to_agent_2, transfer_to_agent_3])\n", - "# You can also use agent_1.add_functions to add more functions after initialization\n", + "agent_1 = ConversableAgent(..., functions=[transfer_to_agent_2, transfer_to_agent_3])\n", + "# You can also use 
agent_1._add_functions to add more functions after initialization\n", "```\n", "\n", + "### UpdateCondition\n", + "`UpdateCondition` offers a simple way to set up a boolean expression using context variables within `OnCondition`. Its functionality and implementation are quite similar to `UpdateSystemMessage` in that it will substitute in the context variables, allowing you to make use of them in the condition's string.\n", + "\n", + "The code below implements the following logic:\n", + "- if context_variables['condition'] == 1, transfer to agent_1 \n", + "- if context_variables['condition'] == 3, transfer to agent_3\n", + "\n", + "```python\n", + "register_hand_off(\n", + " agent=agent_2,\n", + " hand_to=[\n", + " OnCondition(target=agent_1, condition=UpdateCondition(update_function=\"transfer back to agent 1 if {condition} == 1\")),\n", + " OnCondition(target=agent_3, condition=UpdateCondition(update_function=\"transfer back to agent 3 if {condition} == 3\"))\n", + " ]\n", + ")\n", + "```\n", + "\n", + "\n", "### Registering Handoffs to a nested chat\n", - "In addition to transferring to an agent, you can also trigger a nested chat by doing a handoff and using `ON_CONDITION`. This is a useful way to perform sub-tasks without that work becoming part of the broader swarm's messages.\n", + "In addition to transferring to an agent, you can also trigger a nested chat by doing a handoff and using `OnCondition`. 
This is a useful way to perform sub-tasks without that work becoming part of the broader swarm's messages.\n", "\n", "Configuring the nested chat is similar to [establishing a nested chat for an agent](https://docs.ag2.ai/docs/tutorial/conversation-patterns#nested-chats).\n", "\n", @@ -121,8 +143,9 @@ "Finally, we add the nested chat as a handoff in the same way as we do to an agent:\n", "\n", "```python\n", - "agent_1.handoff(\n", - " hand_to=[ON_CONDITION(\n", + "register_hand_off(\n", + " agent=agent_1,\n", + " hand_to=[OnCondition(\n", " target={\n", " \"chat_queue\":[nested_chats],\n", " \"config\": Any,\n", @@ -144,11 +167,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### AFTER_WORK\n", + "### AfterWork\n", "\n", - "When the active agent's response doesn't suggest a tool call or handoff, the chat will terminate by default. However, you can register an `AFTER_WORK` handoff to control what to do next. You can register these `AFTER_WORK` handoffs at the agent level and also the swarm level (through the `after_work` parameter on `initiate_swarm_chat`). The agent level takes precedence over the swarm level.\n", + "When the active agent's response doesn't suggest a tool call or handoff, the chat will terminate by default. However, you can register an `AfterWork` handoff to control what to do next. You can register these `AfterWork` handoffs at the agent level and also the swarm level (through the `after_work` parameter on `initiate_swarm_chat`). The agent level takes precedence over the swarm level.\n", "\n", - "The AFTER_WORK takes a single parameter and this can be an agent, an agent's name, an `AfterWorkOption`, or a callable function.\n", + "The AfterWork takes a single parameter and this can be an agent, an agent's name, an `AfterWorkOption`, or a callable function.\n", "\n", "The `AfterWorkOption` options are:\n", "- `TERMINATE`: Terminate the chat \n", @@ -156,33 +179,33 @@ "- `REVERT_TO_USER`: Revert to the user agent. 
Only if a user agent is passed in when initializing. (See below for more details)\n", "\n", "The callable function signature is:\n", - "`def my_after_work_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any]], groupchat: GroupChat) -> Union[AfterWorkOption, SwarmAgent, str]:`\n", + "`def my_after_work_func(last_speaker: ConversableAgent, messages: List[Dict[str, Any]], groupchat: GroupChat) -> Union[AfterWorkOption, ConversableAgent, str]:`\n", "\n", - "Note: there should only be one `AFTER_WORK`, if your requirement is more complex, use the callable function parameter.\n", + "Note: there should only be one `AfterWork`; if your requirement is more complex, use the callable function parameter.\n", "\n", - "Here are examples of registering AFTER_WORKS\n", + "Here are examples of registering `AfterWork` handoffs.\n", "\n", "```python\n", "# Register the handoff to an agent\n", - "agent_1.handoff(hand_to=[\n", - " ON_CONDITION(...), \n", - " ON_CONDITION(...),\n", - " AFTER_WORK(agent_4) # Fallback to agent_4 if no ON_CONDITION handoff is suggested\n", + "register_hand_off(agent=agent_1, hand_to=[\n", + " OnCondition(...), \n", + " OnCondition(...),\n", + " AfterWork(agent_4) # Fallback to agent_4 if no OnCondition handoff is suggested\n", "])\n", "\n", "# Register the handoff to an AfterWorkOption\n", - "agent_2.handoff(hand_to=[AFTER_WORK(AfterWorkOption.TERMINATE)]) # Terminate the chat if no handoff is suggested\n", + "register_hand_off(agent=agent_2, hand_to=[AfterWork(AfterWorkOption.TERMINATE)]) # Terminate the chat if no handoff is suggested\n", "\n", - "def my_after_work_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any]], groupchat: GroupChat) -> Union[AfterWorkOption, SwarmAgent, str]:\n", + "def my_after_work_func(last_speaker: ConversableAgent, messages: List[Dict[str, Any]], groupchat: GroupChat) -> Union[AfterWorkOption, ConversableAgent, str]:\n", " if last_speaker.get_context(\"agent_1_done\"):\n", " return agent_2\n", " else:\n", " return 
AfterWorkOption.TERMINATE\n", "\n", "# Register the handoff to a function that will return an agent or AfterWorkOption\n", - "agent_3.handoff(hand_to=[AFTER_WORK(my_after_work_func)])\n", + "register_hand_off(agent=agent_3, hand_to=[AfterWork(my_after_work_func)])\n", "\n", - "# Register the swarm level AFTER_WORK that becomes the default for agents that don't have one specified\n", + "# Register the swarm level AfterWork that becomes the default for agents that don't have one specified\n", "chat_history, context_variables, last_active_agent = initiate_swarm_chat(\n", " ...\n", " after_work=AfterWorkOption.TERMINATE # Or an agent or Callable\n", @@ -200,16 +223,18 @@ "\n", "It can be useful to update a swarm agent's state before they reply. For example, using an agent's context variables you could change their system message based on the state of the workflow.\n", "\n", - "When initialising a swarm agent use the `update_agent_state_before_reply` parameter to register updates that run after the agent is selected, but before they reply.\n", + "When initialising an agent use the `update_agent_state_before_reply` parameter to register updates that run after the agent is selected, but before they reply.\n", "\n", "`update_agent_state_before_reply` takes a list of any combination of the following (executing them in the provided order):\n", "\n", - "- `UPDATE_SYSTEM_MESSAGE` provides a simple way to update the agent's system message via an f-string that substitutes the values of context variables, or a Callable that returns a string\n", + "- `UpdateSystemMessage` provides a simple way to update the agent's system message via an f-string that substitutes the values of context variables, or a Callable that returns a string\n", "- Callable with two parameters of type `ConversableAgent` for the agent and `List[Dict[str Any]]` for the messages, and does not return a value\n", "\n", "Below is an example of setting these up when creating a Swarm agent.\n", "\n", "```python\n", + 
"from autogen import UpdateSystemMessage, ConversableAgent\n", + "\n", "# Creates a system message string\n", "def create_system_prompt_function(my_agent: ConversableAgent, messages: List[Dict[]]) -> str:\n", " preferred_name = my_agent.get_context(\"preferred_name\", \"(name not provided)\")\n", @@ -224,13 +249,13 @@ " agent.set_context(\"context_key\", 43)\n", " agent.update_system_message(\"You are a customer service representative.\")\n", "\n", - "# Create the SwarmAgent and set agent updates\n", - "customer_service = SwarmAgent(\n", + "# Create the swarm agent and set agent updates\n", + "customer_service = ConversableAgent(\n", " name=\"CustomerServiceRep\",\n", " system_message=\"You are a customer service representative.\",\n", " update_agent_state_before_reply=[\n", - " UPDATE_SYSTEM_MESSAGE(\"You are a customer service representative. Quote passport number '{passport_number}'\"),\n", - " UPDATE_SYSTEM_MESSAGE(create_system_prompt_function),\n", + " UpdateSystemMessage(\"You are a customer service representative. Quote passport number '{passport_number}'\"),\n", + " UpdateSystemMessage(create_system_prompt_function),\n", " my_callable_state_update_function]\n", " ...\n", ")\n", @@ -267,11 +292,11 @@ "\n", "In a swarm, the context variables are shared amongst Swarm agents. As context variables are available at the agent level, you can use the context variable getters/setters on the agent to view and change the shared context variables. If you're working with a function that returns a `SwarmResult` you should update the passed in context variables and return it in the `SwarmResult`, this will ensure the shared context is updated.\n", "\n", - "> What is the difference between ON_CONDITION and AFTER_WORK?\n", + "> What is the difference between OnCondition and AfterWork?\n", "\n", - "When registering an ON_CONDITION handoff, we are creating a function schema to be passed to the LLM. 
The LLM will decide whether to call this function.\n", + "When registering an OnCondition handoff, we are creating a function schema to be passed to the LLM. The LLM will decide whether to call this function.\n", "\n", - "When registering an AFTER_WORK handoff, we are defining the fallback mechanism when no tool calls are suggested. This is a higher level of control from the swarm chat level.\n", + "When registering an AfterWork handoff, we are defining the fallback mechanism when no tool calls are suggested. This is a higher level of control from the swarm chat level.\n", "\n", "> When to pass in a user agent?\n", "\n", @@ -317,12 +342,13 @@ "import random\n", "\n", "from autogen import (\n", - " AFTER_WORK,\n", - " ON_CONDITION,\n", + " AfterWork,\n", " AfterWorkOption,\n", - " SwarmAgent,\n", + " ConversableAgent,\n", + " OnCondition,\n", " SwarmResult,\n", " initiate_swarm_chat,\n", + " register_hand_off,\n", ")\n", "\n", "\n", @@ -332,8 +358,8 @@ " return SwarmResult(value=\"success\", context_variables=context_variables)\n", "\n", "\n", - "# 2. A function that returns an SwarmAgent object\n", - "def transfer_to_agent_2() -> SwarmAgent:\n", + "# 2. 
A function that returns a ConversableAgent object\n", + "def transfer_to_agent_2() -> ConversableAgent:\n", " \"\"\"Transfer to agent 2\"\"\"\n", " return agent_2\n", "\n", @@ -354,34 +380,34 @@ " return SwarmResult(value=\"success\", context_variables=context_variables)\n", "\n", "\n", - "agent_1 = SwarmAgent(\n", + "agent_1 = ConversableAgent(\n", " name=\"Agent_1\",\n", " system_message=\"You are Agent 1, first, call the function to update context 1, and transfer to Agent 2\",\n", " llm_config=llm_config,\n", " functions=[update_context_1, transfer_to_agent_2],\n", ")\n", "\n", - "agent_2 = SwarmAgent(\n", + "agent_2 = ConversableAgent(\n", " name=\"Agent_2\",\n", " system_message=\"You are Agent 2, call the function that updates context 2 and transfer to Agent 3\",\n", " llm_config=llm_config,\n", " functions=[update_context_2_and_transfer_to_3],\n", ")\n", "\n", - "agent_3 = SwarmAgent(\n", + "agent_3 = ConversableAgent(\n", " name=\"Agent_3\",\n", " system_message=\"You are Agent 3, tell a joke\",\n", " llm_config=llm_config,\n", ")\n", "\n", - "agent_4 = SwarmAgent(\n", + "agent_4 = ConversableAgent(\n", " name=\"Agent_4\",\n", " system_message=\"You are Agent 4, call the function to get a random number\",\n", " llm_config=llm_config,\n", " functions=[get_random_number],\n", ")\n", "\n", - "agent_5 = SwarmAgent(\n", + "agent_5 = ConversableAgent(\n", " name=\"Agent_5\",\n", " system_message=\"Update context 3 with the random number.\",\n", " llm_config=llm_config,\n", @@ -390,9 +416,9 @@ "\n", "\n", "# This is equivalent to writing a transfer function\n", - "agent_3.register_hand_off(ON_CONDITION(agent_4, \"Transfer to Agent 4\"))\n", + "register_hand_off(agent=agent_3, hand_to=OnCondition(agent_4, \"Transfer to Agent 4\"))\n", "\n", - "agent_4.register_hand_off([AFTER_WORK(agent_5)])\n", + "register_hand_off(agent=agent_4, hand_to=[AfterWork(agent_5)])\n", "\n", "print(\"Agent 1 function schema:\")\n", "for func_schema in agent_1.llm_config[\"tools\"]:\n", 
@@ -576,7 +602,7 @@ " agents=[agent_1, agent_2, agent_3, agent_4, agent_5],\n", " messages=\"start\",\n", " context_variables=context_variables,\n", - " after_work=AFTER_WORK(AfterWorkOption.TERMINATE), # this is the default\n", + " after_work=AfterWork(AfterWorkOption.TERMINATE), # this is the default\n", ")" ] }, @@ -603,12 +629,12 @@ "source": [ "### Demo with User Agent\n", "\n", - "We pass in a user agent to the swarm chat to accept user inputs. With `agent_6`, we register an `AFTER_WORK` handoff to revert to the user agent when no tool calls are suggested. " + "We pass in a user agent to the swarm chat to accept user inputs. With `agent_6`, we register an `AfterWork` handoff to revert to the user agent when no tool calls are suggested. " ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -701,25 +727,26 @@ "\n", "user_agent = UserProxyAgent(name=\"User\", code_execution_config=False)\n", "\n", - "agent_6 = SwarmAgent(\n", + "agent_6 = ConversableAgent(\n", " name=\"Agent_6\",\n", " system_message=\"You are Agent 6. Your job is to tell jokes.\",\n", " llm_config=llm_config,\n", ")\n", "\n", - "agent_7 = SwarmAgent(\n", + "agent_7 = ConversableAgent(\n", " name=\"Agent_7\",\n", " system_message=\"You are Agent 7, explain the joke.\",\n", " llm_config=llm_config,\n", ")\n", "\n", - "agent_6.register_hand_off(\n", - " [\n", - " ON_CONDITION(\n", + "register_hand_off(\n", + " agent=agent_6,\n", + " hand_to=[\n", + " OnCondition(\n", " agent_7, \"Used to transfer to Agent 7. 
Don't call this function, unless the user explicitly tells you to.\"\n", " ),\n", - " AFTER_WORK(AfterWorkOption.REVERT_TO_USER),\n", - " ]\n", + " AfterWork(AfterWorkOption.REVERT_TO_USER),\n", + " ],\n", ")\n", "\n", "chat_result, _, _ = initiate_swarm_chat(\n", diff --git a/website/snippets/data/NotebooksMetadata.mdx b/website/snippets/data/NotebooksMetadata.mdx index 812b7e9d13..7279592e2f 100644 --- a/website/snippets/data/NotebooksMetadata.mdx +++ b/website/snippets/data/NotebooksMetadata.mdx @@ -991,61 +991,5 @@ export const notebooksMetadata = [ "pydanticai" ], "source": "/notebook/tools_interoperability.ipynb" - }, - { - "title": "Agentic RAG workflow on tabular data from a PDF file", - "link": "/notebooks/agentchat_tabular_data_rag_workflow", - "description": "Agentic RAG workflow on tabular data from a PDF file", - "image": null, - "tags": [ - "RAG", - "groupchat" - ], - "source": "/notebook/agentchat_tabular_data_rag_workflow.ipynb" - }, - { - "title": "RealtimeAgent with WebRTC connection", - "link": "/notebooks/agentchat_realtime_webrtc", - "description": "RealtimeAgent using websockets", - "image": null, - "tags": [ - "realtime", - "websockets" - ], - "source": "/notebook/agentchat_realtime_webrtc.ipynb" - }, - { - "title": "Tools with Dependency Injection", - "link": "/notebooks/tools_dependency_injection", - "description": "Tools Dependency Injection", - "image": null, - "tags": [ - "tools", - "dependency injection", - "function calling" - ], - "source": "/notebook/tools_dependency_injection.ipynb" - }, - { - "title": "Chat Context Dependency Injection", - "link": "/notebooks/tools_chat_context_dependency_injection", - "description": "Chat Context Dependency Injection", - "image": null, - "tags": [ - "tools", - "dependency injection", - "function calling" - ], - "source": "/notebook/tools_chat_context_dependency_injection.ipynb" - }, - { - "title": "Using Neo4j's native GraphRAG SDK with AG2 agents for Question & Answering", - "link": 
"/notebooks/agentchat_graph_rag_neo4j_native", - "description": "Neo4j Native GraphRAG utilizes a knowledge graph and can be added as a capability to agents.", - "image": null, - "tags": [ - "RAG" - ], - "source": "/notebook/agentchat_graph_rag_neo4j_native.ipynb" } ];