diff --git a/src/neuroagent/agent_routine.py b/src/neuroagent/agent_routine.py index d3906fd..4caa82c 100644 --- a/src/neuroagent/agent_routine.py +++ b/src/neuroagent/agent_routine.py @@ -53,6 +53,7 @@ async def get_chat_completion( "messages": messages, "tools": tools or None, "tool_choice": agent.tool_choice, + "stream_options": {"include_usage": True}, "stream": stream, } @@ -307,7 +308,6 @@ async def astream( content = await messages_to_openai_content(messages) history = copy.deepcopy(content) init_len = len(messages) - is_streaming = False while len(history) - init_len < max_turns: message: dict[str, Any] = { @@ -333,27 +333,54 @@ async def astream( model_override=model_override, stream=True, ) + draft_tool_calls = [] + draft_tool_calls_index = -1 async for chunk in completion: # type: ignore - delta = json.loads(chunk.choices[0].delta.model_dump_json()) - - # Check for tool calls - if delta["tool_calls"]: - tool = delta["tool_calls"][0]["function"] - if tool["name"]: - yield f"\nCalling tool : {tool['name']} with arguments : " - if tool["arguments"]: - yield tool["arguments"] - - # Check for content - if delta["content"]: - if not is_streaming: - yield "\n\n" - is_streaming = True - yield delta["content"] - - delta.pop("role", None) - merge_chunk(message, delta) - + for choice in chunk.choices: + if choice.finish_reason == "stop": + continue + + elif choice.finish_reason == "tool_calls": + for tool_call in draft_tool_calls: + yield f'9:{{"toolCallId":"{tool_call["id"]}","toolName":"{tool_call["name"]}","args":{tool_call["arguments"]}}}\n' + + # Check for tool calls + elif choice.delta.tool_calls: + for tool_call in choice.delta.tool_calls: + id = tool_call.id + name = tool_call.function.name + arguments = tool_call.function.arguments + + if id is not None: + draft_tool_calls_index += 1 + draft_tool_calls.append( + {"id": id, "name": name, "arguments": ""} + ) + yield f'b:{{"toolCallId":"{id}","toolName":"{name}"}}\n' + + else: + 
draft_tool_calls[draft_tool_calls_index][ + "arguments" + ] += arguments + yield f'c:{{"toolCallId":"{draft_tool_calls[draft_tool_calls_index]["id"]}","argsTextDelta":{json.dumps(arguments)}}}\n' + + else: + yield f"0:{json.dumps(choice.delta.content)}\n" + + delta_json = choice.delta.model_dump() + delta_json.pop("role", None) + merge_chunk(message, delta_json) + + if chunk.choices == []: + usage = chunk.usage + prompt_tokens = usage.prompt_tokens + completion_tokens = usage.completion_tokens + + yield 'd:{{"finishReason":"{reason}","usage":{{"promptTokens":{prompt},"completionTokens":{completion}}}}}\n'.format( + reason="tool-calls" if len(draft_tool_calls) > 0 else "stop", + prompt=prompt_tokens, + completion=completion_tokens, + ) message["tool_calls"] = list(message.get("tool_calls", {}).values()) if not message["tool_calls"]: message["tool_calls"] = None diff --git a/src/neuroagent/app/routers/tools.py b/src/neuroagent/app/routers/tools.py index 853a427..f43fa76 100644 --- a/src/neuroagent/app/routers/tools.py +++ b/src/neuroagent/app/routers/tools.py @@ -113,16 +113,17 @@ async def get_tool_returns( return tool_output -@router.patch("/validate/{thread_id}") +@router.patch("/validate/{thread_id}/{tool_call_id}") async def validate_input( user_request: HILValidation, _: Annotated[Threads, Depends(get_thread)], + tool_call_id: str, session: Annotated[AsyncSession, Depends(get_session)], starting_agent: Annotated[Agent, Depends(get_starting_agent)], ) -> ToolCallSchema: """Validate HIL inputs.""" # We first find the AI TOOL message to modify. 
- tool_call = await session.get(ToolCalls, user_request.tool_call_id) + tool_call = await session.get(ToolCalls, tool_call_id) if not tool_call: raise HTTPException(status_code=404, detail="Specified tool call not found.") if tool_call.validated is not None: diff --git a/src/neuroagent/new_types.py b/src/neuroagent/new_types.py index 498b695..f5ec3c4 100644 --- a/src/neuroagent/new_types.py +++ b/src/neuroagent/new_types.py @@ -30,7 +30,6 @@ class HILResponse(BaseModel): class HILValidation(BaseModel): """Class to send the validated json to the api.""" - tool_call_id: str validated_inputs: dict[str, Any] | None = None is_validated: bool = True diff --git a/src/neuroagent/stream.py b/src/neuroagent/stream.py index a6e1288..19026a5 100644 --- a/src/neuroagent/stream.py +++ b/src/neuroagent/stream.py @@ -1,5 +1,6 @@ """Wrapper around streaming methods to reinitiate connections due to the way fastAPI StreamingResponse works.""" +import json from typing import Any, AsyncIterator from fastapi import Request @@ -51,9 +52,7 @@ async def stream_agent_response( yield chunk # Final chunk that contains the whole response elif chunk.hil_messages: - yield str( - [hil_message.model_dump_json() for hil_message in chunk.hil_messages] - ) + yield f"2:{json.dumps([hil_message.model_dump() for hil_message in chunk.hil_messages])}\n" # Save the new messages in DB thread.update_date = utc_now()