Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fixed bug in collection of text context parts #90

Merged
merged 5 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions aidial_analytics_realtime/analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,16 +108,20 @@ async def make_point(
response_content = "\n".join(response_contents)

if chat_id:
topic = await topic_model.get_topic_by_text(
"\n\n".join(request_contents + response_contents)
topic = to_string(
await topic_model.get_topic_by_text(
"\n\n".join(request_contents + response_contents)
)
)
case RequestType.EMBEDDING:
request_contents = get_embeddings_request_contents(logger, request)

request_content = "\n".join(request_contents)
if chat_id:
topic = await topic_model.get_topic_by_text(
"\n\n".join(request_contents)
topic = to_string(
await topic_model.get_topic_by_text(
"\n\n".join(request_contents)
)
)
case _:
assert_never(request_type)
Expand Down
58 changes: 38 additions & 20 deletions aidial_analytics_realtime/dial.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,71 @@
from logging import Logger
from typing import List
from typing import Iterator, List


def get_chat_completion_request_contents(
    logger: Logger, request: dict
) -> List[str]:
    """Return the text contents of all messages in a chat completion request.

    Args:
        logger: Logger handed down to the per-message extraction.
        request: Chat completion request body. It is indexed with
            ``request["messages"]``, so a missing key raises ``KeyError``.

    Returns:
        Everything yielded by ``_chat_completion_request_contents``,
        materialized into a list.
    """
    return list(_chat_completion_request_contents(logger, request))


def get_chat_completion_response_contents(
    logger: Logger, response: dict
) -> List[str]:
    """Return the text contents of the first choice of a chat completion response.

    Args:
        logger: Logger handed down to the per-message extraction.
        response: Chat completion response body. It is indexed with
            ``response["choices"][0]["message"]``, so a missing key or an
            empty ``choices`` list raises ``KeyError`` / ``IndexError``.

    Returns:
        Everything yielded by ``_chat_completion_response_contents``,
        materialized into a list.
    """
    return list(_chat_completion_response_contents(logger, response))


def get_embeddings_request_contents(logger: Logger, request: dict) -> List[str]:
    """Return the text inputs of an embeddings request as a list.

    Args:
        logger: Logger used to report an unexpected ``input`` type.
        request: Embeddings request body; its ``"input"`` entry is read.

    Returns:
        Everything yielded by ``_embeddings_request_contents``,
        materialized into a list.
    """
    return list(_embeddings_request_contents(logger, request))


def _chat_completion_request_contents(
    logger: Logger, request: dict
) -> Iterator[str]:
    """Yield the text contents of each message in ``request["messages"]``.

    Messages are visited in order and each one is delegated to
    ``_chat_completion_message_contents``. Because this is a generator,
    the ``request["messages"]`` lookup (and any ``KeyError``) happens on
    first iteration, not at call time.
    """
    for message in request["messages"]:
        yield from _chat_completion_message_contents(logger, message)


def _chat_completion_response_contents(
    logger: Logger, response: dict
) -> Iterator[str]:
    """Yield the text contents of the first choice's message.

    Only ``response["choices"][0]["message"]`` is inspected; any further
    choices are not read. The message is delegated to
    ``_chat_completion_message_contents``.
    """
    message = response["choices"][0]["message"]
    yield from _chat_completion_message_contents(logger, message)


def _embeddings_request_contents(
    logger: Logger, request: dict
) -> Iterator[str]:
    """Yield the non-blank text inputs of an embeddings request.

    ``request.get("input")`` is handled by type:

    * ``str`` -- yielded stripped, if anything remains after stripping;
    * ``list`` -- every ``str`` element is treated the same way, and
      non-string elements are skipped without a warning;
    * anything else (including a missing key, i.e. ``None``) -- a warning
      is logged and nothing is yielded.
    """
    inp = request.get("input")

    if isinstance(inp, str):
        yield from _non_empty_string(inp)
    elif isinstance(inp, list):
        for i in inp:
            if isinstance(i, str):
                yield from _non_empty_string(i)
    else:
        logger.warning(f"Unexpected type of embeddings input: {type(inp)}")


def _chat_completion_message_contents(
    logger: Logger, message: dict
) -> Iterator[str]:
    """Yield the non-blank text contents of a single chat message.

    ``message.get("content")`` is handled by type:

    * ``None`` -- nothing is yielded;
    * ``str`` -- yielded stripped, if anything remains after stripping;
    * ``list`` -- each ``dict`` part whose ``"type"`` is ``"text"`` and
      whose ``"text"`` value is truthy contributes that text (stripped);
      all other parts are skipped without a warning;
    * anything else -- a warning is logged and nothing is yielded.
    """
    content = message.get("content")
    if content is None:
        return
    elif isinstance(content, str):
        yield from _non_empty_string(content)
    elif isinstance(content, list):
        for content_part in content:
            if isinstance(content_part, dict):
                # The text of a "text" part lives under the "text" key.
                # Each part is yielded whole, never spread into characters.
                if content_part.get("type") == "text" and (
                    text := content_part.get("text")
                ):
                    yield from _non_empty_string(text)
    else:
        logger.warning(f"Unexpected message content type: {type(content)}")


def _non_empty_string(value: str) -> Iterator[str]:
if non_empty := value.strip():
yield non_empty
8 changes: 6 additions & 2 deletions aidial_analytics_realtime/topic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,14 @@ def __init__(
# Make sure the model is loaded
self._get_topic_by_text("test")

async def get_topic_by_text(self, text: str) -> str | None:
    """Asynchronously compute the topic for ``text``.

    Delegates to the synchronous ``_get_topic_by_text`` through
    ``run_in_cpu_tasks_executor`` (presumably off the event loop --
    confirm against that helper).

    Returns:
        The result of ``_get_topic_by_text``; that method returns ``None``
        when ``text`` is empty after stripping.
    """
    return await run_in_cpu_tasks_executor(self._get_topic_by_text, text)

def _get_topic_by_text(self, text: str) -> str:
def _get_topic_by_text(self, text: str) -> str | None:
text = text.strip()
if not text:
return None

topics, _ = self.model.transform([text])
topic = self.model.get_topic_info(topics[0])

Expand Down
5 changes: 5 additions & 0 deletions tests/influx_writer_mock.py → tests/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,8 @@ def __init__(self):

async def __call__(self, record):
self.points.append(str(record))


class TestTopicModel:
async def get_topic_by_text(self, text: str) -> str | None:
return text or None
Loading
Loading