Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Release/v1.14 #755

Merged
merged 7 commits into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ repos:
- id: ruff-format
files: "^mirascope|^tests|^examples|^docs"
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.390
rev: v1.1.391
hooks:
- id: pyright
- repo: local
Expand Down
14 changes: 6 additions & 8 deletions mirascope/core/anthropic/_utils/_setup_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
AsyncAnthropicVertex,
)
from anthropic.types import Message, MessageParam, MessageStreamEvent
from pydantic import BaseModel

from ...base import BaseMessageParam, BaseTool, _utils
from ...base._utils import AsyncCreateFn, CreateFn
Expand All @@ -36,7 +37,7 @@ def setup_call(
tools: list[type[BaseTool] | Callable] | None,
json_mode: bool,
call_params: AnthropicCallParams,
extract: bool,
response_model: type[BaseModel] | None,
stream: bool | StreamConfig,
) -> tuple[
AsyncCreateFn[Message, MessageStreamEvent],
Expand All @@ -58,7 +59,7 @@ def setup_call(
tools: list[type[BaseTool] | Callable] | None,
json_mode: bool,
call_params: AnthropicCallParams,
extract: bool,
response_model: type[BaseModel] | None,
stream: bool | StreamConfig,
) -> tuple[
CreateFn[Message, MessageStreamEvent],
Expand All @@ -85,7 +86,7 @@ def setup_call(
tools: list[type[BaseTool] | Callable] | None,
json_mode: bool,
call_params: AnthropicCallParams,
extract: bool,
response_model: type[BaseModel] | None,
stream: bool | StreamConfig,
) -> tuple[
Callable[..., Message | Awaitable[Message]],
Expand All @@ -111,17 +112,14 @@ def setup_call(
call_kwargs["system"] = messages.pop(0)["content"] # pyright: ignore [reportGeneralTypeIssues]

if json_mode:
json_mode_content = _utils.json_mode_content(
tool_types[0] if tool_types else None
)
json_mode_content = _utils.json_mode_content(response_model)
if isinstance(messages[-1]["content"], str):
messages[-1]["content"] += json_mode_content
else:
messages[-1]["content"] = list(messages[-1]["content"]) + [
{"type": "text", "text": json_mode_content}
]
call_kwargs.pop("tools", None)
elif extract:
elif response_model:
assert tool_types, "At least one tool must be provided for extraction."
call_kwargs["tool_choice"] = {"type": "tool", "name": tool_types[0]._name()}
call_kwargs |= {
Expand Down
8 changes: 5 additions & 3 deletions mirascope/core/anthropic/call_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
usage docs: learn/calls.md#handling-responses
"""

from functools import cached_property

from anthropic.types import (
Message,
MessageParam,
Expand Down Expand Up @@ -97,13 +99,13 @@ def cost(self) -> float | None:
return calculate_cost(self.input_tokens, self.output_tokens, self.model)

@computed_field
@property
@cached_property
def message_param(self) -> SerializeAsAny[MessageParam]:
"""Returns the assistants's response as a message parameter."""
return MessageParam(**self.response.model_dump(include={"content", "role"}))

@computed_field
@property
@cached_property
def tools(self) -> list[AnthropicTool] | None:
"""Returns any available tool calls as their `AnthropicTool` definition.

Expand All @@ -125,7 +127,7 @@ def tools(self) -> list[AnthropicTool] | None:
return extracted_tools

@computed_field
@property
@cached_property
def tool(self) -> AnthropicTool | None:
"""Returns the 0th tool for the 0th choice message.

Expand Down
30 changes: 17 additions & 13 deletions mirascope/core/azure/_utils/_setup_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,16 @@
UserMessage,
)
from azure.core.credentials import AzureKeyCredential
from pydantic import BaseModel

from ...base import BaseMessageParam, BaseTool, _utils
from ...base._utils import AsyncCreateFn, CreateFn, get_async_create_fn, get_create_fn
from ...base._utils import (
DEFAULT_TOOL_DOCSTRING,
AsyncCreateFn,
CreateFn,
get_async_create_fn,
get_create_fn,
)
from ...base.call_params import CommonCallParams
from ...base.stream_config import StreamConfig
from .._call_kwargs import AzureCallKwargs
Expand All @@ -41,7 +48,7 @@ def setup_call(
tools: list[type[BaseTool] | Callable] | None,
json_mode: bool,
call_params: AzureCallParams | CommonCallParams,
extract: bool,
response_model: type[BaseModel] | None,
stream: bool | StreamConfig,
) -> tuple[
AsyncCreateFn[ChatCompletions, StreamingChatCompletionsUpdate],
Expand All @@ -63,7 +70,7 @@ def setup_call(
tools: list[type[BaseTool] | Callable] | None,
json_mode: bool,
call_params: AzureCallParams | CommonCallParams,
extract: bool,
response_model: type[BaseModel] | None,
stream: bool | StreamConfig,
) -> tuple[
CreateFn[ChatCompletions, StreamingChatCompletionsUpdate],
Expand All @@ -85,7 +92,7 @@ def setup_call(
tools: list[type[BaseTool] | Callable] | None,
json_mode: bool,
call_params: AzureCallParams | CommonCallParams,
extract: bool,
response_model: type[BaseModel] | None,
stream: bool | StreamConfig,
) -> tuple[
CreateFn[ChatCompletions, StreamingChatCompletionsUpdate]
Expand All @@ -108,25 +115,22 @@ def setup_call(
messages = cast(list[BaseMessageParam | ChatRequestMessage], messages)
messages = convert_message_params(messages)
if json_mode:
if tool_types and tool_types[0].model_config.get("strict", False):
if response_model and response_model.model_config.get("strict", False):
call_kwargs["response_format"] = ChatCompletionsResponseFormatJSON(
{
"name": tool_types[0]._name(),
"description": tool_types[0]._description(),
"name": response_model.__name__,
"description": response_model.__doc__ or DEFAULT_TOOL_DOCSTRING,
"strict": True,
"schema": tool_types[0].model_json_schema(
"schema": response_model.model_json_schema(
schema_generator=GenerateAzureStrictToolJsonSchema
),
}
)
else:
call_kwargs["response_format"] = ChatCompletionsResponseFormatJSON()
json_mode_content = _utils.json_mode_content(
tool_types[0] if tool_types else None
).strip()
json_mode_content = _utils.json_mode_content(response_model).strip()
messages.append(UserMessage(content=json_mode_content))
call_kwargs.pop("tools", None)
elif extract:
elif response_model:
assert tool_types, "At least one tool must be provided for extraction."
if tool_types and tool_types[0].model_config.get("strict", False):
warnings.warn(
Expand Down
8 changes: 5 additions & 3 deletions mirascope/core/azure/call_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
usage docs: learn/calls.md#handling-responses
"""

from functools import cached_property

from azure.ai.inference.models import (
AssistantMessage,
ChatCompletions,
Expand Down Expand Up @@ -101,7 +103,7 @@ def cost(self) -> float | None:
return calculate_cost(self.input_tokens, self.output_tokens, self.model)

@computed_field
@property
@cached_property
def message_param(self) -> SerializeAsAny[AssistantMessage]:
"""Returns the assistants's response as a message parameter."""
message_param = self.response.choices[0].message
Expand All @@ -110,7 +112,7 @@ def message_param(self) -> SerializeAsAny[AssistantMessage]:
)

@computed_field
@property
@cached_property
def tools(self) -> list[AzureTool] | None:
"""Returns any available tool calls as their `AzureTool` definition.

Expand All @@ -134,7 +136,7 @@ def tools(self) -> list[AzureTool] | None:
return extracted_tools

@computed_field
@property
@cached_property
def tool(self) -> AzureTool | None:
"""Returns the 0th tool for the 0th choice message.

Expand Down
1 change: 1 addition & 0 deletions mirascope/core/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
"CacheControlPart",
"call_factory",
"CommonCallParams",
"DocumentPart",
"FromCallArgs",
"GenerateJsonSchemaNoTitles",
"ImagePart",
Expand Down
16 changes: 16 additions & 0 deletions mirascope/core/base/_call_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from ._create import create_factory
from ._extract import extract_factory
from ._extract_with_tools import extract_with_tools_factory
from ._utils import (
BaseType,
GetJsonOutput,
Expand Down Expand Up @@ -192,6 +193,20 @@ def base_call(
client=client,
call_params=call_params,
) # pyright: ignore [reportReturnType, reportCallIssue]
elif tools:
return partial(
extract_with_tools_factory(
TCallResponse=TCallResponse,
setup_call=setup_call,
get_json_output=get_json_output,
),
model=model,
tools=tools,
response_model=response_model,
output_parser=output_parser,
client=client,
call_params=call_params,
) # pyright: ignore [reportReturnType, reportCallIssue]
else:
return partial(
extract_factory(
Expand Down Expand Up @@ -228,6 +243,7 @@ def base_call(
create_factory(TCallResponse=TCallResponse, setup_call=setup_call),
model=model,
tools=tools,
response_model=None,
output_parser=output_parser,
json_mode=json_mode,
client=client,
Expand Down
11 changes: 9 additions & 2 deletions mirascope/core/base/_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from functools import wraps
from typing import Any, ParamSpec, TypeVar, cast, overload

from pydantic import BaseModel

from ._utils import (
SameSyncAndAsyncClientSetupCall,
SetupCall,
Expand Down Expand Up @@ -72,6 +74,7 @@ def decorator(
fn: Callable[_P, _BaseDynamicConfigT],
model: str,
tools: list[type[BaseTool] | Callable] | None,
response_model: type[BaseModel] | None,
output_parser: Callable[[_BaseCallResponseT], _ParsedOutputT] | None,
json_mode: bool,
client: _SameSyncAndAsyncClientT | _SyncBaseClientT | None,
Expand All @@ -83,6 +86,7 @@ def decorator(
fn: Callable[_P, Messages.Type],
model: str,
tools: list[type[BaseTool] | Callable] | None,
response_model: type[BaseModel] | None,
output_parser: Callable[[_BaseCallResponseT], _ParsedOutputT] | None,
json_mode: bool,
client: _SameSyncAndAsyncClientT | _SyncBaseClientT | None,
Expand All @@ -98,6 +102,7 @@ def decorator(
],
model: str,
tools: list[type[BaseTool] | Callable] | None,
response_model: type[BaseModel] | None,
output_parser: Callable[[_BaseCallResponseT], _ParsedOutputT] | None,
json_mode: bool,
client: _SameSyncAndAsyncClientT | _AsyncBaseClientT | None,
Expand All @@ -112,6 +117,7 @@ def decorator(
fn: Callable[_P, Awaitable[Messages.Type] | Coroutine[Any, Any, Messages.Type]],
model: str,
tools: list[type[BaseTool] | Callable] | None,
response_model: type[BaseModel] | None,
output_parser: Callable[[_BaseCallResponseT], _ParsedOutputT] | None,
json_mode: bool,
client: _SameSyncAndAsyncClientT | _AsyncBaseClientT | None,
Expand All @@ -132,6 +138,7 @@ def decorator(
| Callable[_P, Awaitable[Messages.Type] | Coroutine[Any, Any, Messages.Type]],
model: str,
tools: list[type[BaseTool] | Callable] | None,
response_model: type[BaseModel] | None,
output_parser: Callable[[_BaseCallResponseT], _ParsedOutputT] | None,
json_mode: bool,
client: _SameSyncAndAsyncClientT | _AsyncBaseClientT | _SyncBaseClientT | None,
Expand Down Expand Up @@ -174,7 +181,7 @@ async def inner_async(
tools=tools,
json_mode=json_mode,
call_params=call_params,
extract=False,
response_model=response_model,
stream=False,
)
start_time = datetime.datetime.now().timestamp() * 1000
Expand Down Expand Up @@ -218,7 +225,7 @@ def inner(
tools=tools,
json_mode=json_mode,
call_params=call_params,
extract=False,
response_model=response_model,
stream=False,
)
start_time = datetime.datetime.now().timestamp() * 1000
Expand Down
9 changes: 6 additions & 3 deletions mirascope/core/base/_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def extract_factory( # noqa: ANN202
):
"""Returns the wrapped function with the provider specific interfaces."""
create_decorator = create_factory(
TCallResponse=TCallResponse, setup_call=setup_call
TCallResponse=TCallResponse,
setup_call=setup_call,
)

@overload
Expand Down Expand Up @@ -110,10 +111,12 @@ def decorator(
]:
fn._model = model # pyright: ignore [reportFunctionMemberAccess]
fn.__mirascope_call__ = True # pyright: ignore [reportFunctionMemberAccess]
tool = setup_extract_tool(response_model, TToolType)
create_decorator_kwargs = {
"model": model,
"tools": [tool],
"tools": [setup_extract_tool(response_model, TToolType)]
if not json_mode
else None,
"response_model": response_model,
"output_parser": None,
"json_mode": json_mode,
"client": client,
Expand Down
Loading
Loading