Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support-streaming-partial-tools #718

Merged
merged 12 commits into from
Nov 27, 2024
28 changes: 26 additions & 2 deletions mirascope/core/anthropic/_utils/_handle_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def _handle_chunk(
current_tool_call: ToolUseBlock,
current_tool_type: type[AnthropicTool] | None,
tool_types: list[type[AnthropicTool]] | None,
partial_tools: bool = False,
) -> tuple[
str,
AnthropicTool | None,
Expand Down Expand Up @@ -59,31 +60,54 @@ def _handle_chunk(
if chunk.type == "content_block_delta" and chunk.delta.type == "input_json_delta":
buffer += chunk.delta.partial_json

# Return partial tool if enabled
if partial_tools and current_tool_type:
partial_tool_call = ToolUseBlock(
id=current_tool_call.id,
input=buffer,
name=current_tool_call.name,
type="tool_use",
)
partial_tool = current_tool_type.from_tool_call(partial_tool_call, True)
partial_tool.delta = chunk.delta.partial_json
return buffer, partial_tool, current_tool_call, current_tool_type
return buffer, None, current_tool_call, current_tool_type


def handle_stream(
stream: Generator[MessageStreamEvent, None, None],
tool_types: list[type[AnthropicTool]] | None,
partial_tools: bool = False,
) -> Generator[tuple[AnthropicCallResponseChunk, AnthropicTool | None], None, None]:
"""Iterator over the stream and constructs tools as they are streamed."""
current_tool_call = ToolUseBlock(id="", input={}, name="", type="tool_use")
current_tool_type, buffer = None, ""
for chunk in stream:
buffer, tool, current_tool_call, current_tool_type = _handle_chunk(
buffer, chunk, current_tool_call, current_tool_type, tool_types
buffer,
chunk,
current_tool_call,
current_tool_type,
tool_types,
partial_tools,
)
yield AnthropicCallResponseChunk(chunk=chunk), tool


async def handle_stream_async(
stream: AsyncGenerator[MessageStreamEvent, None],
tool_types: list[type[AnthropicTool]] | None,
partial_tools: bool = False,
) -> AsyncGenerator[tuple[AnthropicCallResponseChunk, AnthropicTool | None], None]:
current_tool_call = ToolUseBlock(id="", input={}, name="", type="tool_use")
current_tool_type, buffer = None, ""
async for chunk in stream:
buffer, tool, current_tool_call, current_tool_type = _handle_chunk(
buffer, chunk, current_tool_call, current_tool_type, tool_types
buffer,
chunk,
current_tool_call,
current_tool_type,
tool_types,
partial_tools,
)
yield AnthropicCallResponseChunk(chunk=chunk), tool
7 changes: 4 additions & 3 deletions mirascope/core/anthropic/_utils/_setup_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from ...base import BaseMessageParam, BaseTool, _utils
from ...base._utils import AsyncCreateFn, CreateFn
from ...base.stream_config import StreamConfig
from .._call_kwargs import AnthropicCallKwargs
from ..call_params import AnthropicCallParams
from ..dynamic_config import AnthropicDynamicConfig, AsyncAnthropicDynamicConfig
Expand All @@ -36,7 +37,7 @@ def setup_call(
json_mode: bool,
call_params: AnthropicCallParams,
extract: bool,
stream: bool,
stream: bool | StreamConfig,
) -> tuple[
AsyncCreateFn[Message, MessageStreamEvent],
str | None,
Expand All @@ -58,7 +59,7 @@ def setup_call(
json_mode: bool,
call_params: AnthropicCallParams,
extract: bool,
stream: bool,
stream: bool | StreamConfig,
) -> tuple[
CreateFn[Message, MessageStreamEvent],
str | None,
Expand All @@ -85,7 +86,7 @@ def setup_call(
json_mode: bool,
call_params: AnthropicCallParams,
extract: bool,
stream: bool,
stream: bool | StreamConfig,
) -> tuple[
Callable[..., Message | Awaitable[Message]],
str | None,
Expand Down
17 changes: 13 additions & 4 deletions mirascope/core/anthropic/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@

from __future__ import annotations

import copy
from typing import Any, cast

from anthropic.types import ToolParam, ToolUseBlock
from pydantic.json_schema import SkipJsonSchema
from typing_extensions import TypedDict

from ..base import BaseTool, ToolConfig
from ..base._partial import partial


class _CacheControl(TypedDict):
Expand Down Expand Up @@ -80,13 +81,21 @@ def format_book(title: str, author: str) -> str:
return ToolParam(**kwargs)

@classmethod
def from_tool_call(cls, tool_call: ToolUseBlock) -> AnthropicTool:
def from_tool_call(
cls, tool_call: ToolUseBlock, allow_partial: bool = False
) -> AnthropicTool:
"""Constructs an `AnthropicTool` instance from a `tool_call`.

Args:
tool_call: The Anthropic tool call from which to construct this tool
instance.
"""
model_json = copy.deepcopy(tool_call.input)
willbakst marked this conversation as resolved.
Show resolved Hide resolved
model_json["tool_call"] = tool_call.model_dump() # pyright: ignore [reportIndexIssue]
model_json = {"tool_call": tool_call}
model_json |= (
cls._dict_from_json(tool_call.input, True)
if isinstance(tool_call.input, str)
else cast(dict[str, Any], tool_call.input)
)
if allow_partial:
return partial(cls, {"tool_call", "delta"}).model_validate(model_json)
return cls.model_validate(model_json)
2 changes: 2 additions & 0 deletions mirascope/core/azure/_utils/_handle_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def _handle_chunk(
def handle_stream(
stream: Generator[StreamingChatCompletionsUpdate, None, None],
tool_types: list[type[AzureTool]] | None,
partial_tools: bool = False,
) -> Generator[tuple[AzureCallResponseChunk, AzureTool | None], None, None]:
"""Iterator over the stream and constructs tools as they are streamed."""
current_tool_call = ChatCompletionsToolCall(
Expand Down Expand Up @@ -102,6 +103,7 @@ def handle_stream(
async def handle_stream_async(
stream: AsyncGenerator[StreamingChatCompletionsUpdate, None],
tool_types: list[type[AzureTool]] | None,
partial_tools: bool = False,
) -> AsyncGenerator[tuple[AzureCallResponseChunk, AzureTool | None], None]:
"""Async iterator over the stream and constructs tools as they are streamed."""
current_tool_call = ChatCompletionsToolCall(
Expand Down
7 changes: 4 additions & 3 deletions mirascope/core/azure/_utils/_setup_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ...base import BaseMessageParam, BaseTool, _utils
from ...base._utils import AsyncCreateFn, CreateFn, get_async_create_fn, get_create_fn
from ...base.call_params import CommonCallParams
from ...base.stream_config import StreamConfig
from .._call_kwargs import AzureCallKwargs
from ..call_params import AzureCallParams
from ..dynamic_config import AsyncAzureDynamicConfig, AzureDynamicConfig
Expand All @@ -41,7 +42,7 @@ def setup_call(
json_mode: bool,
call_params: AzureCallParams | CommonCallParams,
extract: bool,
stream: bool,
stream: bool | StreamConfig,
) -> tuple[
AsyncCreateFn[ChatCompletions, StreamingChatCompletionsUpdate],
str | None,
Expand All @@ -63,7 +64,7 @@ def setup_call(
json_mode: bool,
call_params: AzureCallParams | CommonCallParams,
extract: bool,
stream: bool,
stream: bool | StreamConfig,
) -> tuple[
CreateFn[ChatCompletions, StreamingChatCompletionsUpdate],
str | None,
Expand All @@ -85,7 +86,7 @@ def setup_call(
json_mode: bool,
call_params: AzureCallParams | CommonCallParams,
extract: bool,
stream: bool,
stream: bool | StreamConfig,
) -> tuple[
CreateFn[ChatCompletions, StreamingChatCompletionsUpdate]
| AsyncCreateFn[ChatCompletions, StreamingChatCompletionsUpdate],
Expand Down
4 changes: 3 additions & 1 deletion mirascope/core/base/_call_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .call_response_chunk import BaseCallResponseChunk
from .dynamic_config import BaseDynamicConfig
from .stream import BaseStream, stream_factory
from .stream_config import StreamConfig
from .structured_stream import structured_stream_factory
from .tool import BaseTool

Expand Down Expand Up @@ -123,7 +124,7 @@ def call_factory( # noqa: ANN202
def base_call(
model: str,
*,
stream: bool = False,
stream: bool | StreamConfig = False,
tools: list[type[BaseTool] | Callable] | None = None,
response_model: type[_ResponseModelT] | None = None,
output_parser: Callable[[_BaseCallResponseT], _ParsedOutputT]
Expand Down Expand Up @@ -221,6 +222,7 @@ def base_call(
json_mode=json_mode,
client=client,
call_params=call_params,
partial_tools=isinstance(stream, dict) and stream.get("partial_tools"),
) # pyright: ignore [reportReturnType, reportCallIssue]
return partial(
create_factory(TCallResponse=TCallResponse, setup_call=setup_call),
Expand Down
10 changes: 8 additions & 2 deletions mirascope/core/base/_partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def _process_annotation(annotation: type) -> type:
) # pyright: ignore [reportReturnType]


def partial(wrapped_class: type[Model]) -> type[Model]:
def partial(
wrapped_class: type[Model], preserve_fields: set[str] | None = None
) -> type[Model]:
"""Generate a new class with all attributes optionals.

This decorator will wrap a class inheriting from BaseModel and will recursively
Expand All @@ -60,6 +62,8 @@ class User(BaseModel):
user = User() # All fields optional
```
"""
if preserve_fields is None:
preserve_fields = set()

def _make_field_optional(
field: FieldInfo,
Expand All @@ -83,7 +87,9 @@ def _make_field_optional(
__validators__=None,
__cls_kwargs__=None,
**{
field_name: _make_field_optional(field_info)
field_name: (field_info.annotation, field_info)
if field_name in preserve_fields
else _make_field_optional(field_info)
for field_name, field_info in wrapped_class.model_fields.items()
},
)
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def call(self: base) -> Any: # noqa: ANN401
if self.model_fields[field_name].alias
else field_name
): getattr(self, field_name)
for field_name in self.model_dump(exclude={"tool_call"})
for field_name in self.model_dump(exclude={"tool_call", "delta"})
}
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from typing_extensions import TypeIs

from ..stream_config import StreamConfig
from ._protocols import AsyncCreateFn, CreateFn

_StreamedResponse = TypeVar("_StreamedResponse")
Expand Down Expand Up @@ -54,7 +55,7 @@ def get_async_create_fn(
@overload
def create_or_stream(
*,
stream: Literal[True] = True,
stream: Literal[True] | StreamConfig = True,
**kwargs: Any, # noqa: ANN401
) -> Awaitable[AsyncGenerator[_StreamedResponse, None]]: ...

Expand All @@ -67,7 +68,7 @@ def create_or_stream(

def create_or_stream(
*,
stream: bool = False,
stream: bool | StreamConfig = False,
**kwargs: Any, # noqa: ANN401
) -> (
Awaitable[AsyncGenerator[_StreamedResponse, None]]
Expand Down Expand Up @@ -101,7 +102,7 @@ def get_create_fn(
@overload
def create_or_stream(
*,
stream: Literal[True] = True,
stream: Literal[True] | StreamConfig = True,
**kwargs: Any, # noqa: ANN401
) -> Generator[_StreamedResponse, None, None]: ...

Expand All @@ -114,7 +115,7 @@ def create_or_stream(

def create_or_stream(
*,
stream: bool = False,
stream: bool | StreamConfig = False,
**kwargs: Any, # noqa: ANN401
) -> Generator[_StreamedResponse, None, None] | _NonStreamedResponse:
if stream:
Expand Down
Loading