Make streaming vercel compatible
WonderPG committed Jan 10, 2025
1 parent 818d7c0 commit 38a6934
Showing 4 changed files with 53 additions and 27 deletions.
69 changes: 48 additions & 21 deletions src/neuroagent/agent_routine.py
@@ -53,6 +53,7 @@ async def get_chat_completion(
"messages": messages,
"tools": tools or None,
"tool_choice": agent.tool_choice,
"stream_options": {"include_usage": True},
"stream": stream,
}
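For context, stream_options={"include_usage": True} asks the OpenAI API to append one final chunk whose choices list is empty and whose usage field carries the token counts; that is the chunk the chunk.choices == [] branch further down consumes. A minimal sketch with the OpenAI Python client (model and prompt are placeholders):

    import asyncio
    from openai import AsyncOpenAI

    async def main() -> None:
        client = AsyncOpenAI()
        stream = await client.chat.completions.create(
            model="gpt-4o-mini",  # placeholder model
            messages=[{"role": "user", "content": "Hello"}],
            stream=True,
            stream_options={"include_usage": True},
        )
        async for chunk in stream:
            if not chunk.choices:
                # Final usage-only chunk: choices is empty, usage is populated.
                print(chunk.usage.prompt_tokens, chunk.usage.completion_tokens)

    asyncio.run(main())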

@@ -307,7 +308,6 @@ async def astream(
content = await messages_to_openai_content(messages)
history = copy.deepcopy(content)
init_len = len(messages)
is_streaming = False

while len(history) - init_len < max_turns:
message: dict[str, Any] = {
@@ -333,27 +333,54 @@
model_override=model_override,
stream=True,
)
draft_tool_calls = []
draft_tool_calls_index = -1
async for chunk in completion: # type: ignore
delta = json.loads(chunk.choices[0].delta.model_dump_json())

# Check for tool calls
if delta["tool_calls"]:
tool = delta["tool_calls"][0]["function"]
if tool["name"]:
yield f"\nCalling tool : {tool['name']} with arguments : "
if tool["arguments"]:
yield tool["arguments"]

# Check for content
if delta["content"]:
if not is_streaming:
yield "\n<begin_llm_response>\n"
is_streaming = True
yield delta["content"]

delta.pop("role", None)
merge_chunk(message, delta)

for choice in chunk.choices:
if choice.finish_reason == "stop":
continue

elif choice.finish_reason == "tool_calls":
for tool_call in draft_tool_calls:
yield f"9:{{'toolCallId':'{tool_call['id']}','toolName':'{tool_call['name']}','args':{tool_call['arguments']}}}\n"

# Check for tool calls
elif choice.delta.tool_calls:
for tool_call in choice.delta.tool_calls:
id = tool_call.id
name = tool_call.function.name
arguments = tool_call.function.arguments

if id is not None:
draft_tool_calls_index += 1
draft_tool_calls.append(
{"id": id, "name": name, "arguments": ""}
)
yield f"b:{{'toolCallId':{id},'toolName':{name}}}\n"

else:
draft_tool_calls[draft_tool_calls_index][
"arguments"
] += arguments
yield f"c:{{toolCallId:{id}; argsTextDelta:{arguments}}}\n"

else:
yield f"0:{json.dumps(choice.delta.content)}\n"

delta_json = choice.delta.model_dump()
delta_json.pop("role", None)
merge_chunk(message, delta_json)

if chunk.choices == []:
usage = chunk.usage
prompt_tokens = usage.prompt_tokens
completion_tokens = usage.completion_tokens

yield 'd:{{"finishReason":"{reason}","usage":{{"promptTokens":{prompt},"completionTokens":{completion}}}}}\n'.format(
reason="tool-calls" if len(draft_tool_calls) > 0 else "stop",
prompt=prompt_tokens,
completion=completion_tokens,
)
message["tool_calls"] = list(message.get("tool_calls", {}).values())
if not message["tool_calls"]:
message["tool_calls"] = None
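The one-letter prefixes above follow the Vercel AI SDK data stream protocol: every line has the shape <type>:<json>\n, with 0: for text deltas, b: and c: for tool-call start and argument deltas, 9: for a completed tool call, and d: for the finish message with usage. A minimal encoder sketch (payload shapes are assumptions based on the SDK docs; the protocol expects strict JSON, hence json.dumps throughout):

    import json
    from typing import Any

    def frame(prefix: str, payload: Any) -> str:
        """Encode one data-stream-protocol line: '<type>:<json>\n'."""
        return f"{prefix}:{json.dumps(payload)}\n"

    frame("0", "partial text")                                      # text delta
    frame("b", {"toolCallId": "call_1", "toolName": "search"})      # tool-call start
    frame("c", {"toolCallId": "call_1", "argsTextDelta": '{"q":'})  # args delta
    frame("9", {"toolCallId": "call_1", "toolName": "search",
                "args": {"q": "x"}})                                # completed call
    frame("d", {"finishReason": "stop",
                "usage": {"promptTokens": 10, "completionTokens": 5}})

A strict protocol parser expects double-quoted JSON in every payload, which is worth noting for the f-string-built b:, c: and 9: frames in the hunk above.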
5 changes: 3 additions & 2 deletions src/neuroagent/app/routers/tools.py
@@ -113,16 +113,17 @@ async def get_tool_returns(
return tool_output


@router.patch("/validate/{thread_id}")
@router.patch("/validate/{thread_id}/{tool_call_id}")
async def validate_input(
user_request: HILValidation,
_: Annotated[Threads, Depends(get_thread)],
tool_call_id: str,
session: Annotated[AsyncSession, Depends(get_session)],
starting_agent: Annotated[Agent, Depends(get_starting_agent)],
) -> ToolCallSchema:
"""Validate HIL inputs."""
# We first find the AI TOOL message to modify.
tool_call = await session.get(ToolCalls, user_request.tool_call_id)
tool_call = await session.get(ToolCalls, tool_call_id)
if not tool_call:
raise HTTPException(status_code=404, detail="Specified tool call not found.")
if tool_call.validated is not None:
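A hypothetical client call against the reworked endpoint, with the tool call id now carried in the path rather than in the body (the base URL, route prefix, and payload values are made up):

    import asyncio
    import httpx

    async def validate(thread_id: str, tool_call_id: str) -> None:
        async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
            resp = await client.patch(
                f"/validate/{thread_id}/{tool_call_id}",  # assumed mount point
                json={"validated_inputs": {"param": "value"}, "is_validated": True},
            )
            resp.raise_for_status()
            print(resp.json())

    asyncio.run(validate("thread-123", "call-456"))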
1 change: 0 additions & 1 deletion src/neuroagent/new_types.py
@@ -30,7 +30,6 @@ class HILResponse(BaseModel):
class HILValidation(BaseModel):
"""Class to send the validated json to the api."""

tool_call_id: str
validated_inputs: dict[str, Any] | None = None
is_validated: bool = True
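With the identifier promoted to a path parameter, the request body is reduced to the optional validated inputs and the validation flag.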

5 changes: 2 additions & 3 deletions src/neuroagent/stream.py
@@ -1,5 +1,6 @@
"""Wrapper around streaming methods to reinitiate connections due to the way fastAPI StreamingResponse works."""

import json
from typing import Any, AsyncIterator

from fastapi import Request
@@ -51,9 +52,7 @@ async def stream_agent_response(
yield chunk
# Final chunk that contains the whole response
elif chunk.hil_messages:
yield str(
[hil_message.model_dump_json() for hil_message in chunk.hil_messages]
)
yield f"2:{json.dumps([hil_message.model_dump_json() for hil_message in chunk.hil_messages])}\n"

# Save the new messages in DB
thread.update_date = utc_now()
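With this change, HIL messages ride the same data stream protocol as a 2: data frame. A sketch of the resulting wire format (the HILResponse field below is illustrative, not the real model):

    import json
    from pydantic import BaseModel

    class HILResponse(BaseModel):
        message: str  # illustrative field; the real model differs

    hil_messages = [HILResponse(message="Please confirm the tool inputs.")]

    # Mirrors the committed line: each message is dumped to a JSON string,
    # then the list of strings is JSON-encoded again for the frame.
    line = f"2:{json.dumps([m.model_dump_json() for m in hil_messages])}\n"
    print(line)  # 2:["{\"message\":\"Please confirm the tool inputs.\"}"]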