Expand model configuration and logging (#90)
* Cleanup
* Simplify pipeline config
* Only register Ollama routes when using an Ollama setup
* Allow for silent model request ignore
* Update default system prompt to reflect major version
* Fix system prompt handling and add env logging
* Allow model parameter changes
1 parent 4d0dba4 · commit a1acaad
Showing 7 changed files with 203 additions and 86 deletions.
@@ -1,3 +1,3 @@
-You are Chipper, a helpful and professional assistant.
+You are Chipper 2.0, a helpful and professional assistant.
 Given the above conversation and the following information, answer the question.
 Ignore your own knowledge.
@@ -1,105 +1,187 @@
 import os
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Callable, Dict, Optional

-from api.config import system_prompt_value
+from api.config import SYSTEM_PROMPT_VALUE, logger
 from core.pipeline_config import ModelProvider, QueryPipelineConfig
-def get_env_param(param_name, converter=None, default=None):
-    value = os.getenv(param_name)
+class EnvKeys(str, Enum):
+    PROVIDER = "PROVIDER"
+    MODEL_NAME = "MODEL_NAME"
+    HF_MODEL_NAME = "HF_MODEL_NAME"
+    EMBEDDING_MODEL = "EMBEDDING_MODEL_NAME"
+    HF_EMBEDDING_MODEL = "HF_EMBEDDING_MODEL_NAME"
+    HF_API_KEY = "HF_API_KEY"
+    OLLAMA_URL = "OLLAMA_URL"
+    ALLOW_MODEL_PULL = "ALLOW_MODEL_PULL"
+    ES_URL = "ES_URL"
+    ES_INDEX = "ES_INDEX"
+    ES_TOP_K = "ES_TOP_K"
+    ES_NUM_CANDIDATES = "ES_NUM_CANDIDATES"
+    ES_BASIC_AUTH_USER = "ES_BASIC_AUTH_USERNAME"
+    ES_BASIC_AUTH_PASSWORD = "ES_BASIC_AUTH_PASSWORD"
+    ENABLE_CONVERSATION_LOGS = "ENABLE_CONVERSATION_LOGS"
+
+
+@dataclass
+class GenerationParams:
+    context_window: tuple[str, type, str] = ("CONTEXT_WINDOW", int, "8192")
+    temperature: tuple[str, type, None] = ("TEMPERATURE", float, None)
+    seed: tuple[str, type, None] = ("SEED", int, None)
+    top_k: tuple[str, type, None] = ("TOP_K", int, None)
+    top_p: tuple[str, type, None] = ("TOP_P", float, None)
+    min_p: tuple[str, type, None] = ("MIN_P", float, None)
+    repeat_last_n: tuple[str, type, None] = ("REPEAT_LAST_N", int, None)
+    repeat_penalty: tuple[str, type, None] = ("REPEAT_PENALTY", float, None)
+    num_predict: tuple[str, type, None] = ("NUM_PREDICT", int, None)
+    tfs_z: tuple[str, type, None] = ("TFS_Z", float, None)
+
+
+def get_env_value(
+    key: str, converter: Optional[Callable] = None, default: Optional[str] = None
+) -> Any:
+    """Get and convert environment variable value with optional default."""
+    value = os.getenv(key)
     if value is None:
         return None

-    if converter is not None:
+    if converter:
         try:
-            if default is not None and value == "":
-                return converter(default)
-            return converter(value)
+            return converter(default if value == "" else value)
         except (ValueError, TypeError):
             return None
     return value
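Because EnvKeys subclasses both str and Enum, its members can be passed directly to os.getenv, and str methods such as .lower() return the matching config key. A minimal, self-contained sketch of that behavior (trimmed-down enum; the environment value is illustrative):

import os
from enum import Enum


class Keys(str, Enum):
    # Trimmed-down stand-in for the EnvKeys enum in the diff above
    ES_TOP_K = "ES_TOP_K"


os.environ["ES_TOP_K"] = "7"  # illustrative value

# The member hashes and compares like its string value, so it works as a key
assert os.getenv(Keys.ES_TOP_K) == "7"
# str methods apply to the underlying value, yielding the config-dict key
assert Keys.ES_TOP_K.lower() == "es_top_k"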
-def create_pipeline_config(model: str = None, index: str = None) -> QueryPipelineConfig:
-    provider_name = os.getenv("PROVIDER", "ollama")
+def get_provider_specific_config() -> dict[str, Any]:
+    """Get provider-specific configuration."""
     provider = (
         ModelProvider.HUGGINGFACE
-        if provider_name.lower() == "hf"
+        if os.getenv(EnvKeys.PROVIDER, "ollama").lower() == "hf"
         else ModelProvider.OLLAMA
     )

-    if provider == ModelProvider.HUGGINGFACE:
-        model_name = model or os.getenv("HF_MODEL_NAME")
-        embedding_model = os.getenv("HF_EMBEDDING_MODEL_NAME")
-    else:
-        model_name = model or os.getenv("MODEL_NAME")
-        embedding_model = os.getenv("EMBEDDING_MODEL_NAME")
-
-    config_params = {
+    config = {
         "provider": provider,
-        "embedding_model": embedding_model,
-        "model_name": model_name,
-        "system_prompt": system_prompt_value,
+        "model_name": os.getenv(
+            EnvKeys.HF_MODEL_NAME
+            if provider == ModelProvider.HUGGINGFACE
+            else EnvKeys.MODEL_NAME
+        ),
+        "embedding_model": os.getenv(
+            EnvKeys.HF_EMBEDDING_MODEL
+            if provider == ModelProvider.HUGGINGFACE
+            else EnvKeys.EMBEDDING_MODEL
+        ),
+        "system_prompt": SYSTEM_PROMPT_VALUE,
     }
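A hedged usage sketch of the provider switch (assumes get_provider_specific_config is importable as in the diff; the model names are illustrative, and any PROVIDER value other than "hf" falls back to Ollama):

import os

os.environ["PROVIDER"] = "hf"                           # select Hugging Face
os.environ["HF_MODEL_NAME"] = "org/chat-model"          # illustrative name
os.environ["HF_EMBEDDING_MODEL_NAME"] = "org/embedder"  # illustrative name

config = get_provider_specific_config()
# config["provider"] == ModelProvider.HUGGINGFACE, and the HF_* variables
# supply model_name / embedding_model; with PROVIDER unset the code defaults
# to Ollama and reads MODEL_NAME / EMBEDDING_MODEL_NAME instead.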
-    # Provider specific parameters
     if provider == ModelProvider.HUGGINGFACE:
-        if (hf_key := os.getenv("HF_API_KEY")) is not None:
-            config_params["hf_api_key"] = hf_key
-    else:
-        if (ollama_url := os.getenv("OLLAMA_URL")) is not None:
-            config_params["ollama_url"] = ollama_url
-
-    # Model pull configuration
-    allow_pull = os.getenv("ALLOW_MODEL_PULL")
-    if allow_pull is not None:
-        config_params["allow_model_pull"] = allow_pull.lower() == "true"
-
-    # Core generation parameters
-    if (context_window := get_env_param("CONTEXT_WINDOW", int, "8192")) is not None:
-        config_params["context_window"] = context_window
-
-    for param in ["TEMPERATURE", "SEED", "TOP_K"]:
-        if (
-            value := get_env_param(param, float if param == "TEMPERATURE" else int)
-        ) is not None:
-            config_params[param.lower()] = value
-
-    # Advanced sampling parameters
-    for param in ["TOP_P", "MIN_P"]:
-        if (value := get_env_param(param, float)) is not None:
-            config_params[param.lower()] = value
-
-    # Mirostat parameters
-    if (mirostat := get_env_param("MIROSTAT", int)) is not None:
-        config_params["mirostat"] = mirostat
-    for param in ["MIROSTAT_ETA", "MIROSTAT_TAU"]:
-        if (value := get_env_param(param, float)) is not None:
-            config_params[param.lower()] = value
+        config["hf_api_key"] = os.getenv(EnvKeys.HF_API_KEY)
+    elif ollama_url := os.getenv(EnvKeys.OLLAMA_URL):
+        config["ollama_url"] = ollama_url
+
+    return config
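Note the behavioral shift from the old `is not None` checks to walrus-plus-truthiness: a variable that is set but empty is now treated like an unset one. A self-contained sketch of the difference:

import os

os.environ["OLLAMA_URL"] = ""  # present in the environment, but empty

config = {}
if url := os.getenv("OLLAMA_URL"):  # "" is falsy, so this branch is skipped
    config["ollama_url"] = url

assert "ollama_url" not in config  # the old `is not None` check would have stored ""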
+def get_elasticsearch_config(index: Optional[str] = None) -> dict[str, Any]:
+    """Get Elasticsearch configuration if enabled."""
+    if not (es_url := os.getenv(EnvKeys.ES_URL)):
+        return {}
+
-    # Elasticsearch parameters
-    if (es_url := os.getenv("ES_URL")) is not None:
-        config_params["es_url"] = es_url
+    config = {
+        "es_url": es_url,
+        "es_index": index or os.getenv(EnvKeys.ES_INDEX),
+        "es_basic_auth_user": os.getenv(EnvKeys.ES_BASIC_AUTH_USER),
+        "es_basic_auth_password": os.getenv(EnvKeys.ES_BASIC_AUTH_PASSWORD),
+    }
+
+    for env_key, default in [
+        (EnvKeys.ES_TOP_K, "5"),
+        (EnvKeys.ES_NUM_CANDIDATES, "-1"),
+    ]:
+        if value := get_env_value(env_key, int, default):
+            config[env_key.lower()] = value
+
+    return config
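A quick sketch of the resulting Elasticsearch behavior (URL and index are illustrative; assumes the function above is in scope). Worth noting: the "5" and "-1" defaults only apply when a variable is set to an empty string, because get_env_value returns None for unset variables before the default is ever consulted:

import os

assert get_elasticsearch_config() == {}  # no ES_URL -> Elasticsearch disabled

os.environ["ES_URL"] = "http://localhost:9200"  # illustrative
os.environ["ES_TOP_K"] = ""                     # set-but-empty -> default "5" applies

es = get_elasticsearch_config(index="docs")
# es["es_index"] == "docs" and es["es_top_k"] == 5; ES_NUM_CANDIDATES was
# never set, so get_env_value returns None and that key is omitted.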
+def create_pipeline_config(
+    model: Optional[str] = None,
+    index: Optional[str] = None,
+    temperature: Optional[float] = None,
+    top_k: Optional[int] = None,
+    top_p: Optional[float] = None,
+    min_p: Optional[float] = None,
+    repeat_last_n: Optional[int] = None,
+    repeat_penalty: Optional[float] = None,
+    num_predict: Optional[int] = None,
+    tfs_z: Optional[float] = None,
+    context_window: Optional[int] = None,
+    seed: Optional[int] = None,
+    **additional_params: Dict[str, Any],
+) -> QueryPipelineConfig:
+    """Create pipeline configuration from environment variables with optional parameter overrides."""
+    config = get_provider_specific_config()
+    if model:
+        config["model_name"] = model
+
+    # Add generation parameters from environment first
+    params = GenerationParams()
+    for param in params.__annotations__:
+        env_key, converter, default = getattr(params, param)
+        if value := get_env_value(env_key, converter, default):
+            config[param] = value
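The GenerationParams dataclass acts as a declarative table: each field stores an (env var, converter, default) triple, and iterating __annotations__ walks the fields in declaration order, so one loop covers every generation knob. A minimal standalone sketch of the same pattern with two fields:

import os
from dataclasses import dataclass


@dataclass
class Params:
    # (environment variable, converter, default) per generation knob
    temperature: tuple = ("TEMPERATURE", float, None)
    seed: tuple = ("SEED", int, None)


os.environ["TEMPERATURE"] = "0.2"  # illustrative

config = {}
params = Params()
for field in params.__annotations__:      # "temperature", then "seed"
    env_key, converter, default = getattr(params, field)
    if (value := os.getenv(env_key)) is not None:
        config[field] = converter(value)

assert config == {"temperature": 0.2}     # SEED unset, so it is omitted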
+    # Override with any provided parameters
+    generation_params = {
+        "temperature": temperature,
+        "top_k": top_k,
+        "top_p": top_p,
+        "min_p": min_p,
+        "repeat_last_n": repeat_last_n,
+        "repeat_penalty": repeat_penalty,
+        "num_predict": num_predict,
+        "tfs_z": tfs_z,
+        "context_window": context_window,
+        "seed": seed,
+    }
+
-    if index is not None:
-        config_params["es_index"] = index
-    elif (es_index := os.getenv("ES_INDEX")) is not None:
-        config_params["es_index"] = es_index
+    # Update config with provided non-None parameters
+    config.update({k: v for k, v in generation_params.items() if v is not None})
+
+    # Add any additional parameters passed
+    config.update(additional_params)
+
+    # Add mirostat parameters
+    if mirostat := get_env_value("MIROSTAT", int):
+        config["mirostat"] = mirostat
+    for param in ["MIROSTAT_ETA", "MIROSTAT_TAU"]:
+        if value := get_env_value(param, float):
+            config[param.lower()] = value
+
-    if (es_top_k := get_env_param("ES_TOP_K", int, "5")) is not None:
-        config_params["es_top_k"] = es_top_k
+    # Add model pull configuration
+    if allow_pull := os.getenv(EnvKeys.ALLOW_MODEL_PULL):
+        config["allow_model_pull"] = allow_pull.lower() == "true"
+
-    if (
-        es_num_candidates := get_env_param("ES_NUM_CANDIDATES", int, "-1")
-    ) is not None:
-        config_params["es_num_candidates"] = es_num_candidates
+    # Add conversation logs setting
+    if value := os.getenv(EnvKeys.ENABLE_CONVERSATION_LOGS):
+        config["enable_conversation_logs"] = value.lower() == "true"
+
-    if (es_user := os.getenv("ES_BASIC_AUTH_USERNAME")) is not None:
-        config_params["es_basic_auth_user"] = es_user
+    # Add stop sequence
+    if stop_sequence := os.getenv("STOP_SEQUENCE"):
+        config["stop_sequence"] = stop_sequence
+
-    if (es_pass := os.getenv("ES_BASIC_AUTH_PASSWORD")) is not None:
-        config_params["es_basic_auth_password"] = es_pass
+    # Add Elasticsearch config
+    config.update(get_elasticsearch_config(index))
+
-    if (enable_conversation_logs := os.getenv("ENABLE_CONVERSATION_LOGS")) is not None:
-        config_params["enable_conversation_logs"] = enable_conversation_logs
+    logger.info("\nPipeline Configuration:")
+    for key, value in sorted(config.items()):
+        if any(sensitive in key.lower() for sensitive in ["password", "key", "auth"]):
+            logger.info(f"  {key}: ****")
+        else:
+            logger.info(f"  {key}: {value}")
+
-    return QueryPipelineConfig(**config_params)
+    return QueryPipelineConfig(**config)
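Putting it together, this realizes the "Allow model parameter changes" bullet: environment variables set the baseline, and explicit keyword arguments override them per call. A hedged usage sketch (the model tag and index name are illustrative):

config = create_pipeline_config(
    model="llama3.2",  # illustrative tag; overrides MODEL_NAME / HF_MODEL_NAME
    index="docs",      # forwarded to get_elasticsearch_config()
    temperature=0.1,   # beats any TEMPERATURE environment value
    seed=42,           # reproducible sampling
)

The sorted configuration dump logged just before the return masks any key whose name contains "password", "key", or "auth", so credentials never reach the logs.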