Skip to content

Commit

Permalink
RAG server patch for consistency
Browse files Browse the repository at this point in the history
Signed-off-by: Bangqi Zhu <[email protected]>
  • Loading branch information
Bangqi Zhu committed Jan 10, 2025
1 parent 5351ff8 commit c429d81
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 15 deletions.
40 changes: 38 additions & 2 deletions presets/ragengine/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,49 @@
from llama_index.llms.openai import OpenAI
from llama_index.core.llms.callbacks import llm_completion_callback
import requests
from urllib.parse import urlparse, urljoin
from ragengine.config import LLM_INFERENCE_URL, LLM_ACCESS_SECRET #, LLM_RESPONSE_FIELD

OPENAI_URL_PREFIX = "https://api.openai.com"
HUGGINGFACE_URL_PREFIX = "https://api-inference.huggingface.co"

class Inference(CustomLLM):
params: dict = {}
model: str = ""

def set_params(self, params: dict) -> None:
self.params = params

def get_param(self, key, default=None):
return self.params.get(key, default)
# Get base URL
def _get_base_url(self) -> str:
parsed = urlparse(LLM_INFERENCE_URL)
base_url = f"{parsed.scheme}://{parsed.netloc}"
return urljoin(base_url, "/v1/models")

#Fetch and set the model from the inference endpoint
def set_model(self) -> None:

try:
models_url = self._get_base_url()
headers = {"Authorization": f"Bearer {LLM_ACCESS_SECRET}"}
response = requests.get(models_url, headers=headers)

if response.status_code == 404:
self.model = None
return

response.raise_for_status()

data = response.json()
if data.get("data") and len(data["data"]) > 0:
self.model = data["data"][0]["id"]
else:
raise ValueError("No model found in response")

except requests.RequestException as e:
raise Exception(f"Failed to fetch model information: {str(e)}")

@llm_completion_callback()
def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
Expand Down Expand Up @@ -53,8 +83,14 @@ def _huggingface_remote_complete(self, prompt: str, **kwargs: Any) -> Completion

def _custom_api_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
headers = {"Authorization": f"Bearer {LLM_ACCESS_SECRET}"}
data = {"prompt": prompt, **kwargs}

if self.model != None:
data = {"prompt": prompt, "model":self.model}
else:
data = {"prompt": prompt}

for param in self.params:
data[param] = self.params[param]

response = requests.post(LLM_INFERENCE_URL, json=data, headers=headers)
response_data = response.json()

Expand Down
4 changes: 3 additions & 1 deletion presets/ragengine/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ async def index_documents(request: IndexRequest): # TODO: Research async/sync wh
@app.post("/query", response_model=QueryResponse)
async def query_index(request: QueryRequest):
try:
llm_params = request.llm_params or {} # Default to empty dict if no params provided
llm_params = {}
for key, value in request.model_extra.items():
llm_params[key] = value
rerank_params = request.rerank_params or {} # Default to empty dict if no params provided
return rag_ops.query(request.index_name, request.query, request.top_k, llm_params, rerank_params)
except Exception as e:
Expand Down
4 changes: 2 additions & 2 deletions presets/ragengine/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from typing import Dict, List, Optional

from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict

class Document(BaseModel):
text: str
Expand All @@ -22,7 +22,7 @@ class QueryRequest(BaseModel):
index_name: str
query: str
top_k: int = 10
llm_params: Optional[Dict] = None # Accept a dictionary for parameters
model_config = ConfigDict(extra='allow')
rerank_params: Optional[Dict] = None # Accept a dictionary for parameters

class ListDocumentsResponse(BaseModel):
Expand Down
19 changes: 13 additions & 6 deletions presets/ragengine/tests/api/test_main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from unittest.mock import patch
from unittest.mock import patch, MagicMock

from llama_index.core.storage.index_store import SimpleIndexStore

Expand Down Expand Up @@ -39,7 +39,15 @@ def test_index_documents_success():
assert not doc2["metadata"]

@patch('requests.post')
def test_query_index_success(mock_post):
@patch('requests.get')
def test_query_index_success(mock_get, mock_post):
mock_get.return_value = MagicMock(
status_code=200,
json=lambda: {
"data": [{"id": "test-model"}]
}
)

# Define Mock Response for Custom Inference API
mock_response = {
"result": "This is the completion from the API"
Expand All @@ -62,7 +70,7 @@ def test_query_index_success(mock_post):
"index_name": "test_index",
"query": "test query",
"top_k": 1,
"llm_params": {"temperature": 0.7}
"temperature": 0.7
}

response = client.post("/query", json=request_data)
Expand Down Expand Up @@ -135,7 +143,7 @@ def test_reranker_and_query_with_index(mock_post):
"index_name": "test_index",
"query": "what is the capital of france?",
"top_k": 5,
"llm_params": {"temperature": 0.7},
"temperature": 0.7,
"rerank_params": {"top_n": top_n}
}

Expand Down Expand Up @@ -171,14 +179,13 @@ def test_query_index_failure():
"index_name": "non_existent_index", # Use an index name that doesn't exist
"query": "test query",
"top_k": 1,
"llm_params": {"temperature": 0.7}
"temperature": 0.7
}

response = client.post("/query", json=request_data)
assert response.status_code == 500
assert response.json()["detail"] == "No such index: 'non_existent_index' exists."


def test_list_all_indexed_documents_success():
response = client.get("/indexed-documents")
assert response.status_code == 200
Expand Down
26 changes: 22 additions & 4 deletions presets/ragengine/tests/vector_store/test_base_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT license.

import os
from unittest.mock import patch
from unittest.mock import patch, MagicMock
import pytest
from abc import ABC, abstractmethod

Expand Down Expand Up @@ -66,7 +66,15 @@ def check_indexed_documents(self, vector_store_manager):
pass

@patch('requests.post')
def test_query_documents(self, mock_post, vector_store_manager):
@patch('requests.get')
def test_query_documents(self, mock_get, mock_post, vector_store_manager):
mock_get.return_value = MagicMock(
status_code=200,
json=lambda: {
"data": [{"id": "test-model"}]
}
)

mock_response = {
"result": "This is the completion from the API"
}
Expand All @@ -87,13 +95,23 @@ def test_query_documents(self, mock_post, vector_store_manager):
assert query_result["source_nodes"][0]["text"] == "First document"
assert query_result["source_nodes"][0]["score"] == pytest.approx(self.expected_query_score, rel=1e-6)

mock_get.assert_called_once()

mock_post.assert_called_once_with(
LLM_INFERENCE_URL,
json={"prompt": "Context information is below.\n---------------------\ntype: text\n\nFirst document\n---------------------\nGiven the context information and not prior knowledge, answer the query.\nQuery: First\nAnswer: ", "formatted": True, 'temperature': 0.7},
json={"prompt": "Context information is below.\n---------------------\ntype: text\n\nFirst document\n---------------------\nGiven the context information and not prior knowledge, answer the query.\nQuery: First\nAnswer: ", "model": "test-model", 'temperature': 0.7},
headers={"Authorization": f"Bearer {LLM_ACCESS_SECRET}"}
)

def test_add_document(self, vector_store_manager):
@patch('requests.get')
def test_add_document(self, mock_get, vector_store_manager):
mock_get.return_value = MagicMock(
status_code=200,
json=lambda: {
"data": [{"id": "test-model"}]
}
)

documents = [Document(text="Third document", metadata={"type": "text"})]
vector_store_manager.index_documents("test_index", documents)

Expand Down
3 changes: 3 additions & 0 deletions presets/ragengine/vector_store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ def query(self,
"""
if index_name not in self.index_map:
raise ValueError(f"No such index: '{index_name}' exists.")
if self.llm.model == "":
self.llm.set_model()

self.llm.set_params(llm_params)

node_postprocessors = []
Expand Down

0 comments on commit c429d81

Please sign in to comment.