Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: align versions with inference endpoints #8

Merged
merged 12 commits into from
Jan 3, 2025
7 changes: 7 additions & 0 deletions narrative_llm_tools/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def main() -> None:
)
parser.add_argument("file", help="Path to the JSONL file to validate.")
parser.add_argument("--threads", type=int, default=4, help="Number of threads to use")
parser.add_argument("--clean", type=str, help="Output validated lines to specified file")
parser.add_argument("--quiet", "-q", action="store_true", help="Suppress progress bar")

args = parser.parse_args()
Expand Down Expand Up @@ -57,6 +58,12 @@ def main() -> None:
# Collect errors from results
errors = [error for result in results for error in result.errors]

if args.clean:
with open(args.clean, "w") as f:
for result in results:
if not result.errors:
f.write(result.original_line + "\n")

if errors:
print("Validation FAILED.\n")
for err in sorted(errors, key=lambda x: int(x.split()[1].rstrip(":"))):
Expand Down
48 changes: 32 additions & 16 deletions narrative_llm_tools/handlers/huggingface.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import json
import logging
from collections.abc import Hashable
from typing import Any, Literal, Optional, Protocol
from typing import Any, Literal, Protocol

from pydantic import BaseModel
from torch import Tensor
from transformers import pipeline # type: ignore

from narrative_llm_tools.rest_api_client.types import RestApiResponse
from narrative_llm_tools.rest_api_client.types import RestApiResponse, ReturnToLlmBehavior
from narrative_llm_tools.state.conversation_state import (
ConversationMessage,
ConversationState,
Expand All @@ -17,14 +17,15 @@
from narrative_llm_tools.tools import Tool
from narrative_llm_tools.utils.format_enforcer import get_format_enforcer

logger = logging.getLogger(__name__)
logger = logging.getLogger("narrative-llm-tools")
logger.setLevel(logging.WARNING)


class HandlerResponse(BaseModel):
"""Response from the handler."""

tool_calls: list[dict[str, Any]]
warnings: Optional[list[str]]
warnings: list[str] | None


class ModelConfig(BaseModel):
Expand All @@ -34,7 +35,6 @@ class ModelConfig(BaseModel):
path: str
max_new_tokens: int = 4096
device_map: str = "auto"
low_cpu_mem_usage: bool = False
begin_token: str = "<|begin_of_text|>"
eot_token: str = "<|eot_id|>"

Expand Down Expand Up @@ -111,14 +111,14 @@ class AuthenticationError(EndpointError):


class EndpointHandler:
def __init__(self, path: str = "", low_cpu_mem_usage: bool = False) -> None:
def __init__(self, path: str = "") -> None:
"""
Initialize the EndpointHandler with the provided model path.

Args:
path (str, optional): The path or identifier of the model. Defaults to "".
"""
self.config = ModelConfig(path=path, low_cpu_mem_usage=low_cpu_mem_usage)
self.config = ModelConfig(path=path)

try:
self.pipeline: Pipeline = self._create_pipeline()
Expand All @@ -135,11 +135,10 @@ def _create_pipeline(self) -> Pipeline:
model=self.config.path,
max_new_tokens=self.config.max_new_tokens,
device_map=self.config.device_map,
low_cpu_mem_usage=self.config.low_cpu_mem_usage,
)
return pipe # type: ignore

def __call__(self, data: dict[str, Any]) -> HandlerResponse:
def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
"""
Generate model output given a conversation and optional tools/parameters.

Expand Down Expand Up @@ -295,7 +294,7 @@ def __call__(self, data: dict[str, Any]) -> HandlerResponse:
if not isinstance(tool_call, dict):
raise ModelOutputError("Model output is not a list of tool calls.")

return HandlerResponse(tool_calls=return_msg, warnings=None)
return HandlerResponse(tool_calls=return_msg, warnings=None).model_dump(exclude_none=True)

except (
ValidationError,
Expand Down Expand Up @@ -331,6 +330,7 @@ def _execute_tool_calls(self, tool_calls: list[Tool], state: ConversationState)
"""Execute tool calls and update conversation state."""
logger.debug(f"Executing tool calls: {tool_calls}")
rest_api_catalog = state.get_rest_api_catalog()
logger.info(f"Rest API catalog: {rest_api_catalog}")

if not rest_api_catalog:
logger.info("No rest API catalog is available, skipping all tool calls.")
Expand All @@ -344,21 +344,37 @@ def _execute_tool_calls(self, tool_calls: list[Tool], state: ConversationState)
api_client = rest_api_catalog[tool.name]
api_response: RestApiResponse = api_client.call(tool.parameters)
api_client_behavior = (
api_client.config.response_behavior.get(api_response.status)
if api_client.config.response_behavior.get(api_response.status)
api_client.config.response_behavior.get(str(api_response.status))
if api_client.config.response_behavior.get(str(api_response.status))
else api_client.config.response_behavior.get("default")
)

if api_response.type == "json" and api_client_behavior == "return_to_llm":
tool_responses.append(ToolResponse(name=tool.name, content=api_response.body))
logger.info(f"API response: {api_response}, behavior: {api_client_behavior}")
behavior_type = api_client_behavior.behavior_type if api_client_behavior else None

if (
behavior_type
and behavior_type == "return_to_llm"
):
llm_response_behavior: ReturnToLlmBehavior = api_client_behavior # type: ignore

response = (
llm_response_behavior.response
if llm_response_behavior.response
else api_response.body
)
tool_responses.append(ToolResponse(name=tool.name, content=response))
elif (
api_response.type == "json" and api_client_behavior == "return_response_to_user"
api_response.type == "json"
and behavior_type
and behavior_type == "return_response_to_user"
):
tool_responses.append(ToolResponse(name=tool.name, content=api_response.body))
return_to_user = True
elif (
api_response.type == "json"
and api_client_behavior == "return_request_to_user"
and behavior_type
and behavior_type == "return_request_to_user"
and api_response.request
):
tool_responses.append(
Expand Down
2 changes: 1 addition & 1 deletion narrative_llm_tools/rest_api_client/rest_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
RestApiResponse,
)

logger = logging.getLogger(__name__)
logger = logging.getLogger("narrative-llm-tools")


class RestApiClient(BaseModel):
Expand Down
11 changes: 6 additions & 5 deletions narrative_llm_tools/rest_api_client/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __eq__(self, other: Any) -> bool:

class Behavior(BaseModel):
behavior_type: str
response: str | None = None

def __hash__(self) -> int:
return hash(self.behavior_type)
Expand All @@ -38,17 +39,17 @@ def __eq__(self, other: Any) -> bool:

class ReturnToLlmBehavior(Behavior):
behavior_type: Literal["return_to_llm"] = "return_to_llm"
llm_response: str | None = None
response: str | None = None


class ReturnResponseToUserBehavior(Behavior):
behavior_type: Literal["return_response_to_user"] = "return_response_to_user"
user_response: str | None = None
response: str | None = None


class ReturnRequestToUserBehavior(Behavior):
behavior_type: Literal["return_request_to_user"] = "return_request_to_user"
user_response: str | None = None
response: str | None = None


class RestApiResponse(BaseModel):
Expand All @@ -62,8 +63,8 @@ class RestApiConfig(BaseModel):
url: str
method: HttpMethod
auth: BearerTokenAuth | None = None
response_behavior: dict[int | Literal["default"], Behavior] = {
"default": ReturnToLlmBehavior(llm_response=None),
response_behavior: dict[str | Literal["default"], Behavior] = {
"default": ReturnToLlmBehavior(response=None),
}
query_path: str | None = None
parameter_location: ParameterLocation | None = None
Expand Down
55 changes: 32 additions & 23 deletions narrative_llm_tools/state/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,13 @@ def validate_conversation_structure(self) -> "Conversation":
conv = self.conversations

if len(conv) < 3:
all_errors.append(
f"'conversation' must have at least 3 messages. Found {len(conv)}."
)
all_errors.append(f"'conversation' must have at least 3 messages. Found {len(conv)}.")

system_count = 0
tool_catalog_count = 0
last_role = None
found_system = False
tool_catalog_schema = None
tool_catalog_schema = None
assistant_call_indices = []
user_count = 0

Expand Down Expand Up @@ -134,15 +132,22 @@ def validate_conversation_structure(self) -> "Conversation":
_, prev_arr = parse_json_array(prev_content, tool_catalog_schema)

if len(arr) != len(prev_arr):
msg_errors.append("tool_response array length must match the preceding 'assistant'/'tool_call' array.")
msg_errors.append(
"tool_response array length must match the preceding "
"'assistant'/'tool_call' array."
)
else:
for idx, (response, prev_call) in enumerate(zip(arr, prev_arr)):
for idx, (response, prev_call) in enumerate(
zip(arr, prev_arr, strict=False)
):
structure_errors = validate_tool_response_structure(response, idx)
if structure_errors:
msg_errors.extend(structure_errors)
continue

matching_errors = validate_tool_response_matching(response, prev_call, idx)

matching_errors = validate_tool_response_matching(
response, prev_call, idx
)
msg_errors.extend(matching_errors)

all_errors.extend(msg_errors)
Expand Down Expand Up @@ -218,20 +223,25 @@ def validate_conversation_object(obj: Any, line_number: int) -> list[str]:
class ValidationResult:
line_number: int
errors: list[str]
original_line: str


def validate_line(args: tuple[str, int]) -> ValidationResult:
line, line_number = args
if not line.strip():
return ValidationResult(line_number, [f"Line {line_number}: Empty line is not allowed."])
return ValidationResult(
line_number,
[f"Line {line_number}: Empty line is not allowed."],
line,
)

try:
conversation_obj = json.loads(line)
except json.JSONDecodeError as e:
return ValidationResult(line_number, [f"Line {line_number}: Invalid JSON - {str(e)}"])
return ValidationResult(line_number, [f"Line {line_number}: Invalid JSON - {str(e)}"], line)

line_errors = validate_conversation_object(conversation_obj, line_number)
return ValidationResult(line_number, line_errors)
return ValidationResult(line_number, line_errors, line)


def extract_enumerated_names(tool_catalog_schema: Mapping[str, Any]) -> set[str]:
Expand Down Expand Up @@ -291,33 +301,32 @@ def validate_tool_catalog_schema(schema_str: str) -> tuple[Any, list[str]]:
return None, errors


def validate_tool_response_structure(response: dict, idx: int) -> list[str]:
def validate_tool_response_structure(response: dict[str, Any], idx: int) -> list[str]:
"""Validate the structure of a single tool response object."""
errors = []

if not isinstance(response, dict):
errors.append(f"Response at index {idx} must be an object")
errors.append(f"Response at index {idx} must be an object") # type: ignore[unreachable]
return errors

if set(response.keys()) != {"name", "content"}:
errors.append(
f"Response at index {idx} must have exactly 'name' and 'content' fields"
)
errors.append(f"Response at index {idx} must have exactly 'name' and 'content' fields")
return errors

if not isinstance(response["name"], str) or not isinstance(response["content"], str):
errors.append(
f"Response at index {idx}: 'name' and 'content' must be strings"
)

errors.append(f"Response at index {idx}: 'name' and 'content' must be strings")

return errors

def validate_tool_response_matching(response: dict, prev_call: dict, idx: int) -> list[str]:

def validate_tool_response_matching(
response: dict[str, Any], prev_call: dict[str, Any], idx: int
) -> list[str]:
"""Validate that a tool response matches its corresponding tool call."""
errors = []
if response["name"] != prev_call.get("name"):
errors.append(
f"Response at index {idx}: name '{response['name']}' does not match "
f"tool call name '{prev_call.get('name')}'"
)
return errors
return errors
4 changes: 3 additions & 1 deletion narrative_llm_tools/state/conversation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from narrative_llm_tools.rest_api_client.rest_api_client import RestApiClient
from narrative_llm_tools.tools.json_schema_tools import JsonSchemaTools, Tool

logger = logging.getLogger(__name__)
logger = logging.getLogger("narrative-llm-tools")


class ConversationMessage(BaseModel):
Expand Down Expand Up @@ -285,6 +285,7 @@ def _handle_tool_call(self, message: ConversationMessage) -> None:
Handles adding a tool_call message and performing relevant state transitions.
"""
tool_calls = self.parse_tool_calls_content(message.content)
logger.info(f"Handling tool call: {message}")
self.raw_messages.append(message)

if self.responded_to_user(message.content):
Expand All @@ -300,6 +301,7 @@ def _handle_tool_response(self, message: ConversationMessage) -> None:
"""
Handles adding a tool response message and updating state accordingly.
"""
logger.info(f"Handling tool response: {message}")
self.raw_messages.append(message)

if self.status == ConversationStatus.WAITING_TOOL_RESPONSE:
Expand Down
10 changes: 3 additions & 7 deletions narrative_llm_tools/state/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,8 @@ def validate_value(cls, v: Any) -> str:
if isinstance(v, str):
return v
raise ValueError(f"Message 'content' must be a string, got {type(v)}")

model_config = {
'extra': 'forbid'
}

model_config = {"extra": "forbid"}


class SystemMessage(BaseMessage):
Expand Down Expand Up @@ -126,6 +124,4 @@ def validate_response(cls, v: Any) -> str:
class MessageWrapper(BaseModel):
message: Message

model_config = {
'extra': 'forbid'
}
model_config = {"extra": "forbid"}
2 changes: 1 addition & 1 deletion narrative_llm_tools/tools/json_schema_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from narrative_llm_tools.rest_api_client.rest_api_client import RestApiClient
from narrative_llm_tools.rest_api_client.types import RestApiConfig

logger = logging.getLogger(__name__)
logger = logging.getLogger("narrative-llm-tools")


class NameProperty(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion narrative_llm_tools/utils/format_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from narrative_llm_tools.tools.json_schema_tools import JsonSchemaTools

logger = logging.getLogger(__name__)
logger = logging.getLogger("narrative-llm-tools")


class TransformersPrefixAllowedTokensFn(Protocol):
Expand Down
Loading
Loading