refactor: abstract duplicate query into a function
winstxnhdw committed Nov 23, 2023
1 parent 1457434 commit cb6052c
Showing 8 changed files with 69 additions and 72 deletions.
36 changes: 3 additions & 33 deletions server/api/v1/query.py
@@ -1,16 +1,11 @@
from json import dumps
from typing import Annotated

from fastapi import Depends
from redis.asyncio import Redis

from server.api.v1 import v1
from server.config import Config
from server.databases.redis import create_query_parameters, redis_get
from server.databases.redis import redis_query as redis_query_helper
from server.dependencies import get_redis_client
from server.features import LLM, Embedding, question_answering
from server.features.llm.types import Message
from server.features import query_llm
from server.schemas.v1 import Answer, Query


@@ -19,38 +14,13 @@ async def query(
redis: Annotated[Redis, Depends(get_redis_client)],
chat_id: str,
request: Query,
top_k: int = 5,
store_query: bool = True
) -> Answer:
"""
Summary
-------
the `/query` route provides an endpoint for performing retrieval-augmented generation
"""
redis_query = redis_query_helper('tag', chat_id, request.top_k)

redis_query_parameters = create_query_parameters(
Embedding().encode_query(request.query)
)

search_response = await redis.ft(Config.redis_index_name).search(
redis_query,
redis_query_parameters # type: ignore (this is a bug in the redis-py library)
)

context = ' '.join(
document['content'] for document
in search_response.docs # type: ignore
)

message_history: list[Message] = await redis_get(redis, f'chat:{chat_id}', _ := [])
message_history.append({
'role': 'user',
'content': f'Given the following context:\n\n{context}\n\nPlease answer the following question:\n\n{request.query}'
})

messages = question_answering(message_history, LLM.query)

if not store_query:
await redis.set(f'chat:{chat_id}', dumps(messages))

messages = await query_llm(redis, request.query, chat_id, top_k, store_query)
return Answer(messages=messages)
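
With `top_k` and `store_query` moved out of the `Query` body and onto the function signature, they now travel as query parameters while the JSON body only carries the question. A minimal client sketch, not part of the commit, assuming the route is mounted at `POST /v1/{chat_id}/query` on a local server:

# Hypothetical request against the refactored endpoint; the path, host and port
# are assumptions, since the route decorator sits outside this hunk.
from httpx import post

response = post(
    'http://localhost:8000/v1/example-chat/query',
    params={'top_k': 5, 'store_query': True},              # query parameters after the refactor
    json={'query': 'Why did the chicken cross the road?'}  # Query schema no longer carries top_k
)

print(response.json()['messages'])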
37 changes: 3 additions & 34 deletions server/api/v1/query_with_image.py
@@ -1,16 +1,11 @@
from json import dumps
from typing import Annotated

from fastapi import Depends, UploadFile
from redis.asyncio import Redis

from server.api.v1 import v1
from server.config import Config
from server.databases.redis import create_query_parameters, redis_get
from server.databases.redis import redis_query as redis_query_helper
from server.dependencies import get_redis_client
from server.features import LLM, Embedding, extract_text_from_image, question_answering
from server.features.llm.types import Message
from server.features import extract_text_from_image, query_llm
from server.schemas.v1 import Answer


@@ -25,35 +20,9 @@ async def query_with_image(
"""
Summary
-------
the `/query_with_image` route provides an endpoint for performing retrieval-augmented generation
the `/query_with_image` route is similar to `/query` but it accepts an image as input
"""
redis_query = redis_query_helper('tag', chat_id, top_k)

extracted_query = extract_text_from_image(request.file)

redis_query_parameters = create_query_parameters(
Embedding().encode_query(extracted_query)
)

search_response = await redis.ft(Config.redis_index_name).search(
redis_query,
redis_query_parameters # type: ignore (this is a bug in the redis-py library)
)

context = ' '.join(
document['content'] for document
in search_response.docs # type: ignore
)

message_history: list[Message] = await redis_get(redis, f'chat:{chat_id}', _ := [])
message_history.append({
'role': 'user',
'content': f'Given the following context:\n\n{context}\n\nPlease answer the following question:\n\n{extracted_query}'
})

messages = question_answering(message_history, LLM.query)

if not store_query:
await redis.set(f'chat:{chat_id}', dumps(messages))
messages = await query_llm(redis, extracted_query, chat_id, top_k, store_query)

return Answer(messages=messages)
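
The image variant takes the same query parameters, but the question is first pulled out of the uploaded file by `extract_text_from_image` before being handed to `query_llm`. A hedged multipart sketch, not part of the commit; the `/v1/{chat_id}/query_with_image` path and the `request` form-field name (inferred from the use of `request.file` above) are assumptions:

# Hypothetical multipart upload; path, host and form-field name are assumptions.
from httpx import post

with open('question.png', 'rb') as image:
    response = post(
        'http://localhost:8000/v1/example-chat/query_with_image',
        params={'top_k': 5, 'store_query': True},
        files={'request': ('question.png', image, 'image/png')}
    )

print(response.json()['messages'])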
5 changes: 3 additions & 2 deletions server/api/v1/search.py
@@ -14,9 +14,10 @@

@v1.post('/{chat_id}/search')
async def search(
redis: Annotated[Redis, Depends(get_redis_client)],
chat_id: str,
request: Query,
redis: Annotated[Redis, Depends(get_redis_client)]
top_k: int = 5
) -> str:
"""
Summary
@@ -28,7 +29,7 @@ async def search(
)

search_response = await redis.ft(Config.redis_index_name).search(
redis_query_helper('tag', chat_id, request.top_k),
redis_query_helper('tag', chat_id, top_k),
redis_query_parameters # type: ignore (this is a bug in the redis-py library)
)

2 changes: 1 addition & 1 deletion server/features/__init__.py
@@ -10,4 +10,4 @@
extract_text_from_image as extract_text_from_image,
)
from server.features.llm import LLM as LLM
from server.features.question_answering import question_answering as question_answering
from server.features.query import query_llm as query_llm
1 change: 1 addition & 0 deletions server/features/query/__init__.py
@@ -0,0 +1 @@
from server.features.query.query_llm import query_llm as query_llm
58 changes: 58 additions & 0 deletions server/features/query/query_llm.py
@@ -0,0 +1,58 @@
from json import dumps

from redis.asyncio import Redis

from server.config import Config
from server.databases.redis import create_query_parameters, redis_get
from server.databases.redis import redis_query as redis_query_helper
from server.features import LLM, Embedding
from server.features.llm.types import Message
from server.features.query.question_answering import question_answering


async def query_llm(redis: Redis, query: str, chat_id: str, top_k: int, store_query: bool) -> list[Message]:
"""
Summary
-------
the query feature provides a reusable query function for all types of queries
Parameters
----------
redis (Redis) : a Redis client
query (str) : a query string
chat_id (str) : a chat id
top_k (int) : the number of documents to retrieve
store_query (bool) : whether or not to store the query in the database
Returns
-------
messages (list[Message]) : a list of messages
"""
redis_query = redis_query_helper('tag', chat_id, top_k)

redis_query_parameters = create_query_parameters(
Embedding().encode_query(query)
)

search_response = await redis.ft(Config.redis_index_name).search(
redis_query,
redis_query_parameters # type: ignore (this is a bug in the redis-py library)
)

context = ' '.join(
document['content'] for document
in search_response.docs # type: ignore
)

message_history: list[Message] = await redis_get(redis, f'chat:{chat_id}', _ := [])
message_history.append({
'role': 'user',
'content': f'Given the following context:\n\n{context}\n\nPlease answer the following question:\n\n{query}'
})

messages = question_answering(message_history, LLM.query)

if not store_query:
await redis.set(f'chat:{chat_id}', dumps(messages))

return messages
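
Because the new helper only needs a Redis client and plain arguments, it can also be exercised outside the FastAPI routes. A standalone sketch, not part of the commit, assuming Redis is reachable on localhost:6379, the search index named by `Config.redis_index_name` already exists, and the `Embedding` and `LLM` features are ready as they would be once the application has started:

# Hypothetical standalone usage of query_llm; connection details are assumptions.
from asyncio import run

from redis.asyncio import Redis

from server.features import query_llm


async def main() -> None:
    redis = Redis(host='localhost', port=6379)

    try:
        messages = await query_llm(
            redis,
            'Why did the chicken cross the road?',  # query
            'example-chat',                         # chat_id
            5,                                      # top_k
            True                                    # store_query
        )

        for message in messages:
            print(f"{message['role']}: {message['content']}")

    finally:
        await redis.close()


if __name__ == '__main__':
    run(main())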
File renamed without changes: server/features/question_answering.py → server/features/query/question_answering.py
2 changes: 0 additions & 2 deletions server/schemas/v1/query.py
@@ -10,7 +10,5 @@ class Query(BaseModel):
Attributes
----------
query (str) : the query
index_name (str) : the index name
"""
query: str = Field(examples=['Why did the chicken cross the road?'])
top_k: int = Field(examples=[1])
