fix: Ensure model provided in vLLM inference #820

@@ -1,18 +1,32 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from typing import Any
from dataclasses import field
from llama_index.core.llms import CustomLLM, CompletionResponse, LLMMetadata, CompletionResponseGen
from llama_index.llms.openai import OpenAI
from llama_index.core.llms.callbacks import llm_completion_callback
import requests
from requests.exceptions import HTTPError
from urllib.parse import urlparse, urljoin
from ragengine.config import LLM_INFERENCE_URL, LLM_ACCESS_SECRET #, LLM_RESPONSE_FIELD

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

OPENAI_URL_PREFIX = "https://api.openai.com"
HUGGINGFACE_URL_PREFIX = "https://api-inference.huggingface.co"
DEFAULT_HEADERS = {
    "Authorization": f"Bearer {LLM_ACCESS_SECRET}",
    "Content-Type": "application/json"
}

class Inference(CustomLLM):
    params: dict = {}
    _default_model: str = None
    _model_retrieval_attempted: bool = False

    def set_params(self, params: dict) -> None:
        self.params = params
@@ -25,7 +39,7 @@
        pass

    @llm_completion_callback()
-    def complete(self, prompt: str, **kwargs) -> CompletionResponse:
+    def complete(self, prompt: str, formatted: bool, **kwargs) -> CompletionResponse:
        try:
            if LLM_INFERENCE_URL.startswith(OPENAI_URL_PREFIX):
                return self._openai_complete(prompt, **kwargs, **self.params)
@@ -38,29 +52,99 @@
        self.params = {}

    def _openai_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
-        llm = OpenAI(
-            api_key=LLM_ACCESS_SECRET,
-            **kwargs  # Pass all kwargs directly; kwargs may include model, temperature, max_tokens, etc.
-        )
-        return llm.complete(prompt)
+        return OpenAI(api_key=LLM_ACCESS_SECRET, **kwargs).complete(prompt)

    def _huggingface_remote_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
-        headers = {"Authorization": f"Bearer {LLM_ACCESS_SECRET}"}
-        data = {"messages": [{"role": "user", "content": prompt}]}
-        response = requests.post(LLM_INFERENCE_URL, json=data, headers=headers)
-        response_data = response.json()
-        return CompletionResponse(text=str(response_data))
+        return self._post_request(
+            {"messages": [{"role": "user", "content": prompt}]},
+            headers={"Authorization": f"Bearer {LLM_ACCESS_SECRET}"}
+        )

    def _custom_api_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
-        headers = {"Authorization": f"Bearer {LLM_ACCESS_SECRET}"}
+        model = kwargs.pop("model", self._get_default_model())

[Review comment] Note: this first takes the user-specified "model" from the request; if it is not present, the default model is fetched dynamically by hitting the /v1/models endpoint.

        data = {"prompt": prompt, **kwargs}
+        if model:
+            data["model"] = model  # Include the model only if it is not None

+        # DEBUG: Call the debugging function
+        # self._debug_curl_command(data)
+        try:
+            return self._post_request(data, headers=DEFAULT_HEADERS)
+        except HTTPError as e:
+            if e.response.status_code == 400:
+                logger.warning(
+                    f"Potential issue with 'model' parameter in API response. "
+                    f"Response: {str(e)}. Attempting to update the model name as a mitigation..."
+                )
+                self._default_model = self._fetch_default_model()  # Fetch default model dynamically
+                if self._default_model:
+                    logger.info(f"Default model '{self._default_model}' fetched successfully. Retrying request...")
+                    data["model"] = self._default_model
+                    return self._post_request(data, headers=DEFAULT_HEADERS)
+                else:
+                    logger.error("Failed to fetch a default model. Aborting retry.")
+            raise  # Re-raise the exception if not recoverable
+        except Exception as e:
+            logger.error(f"An unexpected error occurred: {e}")
+            raise

+    def _get_models_endpoint(self) -> str:
+        """
+        Constructs the URL for the /v1/models endpoint based on LLM_INFERENCE_URL.
+        """
+        parsed = urlparse(LLM_INFERENCE_URL)
+        return urljoin(f"{parsed.scheme}://{parsed.netloc}", "/v1/models")

-        response = requests.post(LLM_INFERENCE_URL, json=data, headers=headers)
-        response_data = response.json()
+    def _fetch_default_model(self) -> str:
+        """
+        Fetch the default model from the /v1/models endpoint.
+        """
+        try:
+            models_url = self._get_models_endpoint()
+            response = requests.get(models_url, headers=DEFAULT_HEADERS)
+            response.raise_for_status()  # Raise an exception for HTTP errors (includes 404)

+            models = response.json().get("data", [])
+            return models[0].get("id") if models else None
+        except Exception as e:

[Review comment] Can we check for the 404 status code here? Sometimes the serving server may still be in its initialization phase. I'm thinking we should re-raise every error except a 404.

[Reply] The current approach (catching Exception and handling the 404 implicitly) is sufficient. It already logs the 404, which can occur if vLLM is initializing or if the endpoint doesn't exist (e.g., a non-vLLM backend). Let me know if you have a specific log message in mind to add.

[Reply] If vLLM is initializing, you will get a network error instead of a 404, because the HTTP server has not been launched yet.

+            logger.error(f"Error fetching models from {models_url}: {e}. \"model\" parameter will not be included with inference call.")
+            return None

+    def _get_default_model(self) -> str:
+        """
+        Returns the cached default model if available, otherwise fetches and caches it.
+        """
+        if not self._default_model and not self._model_retrieval_attempted:
+            self._model_retrieval_attempted = True
+            self._default_model = self._fetch_default_model()

+        return self._default_model

-        # Dynamically extract the field from the response based on the specified response_field
-        # completion_text = response_data.get(RESPONSE_FIELD, "No response field found") # not necessary for now
-        return CompletionResponse(text=str(response_data))
+    def _post_request(self, data: dict, headers: dict) -> CompletionResponse:
+        try:
+            response = requests.post(LLM_INFERENCE_URL, json=data, headers=headers)
+            response.raise_for_status()  # Raise exception for HTTP errors
+            response_data = response.json()
+            return CompletionResponse(text=str(response_data))
+        except requests.RequestException as e:
+            logger.error(f"Error during POST request to {LLM_INFERENCE_URL}: {e}")
+            raise

+    def _debug_curl_command(self, data: dict) -> None:
+        """
+        Constructs and prints the equivalent curl command for debugging purposes.
+        """
+        import json
+        # Construct curl command
+        curl_command = (
+            f"curl -X POST {LLM_INFERENCE_URL} "
+            + " ".join([f'-H "{key}: {value}"' for key, value in {
+                "Authorization": f"Bearer {LLM_ACCESS_SECRET}",
+                "Content-Type": "application/json"
+            }.items()])
+            + f" -d '{json.dumps(data)}'"
+        )
+        logger.info("Equivalent curl command:")
+        logger.info(curl_command)

    @property
    def metadata(self) -> LLMMetadata:
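
The review thread on _fetch_default_model above points out that a vLLM server that is still starting up surfaces as a network/connection error, while a backend without a /v1/models route surfaces as an HTTP 404. Below is a minimal sketch, not part of this diff, of how those two cases could be told apart with requests; the helper name probe_models_endpoint and the timeout value are assumptions.

# Illustration only (not code from this PR): separating "server still starting"
# (connection refused) from "models endpoint missing" (HTTP 404).
import logging
from typing import Optional

import requests

logger = logging.getLogger(__name__)

def probe_models_endpoint(models_url: str, headers: dict) -> Optional[str]:
    """Return the first model id reported by /v1/models, or None if unavailable."""
    try:
        response = requests.get(models_url, headers=headers, timeout=5)
        response.raise_for_status()
        models = response.json().get("data", [])
        return models[0].get("id") if models else None
    except requests.exceptions.ConnectionError as e:
        # HTTP server not reachable yet, e.g. vLLM is still initializing.
        logger.warning(f"Could not reach {models_url}: {e}")
        return None
    except requests.exceptions.HTTPError as e:
        if e.response is not None and e.response.status_code == 404:
            # Route does not exist, e.g. a non-vLLM backend.
            logger.warning(f"{models_url} returned 404; omitting the 'model' parameter.")
            return None
        raise  # Other HTTP errors are unexpected and re-raised, as suggested in the review.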

[Review comment] I suppose this part does something similar to the vLLM part, but if you want to merge it now for efficiency, that is fine.
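
For orientation, here is a rough usage sketch of the Inference class changed in this PR. The call site and the import path ragengine.inference.inference are assumptions and are not taken from this diff.

# Assumed usage flow (not from this PR): set per-request params, then call
# complete(); the URL prefix in LLM_INFERENCE_URL decides whether the OpenAI,
# Hugging Face, or custom (vLLM) path is used.
from ragengine.inference.inference import Inference  # import path is an assumption

llm = Inference()
llm.set_params({"temperature": 0.2, "max_tokens": 128})  # may also carry "model"

# On the custom/vLLM path, "model" is taken from the params above when present,
# otherwise the id reported by the backend's /v1/models endpoint is used.
response = llm.complete("What does this PR change?", formatted=False)
print(response.text)

If the backend rejects the request with a 400 because of a stale model name, the retry logic in _custom_api_complete refreshes the default model once and resends the request.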