Skip to content

Commit

Permalink
Add return_contexts
Browse files Browse the repository at this point in the history
  • Loading branch information
szymon-planeta committed Jan 29, 2024
1 parent 159453f commit e5ff00e
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 2 deletions.
15 changes: 13 additions & 2 deletions backend/danswer/chat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,17 @@ class DanswerQuotes(BaseModel):
quotes: list[DanswerQuote]


class DanswerContext(BaseModel):
    """One context document chunk that was provided to the LLM when answering."""

    # Full chunk text that was included in the LLM prompt.
    content: str
    # ID of the source document this chunk came from.
    document_id: str
    # Human-readable identifier of the source document (e.g. title/link text).
    semantic_identifier: str
    # Short preview snippet of the document — presumably distinct from
    # semantic_identifier; TODO confirm against InferenceChunk.blurb.
    blurb: str


class DanswerContexts(BaseModel):
    """Container for all LLM context chunks returned alongside an answer."""

    # One entry per chunk that was fed to the LLM for this answer.
    contexts: list[DanswerContext]


class DanswerAnswer(BaseModel):
    """Wrapper for the generated answer text; None when no answer is available."""

    answer: str | None

Expand All @@ -87,11 +98,11 @@ class QAResponse(SearchResponse, DanswerAnswer):
error_msg: str | None = None


AnswerQuestionReturn = tuple[DanswerAnswer, DanswerQuotes]
AnswerQuestionReturn = tuple[DanswerAnswer, DanswerQuotes, DanswerContexts]


AnswerQuestionStreamReturn = Iterator[
DanswerAnswerPiece | DanswerQuotes | StreamingError
DanswerAnswerPiece | DanswerQuotes | DanswerContexts | StreamingError
]


Expand Down
22 changes: 22 additions & 0 deletions backend/danswer/one_shot_answer/answer_question.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from collections.abc import Callable
from collections.abc import Iterator
import itertools
from typing import cast

from sqlalchemy.orm import Session

from danswer.chat.chat_utils import get_chunks_for_qa
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerContext
from danswer.chat.models import DanswerContexts
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import LLMMetricsContainer
from danswer.chat.models import LLMRelevanceFilterResponse
Expand Down Expand Up @@ -67,6 +70,7 @@ def stream_answer_objects(
| LLMRelevanceFilterResponse
| DanswerAnswerPiece
| DanswerQuotes
| DanswerContexts
| StreamingError
| ChatMessageDetail
]:
Expand Down Expand Up @@ -229,6 +233,22 @@ def stream_answer_objects(
else no_gen_ai_response()
)

if qa_model is not None and query_req.return_contexts:
contexts = DanswerContexts(
contexts=[
DanswerContext(
content=context_doc.content,
document_id=context_doc.document_id,
semantic_identifier=context_doc.semantic_identifier,
blurb=context_doc.blurb,
)
for context_doc in llm_chunks
]
)

response_packets = itertools.chain(response_packets, [contexts])


# Capture outputs and errors
llm_output = ""
error: str | None = None
Expand Down Expand Up @@ -316,6 +336,8 @@ def get_search_answer(
qa_response.llm_chunks_indices = packet.relevant_chunk_indices
elif isinstance(packet, DanswerQuotes):
qa_response.quotes = packet
elif isinstance(packet, DanswerContexts):
qa_response.contexts = packet
elif isinstance(packet, StreamingError):
qa_response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail):
Expand Down
1 change: 1 addition & 0 deletions backend/danswer/one_shot_answer/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,6 @@ def answer_question_stream(
prompt: str,
llm_context_docs: list[InferenceChunk],
metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
return_contexts: bool = False,
) -> AnswerQuestionStreamReturn:
raise NotImplementedError
3 changes: 3 additions & 0 deletions backend/danswer/one_shot_answer/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pydantic import BaseModel
from pydantic import root_validator

from danswer.chat.models import DanswerContexts
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import QADocsResponse
from danswer.configs.constants import MessageType
Expand All @@ -25,6 +26,7 @@ class DirectQARequest(BaseModel):
persona_id: int
retrieval_options: RetrievalDetails
chain_of_thought: bool = False
return_contexts: bool = False

@root_validator
def check_chain_of_thought_and_prompt_id(
Expand Down Expand Up @@ -53,3 +55,4 @@ class OneShotQAResponse(BaseModel):
error_msg: str | None = None
answer_valid: bool = True # Reflexion result, default True if Reflexion not run
chat_message_id: int | None = None
contexts: DanswerContexts | None = None

0 comments on commit e5ff00e

Please sign in to comment.