Skip to content

Commit

Permalink
Add parallel tool call test
Browse files Browse the repository at this point in the history
  • Loading branch information
ekzhu committed Jan 17, 2025
1 parent 21263b9 commit 1e65988
Showing 1 changed file with 138 additions and 2 deletions.
140 changes: 138 additions & 2 deletions python/packages/autogen-agentchat/tests/test_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
from autogen_core import Image
from autogen_core import FunctionCall, Image
from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType, MemoryQueryResult
from autogen_core.model_context import BufferedChatCompletionContext
from autogen_core.models import LLMMessage
from autogen_core.models import FunctionExecutionResult, LLMMessage
from autogen_core.models._model_client import ModelFamily
from autogen_core.tools import FunctionTool
from autogen_ext.models.openai import OpenAIChatCompletionClient
Expand Down Expand Up @@ -281,6 +281,142 @@ async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) ->
assert state == state2


@pytest.mark.asyncio
async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
chat_completions = [
ChatCompletion(
id="id1",
choices=[
Choice(
finish_reason="tool_calls",
index=0,
message=ChatCompletionMessage(
content=None,
tool_calls=[
ChatCompletionMessageToolCall(
id="1",
type="function",
function=Function(
name="_pass_function",
arguments=json.dumps({"input": "task1"}),
),
),
ChatCompletionMessageToolCall(
id="2",
type="function",
function=Function(
name="_pass_function",
arguments=json.dumps({"input": "task2"}),
),
),
ChatCompletionMessageToolCall(
id="3",
type="function",
function=Function(
name="_echo_function",
arguments=json.dumps({"input": "task3"}),
),
),
],
role="assistant",
),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
ChatCompletion(
id="id2",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content="pass", role="assistant"),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
ChatCompletion(
id="id2",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content="TERMINATE", role="assistant"),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
agent = AssistantAgent(
"tool_use_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
tools=[
_pass_function,
_fail_function,
FunctionTool(_echo_function, description="Echo"),
],
)
result = await agent.run(task="task")

assert len(result.messages) == 4
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].models_usage is None
assert isinstance(result.messages[1], ToolCallRequestEvent)
assert result.messages[1].content == [
FunctionCall(id="1", arguments=r'{"input": "task1"}', name="_pass_function"),
FunctionCall(id="2", arguments=r'{"input": "task2"}', name="_pass_function"),
FunctionCall(id="3", arguments=r'{"input": "task3"}', name="_echo_function"),
]
assert result.messages[1].models_usage is not None
assert result.messages[1].models_usage.completion_tokens == 5
assert result.messages[1].models_usage.prompt_tokens == 10
assert isinstance(result.messages[2], ToolCallExecutionEvent)
expected_content = [
FunctionExecutionResult(call_id="1", content="pass"),
FunctionExecutionResult(call_id="2", content="pass"),
FunctionExecutionResult(call_id="3", content="task3"),
]
for expected in expected_content:
assert expected in result.messages[2].content
assert result.messages[2].models_usage is None
assert isinstance(result.messages[3], ToolCallSummaryMessage)
assert result.messages[3].content == "pass\npass\ntask3"
assert result.messages[3].models_usage is None

# Test streaming.
mock.curr_index = 0 # Reset the mock
index = 0
async for message in agent.run_stream(task="task"):
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1

# Test state saving and loading.
state = await agent.save_state()
agent2 = AssistantAgent(
"tool_use_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
)
await agent2.load_state(state)
state2 = await agent2.save_state()
assert state == state2


@pytest.mark.asyncio
async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
handoff = Handoff(target="agent2")
Expand Down

0 comments on commit 1e65988

Please sign in to comment.