Skip to content

Commit

Permalink
Fix tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
Oleksii-Klimov committed Jan 23, 2024
1 parent 87283f9 commit 3d79c15
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 111 deletions.
126 changes: 75 additions & 51 deletions tests/unit_tests/chain/test_command_chain_best_effort.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from aidial_assistant.chain.history import History, ScopedMessage
from aidial_assistant.commands.base import Command, TextResult
from aidial_assistant.model.model_client import ModelClient
from aidial_assistant.model.model_client import ModelClient, ModelClientRequest
from aidial_assistant.utils.open_ai import (
assistant_message,
system_message,
Expand Down Expand Up @@ -82,16 +82,20 @@ async def test_model_doesnt_support_protocol():
]
assert model_client.agenerate.call_args_list == [
call(
[
system_message(f"system_prefix={SYSTEM_MESSAGE}"),
user_message(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"),
]
ModelClientRequest(
messages=[
system_message(f"system_prefix={SYSTEM_MESSAGE}"),
user_message(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"),
]
)
),
call(
[
system_message(SYSTEM_MESSAGE),
user_message(USER_MESSAGE),
]
ModelClientRequest(
messages=[
system_message(SYSTEM_MESSAGE),
user_message(USER_MESSAGE),
]
)
),
]

Expand Down Expand Up @@ -132,26 +136,34 @@ async def test_model_partially_supports_protocol():
]
assert model_client.agenerate.call_args_list == [
call(
[
system_message(f"system_prefix={SYSTEM_MESSAGE}"),
user_message(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"),
]
ModelClientRequest(
messages=[
system_message(f"system_prefix={SYSTEM_MESSAGE}"),
user_message(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"),
]
)
),
call(
[
system_message(f"system_prefix={SYSTEM_MESSAGE}"),
user_message(USER_MESSAGE),
assistant_message(TEST_COMMAND_REQUEST),
user_message(f"{TEST_COMMAND_RESPONSE}{ENFORCE_JSON_FORMAT}"),
]
ModelClientRequest(
messages=[
system_message(f"system_prefix={SYSTEM_MESSAGE}"),
user_message(USER_MESSAGE),
assistant_message(TEST_COMMAND_REQUEST),
user_message(
f"{TEST_COMMAND_RESPONSE}{ENFORCE_JSON_FORMAT}"
),
]
)
),
call(
[
system_message(SYSTEM_MESSAGE),
user_message(
f"user_message={USER_MESSAGE}, error={FAILED_PROTOCOL_ERROR}, dialogue={succeeded_dialogue}"
),
]
ModelClientRequest(
messages=[
system_message(SYSTEM_MESSAGE),
user_message(
f"user_message={USER_MESSAGE}, error={FAILED_PROTOCOL_ERROR}, dialogue={succeeded_dialogue}"
),
]
)
),
]

Expand Down Expand Up @@ -197,26 +209,34 @@ async def test_no_tokens_for_tools():
]
assert model_client.agenerate.call_args_list == [
call(
[
system_message(f"system_prefix={SYSTEM_MESSAGE}"),
user_message(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"),
]
ModelClientRequest(
messages=[
system_message(f"system_prefix={SYSTEM_MESSAGE}"),
user_message(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"),
]
)
),
call(
[
system_message(f"system_prefix={SYSTEM_MESSAGE}"),
user_message(USER_MESSAGE),
assistant_message(TEST_COMMAND_REQUEST),
user_message(f"{TEST_COMMAND_RESPONSE}{ENFORCE_JSON_FORMAT}"),
]
ModelClientRequest(
messages=[
system_message(f"system_prefix={SYSTEM_MESSAGE}"),
user_message(USER_MESSAGE),
assistant_message(TEST_COMMAND_REQUEST),
user_message(
f"{TEST_COMMAND_RESPONSE}{ENFORCE_JSON_FORMAT}"
),
]
)
),
call(
[
system_message(SYSTEM_MESSAGE),
user_message(
f"user_message={USER_MESSAGE}, error={NO_TOKENS_ERROR}, dialogue=[]"
),
]
ModelClientRequest(
messages=[
system_message(SYSTEM_MESSAGE),
user_message(
f"user_message={USER_MESSAGE}, error={NO_TOKENS_ERROR}, dialogue=[]"
),
]
)
),
]

Expand Down Expand Up @@ -255,18 +275,22 @@ async def test_model_request_limit_exceeded():
]
assert model_client.agenerate.call_args_list == [
call(
[
system_message(f"system_prefix={SYSTEM_MESSAGE}"),
user_message(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"),
]
ModelClientRequest(
messages=[
system_message(f"system_prefix={SYSTEM_MESSAGE}"),
user_message(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"),
]
)
),
call(
[
system_message(SYSTEM_MESSAGE),
user_message(
f"user_message={USER_MESSAGE}, error={LIMIT_EXCEEDED_ERROR}, dialogue=[]"
),
]
ModelClientRequest(
messages=[
system_message(SYSTEM_MESSAGE),
user_message(
f"user_message={USER_MESSAGE}, error={LIMIT_EXCEEDED_ERROR}, dialogue=[]"
),
]
)
),
]
assert model_request_limiter.verify_limit.call_args_list == [
Expand Down
49 changes: 0 additions & 49 deletions tests/unit_tests/chain/test_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,55 +71,6 @@ async def test_history_truncation(
]


@pytest.mark.asyncio
async def test_truncation_overflow():
history = History(
assistant_system_message_template=Template(""),
best_effort_template=Template(""),
scoped_messages=[
ScopedMessage(message=system_message("a"), user_index=0),
ScopedMessage(message=user_message("b"), user_index=1),
],
)

model_client = Mock(spec=ModelClient)
model_client.get_discarded_messages.return_value = 1

with pytest.raises(Exception) as exc_info:
await history.truncate(model_client, MAX_PROMPT_TOKENS)

assert (
str(exc_info.value) == "No user messages left after history truncation."
)


@pytest.mark.asyncio
async def test_truncation_with_incorrect_message_sequence():
history = History(
assistant_system_message_template=Template(""),
best_effort_template=Template(""),
scoped_messages=[
ScopedMessage(
message=user_message("a"),
scope=MessageScope.INTERNAL,
user_index=0,
),
ScopedMessage(message=user_message("b"), user_index=0),
],
)

model_client = Mock(spec=ModelClient)
model_client.get_discarded_messages.return_value = 1

with pytest.raises(Exception) as exc_info:
await history.truncate(model_client, MAX_PROMPT_TOKENS)

assert (
str(exc_info.value)
== "Internal messages must be followed by an assistant reply."
)


def test_protocol_messages_with_system_message():
system_content = "<system message>"
user_content = "<user message>"
Expand Down
14 changes: 8 additions & 6 deletions tests/unit_tests/model/test_model_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Any
from unittest.mock import Mock, call

import pytest
from openai import AsyncOpenAI
from openai._types import NOT_GIVEN
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
from pydantic import BaseModel

Expand Down Expand Up @@ -35,7 +37,7 @@ class Choice(BaseModel):

class Chunk(BaseModel):
choices: list[Choice]
statistics: dict[str, int] | None = None
statistics: dict[str, Any] | None = None
usage: Usage | None = None


Expand All @@ -47,7 +49,7 @@ async def test_discarded_messages():
[
Chunk(
choices=[Choice(delta=Delta(content=""))],
statistics={"discarded_messages": 2},
statistics={"discarded_messages": [0, 1]},
)
]
)
Expand All @@ -61,7 +63,7 @@ async def test_discarded_messages():
)

assert extra_results_callback.on_discarded_messages.call_args_list == [
call(2)
call([0, 1])
]


Expand Down Expand Up @@ -139,8 +141,8 @@ async def test_api_args():
messages=messages,
**MODEL_ARGS,
stream=True,
tools=None,
max_tokens=None,
extra_body={},
tools=NOT_GIVEN,
max_tokens=NOT_GIVEN,
extra_body=None,
)
]
8 changes: 4 additions & 4 deletions tests/unit_tests/utils/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,22 @@ def test_parse_history():
ScopedMessage(
scope=MessageScope.INTERNAL,
message=assistant_message(FIRST_REQUEST_FIXED),
user_index=0,
user_index=1,
),
ScopedMessage(
scope=MessageScope.INTERNAL,
message=user_message(FIRST_RESPONSE),
user_index=0,
user_index=1,
),
ScopedMessage(
scope=MessageScope.INTERNAL,
message=assistant_message(SECOND_REQUEST),
user_index=0,
user_index=1,
),
ScopedMessage(
scope=MessageScope.INTERNAL,
message=user_message(content=SECOND_RESPONSE),
user_index=0,
user_index=1,
),
ScopedMessage(
scope=MessageScope.USER,
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ async def agenerate(

@staticmethod
def agenerate_key(request: ModelClientRequest) -> str:
return json.dumps(request)
return json.dumps(request.json())


class TestCommand(Command):
Expand Down

0 comments on commit 3d79c15

Please sign in to comment.