diff --git a/bot/engine/src/main/kotlin/engine/config/RAGAnswerHandler.kt b/bot/engine/src/main/kotlin/engine/config/RAGAnswerHandler.kt index 2215a6a269..1c73b4c785 100644 --- a/bot/engine/src/main/kotlin/engine/config/RAGAnswerHandler.kt +++ b/bot/engine/src/main/kotlin/engine/config/RAGAnswerHandler.kt @@ -30,10 +30,7 @@ import ai.tock.bot.engine.action.SendSentence import ai.tock.bot.engine.action.SendSentenceWithFootnotes import ai.tock.bot.engine.dialog.Dialog import ai.tock.bot.engine.user.PlayerType -import ai.tock.genai.orchestratorclient.requests.ChatMessage -import ai.tock.genai.orchestratorclient.requests.ChatMessageType -import ai.tock.genai.orchestratorclient.requests.DialogDetails -import ai.tock.genai.orchestratorclient.requests.RAGQuery +import ai.tock.genai.orchestratorclient.requests.* import ai.tock.genai.orchestratorclient.responses.ObservabilityInfo import ai.tock.genai.orchestratorclient.responses.RAGResponse import ai.tock.genai.orchestratorclient.responses.TextWithFootnotes @@ -189,10 +186,14 @@ object RAGAnswerHandler : AbstractProactiveAnswerHandler { ) ), questionAnsweringLlmSetting = ragConfiguration.llmSetting, - questionAnsweringPromptInputs = mapOf( - "question" to action.toString(), - "locale" to userPreferences.locale.displayLanguage, - "no_answer" to ragConfiguration.noAnswerSentence + questionAnsweringPrompt = PromptTemplate( + formatter = Formatter.F_STRING.id, + template = ragConfiguration.llmSetting.prompt, + inputs = mapOf( + "question" to action.toString(), + "locale" to userPreferences.locale.displayLanguage, + "no_answer" to ragConfiguration.noAnswerSentence + ) ), embeddingQuestionEmSetting = ragConfiguration.emSetting, documentIndexName = indexName, diff --git a/gen-ai/orchestrator-client/src/main/kotlin/ai/tock/genai/orchestratorclient/requests/RAGQuery.kt b/gen-ai/orchestrator-client/src/main/kotlin/ai/tock/genai/orchestratorclient/requests/RAGQuery.kt index 1e5364ad7a..dfb0bafddb 100644 --- a/gen-ai/orchestrator-client/src/main/kotlin/ai/tock/genai/orchestratorclient/requests/RAGQuery.kt +++ b/gen-ai/orchestrator-client/src/main/kotlin/ai/tock/genai/orchestratorclient/requests/RAGQuery.kt @@ -24,10 +24,10 @@ import ai.tock.genai.orchestratorcore.models.vectorstore.VectorStoreSetting data class RAGQuery( // val condenseQuestionLlmSetting: LLMSetting, - // val condenseQuestionPromptInputs: Map, + // val condenseQuestionPrompt: PromptTemplate, val dialog: DialogDetails?, val questionAnsweringLlmSetting: LLMSetting, - val questionAnsweringPromptInputs: Map, + val questionAnsweringPrompt: PromptTemplate, val embeddingQuestionEmSetting: EMSetting, val documentIndexName: String, val documentSearchParams: DocumentSearchParamsBase, diff --git a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/models/llm/llm_setting.py b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/models/llm/llm_setting.py index 6f55a57123..b192192a00 100644 --- a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/models/llm/llm_setting.py +++ b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/models/llm/llm_setting.py @@ -39,8 +39,3 @@ class BaseLLMSetting(BaseModel): ge=0, le=2, ) - prompt: str = Field( - description='The prompt to generate completions for.', - examples=['How to learn to ride a bike without wheels!'], - min_length=1, - ) diff --git a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/routers/requests/requests.py 
b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/routers/requests/requests.py index 2eb91860c1..e6bcd14a06 100644 --- a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/routers/requests/requests.py +++ b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/routers/requests/requests.py @@ -85,6 +85,10 @@ class BaseQuery(BaseModel): observability_setting: Optional[ObservabilitySetting] = Field( description='The observability settings.', default=None ) + compressor_setting: Optional[DocumentCompressorSetting] = Field( + description='Compressor settings, to rerank relevant documents returned by retriever.', + default=None, + ) class QAQuery(BaseQuery): @@ -159,43 +163,20 @@ class RagQuery(BaseQuery): """The RAG query model""" dialog: Optional[DialogDetails] = Field(description='The user dialog details.') - question_answering_prompt_inputs: Any = Field( - description='Key-value inputs for the llm prompt when used as a template. Please note that the ' - 'chat_history field must not be specified here, it will be override by the dialog.history field', - ) # condense_question_llm_setting: LLMSetting = # Field(description="LLM setting, used to condense the user's question.") - # condense_question_prompt_inputs: Any = ( - # Field( - # description='Key-value inputs for the condense question llm prompt, when used as a template.', - # ), + # condense_question_prompt: PromptTemplate = Field( + # description='Prompt template, used to create a prompt with inputs for jinja and fstring format' # ) question_answering_llm_setting: LLMSetting = Field( description='LLM setting, used to perform a QA Prompt.' ) - question_answering_prompt_inputs: Any = Field( - description='Key-value inputs for the llm prompt when used as a template. Please note that the ' - 'chat_history field must not be specified here, it will be override by the dialog.history field', - ) - embedding_question_em_setting: EMSetting = Field( - description="Embedding model setting, used to calculate the user's question vector." - ) - document_index_name: str = Field( - description='Index name corresponding to a document collection in the vector database.', - ) - document_search_params: DocumentSearchParams = Field( - description='The document search parameters. Ex: number of documents, metadata filter', - ) - observability_setting: Optional[ObservabilitySetting] = Field( - description='The observability settings.', default=None + question_answering_prompt : PromptTemplate = Field( + description='Prompt template, used to create a prompt with inputs for jinja and fstring format' ) guardrail_setting: Optional[GuardrailSetting] = Field( description='Guardrail settings, to classify LLM output toxicity.', default=None ) - compressor_setting: Optional[DocumentCompressorSetting] = Field( - description='Compressor settings, to rerank relevant documents returned by retriever.', - default=None, - ) documents_required: Optional[bool] = Field( description='Specifies whether the presence of documents is mandatory for generating answers. ' 'If set to True, the system will only provide answers when relevant documents are found. ' @@ -223,7 +204,11 @@ class RagQuery(BaseQuery): 'secret': 'ab7***************************A1IV4B', }, 'temperature': 1.2, - 'prompt': """Use the following context to answer the question at the end. + 'model': 'gpt-3.5-turbo', + }, + 'question_answering_prompt': { + 'formatter': 'f-string', + 'template': """Use the following context to answer the question at the end. 
If you don't know the answer, just say {no_answer}. Context: @@ -233,12 +218,11 @@ class RagQuery(BaseQuery): {question} Answer in {locale}:""", - 'model': 'gpt-3.5-turbo', - }, - 'question_answering_prompt_inputs': { - 'question': 'How to get started playing guitar ?', - 'no_answer': "Sorry, I don't know.", - 'locale': 'French', + 'inputs': { + 'question': 'How to get started playing guitar ?', + 'no_answer': "Sorry, I don't know.", + 'locale': 'French', + } }, 'embedding_question_em_setting': { 'provider': 'OpenAI', diff --git a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/completion/completion_service.py b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/completion/completion_service.py index bda3b79b51..b65d5351f2 100644 --- a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/completion/completion_service.py +++ b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/completion/completion_service.py @@ -16,23 +16,14 @@ import logging import time -from typing import Optional -from jinja2 import Template, TemplateError from langchain_core.output_parsers import NumberedListOutputParser from langchain_core.prompts import PromptTemplate as LangChainPromptTemplate -from langchain_core.runnables import RunnableConfig -from gen_ai_orchestrator.errors.exceptions.exceptions import ( - GenAIPromptTemplateException, -) from gen_ai_orchestrator.errors.handlers.openai.openai_exception_handler import ( openai_exception_handler, ) -from gen_ai_orchestrator.models.errors.errors_models import ErrorInfo from gen_ai_orchestrator.models.observability.observability_trace import ObservabilityTrace -from gen_ai_orchestrator.models.prompt.prompt_formatter import PromptFormatter -from gen_ai_orchestrator.models.prompt.prompt_template import PromptTemplate from gen_ai_orchestrator.routers.requests.requests import ( SentenceGenerationQuery, ) @@ -42,6 +33,7 @@ from gen_ai_orchestrator.services.langchain.factories.langchain_factory import ( get_llm_factory, create_observability_callback_handler, ) +from gen_ai_orchestrator.services.utils.prompt_utility import validate_prompt_template logger = logging.getLogger(__name__) @@ -90,29 +82,3 @@ async def generate_and_split_sentences( ) return SentenceGenerationResponse(sentences=sentences) - - -def validate_prompt_template(prompt: PromptTemplate): - """ - Prompt template validation - - Args: - prompt: The prompt template - - Returns: - Nothing.
- Raises: - GenAIPromptTemplateException: if template is incorrect - """ - if PromptFormatter.JINJA2 == prompt.formatter: - try: - Template(prompt.template).render(prompt.inputs) - except TemplateError as exc: - logger.error('Prompt completion - template validation failed!') - logger.error(exc) - raise GenAIPromptTemplateException( - ErrorInfo( - error=exc.__class__.__name__, - cause=str(exc), - ) - ) diff --git a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/callbacks/rag_callback_handler.py b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/callbacks/rag_callback_handler.py new file mode 100644 index 0000000000..e85003d9ac --- /dev/null +++ b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/callbacks/rag_callback_handler.py @@ -0,0 +1,61 @@ +# Copyright (C) 2023-2024 Credit Mutuel Arkea +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""RAG callback handler for LangChain.""" + +import logging +from typing import Any, Dict, Optional + +from langchain.callbacks.base import BaseCallbackHandler +from langchain_core.messages import SystemMessage, AIMessage +from langchain_core.prompt_values import ChatPromptValue, StringPromptValue + +logger = logging.getLogger(__name__) + + +class RAGCallbackHandler(BaseCallbackHandler): + """Customized RAG callback handler that retrieves data from the chain execution.""" + + records: Dict[str, Any] = { + 'chat_prompt': None, + 'chat_chain_output': None, + 'rag_prompt': None, + 'rag_chain_output': None, + 'documents': None, + } + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Record named sub-chain outputs and retrieved documents when a chain starts.""" + + if kwargs['name'] == 'chat_chain_output' and isinstance(inputs, AIMessage): + self.records['chat_chain_output'] = inputs.content + + if kwargs['name'] == 'rag_chain_output' and isinstance(inputs, AIMessage): + self.records['rag_chain_output'] = inputs.content + + if kwargs['name'] == 'RunnableAssign' and 'documents' in inputs: + self.records['documents'] = inputs['documents'] + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """Record the chat and RAG prompts when a prompt-building chain ends.""" + + if isinstance(outputs, ChatPromptValue): + self.records['chat_prompt'] = next( + (msg.content for msg in outputs.messages if isinstance(msg, SystemMessage)), None + ) + + if isinstance(outputs, StringPromptValue): + self.records['rag_prompt'] = outputs.text diff --git a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/callbacks/retriever_json_callback_handler.py b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/callbacks/retriever_json_callback_handler.py deleted file mode 100644 index 2ee502e34a..0000000000 ---
a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/callbacks/retriever_json_callback_handler.py +++ /dev/null @@ -1,188 +0,0 @@ -# Copyright (C) 2023-2024 Credit Mutuel Arkea -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -"""Retriever callback handler for LangChain.""" - -import logging -import re -from typing import Any, Dict, List, Optional, Union - -from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import AgentAction, AgentFinish, LLMResult - -logger = logging.getLogger(__name__) - - -class RetrieverJsonCallbackHandler(BaseCallbackHandler): - """Callback Handler that reorganize logs to json data.""" - - def __init__(self, color: Optional[str] = None) -> None: - """Initialize callback handler.""" - self.logger = logger - self.color = color - - self.records: Dict[str, Any] = { - # "on_llm_start_records": [], - # "on_llm_token_records": [], - # "on_llm_end_records": [], - 'on_chain_start_records': [], - 'on_chain_end_records': [], - # "on_tool_start_records": [], - # "on_tool_end_records": [], - 'on_text_records': [], - # "on_agent_finish_records": [], - # "on_agent_action_records": [], - 'action_records': [], - } - - def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any - ) -> None: - """Do nothing.""" - pass - - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Do nothing.""" - pass - - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Do nothing.""" - pass - - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Do nothing.""" - pass - - def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any - ) -> None: - """Print out that we are entering a chain.""" - # filter to gest only input documents - if 'input_documents' in inputs: - docs = inputs['input_documents'] - input_documents = [ - {'page_content': doc.page_content, 'metadata': doc.metadata} - for doc in docs - ] - json_data = { - 'event_name': 'on_chain_start', - 'inputs': { - 'input_documents': input_documents, - 'question': inputs['question'], - 'chat_history': inputs['chat_history'], - }, - } - if json_data not in self.records['on_chain_start_records']: - self.records['on_chain_start_records'].append(json_data) - if json_data not in self.records['action_records']: - self.records['action_records'].append(json_data) - - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: - """Print out that we finished a chain.""" - # reponse FAQ - if 'text' in outputs: - json_data = {'event_name': 'on_chain_end', 'output': outputs['text']} - if json_data not in self.records['on_chain_end_records']: - self.records['on_chain_end_records'].append(json_data) - if json_data not in self.records['action_records']: - self.records['action_records'].append(json_data) - - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Do nothing.""" - pass - - def 
on_tool_start( - self, - serialized: Dict[str, Any], - input_str: str, - **kwargs: Any, - ) -> None: - """Do nothing.""" - pass - - def on_agent_action( - self, action: AgentAction, color: Optional[str] = None, **kwargs: Any - ) -> Any: - """Do nothing.""" - pass - - def on_tool_end( - self, - output: str, - color: Optional[str] = None, - observation_prefix: Optional[str] = None, - llm_prefix: Optional[str] = None, - **kwargs: Any, - ) -> None: - """Do nothing.""" - pass - - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Do nothing.""" - pass - - def on_text( - self, - text: str, - color: Optional[str] = None, - end: str = '', - **kwargs: Any, - ) -> None: - """Run when agent ends.""" - json_data = { - 'event_name': 'on_text', - 'text': self.normalise_prompt(text), - } - if json_data not in self.records['on_text_records']: - self.records['on_text_records'].append(json_data) - if json_data not in self.records['action_records']: - self.records['action_records'].append(json_data) - - def on_agent_finish( - self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any - ) -> None: - """Do nothing.""" - pass - - def show_records(self, record_name: str = None): - """Show registered records from handler""" - if record_name != None and record_name in self.records: - records = self.records[record_name] - else: - records = self.records - return records - - - def normalise_prompt(self, prompt: str): - """ - Remove 'on after prompt' and color on prompt. - To identify the color ansi sequence, the function uses this regular expression : \x1B\[[0-?]*[ -/]*[@-~] - - Args: - prompt: the prompt to normalise - """ - - # remove ansi escape sequences - ansi_escape = re.compile(r'\x1B\[[0-?]*[ -/]*[@-~]') - prompt = ansi_escape.sub('', prompt) - - # remove a static sentence - return prompt.replace('Prompt after formatting:\n', '') - diff --git a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/rag_chain.py b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/rag_chain.py index 4af1ef5d8b..259de72878 100644 --- a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/rag_chain.py +++ b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/langchain/rag_chain.py @@ -18,8 +18,8 @@ """ import logging -import re import time +from functools import partial from logging import ERROR, WARNING from typing import List, Optional @@ -31,9 +31,12 @@ ) from langchain_community.chat_message_histories import ChatMessageHistory from langchain_core.documents import Document -from langchain_core.prompts import PromptTemplate +from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts import PromptTemplate as LangChainPromptTemplate, ChatPromptTemplate, MessagesPlaceholder +from langchain_core.runnables import RunnablePassthrough, RunnableParallel, RunnableSerializable from langchain_core.vectorstores import VectorStoreRetriever from langfuse.callback import CallbackHandler as LangfuseCallbackHandler +from typing_extensions import Any from gen_ai_orchestrator.errors.exceptions.exceptions import ( GenAIGuardCheckException, @@ -58,10 +61,10 @@ TextWithFootnotes, ) from gen_ai_orchestrator.routers.requests.requests import RagQuery -from gen_ai_orchestrator.routers.responses.responses import RagResponse, ObservabilityInfo -from 
gen_ai_orchestrator.services.langchain.callbacks.retriever_json_callback_handler import ( - RetrieverJsonCallbackHandler, +from gen_ai_orchestrator.services.langchain.callbacks.rag_callback_handler import ( + RAGCallbackHandler, ) +from gen_ai_orchestrator.routers.responses.responses import RagResponse, ObservabilityInfo from gen_ai_orchestrator.services.langchain.factories.langchain_factory import ( create_observability_callback_handler, get_compressor_factory, @@ -70,13 +73,14 @@ get_llm_factory, get_vector_store_factory, ) +from gen_ai_orchestrator.services.utils.prompt_utility import validate_prompt_template logger = logging.getLogger(__name__) @opensearch_exception_handler @openai_exception_handler(provider='OpenAI or AzureOpenAIService') -async def execute_qa_chain(query: RagQuery, debug: bool) -> RagResponse: +async def execute_rag_chain(query: RagQuery, debug: bool) -> RagResponse: """ RAG chain execution, using the LLM and Embedding settings specified in the query @@ -111,17 +115,17 @@ async def execute_qa_chain(query: RagQuery, debug: bool) -> RagResponse: ) inputs = { - **query.question_answering_prompt_inputs, + **query.question_answering_prompt.inputs, 'chat_history': message_history.messages, } logger.debug( - 'RAG chain - Use RetrieverJsonCallbackHandler for debugging : %s', + 'RAG chain - Use RAGCallbackHandler for debugging : %s', debug, ) callback_handlers = [] - records_callback_handler = RetrieverJsonCallbackHandler() + records_callback_handler = RAGCallbackHandler() observability_handler = None if debug: # Debug callback handler @@ -143,7 +147,7 @@ async def execute_qa_chain(query: RagQuery, debug: bool) -> RagResponse: ) # RAG Guard - __rag_guard(inputs, response, query.documents_required) + rag_guard(inputs, response, query.documents_required) # Guardrail if query.guardrail_setting: @@ -168,13 +172,13 @@ async def execute_qa_chain(query: RagQuery, debug: bool) -> RagResponse: content=get_source_content(doc), score=doc.metadata.get('retriever_score', None) ), - response['source_documents'], + response['documents'], ) ), ), observability_info=get_observability_info(observability_handler), debug=get_rag_debug_data( - query, response, records_callback_handler, rag_duration + query, records_callback_handler, rag_duration ) if debug else None, @@ -208,9 +212,10 @@ def get_source_content(doc: Document) -> str: return doc.page_content -def create_rag_chain(query: RagQuery, vector_db_async_mode: Optional[bool] = True) -> ConversationalRetrievalChain: +def create_rag_chain(query: RagQuery, vector_db_async_mode: Optional[bool] = True) -> RunnableSerializable[ + Any, dict[str, Any]]: """ - Create the RAG chain from RagQuery, using the LLM and Embedding settings specified in the query + Create the RAG chain from RagQuery, using the LLM and Embedding settings specified in the query. Args: query: The RAG query @@ -218,6 +223,7 @@ def create_rag_chain(query: RagQuery, vector_db_async_mode: Optional[bool] = Tru Returns: The RAG chain. 
""" + llm_factory = get_llm_factory(setting=query.question_answering_llm_setting) em_factory = get_em_factory(setting=query.embedding_question_em_setting) vector_store_factory = get_vector_store_factory( @@ -233,37 +239,76 @@ def create_rag_chain(query: RagQuery, vector_db_async_mode: Optional[bool] = Tru if query.compressor_setting: retriever = add_document_compressor(retriever, query.compressor_setting) + # Log progress and validate prompt template + logger.info('RAG chain - Validating LLM prompt template') + validate_prompt_template(query.question_answering_prompt) + logger.debug('RAG chain - Document index name: %s', query.document_index_name) - logger.debug('RAG chain - Create a ConversationalRetrievalChain from LLM') - - return ConversationalRetrievalChain.from_llm( - llm=llm_factory.get_language_model(), - retriever=retriever, - return_source_documents=True, - return_generated_question=True, - combine_docs_chain_kwargs={ - 'prompt': PromptTemplate( - template=llm_factory.setting.prompt, - input_variables=__find_input_variables(llm_factory.setting.prompt), - ) - }, + + # Build LLM and prompt templates + llm = llm_factory.get_language_model() + rag_prompt = build_rag_prompt(query) + + # Construct the RAG chain using the prompt and LLM, + # This chain will consume the documents retrieved by the retriever as input. + rag_chain = construct_rag_chain(llm, rag_prompt) + + # Build the chat chain for question contextualization + chat_chain = build_question_condensation_chain(llm) + + # Function to contextualize the question based on chat history + contextualize_question_fn = partial(contextualize_question, chat_chain=chat_chain) + + # Final RAG chain with retriever and source documents + rag_chain_with_retriever = ( + contextualize_question_fn | + RunnableParallel( {"documents": retriever, "question": RunnablePassthrough()} ) | + RunnablePassthrough.assign(answer=rag_chain) ) + return rag_chain_with_retriever -def __find_input_variables(template): - """ - Search for input variables on a given template - Args: - template: the template to search on +def build_rag_prompt(query: RagQuery) -> LangChainPromptTemplate: """ + Build the RAG prompt template. + """ + return LangChainPromptTemplate.from_template( + template=query.question_answering_prompt.template, + template_format=query.question_answering_prompt.formatter.value, + partial_variables=query.question_answering_prompt.inputs + ) - motif = r'\{([^}]+)\}' - variables = re.findall(motif, template) - return variables +def construct_rag_chain(llm, rag_prompt): + """ + Construct the RAG chain from LLM and prompt. + """ + return { + "context": lambda inputs: "\n\n".join(doc.page_content for doc in inputs["documents"]), + "question": lambda inputs: inputs["question"] # Override the user's original question with the condensed one + } | rag_prompt | llm | StrOutputParser(name="rag_chain_output") +def build_question_condensation_chain(llm) -> ChatPromptTemplate: + """ + Build the chat chain for contextualizing questions. + """ + return ChatPromptTemplate.from_messages([ + ("system", """Given a chat history and the latest user question which might reference context in \ + the chat history, formulate a standalone question which can be understood without the chat history. 
\ + Do NOT answer the question, just reformulate it if needed and otherwise return it as is."""), + MessagesPlaceholder(variable_name="chat_history"), + ("human", "{question}"), + ]) | llm | StrOutputParser(name="chat_chain_output") + +def contextualize_question(inputs: dict, chat_chain) -> str: + """ + Contextualize the question based on the chat history. + """ + if inputs.get("chat_history") and len(inputs["chat_history"]) > 0: + return chat_chain + return inputs["question"] -def __rag_guard(inputs, response, documents_required): +def rag_guard(inputs, response, documents_required): """ Validates the RAG system's response based on the presence or absence of source documents and the `documentsRequired` setting. @@ -274,7 +319,7 @@ def __rag_guard(inputs, response, documents_required): documents_required (bool): Specifies whether documents are mandatory for the response. """ - no_docs_retrieved = response['source_documents'] == [] + no_docs_retrieved = response['documents'] == [] no_docs_but_required = no_docs_retrieved and documents_required chain_can_give_no_answer_reply = 'no_answer' in inputs chain_reply_no_answer = False @@ -287,12 +332,12 @@ def __rag_guard(inputs, response, documents_required): return # Everything else isn't expected message = 'The RAG system cannot provide an answer when no documents are found and documents are required' - __rag_log(level=ERROR, message=message, inputs=inputs, response=response) + rag_log(level=ERROR, message=message, inputs=inputs, response=response) raise GenAIGuardCheckException(ErrorInfo(cause=message)) return -def __rag_log(level, message, inputs, response): +def rag_log(level, message, inputs, response): """ RAG logging @@ -311,68 +356,43 @@ def __rag_log(level, message, inputs, response): 'message': message, 'question': inputs['question'], 'answer': response['answer'], - 'documents': response['source_documents'], + 'documents': response['documents'], }, ) -def get_rag_documents(response) -> List[RagDocument]: +def get_rag_documents(handler: RAGCallbackHandler) -> List[RagDocument]: """ Get documents used on RAG context Args: response: the rag answer """ + return [ # Get first 100 char of content RagDocument( - content=doc.page_content[0: len(doc.metadata['title']) + 100] + '...', + content=doc.page_content[0:len(doc.metadata['title'])+100] + '...', metadata=RagDocumentMetadata(**doc.metadata), ) - for doc in response['source_documents'] + for doc in handler.records['documents'] ] -def get_condense_question(handler: RetrieverJsonCallbackHandler) -> Optional[str]: - """Get the condensed question""" - - on_text_records = handler.show_records('on_text_records') - # If the handler records 2 texts (prompts), this means that 2 LLM providers are invoked - if len(on_text_records) == 2: - # So the user question is condensed - on_chain_start_records = handler.show_records('on_chain_start_records') - return on_chain_start_records[0]['inputs']['question'] - else: - # Else, the user's question was not formulated - return None - - -def get_llm_prompts(handler: RetrieverJsonCallbackHandler) -> (Optional[str], str): - """Get used llm prompt""" - - on_text_records = handler.show_records('on_text_records') - # If the handler records 2 texts (prompts), this means that 2 LLM providers are invoked - if len(on_text_records) == 2: - return on_text_records[0]['text'], on_text_records[1]['text'] - - # Else, only the LLM for "question answering" was invoked - return None, on_text_records[0]['text'] - - def get_rag_debug_data( - query: RagQuery, response, 
records_callback_handler, rag_duration + query: RagQuery, records_callback_handler: RAGCallbackHandler, rag_duration ) -> RagDebugData: """RAG debug data assembly""" return RagDebugData( - user_question=query.question_answering_prompt_inputs['question'], - condense_question_prompt=get_llm_prompts(records_callback_handler)[0], - condense_question=get_condense_question(records_callback_handler), - question_answering_prompt=get_llm_prompts(records_callback_handler)[1], - documents=get_rag_documents(response), + user_question=query.question_answering_prompt.inputs['question'], + condense_question_prompt=records_callback_handler.records['chat_prompt'], + condense_question=records_callback_handler.records['chat_chain_output'], + question_answering_prompt=records_callback_handler.records['rag_prompt'], + documents=get_rag_documents(records_callback_handler), document_index_name=query.document_index_name, document_search_params=query.document_search_params, - answer=response['answer'], + answer=records_callback_handler.records['rag_chain_output'], duration=rag_duration, ) diff --git a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/llm/llm_service.py b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/llm/llm_service.py index 14536566e9..b9f9a25d3a 100644 --- a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/llm/llm_service.py +++ b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/llm/llm_service.py @@ -51,28 +51,3 @@ async def check_llm_setting(query: LLMProviderSettingStatusQuery) -> bool: trace_name=ObservabilityTrace.CHECK_LLM_SETTINGS.value) return await get_llm_factory(query.setting).check_llm_setting(langfuse_callback_handler) - - -def llm_inference_with_parser( - llm_factory: LangChainLLMFactory, parser: BaseOutputParser -) -> AIMessage: - """ - Perform LLM inference and format the output content based on the given parser. - - :param llm_factory: LangChain LLM Factory. - :param parser: Parser to format the output. - - :return: Result of the language model inference with the content formatted. 
- """ - - # Change the prompt with added format instructions - format_instructions = parser.get_format_instructions() - formatted_prompt = llm_factory.setting.prompt + '\n' + format_instructions - - # Inference of the LLM with the formatted prompt - llm_output = llm_factory.invoke(formatted_prompt) - - # Apply the parsing on the LLM output - llm_output.content = parser.parse(llm_output.content) - - return llm_output diff --git a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/rag/rag_service.py b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/rag/rag_service.py index cdb931012e..6d5b40d909 100644 --- a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/rag/rag_service.py +++ b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/rag/rag_service.py @@ -16,9 +16,9 @@ from gen_ai_orchestrator.routers.requests.requests import RagQuery from gen_ai_orchestrator.routers.responses.responses import RagResponse -from gen_ai_orchestrator.services.langchain.rag_chain import execute_qa_chain +from gen_ai_orchestrator.services.langchain.rag_chain import execute_rag_chain async def rag(query: RagQuery, debug: bool) -> RagResponse: """Launch execution of the RAG chain""" - return await execute_qa_chain(query, debug) + return await execute_rag_chain(query, debug) diff --git a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/utils/__init__.py b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/utils/__init__.py new file mode 100644 index 0000000000..0b6c73c789 --- /dev/null +++ b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/utils/__init__.py @@ -0,0 +1,14 @@ +# Copyright (C) 2024 Credit Mutuel Arkea +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/utils/prompt_utility.py b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/utils/prompt_utility.py new file mode 100644 index 0000000000..7ec3af49fb --- /dev/null +++ b/gen-ai/orchestrator-server/src/main/python/server/src/gen_ai_orchestrator/services/utils/prompt_utility.py @@ -0,0 +1,37 @@ +import logging + +from jinja2 import Template, TemplateError + +from gen_ai_orchestrator.errors.exceptions.exceptions import ( + GenAIPromptTemplateException, +) +from gen_ai_orchestrator.models.errors.errors_models import ErrorInfo +from gen_ai_orchestrator.models.prompt.prompt_formatter import PromptFormatter +from gen_ai_orchestrator.models.prompt.prompt_template import PromptTemplate + +logger = logging.getLogger(__name__) + +def validate_prompt_template(prompt: PromptTemplate): + """ + Prompt template validation + + Args: + prompt: The prompt template + + Returns: + Nothing. 
+ Raises: + GenAIPromptTemplateException: if template is incorrect + """ + if PromptFormatter.JINJA2 == prompt.formatter: + try: + Template(prompt.template).render(prompt.inputs) + except TemplateError as exc: + logger.error('Prompt completion - template validation failed!') + logger.error(exc) + raise GenAIPromptTemplateException( + ErrorInfo( + error=exc.__class__.__name__, + cause=str(exc), + ) + ) diff --git a/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_completion_service.py b/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_completion_service.py index 6637b7af7d..58c1a00aa2 100644 --- a/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_completion_service.py +++ b/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_completion_service.py @@ -13,9 +13,7 @@ # limitations under the License. # from gen_ai_orchestrator.models.prompt.prompt_template import PromptTemplate -from gen_ai_orchestrator.services.completion.completion_service import ( - validate_prompt_template, -) +from gen_ai_orchestrator.services.utils.prompt_utility import validate_prompt_template def test_validate_prompt_template(): diff --git a/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_langchain_callbacks.py b/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_langchain_callbacks.py index af2a66a6d0..8895550589 100644 --- a/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_langchain_callbacks.py +++ b/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_langchain_callbacks.py @@ -13,112 +13,58 @@ # limitations under the License. # from langchain_core.documents import Document +from langchain_core.messages import AIMessage, SystemMessage, HumanMessage +from langchain_core.prompt_values import StringPromptValue, ChatPromptValue -from gen_ai_orchestrator.services.langchain.callbacks.retriever_json_callback_handler import ( - RetrieverJsonCallbackHandler, +from gen_ai_orchestrator.services.langchain.callbacks.rag_callback_handler import ( + RAGCallbackHandler, ) -def test_retriever_json_callback_handler_on_chain_start(): +def test_rag_callback_handler_qa_documents(): """Check records are added (in the correct entries)""" - handler = RetrieverJsonCallbackHandler() - _inputs = { - 'input_documents': [ - Document( - page_content='some page content', - metadata={'some meta': 'some meta value'}, - ) - ], - 'question': 'What is happening?', - 'chat_history': [], - } - handler.on_chain_start(serialized={}, inputs=_inputs) - expected_json_data = { - 'event_name': 'on_chain_start', - 'inputs': { - 'input_documents': [ - { - 'page_content': 'some page content', - 'metadata': {'some meta': 'some meta value'}, - } - ], - 'question': _inputs['question'], - 'chat_history': _inputs['chat_history'], - }, - } - assert handler.records['on_chain_start_records'][0] == expected_json_data - assert handler.records['action_records'][0] == expected_json_data - - -def test_retriever_json_callback_handler_on_chain_start_no_double_entries(): - """Check records are added only once in history.""" - handler = RetrieverJsonCallbackHandler() - _inputs = { - 'input_documents': [ - Document( - page_content='some page content', - metadata={'some meta': 'some meta value'}, - ) - ], - 'question': 'What is happening?', - 'chat_history': [], - } - handler.on_chain_start(serialized={}, inputs=_inputs) - expected_json_data = { - 'event_name': 'on_chain_start', - 'inputs': { - 'input_documents': [ - { - 
'page_content': 'some page content', - 'metadata': {'some meta': 'some meta value'}, - } - ], - 'question': _inputs['question'], - 'chat_history': _inputs['chat_history'], - }, - } - assert expected_json_data in handler.records['on_chain_start_records'] - assert expected_json_data in handler.records['action_records'] - assert len(handler.records['on_chain_start_records']) == 1 - assert len(handler.records['action_records']) == 1 - handler.on_chain_start(serialized={}, inputs=_inputs) - assert expected_json_data in handler.records['on_chain_start_records'] - assert expected_json_data in handler.records['action_records'] - assert len(handler.records['on_chain_start_records']) == 1 - assert len(handler.records['action_records']) == 1 - - -def test_retriever_json_callback_handler_on_chain_start_no_inputs(): - """Check no records are added if none are present in chain inputs.""" - handler = RetrieverJsonCallbackHandler() - _inputs = {'question': 'What is happening?', 'chat_history': []} - handler.on_chain_start(serialized={}, inputs=_inputs) - assert len(handler.records['on_chain_start_records']) == 0 - assert len(handler.records['action_records']) == 0 + handler = RAGCallbackHandler() + docs = [Document( + page_content='some page content', + metadata={'some meta': 'some meta value'}, + )] + handler.on_chain_start(serialized={}, + inputs={'documents': docs}, + **{'name': 'RunnableAssign'}) + assert handler.records['documents'] == docs +def test_rag_callback_handler_chat_prompt_output(): + """Check records are added (in the correct entries)""" + handler = RAGCallbackHandler() + llm_output = 'llm result !' + handler.on_chain_start(serialized={}, + inputs=AIMessage(content=llm_output), + **{'name': 'chat_chain_output'}) + assert handler.records['chat_chain_output'] == llm_output -def test_retriever_json_callback_handler_on_chain_end(): +def test_rag_callback_handler_qa_prompt_output(): """Check records are added (in the correct entries)""" - handler = RetrieverJsonCallbackHandler() - _outputs = { - 'text': 'This is what is happening', - } - handler.on_chain_end(outputs=_outputs) - expected_json_data = { - 'event_name': 'on_chain_end', - 'output': 'This is what is happening', - } - assert handler.records['on_chain_end_records'][0] == expected_json_data - assert handler.records['action_records'][0] == expected_json_data + handler = RAGCallbackHandler() + llm_output = 'llm result !' + handler.on_chain_start(serialized={}, + inputs=AIMessage(content=llm_output), + **{'name': 'rag_chain_output'}) + assert handler.records['rag_chain_output'] == llm_output +def test_rag_callback_handler_chat_prompt(): + """Check records are added (in the correct entries)""" + handler = RAGCallbackHandler() + prompt = 'A custom prompt !' + outputs = ChatPromptValue(messages=[ + SystemMessage(content=prompt), + HumanMessage(content='hi !') + ]) + handler.on_chain_end(serialized={}, outputs=outputs) + assert handler.records['chat_prompt'] == prompt -def test_retriever_json_callback_handler_on_text(): +def test_rag_callback_handler_qa_prompt(): """Check records are added (in the correct entries)""" - handler = RetrieverJsonCallbackHandler() - handler.on_text(text='Some text arrives') - expected_json_data = { - 'event_name': 'on_text', - 'text': 'Some text arrives', - } - assert handler.records['on_text_records'][0] == expected_json_data - assert handler.records['action_records'][0] == expected_json_data + handler = RAGCallbackHandler() + prompt = 'A custom prompt !' 
+ handler.on_chain_end(serialized={}, outputs=StringPromptValue(text=prompt)) + assert handler.records['rag_prompt'] == prompt diff --git a/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_rag_chain.py b/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_rag_chain.py index c6c9470c75..6e26bed0af 100644 --- a/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_rag_chain.py +++ b/gen-ai/orchestrator-server/src/main/python/server/tests/services/test_rag_chain.py @@ -30,48 +30,23 @@ ) from gen_ai_orchestrator.routers.requests.requests import RagQuery from gen_ai_orchestrator.services.langchain import rag_chain -from gen_ai_orchestrator.services.langchain.callbacks.retriever_json_callback_handler import ( - RetrieverJsonCallbackHandler, -) + from gen_ai_orchestrator.services.langchain.factories.langchain_factory import ( get_guardrail_factory, ) from gen_ai_orchestrator.services.langchain.impls.document_compressor.bloomz_rerank import BloomzRerank from gen_ai_orchestrator.services.langchain.rag_chain import ( check_guardrail_output, - execute_qa_chain, - get_condense_question, - get_llm_prompts, + execute_rag_chain, ) -# 'Mock an item where it is used, not where it came from.' -# (https://www.toptal.com/python/an-introduction-to-mocking-in-python) -# See https://docs.python.org/3/library/unittest.mock.html#where-to-patch -# Here: -# --> Not where it came from: -# @patch('llm_orchestrator.services.langchain.factories.langchain_factory.get_llm_factory') -# --> But where it is used (in the execute_qa_chain method of the llm_orchestrator.services.langchain.rag_chain -# module that imports get_llm_factory): - - -@patch( - 'gen_ai_orchestrator.services.langchain.rag_chain.ContextualCompressionRetriever' -) @patch('gen_ai_orchestrator.services.langchain.impls.document_compressor.bloomz_rerank.requests.post') -@patch( - 'gen_ai_orchestrator.services.langchain.factories.langchain_factory.get_callback_handler_factory' -) -@patch('gen_ai_orchestrator.services.langchain.rag_chain.get_llm_factory') -@patch('gen_ai_orchestrator.services.langchain.rag_chain.get_em_factory') -@patch('gen_ai_orchestrator.services.langchain.rag_chain.get_vector_store_factory') -@patch('gen_ai_orchestrator.services.langchain.rag_chain.PromptTemplate') -@patch('gen_ai_orchestrator.services.langchain.rag_chain.__find_input_variables') -@patch( - 'gen_ai_orchestrator.services.langchain.rag_chain.ConversationalRetrievalChain.from_llm' -) -@patch('gen_ai_orchestrator.services.langchain.rag_chain.RetrieverJsonCallbackHandler') -@patch('gen_ai_orchestrator.services.langchain.rag_chain.__rag_guard') +@patch('gen_ai_orchestrator.services.langchain.factories.langchain_factory.get_compressor_factory') +@patch('gen_ai_orchestrator.services.langchain.factories.langchain_factory.get_callback_handler_factory') +@patch('gen_ai_orchestrator.services.langchain.rag_chain.create_rag_chain') +@patch('gen_ai_orchestrator.services.langchain.rag_chain.RAGCallbackHandler') +@patch('gen_ai_orchestrator.services.langchain.rag_chain.rag_guard') @patch('gen_ai_orchestrator.services.langchain.rag_chain.RagResponse') @patch('gen_ai_orchestrator.services.langchain.rag_chain.TextWithFootnotes') @patch('gen_ai_orchestrator.services.langchain.rag_chain.RagDebugData') @@ -82,15 +57,10 @@ async def test_rag_chain( mocked_rag_response, mocked_rag_guard, mocked_callback_init, - mocked_chain_builder, - mocked_find_input_variables, - mocked_prompt_template, - mocked_get_vector_store_factory, - 
mocked_get_em_factory, - mocked_get_llm_factory, + mocked_create_rag_chain, mocked_get_callback_handler_factory, + mocked_get_document_compressor_factory, mocked_guardrail_parse, - mocked_compressor_builder, ): """Test the full execute_qa_chain method by mocking all external calls.""" # Build a test RagQuery @@ -106,7 +76,11 @@ async def test_rag_chain( 'provider': 'OpenAI', 'api_key': {'type': 'Raw', 'secret': 'ab7***************************A1IV4B'}, 'temperature': 1.2, - 'prompt': """Use the following context to answer the question at the end. + 'model': 'gpt-3.5-turbo', + }, + 'question_answering_prompt': { + 'formatter': 'f-string', + 'template': """Use the following context to answer the question at the end. If you don't know the answer, just say {no_answer}. Context: @@ -116,12 +90,11 @@ async def test_rag_chain( {question} Answer in {locale}:""", - 'model': 'gpt-3.5-turbo', - }, - 'question_answering_prompt_inputs': { - 'question': 'How to get started playing guitar ?', - 'no_answer': 'Sorry, I don t know.', - 'locale': 'French', + 'inputs' : { + 'question': 'How to get started playing guitar ?', + 'no_answer': 'Sorry, I don t know.', + 'locale': 'French', + } }, 'embedding_question_em_setting': { 'provider': 'OpenAI', @@ -172,18 +145,25 @@ async def test_rag_chain( 'documents_required': True, } query = RagQuery(**query_dict) + inputs = { + **query.question_answering_prompt.inputs, + 'chat_history': [ + HumanMessage(content='Hello, how can I do this?'), + AIMessage(content='you can do this with the following method ....'), + ], + } + docs = [Document( + page_content='some page content', + metadata={'id':'123-abc', 'title':'my-title', 'source': None}, + )] + response = {'answer': 'an answer from llm', 'documents': docs} # Setup mock factories/init return value - em_factory_instance = mocked_get_em_factory.return_value - llm_factory_instance = mocked_get_llm_factory.return_value observability_factory_instance = mocked_get_callback_handler_factory.return_value - mocked_chain = mocked_chain_builder.return_value mocked_callback = mocked_callback_init.return_value - mocked_compressor = mocked_compressor_builder.return_value mocked_langfuse_callback = observability_factory_instance.get_callback_handler() - mocked_chain.ainvoke = AsyncMock( - return_value={'answer': 'an answer from llm', 'source_documents': []} - ) + mocked_chain = mocked_create_rag_chain.return_value + mocked_chain.ainvoke = AsyncMock(return_value=response) mocked_rag_answer = mocked_chain.ainvoke.return_value mocked_response = MagicMock() @@ -193,47 +173,15 @@ async def test_rag_chain( mocked_guardrail_parse.return_value = mocked_response # Call function - await execute_qa_chain(query, debug=True) + await execute_rag_chain(query, debug=True) - # Assert factories are called with the expected settings from query - mocked_get_llm_factory.assert_called_once_with( - setting=query.question_answering_llm_setting - ) - mocked_get_em_factory.assert_called_once_with( - setting=query.embedding_question_em_setting - ) - mocked_get_vector_store_factory.assert_called_once_with( - setting=query.vector_store_setting, - index_name=query.document_index_name, - embedding_function=em_factory_instance.get_embedding_model(), - ) + # Assert that the given observability_setting is used mocked_get_callback_handler_factory.assert_called_once_with( setting=query.observability_setting ) - - # Assert LangChain qa chain is created using the expected settings from query - mocked_chain_builder.assert_called_once_with( - 
llm=llm_factory_instance.get_language_model(), - retriever=mocked_compressor, - return_source_documents=True, - return_generated_question=True, - combine_docs_chain_kwargs={ - # PromptTemplate must be mocked or searching for params in it will fail - 'prompt': mocked_prompt_template( - template=query.question_answering_llm_setting.prompt, - input_variables=['no_answer', 'context', 'question', 'locale'], - ) - }, - ) # Assert qa chain is ainvoke()d with the expected settings from query mocked_chain.ainvoke.assert_called_once_with( - input={ - **query.question_answering_prompt_inputs, - 'chat_history': [ - HumanMessage(content='Hello, how can I do this?'), - AIMessage(content='you can do this with the following method ....'), - ], - }, + input=inputs, config={'callbacks': [mocked_callback, mocked_langfuse_callback]}, ) # Assert the response is build using the expected settings @@ -245,13 +193,18 @@ async def test_rag_chain( debug=mocked_rag_debug_data(query, mocked_rag_answer, mocked_callback, 1), observability_info=None ) - + mocked_get_document_compressor_factory( + setting=query.compressor_setting + ) + # Assert the rag guardrail is called mocked_guardrail_parse.assert_called_once_with( os.path.join(query.guardrail_setting.api_base, 'guardrail'), json={'text': [mocked_rag_answer['answer']]}, ) - mocked_compressor_builder.assert_called_once() - + # Assert the rag guard is called + mocked_rag_guard.assert_called_once_with( + inputs, response, query.documents_required + ) @patch('gen_ai_orchestrator.services.langchain.impls.guardrail.bloomz_guardrail.requests.post') def test_guardrail_parse_succeed_with_toxicities_encountered( @@ -454,113 +407,60 @@ def test_check_guardrail_output_is_ok(): assert check_guardrail_output(guardrail_output) is True -def test_find_input_variables(): - template = 'This is a {sample} text with {multiple} curly brace sections' - input_vars = rag_chain.__find_input_variables(template) - assert input_vars == ['sample', 'multiple'] - - -@patch('gen_ai_orchestrator.services.langchain.rag_chain.__rag_log') +@patch('gen_ai_orchestrator.services.langchain.rag_chain.rag_log') def test_rag_guard_fails_if_no_docs_in_valid_answer(mocked_log): inputs = {'no_answer': "Sorry, I don't know."} response = { 'answer': 'a valid answer', - 'source_documents': [], + 'documents': [], } try: - rag_chain.__rag_guard(inputs, response,documents_required=True) + rag_chain.rag_guard(inputs, response,documents_required=True) except Exception as e: assert isinstance(e, GenAIGuardCheckException) -@patch('gen_ai_orchestrator.services.langchain.rag_chain.__rag_log') +@patch('gen_ai_orchestrator.services.langchain.rag_chain.rag_log') def test_rag_guard_accepts_no_answer_even_with_docs(mocked_log): inputs = {'no_answer': "Sorry, I don't know."} response = { 'answer': "Sorry, I don't know.", - 'source_documents': ['a doc as a string'], + 'documents': ['a doc as a string'], } - rag_chain.__rag_guard(inputs, response, documents_required=True) - assert response['source_documents'] == ['a doc as a string'] - - -def test_get_llm_prompts_one_record(): - handler = RetrieverJsonCallbackHandler() - handler.on_text(text='LLM 1') - llm_1, llm_2 = get_llm_prompts(handler) - assert llm_1 is None - assert llm_2 == 'LLM 1' - - -def test_get_llm_prompts_one_record(): - handler = RetrieverJsonCallbackHandler() - handler.on_text(text='LLM 1') - handler.on_text(text='LLM 2') - llm_1, llm_2 = get_llm_prompts(handler) - assert llm_1 == 'LLM 1' - assert llm_2 == 'LLM 2' - - -def test_get_condense_question_none(): - handler 
= RetrieverJsonCallbackHandler() - handler.on_text(text='LLM 1') - handler.on_chain_start( - serialized={}, - inputs={ - 'input_documents': [], - 'question': 'Is this a question ?', - 'chat_history': 'chat_history', - }, - ) - question = get_condense_question(handler) - assert question is None - - -def test_get_condense_question(): - handler = RetrieverJsonCallbackHandler() - handler.on_text(text='LLM 1') - handler.on_text(text='LLM 2') - handler.on_chain_start( - serialized={}, - inputs={ - 'input_documents': [], - 'question': 'Is this a question ?', - 'chat_history': 'chat_history', - }, - ) - question = get_condense_question(handler) - assert question == 'Is this a question ?' + rag_chain.rag_guard(inputs, response, documents_required=True) + assert response['documents'] == ['a doc as a string'] + -@patch('gen_ai_orchestrator.services.langchain.rag_chain.__rag_log') +@patch('gen_ai_orchestrator.services.langchain.rag_chain.rag_log') def test_rag_guard_valid_answer_with_docs(mocked_log): inputs = {'no_answer': "Sorry, I don't know."} response = { 'answer': 'a valid answer', - 'source_documents': ['doc1', 'doc2'], + 'documents': ['doc1', 'doc2'], } - rag_chain.__rag_guard(inputs, response, documents_required=True) - assert response['source_documents'] == ['doc1', 'doc2'] + rag_chain.rag_guard(inputs, response, documents_required=True) + assert response['documents'] == ['doc1', 'doc2'] -@patch('gen_ai_orchestrator.services.langchain.rag_chain.__rag_log') +@patch('gen_ai_orchestrator.services.langchain.rag_chain.rag_log') def test_rag_guard_no_answer_with_no_docs(mocked_log): inputs = {'no_answer': "Sorry, I don't know."} response = { 'answer': "Sorry, I don't know.", - 'source_documents': [], + 'documents': [], } - rag_chain.__rag_guard(inputs, response, documents_required=True) - assert response['source_documents'] == [] + rag_chain.rag_guard(inputs, response, documents_required=True) + assert response['documents'] == [] -@patch('gen_ai_orchestrator.services.langchain.rag_chain.__rag_log') +@patch('gen_ai_orchestrator.services.langchain.rag_chain.rag_log') def test_rag_guard_without_no_answer_input(mocked_log): """Test that __rag_guard handles missing no_answer input correctly.""" inputs = {} # No 'no_answer' key response = { 'answer': 'some answer', - 'source_documents': [], + 'documents': [], } with pytest.raises(GenAIGuardCheckException) as exc: - rag_chain.__rag_guard(inputs, response, documents_required=True) + rag_chain.rag_guard(inputs, response, documents_required=True) mocked_log.assert_called_once()
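
The hunks above move the prompt definition out of LLMSetting.prompt into a dedicated PromptTemplate (formatter / template / inputs) and relocate validate_prompt_template into services/utils/prompt_utility.py. The snippet below is a minimal usage sketch, not code from this patch: it assumes only the module paths and model fields visible in the hunks, uses a hypothetical Jinja2 prompt, and relies on the fact that validation only renders Jinja2-formatted templates (f-string templates pass through unchecked).

from gen_ai_orchestrator.errors.exceptions.exceptions import GenAIPromptTemplateException
from gen_ai_orchestrator.models.prompt.prompt_formatter import PromptFormatter
from gen_ai_orchestrator.models.prompt.prompt_template import PromptTemplate
from gen_ai_orchestrator.services.utils.prompt_utility import validate_prompt_template

# Hypothetical prompt payload; field names mirror the PromptTemplate model used in the hunks.
prompt = PromptTemplate(
    formatter=PromptFormatter.JINJA2,
    template='Use the context to answer in {{ locale }}: {{ question }}',
    inputs={'locale': 'French', 'question': 'How to get started playing guitar ?'},
)

try:
    # Renders the Jinja2 template with its inputs; a malformed template raises.
    validate_prompt_template(prompt)
except GenAIPromptTemplateException as exc:
    print('Invalid prompt template:', exc)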
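
The rag_chain.py rewrite replaces ConversationalRetrievalChain with an LCEL composition: contextualize the question, fan out into {documents, question} with RunnableParallel, then attach the answer via RunnablePassthrough.assign. Below is a self-contained sketch of that shape under stated assumptions: the LLM and retriever are RunnableLambda stand-ins and the prompt is a throwaway template, so it illustrates the data flow only, not the orchestrator's factories, history handling, or callbacks.

from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnableParallel, RunnablePassthrough

# Stand-ins: a fake LLM that echoes the start of its prompt, and a fixed single-document retriever.
llm = RunnableLambda(lambda prompt_value: 'ANSWER based on: ' + prompt_value.to_string()[:40])
retriever = RunnableLambda(lambda question: [Document(page_content='Start with an acoustic guitar.')])

rag_prompt = PromptTemplate.from_template('Context:\n{context}\n\nQuestion:\n{question}\nAnswer:')

# Same shape as construct_rag_chain(): documents are flattened into {context}, then prompt | llm | parser.
rag_chain = {
    'context': lambda x: '\n\n'.join(doc.page_content for doc in x['documents']),
    'question': lambda x: x['question'],
} | rag_prompt | llm | StrOutputParser(name='rag_chain_output')

# Same shape as the tail of create_rag_chain(): retrieve and keep the question, then assign the answer.
chain = (
    RunnableParallel({'documents': retriever, 'question': RunnablePassthrough()})
    | RunnablePassthrough.assign(answer=rag_chain)
)

result = chain.invoke('How to get started playing guitar ?')
print(result['answer'])                                   # reply produced by the stand-in LLM
print([doc.page_content for doc in result['documents']])  # documents kept alongside the answer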