fix: Ensure model provided in vLLM inference #820

Merged
merged 23 commits on Jan 14, 2025
2 changes: 1 addition & 1 deletion presets/ragengine/config.py
@@ -38,7 +38,7 @@
"""

# LLM (Large Language Model) configuration
LLM_INFERENCE_URL = os.getenv("LLM_INFERENCE_URL", "http://localhost:5000/chat")
LLM_INFERENCE_URL = os.getenv("LLM_INFERENCE_URL", "http://localhost:5000/v1/completions")
LLM_ACCESS_SECRET = os.getenv("LLM_ACCESS_SECRET", "default-access-secret")
# LLM_RESPONSE_FIELD = os.getenv("LLM_RESPONSE_FIELD", "result") # Uncomment if needed in the future
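For reference, the new default points at vLLM's OpenAI-compatible completions route, which expects a `model` field in the request body; that is the field this PR ensures is included. A minimal sketch of such a request, assuming a vLLM server listening on localhost:5000; the model name, prompt, and temperature are placeholders, not values taken from this PR:

# Sketch only: a direct request to the endpoint configured above.
# The model name and prompt are placeholders.
import os

import requests

url = os.getenv("LLM_INFERENCE_URL", "http://localhost:5000/v1/completions")
secret = os.getenv("LLM_ACCESS_SECRET", "default-access-secret")

payload = {
    "model": "example-served-model",  # vLLM expects the name of the served model here
    "prompt": "Hello from the RAG engine",
    "temperature": 0.7,
}
response = requests.post(
    url,
    json=payload,
    headers={"Authorization": f"Bearer {secret}", "Content-Type": "application/json"},
)
print(response.json())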

118 changes: 101 additions & 17 deletions presets/ragengine/inference/inference.py
@@ -1,18 +1,32 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from typing import Any
from dataclasses import field
from llama_index.core.llms import CustomLLM, CompletionResponse, LLMMetadata, CompletionResponseGen
from llama_index.llms.openai import OpenAI
from llama_index.core.llms.callbacks import llm_completion_callback
import requests
from requests.exceptions import HTTPError
from urllib.parse import urlparse, urljoin
from ragengine.config import LLM_INFERENCE_URL, LLM_ACCESS_SECRET #, LLM_RESPONSE_FIELD

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

OPENAI_URL_PREFIX = "https://api.openai.com"
HUGGINGFACE_URL_PREFIX = "https://api-inference.huggingface.co"
DEFAULT_HEADERS = {
"Authorization": f"Bearer {LLM_ACCESS_SECRET}",
"Content-Type": "application/json"
}

class Inference(CustomLLM):
    params: dict = {}
    _default_model: str = None
    _model_retrieval_attempted: bool = False

    def set_params(self, params: dict) -> None:
        self.params = params
@@ -25,7 +39,7 @@
        pass

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs) -> CompletionResponse:
    def complete(self, prompt: str, formatted: bool, **kwargs) -> CompletionResponse:
        try:
            if LLM_INFERENCE_URL.startswith(OPENAI_URL_PREFIX):
                return self._openai_complete(prompt, **kwargs, **self.params)
@@ -38,29 +52,99 @@
            self.params = {}

    def _openai_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        llm = OpenAI(
            api_key=LLM_ACCESS_SECRET,
            **kwargs # Pass all kwargs directly; kwargs may include model, temperature, max_tokens, etc.
        )
        return llm.complete(prompt)
        return OpenAI(api_key=LLM_ACCESS_SECRET, **kwargs).complete(prompt)

A collaborator commented: I suppose this part does something similar to the vLLM part, but if you want to merge it now for efficiency, it is good.


    def _huggingface_remote_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        headers = {"Authorization": f"Bearer {LLM_ACCESS_SECRET}"}
        data = {"messages": [{"role": "user", "content": prompt}]}
        response = requests.post(LLM_INFERENCE_URL, json=data, headers=headers)
        response_data = response.json()
        return CompletionResponse(text=str(response_data))
        return self._post_request(
            {"messages": [{"role": "user", "content": prompt}]},
            headers={"Authorization": f"Bearer {LLM_ACCESS_SECRET}"}
        )

    def _custom_api_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        headers = {"Authorization": f"Bearer {LLM_ACCESS_SECRET}"}
        model = kwargs.pop("model", self._get_default_model())

@ishaansehgal99 (Collaborator, Author) commented on Jan 14, 2025: Note: this first takes the user's request-specified "model"; if that is not present, it fetches the default dynamically by hitting the /v1/models endpoint.

        data = {"prompt": prompt, **kwargs}
        if model:
            data["model"] = model # Include the model only if it is not None

        # DEBUG: Call the debugging function
        # self._debug_curl_command(data)
        try:
            return self._post_request(data, headers=DEFAULT_HEADERS)
        except HTTPError as e:
            if e.response.status_code == 400:
                logger.warning(
                    f"Potential issue with 'model' parameter in API response. "
                    f"Response: {str(e)}. Attempting to update the model name as a mitigation..."
                )
                self._default_model = self._fetch_default_model() # Fetch default model dynamically
                if self._default_model:
                    logger.info(f"Default model '{self._default_model}' fetched successfully. Retrying request...")
                    data["model"] = self._default_model
                    return self._post_request(data, headers=DEFAULT_HEADERS)
                else:
                    logger.error("Failed to fetch a default model. Aborting retry.")
                    raise # Re-raise the exception if not recoverable
        except Exception as e:
            logger.error(f"An unexpected error occurred: {e}")
            raise

    def _get_models_endpoint(self) -> str:
        """
        Constructs the URL for the /v1/models endpoint based on LLM_INFERENCE_URL.
        """
        parsed = urlparse(LLM_INFERENCE_URL)
        return urljoin(f"{parsed.scheme}://{parsed.netloc}", "/v1/models")

        response = requests.post(LLM_INFERENCE_URL, json=data, headers=headers)
        response_data = response.json()
    def _fetch_default_model(self) -> str:
        """
        Fetch the default model from the /v1/models endpoint.
        """
        try:
            models_url = self._get_models_endpoint()
            response = requests.get(models_url, headers=DEFAULT_HEADERS)
            response.raise_for_status() # Raise an exception for HTTP errors (includes 404)

            models = response.json().get("data", [])
            return models[0].get("id") if models else None
        except Exception as e:
            logger.error(f"Error fetching models from {models_url}: {e}. \"model\" parameter will not be included with inference call.")
            return None

A collaborator commented: Can we check for the 404 status code here? Sometimes the service server may still be in its initialization phase. I am thinking we should raise the error except for a 404.

@ishaansehgal99 (Collaborator, Author) replied: The current approach (catching Exception and handling 404 implicitly) is sufficient. It already logs the 404, which can occur if vLLM is initializing or if the endpoint doesn't exist (e.g., non-vLLM). Let me know if you have a specific log message in mind to add.

The collaborator replied: If vLLM is initializing, you will get a network error instead of a 404, because the HTTP server is not launched yet. (See the sketch after this file's diff.)

    def _get_default_model(self) -> str:
        """
        Returns the cached default model if available, otherwise fetches and caches it.
        """
        if not self._default_model and not self._model_retrieval_attempted:
            self._model_retrieval_attempted = True
            self._default_model = self._fetch_default_model()
        return self._default_model

        # Dynamically extract the field from the response based on the specified response_field
        # completion_text = response_data.get(RESPONSE_FIELD, "No response field found") # not necessary for now
        return CompletionResponse(text=str(response_data))
    def _post_request(self, data: dict, headers: dict) -> CompletionResponse:
        try:
            response = requests.post(LLM_INFERENCE_URL, json=data, headers=headers)
            response.raise_for_status() # Raise exception for HTTP errors
            response_data = response.json()
            return CompletionResponse(text=str(response_data))
        except requests.RequestException as e:
            logger.error(f"Error during POST request to {LLM_INFERENCE_URL}: {e}")
            raise


    def _debug_curl_command(self, data: dict) -> None:
        """
        Constructs and prints the equivalent curl command for debugging purposes.
        """
        import json

        # Construct curl command
        curl_command = (
            f"curl -X POST {LLM_INFERENCE_URL} "
            + " ".join([f'-H "{key}: {value}"' for key, value in {
                "Authorization": f"Bearer {LLM_ACCESS_SECRET}",
                "Content-Type": "application/json"
            }.items()])
            + f" -d '{json.dumps(data)}'"
        )
        logger.info("Equivalent curl command:")
        logger.info(curl_command)


    @property
    def metadata(self) -> LLMMetadata:
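Following up on the review thread above about vLLM start-up behavior, here is a minimal sketch (not part of this PR) of how a caller could separate "server not launched yet" (a connection error) from "endpoint not found" (HTTP 404) when probing /v1/models. The function name and timeout are illustrative.

# Sketch only: distinguish an unreachable server from a missing /v1/models endpoint.
import logging
from typing import Optional

import requests

logger = logging.getLogger(__name__)

def probe_default_model(models_url: str, headers: dict) -> Optional[str]:
    try:
        response = requests.get(models_url, headers=headers, timeout=5)
        response.raise_for_status()
        models = response.json().get("data", [])
        return models[0].get("id") if models else None
    except requests.exceptions.ConnectionError:
        # The HTTP server is not up yet, e.g. vLLM is still initializing.
        logger.warning("Inference server unreachable at %s; retry later.", models_url)
        return None
    except requests.exceptions.HTTPError as e:
        if e.response is not None and e.response.status_code == 404:
            # The endpoint does not exist, e.g. a non-vLLM backend.
            logger.info("No /v1/models endpoint at %s; omitting the 'model' parameter.", models_url)
            return None
        raise  # Other HTTP errors are unexpected; surface them.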
14 changes: 10 additions & 4 deletions presets/ragengine/main.py
@@ -60,11 +60,17 @@
@app.post("/query", response_model=QueryResponse)
async def query_index(request: QueryRequest):
    try:
        llm_params = request.llm_params or {} # Default to empty dict if no params provided
        rerank_params = request.rerank_params or {} # Default to empty dict if no params provided
        return rag_ops.query(request.index_name, request.query, request.top_k, llm_params, rerank_params)
        llm_params = request.llm_params or {} # Default to empty dict if no params provided
        rerank_params = request.rerank_params or {} # Default to empty dict if no params provided
        return rag_ops.query(
            request.index_name, request.query, request.top_k, llm_params, rerank_params
        )
    except ValueError as ve:
        raise HTTPException(status_code=400, detail=str(ve)) # Validation issue
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
        raise HTTPException(
            status_code=500, detail=f"An unexpected error occurred: {str(e)}"
        )

@app.get("/indexed-documents", response_model=ListDocumentsResponse)
async def list_all_indexed_documents():
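A brief usage sketch of the error mapping above: validation problems (for example, querying an index that does not exist, as the updated test below exercises) surface as HTTP 400, while unexpected failures surface as HTTP 500. The RAG engine address is hypothetical.

# Sketch only: interpreting /query status codes from a client.
import requests

RAG_URL = "http://localhost:8000/query"  # hypothetical address of the RAG engine

request_data = {
    "index_name": "non_existent_index",
    "query": "hello",
    "top_k": 5,
    "llm_params": {"temperature": 0.7},
}

resp = requests.post(RAG_URL, json=request_data)
if resp.status_code == 400:
    print("Validation issue:", resp.json()["detail"])
elif resp.status_code == 500:
    print("Unexpected server error:", resp.json()["detail"])
else:
    print(resp.json())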
33 changes: 29 additions & 4 deletions presets/ragengine/models.py
@@ -1,9 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field, model_validator

from pydantic import BaseModel

class Document(BaseModel):
    text: str
@@ -22,8 +23,32 @@
    index_name: str
    query: str
    top_k: int = 10
    llm_params: Optional[Dict] = None # Accept a dictionary for parameters
    rerank_params: Optional[Dict] = None # Accept a dictionary for parameters
    # Accept a dictionary for our LLM parameters
    llm_params: Optional[Dict[str, Any]] = Field(
        default_factory=dict,
        description="Optional parameters for the language model, e.g., temperature, top_p",
    )
    # Accept a dictionary for rerank parameters
    rerank_params: Optional[Dict[str, Any]] = Field(
        default_factory=dict,
        description="Optional parameters for reranking, e.g., top_n, batch_size",
    )

    @model_validator(mode="before")
    def validate_params(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        llm_params = values.get("llm_params", {})
        rerank_params = values.get("rerank_params", {})

        # Validate LLM parameters
        if "temperature" in llm_params and not (0.0 <= llm_params["temperature"] <= 1.0):
            raise ValueError("Temperature must be between 0.0 and 1.0.")

        # TODO: More LLM Param Validations here
        # Validate rerank parameters
        top_k = values.get("top_k")
        if "top_n" in rerank_params and rerank_params["top_n"] > top_k:
            raise ValueError("Invalid configuration: 'top_n' for reranking cannot exceed 'top_k' from the RAG query.")

        return values

class ListDocumentsResponse(BaseModel):
    documents: Dict[str, Dict[str, Dict[str, str]]]
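A short sketch of how the new validator behaves, assuming `QueryRequest` is importable from `ragengine.models` as in the repository layout; index names and query text are placeholders.

# Sketch only: exercising QueryRequest validation.
from ragengine.models import QueryRequest

# Valid: temperature within [0.0, 1.0] and top_n does not exceed top_k.
ok = QueryRequest(
    index_name="docs",
    query="What changed in this PR?",
    top_k=10,
    llm_params={"temperature": 0.7},
    rerank_params={"top_n": 5},
)

# Invalid: top_n greater than top_k is rejected by validate_params.
try:
    QueryRequest(index_name="docs", query="q", top_k=3, rerank_params={"top_n": 5})
except ValueError as exc:  # pydantic's ValidationError subclasses ValueError
    print(exc)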
2 changes: 1 addition & 1 deletion presets/ragengine/tests/api/test_main.py
@@ -175,7 +175,7 @@ def test_query_index_failure():
    }

    response = client.post("/query", json=request_data)
    assert response.status_code == 500
    assert response.status_code == 400
    assert response.json()["detail"] == "No such index: 'non_existent_index' exists."


4 changes: 2 additions & 2 deletions presets/ragengine/tests/vector_store/test_base_store.py
@@ -89,8 +89,8 @@ def test_query_documents(self, mock_post, vector_store_manager):

        mock_post.assert_called_once_with(
            LLM_INFERENCE_URL,
            json={"prompt": "Context information is below.\n---------------------\ntype: text\n\nFirst document\n---------------------\nGiven the context information and not prior knowledge, answer the query.\nQuery: First\nAnswer: ", "formatted": True, 'temperature': 0.7},
            headers={"Authorization": f"Bearer {LLM_ACCESS_SECRET}"}
            json={"prompt": "Context information is below.\n---------------------\ntype: text\n\nFirst document\n---------------------\nGiven the context information and not prior knowledge, answer the query.\nQuery: First\nAnswer: ", 'temperature': 0.7},
            headers={"Authorization": f"Bearer {LLM_ACCESS_SECRET}", 'Content-Type': 'application/json'}
        )

    def test_add_document(self, vector_store_manager):