diff --git a/backend/danswer/chat/models.py b/backend/danswer/chat/models.py
index e04efd92a2f..cd83f1a36a3 100644
--- a/backend/danswer/chat/models.py
+++ b/backend/danswer/chat/models.py
@@ -74,6 +74,17 @@ class DanswerQuotes(BaseModel):
     quotes: list[DanswerQuote]
 
 
+class DanswerContext(BaseModel):  # one context entry built from a chunk in llm_chunks
+    content: str
+    document_id: str
+    semantic_identifier: str
+    blurb: str
+
+
+class DanswerContexts(BaseModel):  # packet wrapping every DanswerContext of one answer
+    contexts: list[DanswerContext]
+
+
 class DanswerAnswer(BaseModel):
     answer: str | None
 
@@ -87,11 +98,11 @@ class QAResponse(SearchResponse, DanswerAnswer):
     error_msg: str | None = None
 
 
-AnswerQuestionReturn = tuple[DanswerAnswer, DanswerQuotes]
+AnswerQuestionReturn = tuple[DanswerAnswer, DanswerQuotes, DanswerContexts]
 
 
 AnswerQuestionStreamReturn = Iterator[
-    DanswerAnswerPiece | DanswerQuotes | StreamingError
+    DanswerAnswerPiece | DanswerQuotes | DanswerContexts | StreamingError
 ]
 
 
diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py
index 04f95fbdbb2..ed3b95b1ecd 100644
--- a/backend/danswer/one_shot_answer/answer_question.py
+++ b/backend/danswer/one_shot_answer/answer_question.py
@@ -1,11 +1,14 @@
 from collections.abc import Callable
 from collections.abc import Iterator
+import itertools
 from typing import cast
 
 from sqlalchemy.orm import Session
 
 from danswer.chat.chat_utils import get_chunks_for_qa
 from danswer.chat.models import DanswerAnswerPiece
+from danswer.chat.models import DanswerContext
+from danswer.chat.models import DanswerContexts
 from danswer.chat.models import DanswerQuotes
 from danswer.chat.models import LLMMetricsContainer
 from danswer.chat.models import LLMRelevanceFilterResponse
@@ -67,6 +70,7 @@ def stream_answer_objects(
     | LLMRelevanceFilterResponse
     | DanswerAnswerPiece
     | DanswerQuotes
+    | DanswerContexts
     | StreamingError
     | ChatMessageDetail
 ]:
@@ -229,6 +233,22 @@ def stream_answer_objects(
         else no_gen_ai_response()
     )
 
+    if qa_model is not None and query_req.return_contexts:  # contexts are opt-in per request
+        contexts = DanswerContexts(
+            contexts=[
+                DanswerContext(
+                    content=context_doc.content,
+                    document_id=context_doc.document_id,
+                    semantic_identifier=context_doc.semantic_identifier,
+                    blurb=context_doc.blurb,  # the chunk's own blurb, not its title
+                )
+                for context_doc in llm_chunks
+            ]
+        )
+
+        response_packets = itertools.chain(response_packets, [contexts])  # yielded last
+
+
     # Capture outputs and errors
     llm_output = ""
     error: str | None = None
@@ -316,6 +336,8 @@ def get_search_answer(
             qa_response.llm_chunks_indices = packet.relevant_chunk_indices
         elif isinstance(packet, DanswerQuotes):
             qa_response.quotes = packet
+        elif isinstance(packet, DanswerContexts):
+            qa_response.contexts = packet  # keep contexts separate; do not clobber quotes
         elif isinstance(packet, StreamingError):
             qa_response.error_msg = packet.error
         elif isinstance(packet, ChatMessageDetail):
diff --git a/backend/danswer/one_shot_answer/interfaces.py b/backend/danswer/one_shot_answer/interfaces.py
index ca916d699df..ad16d4c9e30 100644
--- a/backend/danswer/one_shot_answer/interfaces.py
+++ b/backend/danswer/one_shot_answer/interfaces.py
@@ -22,5 +22,6 @@ def answer_question_stream(
         prompt: str,
         llm_context_docs: list[InferenceChunk],
         metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
+        return_contexts: bool = False,
     ) -> AnswerQuestionStreamReturn:
         raise NotImplementedError
diff --git a/backend/danswer/one_shot_answer/models.py b/backend/danswer/one_shot_answer/models.py
index 1e5d94d27c7..6401b34404e 100644
--- a/backend/danswer/one_shot_answer/models.py
+++ b/backend/danswer/one_shot_answer/models.py
@@ -3,6 +3,7 @@
 from pydantic import BaseModel
 from pydantic import root_validator
 
+from danswer.chat.models import DanswerContexts
 from danswer.chat.models import DanswerQuotes
 from danswer.chat.models import QADocsResponse
 from danswer.configs.constants import MessageType
@@ -25,6 +26,7 @@ class DirectQARequest(BaseModel):
     persona_id: int
     retrieval_options: RetrievalDetails
     chain_of_thought: bool = False
+    return_contexts: bool = False  # when True, the answer stream also yields DanswerContexts
 
     @root_validator
     def check_chain_of_thought_and_prompt_id(
@@ -53,3 +55,4 @@ class OneShotQAResponse(BaseModel):
     error_msg: str | None = None
     answer_valid: bool = True  # Reflexion result, default True if Reflexion not run
     chat_message_id: int | None = None
+    contexts: DanswerContexts | None = None  # stays None unless a DanswerContexts packet arrives