From cb6052ccf393215add4a6dcd60623803a7cefd23 Mon Sep 17 00:00:00 2001 From: winstxnhdw Date: Fri, 24 Nov 2023 02:45:49 +0800 Subject: [PATCH] refactor: abstract duplicate query into a function --- server/api/v1/query.py | 36 +----------- server/api/v1/query_with_image.py | 37 +----------- server/api/v1/search.py | 5 +- server/features/__init__.py | 2 +- server/features/query/__init__.py | 1 + server/features/query/query_llm.py | 58 +++++++++++++++++++ .../question_answering.py} | 0 server/schemas/v1/query.py | 2 - 8 files changed, 69 insertions(+), 72 deletions(-) create mode 100644 server/features/query/__init__.py create mode 100644 server/features/query/query_llm.py rename server/features/{question_answering/__init__.py => query/question_answering.py} (100%) diff --git a/server/api/v1/query.py b/server/api/v1/query.py index eec368b..a9f401b 100644 --- a/server/api/v1/query.py +++ b/server/api/v1/query.py @@ -1,16 +1,11 @@ -from json import dumps from typing import Annotated from fastapi import Depends from redis.asyncio import Redis from server.api.v1 import v1 -from server.config import Config -from server.databases.redis import create_query_parameters, redis_get -from server.databases.redis import redis_query as redis_query_helper from server.dependencies import get_redis_client -from server.features import LLM, Embedding, question_answering -from server.features.llm.types import Message +from server.features import query_llm from server.schemas.v1 import Answer, Query @@ -19,6 +14,7 @@ async def query( redis: Annotated[Redis, Depends(get_redis_client)], chat_id: str, request: Query, + top_k: int = 5, store_query: bool = True ) -> Answer: """ @@ -26,31 +22,5 @@ async def query( ------- the `/query` route provides an endpoint for performning retrieval-augmented generation """ - redis_query = redis_query_helper('tag', chat_id, request.top_k) - - redis_query_parameters = create_query_parameters( - Embedding().encode_query(request.query) - ) - - search_response = await redis.ft(Config.redis_index_name).search( - redis_query, - redis_query_parameters # type: ignore (this is a bug in the redis-py library) - ) - - context = ' '.join( - document['content'] for document - in search_response.docs # type: ignore - ) - - message_history: list[Message] = await redis_get(redis, f'chat:{chat_id}', _ := []) - message_history.append({ - 'role': 'user', - 'content': f'Given the following context:\n\n{context}\n\nPlease answer the following question:\n\n{request.query}' - }) - - messages = question_answering(message_history, LLM.query) - - if not store_query: - await redis.set(f'chat:{chat_id}', dumps(messages)) - + messages = await query_llm(redis, request.query, chat_id, top_k, store_query) return Answer(messages=messages) diff --git a/server/api/v1/query_with_image.py b/server/api/v1/query_with_image.py index 7a5f14c..30c06cc 100644 --- a/server/api/v1/query_with_image.py +++ b/server/api/v1/query_with_image.py @@ -1,16 +1,11 @@ -from json import dumps from typing import Annotated from fastapi import Depends, UploadFile from redis.asyncio import Redis from server.api.v1 import v1 -from server.config import Config -from server.databases.redis import create_query_parameters, redis_get -from server.databases.redis import redis_query as redis_query_helper from server.dependencies import get_redis_client -from server.features import LLM, Embedding, extract_text_from_image, question_answering -from server.features.llm.types import Message +from server.features import extract_text_from_image, query_llm from server.schemas.v1 import Answer @@ -25,35 +20,9 @@ async def query_with_image( """ Summary ------- - the `/query_with_image` route provides an endpoint for performning retrieval-augmented generation + the `/query_with_image` route is similar to `/query` but it accepts an image as input """ - redis_query = redis_query_helper('tag', chat_id, top_k) - extracted_query = extract_text_from_image(request.file) - - redis_query_parameters = create_query_parameters( - Embedding().encode_query(extracted_query) - ) - - search_response = await redis.ft(Config.redis_index_name).search( - redis_query, - redis_query_parameters # type: ignore (this is a bug in the redis-py library) - ) - - context = ' '.join( - document['content'] for document - in search_response.docs # type: ignore - ) - - message_history: list[Message] = await redis_get(redis, f'chat:{chat_id}', _ := []) - message_history.append({ - 'role': 'user', - 'content': f'Given the following context:\n\n{context}\n\nPlease answer the following question:\n\n{extracted_query}' - }) - - messages = question_answering(message_history, LLM.query) - - if not store_query: - await redis.set(f'chat:{chat_id}', dumps(messages)) + messages = await query_llm(redis, extracted_query, chat_id, top_k, store_query) return Answer(messages=messages) diff --git a/server/api/v1/search.py b/server/api/v1/search.py index 75d0d63..f214ff7 100644 --- a/server/api/v1/search.py +++ b/server/api/v1/search.py @@ -14,9 +14,10 @@ @v1.post('/{chat_id}/search') async def search( + redis: Annotated[Redis, Depends(get_redis_client)], chat_id: str, request: Query, - redis: Annotated[Redis, Depends(get_redis_client)] + top_k: int = 5 ) -> str: """ Summary @@ -28,7 +29,7 @@ async def search( ) search_response = await redis.ft(Config.redis_index_name).search( - redis_query_helper('tag', chat_id, request.top_k), + redis_query_helper('tag', chat_id, top_k), redis_query_parameters # type: ignore (this is a bug in the redis-py library) ) diff --git a/server/features/__init__.py b/server/features/__init__.py index 453266b..950bf71 100644 --- a/server/features/__init__.py +++ b/server/features/__init__.py @@ -10,4 +10,4 @@ extract_text_from_image as extract_text_from_image, ) from server.features.llm import LLM as LLM -from server.features.question_answering import question_answering as question_answering +from server.features.query import query_llm as query_llm diff --git a/server/features/query/__init__.py b/server/features/query/__init__.py new file mode 100644 index 0000000..42f3e13 --- /dev/null +++ b/server/features/query/__init__.py @@ -0,0 +1 @@ +from server.features.query.query_llm import query_llm as query_llm diff --git a/server/features/query/query_llm.py b/server/features/query/query_llm.py new file mode 100644 index 0000000..2216e8a --- /dev/null +++ b/server/features/query/query_llm.py @@ -0,0 +1,58 @@ +from json import dumps + +from redis.asyncio import Redis + +from server.config import Config +from server.databases.redis import create_query_parameters, redis_get +from server.databases.redis import redis_query as redis_query_helper +from server.features import LLM, Embedding +from server.features.llm.types import Message +from server.features.query.question_answering import question_answering + + +async def query_llm(redis: Redis, query: str, chat_id: str, top_k: int, store_query: bool) -> list[Message]: + """ + Summary + ------- + the query feature provides a reusable query function for all types of queries + + Parameters + ---------- + redis (Redis) : a Redis client + query (str) : a query string + chat_id (str) : a chat id + top_k (int) : the number of documents to retrieve + store_query (bool) : whether or not to store the query in the database + + Returns + ------- + messages (list[Message]) : a list of messages + """ + redis_query = redis_query_helper('tag', chat_id, top_k) + + redis_query_parameters = create_query_parameters( + Embedding().encode_query(query) + ) + + search_response = await redis.ft(Config.redis_index_name).search( + redis_query, + redis_query_parameters # type: ignore (this is a bug in the redis-py library) + ) + + context = ' '.join( + document['content'] for document + in search_response.docs # type: ignore + ) + + message_history: list[Message] = await redis_get(redis, f'chat:{chat_id}', _ := []) + message_history.append({ + 'role': 'user', + 'content': f'Given the following context:\n\n{context}\n\nPlease answer the following question:\n\n{query}' + }) + + messages = question_answering(message_history, LLM.query) + + if not store_query: + await redis.set(f'chat:{chat_id}', dumps(messages)) + + return messages diff --git a/server/features/question_answering/__init__.py b/server/features/query/question_answering.py similarity index 100% rename from server/features/question_answering/__init__.py rename to server/features/query/question_answering.py diff --git a/server/schemas/v1/query.py b/server/schemas/v1/query.py index b76a0f9..1863fd5 100644 --- a/server/schemas/v1/query.py +++ b/server/schemas/v1/query.py @@ -10,7 +10,5 @@ class Query(BaseModel): Attributes ---------- query (str) : the query - index_name (str) : the index name """ query: str = Field(examples=['Why did the chicken cross the road?']) - top_k: int = Field(examples=[1])