Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: RAG server changes for consistency #813

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions presets/ragengine/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,49 @@
from llama_index.llms.openai import OpenAI
from llama_index.core.llms.callbacks import llm_completion_callback
import requests
from urllib.parse import urlparse, urljoin
from ragengine.config import LLM_INFERENCE_URL, LLM_ACCESS_SECRET #, LLM_RESPONSE_FIELD

OPENAI_URL_PREFIX = "https://api.openai.com"
HUGGINGFACE_URL_PREFIX = "https://api-inference.huggingface.co"

class Inference(CustomLLM):
params: dict = {}
model: str = ""

def set_params(self, params: dict) -> None:
self.params = params

def get_param(self, key, default=None):
return self.params.get(key, default)
# Get base URL
def _get_base_url(self) -> str:
parsed = urlparse(LLM_INFERENCE_URL)
base_url = f"{parsed.scheme}://{parsed.netloc}"
return urljoin(base_url, "/v1/models")

#Fetch and set the model from the inference endpoint
def set_model(self) -> None:

try:
models_url = self._get_base_url()
headers = {"Authorization": f"Bearer {LLM_ACCESS_SECRET}"}
response = requests.get(models_url, headers=headers)

if response.status_code == 404:
self.model = None
return

response.raise_for_status()

data = response.json()
if data.get("data") and len(data["data"]) > 0:
self.model = data["data"][0]["id"]
else:
raise ValueError("No model found in response")

except requests.RequestException as e:
raise Exception(f"Failed to fetch model information: {str(e)}")

@llm_completion_callback()
def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
Expand Down Expand Up @@ -53,8 +83,14 @@ def _huggingface_remote_complete(self, prompt: str, **kwargs: Any) -> Completion

def _custom_api_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
headers = {"Authorization": f"Bearer {LLM_ACCESS_SECRET}"}
data = {"prompt": prompt, **kwargs}

if self.model != None:
data = {"prompt": prompt, "model":self.model}
else:
data = {"prompt": prompt}

for param in self.params:
data[param] = self.params[param]

response = requests.post(LLM_INFERENCE_URL, json=data, headers=headers)
response_data = response.json()

Expand Down
4 changes: 3 additions & 1 deletion presets/ragengine/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ async def index_documents(request: IndexRequest): # TODO: Research async/sync wh
@app.post("/query", response_model=QueryResponse)
async def query_index(request: QueryRequest):
try:
llm_params = request.llm_params or {} # Default to empty dict if no params provided
llm_params = {}
for key, value in request.model_extra.items():
llm_params[key] = value
rerank_params = request.rerank_params or {} # Default to empty dict if no params provided
return rag_ops.query(request.index_name, request.query, request.top_k, llm_params, rerank_params)
except Exception as e:
Expand Down
4 changes: 2 additions & 2 deletions presets/ragengine/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from typing import Dict, List, Optional

from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict

class Document(BaseModel):
text: str
Expand All @@ -22,7 +22,7 @@ class QueryRequest(BaseModel):
index_name: str
query: str
top_k: int = 10
llm_params: Optional[Dict] = None # Accept a dictionary for parameters
model_config = ConfigDict(extra='allow')
rerank_params: Optional[Dict] = None # Accept a dictionary for parameters

class ListDocumentsResponse(BaseModel):
Expand Down
19 changes: 13 additions & 6 deletions presets/ragengine/tests/api/test_main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from unittest.mock import patch
from unittest.mock import patch, MagicMock

from llama_index.core.storage.index_store import SimpleIndexStore

Expand Down Expand Up @@ -39,7 +39,15 @@ def test_index_documents_success():
assert not doc2["metadata"]

@patch('requests.post')
def test_query_index_success(mock_post):
@patch('requests.get')
def test_query_index_success(mock_get, mock_post):
mock_get.return_value = MagicMock(
status_code=200,
json=lambda: {
"data": [{"id": "test-model"}]
}
)

# Define Mock Response for Custom Inference API
mock_response = {
"result": "This is the completion from the API"
Expand All @@ -62,7 +70,7 @@ def test_query_index_success(mock_post):
"index_name": "test_index",
"query": "test query",
"top_k": 1,
"llm_params": {"temperature": 0.7}
"temperature": 0.7
}

response = client.post("/query", json=request_data)
Expand Down Expand Up @@ -135,7 +143,7 @@ def test_reranker_and_query_with_index(mock_post):
"index_name": "test_index",
"query": "what is the capital of france?",
"top_k": 5,
"llm_params": {"temperature": 0.7},
"temperature": 0.7,
"rerank_params": {"top_n": top_n}
}

Expand Down Expand Up @@ -171,14 +179,13 @@ def test_query_index_failure():
"index_name": "non_existent_index", # Use an index name that doesn't exist
"query": "test query",
"top_k": 1,
"llm_params": {"temperature": 0.7}
"temperature": 0.7
}

response = client.post("/query", json=request_data)
assert response.status_code == 500
assert response.json()["detail"] == "No such index: 'non_existent_index' exists."


def test_list_all_indexed_documents_success():
response = client.get("/indexed-documents")
assert response.status_code == 200
Expand Down
26 changes: 22 additions & 4 deletions presets/ragengine/tests/vector_store/test_base_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT license.

import os
from unittest.mock import patch
from unittest.mock import patch, MagicMock
import pytest
from abc import ABC, abstractmethod

Expand Down Expand Up @@ -66,7 +66,15 @@ def check_indexed_documents(self, vector_store_manager):
pass

@patch('requests.post')
def test_query_documents(self, mock_post, vector_store_manager):
@patch('requests.get')
def test_query_documents(self, mock_get, mock_post, vector_store_manager):
mock_get.return_value = MagicMock(
status_code=200,
json=lambda: {
"data": [{"id": "test-model"}]
}
)

mock_response = {
"result": "This is the completion from the API"
}
Expand All @@ -87,13 +95,23 @@ def test_query_documents(self, mock_post, vector_store_manager):
assert query_result["source_nodes"][0]["text"] == "First document"
assert query_result["source_nodes"][0]["score"] == pytest.approx(self.expected_query_score, rel=1e-6)

mock_get.assert_called_once()

mock_post.assert_called_once_with(
LLM_INFERENCE_URL,
json={"prompt": "Context information is below.\n---------------------\ntype: text\n\nFirst document\n---------------------\nGiven the context information and not prior knowledge, answer the query.\nQuery: First\nAnswer: ", "formatted": True, 'temperature': 0.7},
json={"prompt": "Context information is below.\n---------------------\ntype: text\n\nFirst document\n---------------------\nGiven the context information and not prior knowledge, answer the query.\nQuery: First\nAnswer: ", "model": "test-model", 'temperature': 0.7},
headers={"Authorization": f"Bearer {LLM_ACCESS_SECRET}"}
)

def test_add_document(self, vector_store_manager):
@patch('requests.get')
def test_add_document(self, mock_get, vector_store_manager):
mock_get.return_value = MagicMock(
status_code=200,
json=lambda: {
"data": [{"id": "test-model"}]
}
)

documents = [Document(text="Third document", metadata={"type": "text"})]
vector_store_manager.index_documents("test_index", documents)

Expand Down
3 changes: 3 additions & 0 deletions presets/ragengine/vector_store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ def query(self,
"""
if index_name not in self.index_map:
raise ValueError(f"No such index: '{index_name}' exists.")
if self.llm.model == "":
self.llm.set_model()

self.llm.set_params(llm_params)

node_postprocessors = []
Expand Down
Loading