Implement chat generators
TilmanGriesel committed Feb 2, 2025
1 parent 98c1df6 commit 9838f3b
Showing 4 changed files with 151 additions and 66 deletions.
services/api/src/api/handlers.py (6 changes: 2 additions & 4 deletions)
@@ -143,7 +143,7 @@ def run_rag():

        load_duration = time.time_ns() - load_start

-        result = rag.run_query(
+        response_text = rag.run_query(
            query=query, conversation=conversation, print_response=DEBUG
        )

@@ -157,9 +157,7 @@ def run_rag():
            load_duration=load_duration,
            prompt_eval_count=len(conversation) + 1,
            prompt_eval_duration=end_time - (prompt_start or start_time),
-            eval_count=len(result["llm"]["replies"][0].split())
-            if result and "llm" in result and "replies" in result["llm"]
-            else 0,
+            eval_count=len(response_text.split()),
            eval_duration=end_time - (prompt_start or start_time),
        )
        q.put(json.dumps(final_data) + "\n")
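
Note: run_query now returns the reply text directly instead of the raw pipeline dict, so the handler no longer unpacks response["llm"]["replies"] defensively. A minimal sketch of the simplified call-site contract; names outside the diff are illustrative:

response_text = rag.run_query(
    query=query, conversation=conversation, print_response=DEBUG
)
# eval_count approximates generated tokens by whitespace-splitting the reply.
# run_query can return None when the model produces no reply, so a guard like
# this avoids the AttributeError a bare response_text.split() would raise:
eval_count = len(response_text.split()) if response_text else 0
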
services/api/src/core/component_factory.py (20 changes: 7 additions & 13 deletions)
@@ -2,12 +2,11 @@
from typing import Callable, Optional

from core.pipeline_config import ModelProvider, QueryPipelineConfig
-from haystack.components.builders.prompt_builder import PromptBuilder
from haystack.components.embedders import HuggingFaceAPITextEmbedder
-from haystack.components.generators import HuggingFaceAPIGenerator
+from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.utils import Secret
from haystack_integrations.components.embedders.ollama import OllamaTextEmbedder
-from haystack_integrations.components.generators.ollama import OllamaGenerator
+from haystack_integrations.components.generators.ollama import OllamaChatGenerator
from haystack_integrations.components.retrievers.elasticsearch import (
    ElasticsearchEmbeddingRetriever,
)
@@ -28,7 +27,7 @@ def __init__(
        self.streaming_callback = streaming_callback
        self.logger = logging.getLogger(__name__)

-    def create_text_embedder(self):
+    def create_embedder(self):
        self.logger.info(
            f"Initializing Text Embedder with model: {self.config.embedding_model}"
        )
@@ -76,13 +75,8 @@ def create_retriever(self) -> ElasticsearchEmbeddingRetriever:
        self.logger.info("Elasticsearch Retriever initialized successfully")
        return retriever

-    def create_prompt_builder(self, template: str) -> PromptBuilder:
-        """Create prompt builder with specified template."""
-        self.logger.info("Initializing Prompt Builder")
-        return PromptBuilder(template=template)
-
-    def create_generator(self):
-        """Create text generator based on provider configuration."""
+    def create_chat_generator(self):
+        """Create chat generator based on provider configuration."""
        self.logger.info(f"Initializing Generator with model: {self.config.model_name}")

        if self.config.provider == ModelProvider.OLLAMA:
@@ -131,7 +125,7 @@ def create_generator(self):
            logging.info(f"Generation kwargs: {generation_kwargs}")

            # Instantiate generator
-            generator = OllamaGenerator(
+            generator = OllamaChatGenerator(
                model=self.config.model_name,
                url=self.config.ollama_url,
                generation_kwargs=generation_kwargs,
@@ -144,7 +138,7 @@
"HuggingFace API key is required for HuggingFace provider"
)

generator = HuggingFaceAPIGenerator(
generator = HuggingFaceAPIChatGenerator(
api_type="serverless_inference_api",
api_params={
"model": self.config.model_name,
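
Note: the practical difference between the replaced and new generators is the run signature: plain generators take a prompt string, while chat generators consume a list of ChatMessage objects and return ChatMessage replies. A minimal standalone sketch, assuming Haystack 2.x; the model name, URL, and generation kwargs are illustrative:

from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.ollama import OllamaChatGenerator

# Any locally pulled Ollama model works the same way.
generator = OllamaChatGenerator(
    model="llama3.2",
    url="http://localhost:11434",
    generation_kwargs={"temperature": 0.7},
)

# Chat generators take ChatMessage lists rather than a prompt string.
result = generator.run(messages=[
    ChatMessage.from_system("You are a helpful assistant."),
    ChatMessage.from_user("What is retrieval-augmented generation?"),
])
print(result["replies"][0].text)  # replies are ChatMessage objects, not strings
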
services/api/src/core/conversation_logger.py (99 changes: 87 additions & 12 deletions)
@@ -1,28 +1,103 @@
import json
from datetime import datetime
from pathlib import Path
-from typing import List
+from typing import Any, Dict, List, Union
+
+from haystack.dataclasses import ChatMessage, ChatRole


class ConversationLogger:
    def __init__(self, system_info: dict, log_dir: str = "conversation_logs"):
+        """Initialize the conversation logger.
+        Args:
+            system_info: Dictionary containing system information to be logged
+            log_dir: Directory where conversation logs will be stored
+        """
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.system_info = system_info

+    def _serialize_chat_message(
+        self, message: Union[ChatMessage, Dict[str, Any]]
+    ) -> Dict[str, Any]:
+        """Serialize a ChatMessage object or dict into a consistent dictionary format.
+        Args:
+            message: ChatMessage object or dictionary to serialize
+        Returns:
+            Dictionary representation of the message
+        """
+        try:
+            if isinstance(message, ChatMessage):
+                return {
+                    "role": message.role.value
+                    if isinstance(message.role, ChatRole)
+                    else message.role,
+                    "content": message.text,
+                    "name": message.name,
+                    "meta": message.meta,
+                }
+            elif isinstance(message, dict):
+                if "llm" in message and "replies" in message["llm"]:
+                    replies = message["llm"]["replies"]
+                    if replies and isinstance(replies[0], ChatMessage):
+                        return self._serialize_chat_message(replies[0])
+                return message
+
+            raise ValueError(f"Unsupported message type: {type(message)}")
+
+        except Exception as e:
+            return {
+                "error": f"Serialization error: {str(e)}",
+                "content": str(message),
+                "type": str(type(message)),
+            }

    def log_conversation(
-        self, query: str, response: dict, conversation: List[dict] = None
-    ):
+        self,
+        query: str,
+        response: Union[ChatMessage, Dict[str, Any]],
+        conversation: List[ChatMessage] = None,
+    ) -> None:
+        """Log a conversation exchange to a JSON file.
+        Args:
+            query: The user's query string
+            response: Response containing LLM replies (either ChatMessage or dict)
+            conversation: Optional list of previous messages in the conversation
+        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_file = self.log_dir / f"conversation_{timestamp}.json"

-        log_entry = {
-            "timestamp": timestamp,
-            "query": query,
-            "system_info": self.system_info,
-            "response": response.get("llm", {}).get("replies", []),
-            "previous_conversation": conversation or [],
-        }
-
-        with open(log_file, "w", encoding="utf-8") as f:
-            json.dump(log_entry, f, indent=2, ensure_ascii=False)
+        try:
+            response_meta = {}
+            if isinstance(response, dict) and "llm" in response:
+                response_meta = response.get("llm", {}).get("meta", {})
+
+            log_entry = {
+                "timestamp": timestamp,
+                "query": query,
+                "system_info": self.system_info,
+                "response": {
+                    "llm": {
+                        "replies": [self._serialize_chat_message(response)],
+                        "meta": response_meta,
+                    }
+                },
+                "previous_conversation": [
+                    self._serialize_chat_message(msg) for msg in (conversation or [])
+                ],
+            }
+
+            with open(log_file, "w", encoding="utf-8") as f:
+                json.dump(log_entry, f, indent=2, ensure_ascii=False)
+        except Exception as e:
+            error_file = self.log_dir / f"error_{timestamp}.txt"
+            with open(error_file, "w", encoding="utf-8") as f:
+                f.write(f"Error logging conversation: {str(e)}\n")
+                f.write(f"Query: {query}\n")
+                f.write(f"Response type: {type(response)}\n")
+                f.write(f"Response: {str(response)}\n")
services/api/src/core/rag_pipeline.py (92 changes: 55 additions & 37 deletions)
@@ -8,27 +8,32 @@
from core.model_manager import OllamaModelManager
from core.pipeline_config import QueryPipelineConfig
from haystack import Pipeline
+from haystack.components.builders import ChatPromptBuilder
+from haystack.dataclasses import ChatMessage


class RAGQueryPipeline:
-    QUERY_TEMPLATE = """
-    {% if conversation %}
-    Previous conversation:
-    {% for message in conversation %}
-    {{ message.role }}: {{ message.content }}
-    {% endfor %}
-    {% endif %}
-    {{ system_prompt }}
-    Context:
-    {% for document in documents %}
-    {{ document.content }}
-    Source: {{ document.meta.file_path }}
-    {% endfor %}
-    Question: {{ query }}?
-    """
+    template = [
+        ChatMessage.from_system(
+            """
+            Answer the questions based on the given context.
+            {% if conversation %}
+            Previous conversation:
+            {% for message in conversation %}
+            {% endfor %}
+            {% endif %}
+            {{ system_prompt }}
+            Context:
+            {% for document in documents %}
+            {{ document.content }}
+            {% endfor %}
+            Question: {{ question }}
+            """
+        )
+    ]

    def initialize_and_check_models(self) -> Generator[dict, None, None]:
        """Verify model availability and health, pulling models if needed."""
@@ -116,22 +121,21 @@ def create_query_pipeline(self) -> Pipeline:
        pipeline = Pipeline()

        # Create and add components
-        prompt_builder = self.component_factory.create_prompt_builder(
-            self.QUERY_TEMPLATE
-        )
-        text_embedder = self.component_factory.create_text_embedder()
+        embedder = self.component_factory.create_embedder()
        retriever = self.component_factory.create_retriever()
-        llm_generator = self.component_factory.create_generator()
+        llm_generator = self.component_factory.create_chat_generator()

-        pipeline.add_component("prompt_builder", prompt_builder)
-        pipeline.add_component("text_embedder", text_embedder)
+        pipeline.add_component("embedder", embedder)
        pipeline.add_component("retriever", retriever)
+        pipeline.add_component(
+            "prompt_builder", ChatPromptBuilder(template=self.template)
+        )
        pipeline.add_component("llm", llm_generator)

        # Connect components
-        pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
-        pipeline.connect("retriever.documents", "prompt_builder.documents")
-        pipeline.connect("prompt_builder.prompt", "llm.prompt")
+        pipeline.connect("embedder.embedding", "retriever.query_embedding")
+        pipeline.connect("retriever", "prompt_builder.documents")
+        pipeline.connect("prompt_builder.prompt", "llm.messages")

        self.query_pipeline = pipeline
        return pipeline
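
Note: ChatPromptBuilder renders Jinja placeholders inside ChatMessage templates and emits a "prompt" output that is a list of ChatMessage, which is why the final connection now targets the chat generator's "messages" socket instead of "prompt". A minimal sketch of the same wiring pattern, assuming Haystack 2.x; the template and model are illustrative:

from haystack import Pipeline
from haystack.components.builders import ChatPromptBuilder
from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.ollama import OllamaChatGenerator

# Illustrative template; the real one lives in RAGQueryPipeline.template.
template = [
    ChatMessage.from_system(
        "Context:\n"
        "{% for document in documents %}{{ document.content }}\n{% endfor %}"
        "Question: {{ question }}"
    )
]

pipeline = Pipeline()
pipeline.add_component("prompt_builder", ChatPromptBuilder(template=template))
pipeline.add_component("llm", OllamaChatGenerator(model="llama3.2"))
# ChatPromptBuilder emits "prompt" (a ChatMessage list); chat generators
# consume "messages", hence the renamed connection in this commit.
pipeline.connect("prompt_builder.prompt", "llm.messages")

result = pipeline.run({"prompt_builder": {"question": "What is RAG?", "documents": []}})
print(result["llm"]["replies"][0].text)
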
@@ -151,29 +155,43 @@ def run_query(
            self.create_query_pipeline()

        try:
+            messages = [ChatMessage.from_system(self.config.system_prompt)]
+            for message in conversation:
+                if message["role"] == "user":
+                    messages.append(ChatMessage.from_user(message["content"]))
+                else:
+                    messages.append(ChatMessage.from_assistant(message["content"]))
+
+            self.logger.info(f"Messages: {messages}")
+
            # Prepare pipeline inputs
            pipeline_inputs = {
                "prompt_builder": {
-                    "query": query,
+                    "question": query,
+                    "conversation": messages,
                    "system_prompt": self.config.system_prompt,
-                    "conversation": conversation or [],
                },
-                "text_embedder": {"text": query},
+                "embedder": {"text": query},
            }

            # Execute pipeline
            response = self.query_pipeline.run(pipeline_inputs)

-            # Log conversation if enabled
-            if self.conversation_logger:
-                self.conversation_logger.log_conversation(query, response, conversation)
+            response_text = (
+                response["llm"]["replies"][0].text
+                if response["llm"]["replies"]
+                else None
+            )

            # Print response if requested
-            if print_response and response["llm"]["replies"]:
+            if print_response and response_text:
                self.logger.info(f"Query: {query}")
-                self.logger.info(f"Response: {response['llm']['replies'][0]}")
+                self.logger.info(f"Response: {response_text}")

+            # Log conversation if enabled
+            if self.conversation_logger:
+                self.conversation_logger.log_conversation(query, response, messages)

-            return response
+            return response_text

        except elasticsearch.BadRequestError as e:
            self.logger.error(f"Elasticsearch error: {str(e)}")
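
Note: end to end, run_query now accepts the conversation as role/content dicts, converts them to ChatMessage objects internally, and returns the assistant's reply as a plain string (or None). A hedged call-site sketch; construction of the pipeline object is elided and its arguments are hypothetical:

rag = RAGQueryPipeline(config=config)  # hypothetical construction

conversation = [
    {"role": "user", "content": "What does the retriever index?"},
    {"role": "assistant", "content": "Document embeddings in Elasticsearch."},
]

# Returns a string now, not the raw Haystack response dict.
answer = rag.run_query(
    query="And which embedder produces them?",
    conversation=conversation,
    print_response=True,
)
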
