Expand model configuration and logging (#90)
* Cleanup

* Simplify pipeline config

* Only register Ollama routes when using an Ollama setup

* Allow silently ignoring the requested model

* Update default system prompt to reflect major version

* Fix system prompt handling and add env logging

* Allow model parameter changes
TilmanGriesel authored Feb 1, 2025
1 parent 4d0dba4 commit a1acaad
Showing 7 changed files with 203 additions and 86 deletions.
6 changes: 6 additions & 0 deletions services/api/.env.example
@@ -13,6 +13,12 @@ HF_API_KEY=your-huggingface-api-key
ALLOW_MODEL_PULL=true
ALLOW_MODEL_CHANGE=true
ALLOW_INDEX_CHANGE=true
ALLOW_MODEL_PARAMETER_CHANGE=true

# Compatibility
# Ignoring the requested model enables interoperability between
# Ollama clients and a HuggingFace backend by using the API default model instead.
IGNORE_MODEL_REQUEST=false

# Embedding
EMBEDDING_MODEL_NAME=snowflake-arctic-embed2
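A note on how this compatibility flag plays out in practice. The sketch below is illustrative, not code from this repository; the payload shape and default model name are assumptions:

import os

# Hypothetical chat payload from an Ollama client, which always names a model.
request_data = {"model": "llama3", "messages": [{"role": "user", "content": "Hi"}]}

IGNORE_MODEL_REQUEST = os.getenv("IGNORE_MODEL_REQUEST", "false").lower() == "true"
DEFAULT_MODEL = os.getenv("MODEL_NAME", "some-default-model")  # assumed fallback

# With the flag enabled, the client-requested model is dropped and the API
# default is used instead, letting Ollama clients talk to a HuggingFace backend.
model = None if IGNORE_MODEL_REQUEST else request_data.get("model")
effective_model = model or DEFAULT_MODEL
print(effective_model)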
2 changes: 1 addition & 1 deletion services/api/.systemprompt.example
@@ -1,3 +1,3 @@
You are Chipper, a helpful and professional assistant.
You are Chipper 2.0, a helpful and professional assistant.
Given the above conversation and the following information, answer the question.
Ignore your own knowledge.
11 changes: 9 additions & 2 deletions services/api/src/api/config.py
Original file line number Diff line number Diff line change
@@ -24,9 +24,16 @@
APP_VERSION = os.getenv("APP_VERSION", "[DEV]")
BUILD_NUMBER = os.getenv("APP_BUILD_NUM", "0")

# Provider settings
PROVIDER_IS_OLLAMA = os.getenv("PROVIDER", "ollama") == "ollama"

# Feature flags
ALLOW_MODEL_CHANGE = os.getenv("ALLOW_MODEL_CHANGE", "true").lower() == "true"
ALLOW_INDEX_CHANGE = os.getenv("ALLOW_INDEX_CHANGE", "true").lower() == "true"
ALLOW_MODEL_PARAMETER_CHANGE = (
os.getenv("ALLOW_MODEL_PARAMETER_CHANGE", "true").lower() == "true"
)
IGNORE_MODEL_REQUEST = os.getenv("IGNORE_MODEL_REQUEST", "true").lower() == "true"
DEBUG = os.getenv("DEBUG", "true").lower() == "true"

# Rate limiting configuration
@@ -53,7 +60,7 @@ def load_systemprompt(base_path: str) -> str:
env_var_name = "SYSTEM_PROMPT"
env_prompt = os.getenv(env_var_name)

if env_prompt is not None:
if env_prompt is not None and env_prompt.strip() != "":
content = env_prompt.strip()
logger.info(
f"Using system prompt from '{env_var_name}' environment variable; content: '{content}'"
@@ -83,4 +90,4 @@ def load_systemprompt(base_path: str) -> str:
return default_prompt


system_prompt_value = load_systemprompt(os.getenv("SYSTEM_PROMPT_PATH", os.getcwd()))
SYSTEM_PROMPT_VALUE = load_systemprompt(os.getenv("SYSTEM_PROMPT_PATH", os.getcwd()))
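The tightened check in load_systemprompt is best seen with a set-but-empty variable. A minimal before/after sketch:

import os

os.environ["SYSTEM_PROMPT"] = ""  # set, but blank

env_prompt = os.getenv("SYSTEM_PROMPT")

# Old guard: an empty string still passed, so the blank prompt was used.
print(env_prompt is not None)  # True

# New guard: blank values fall through to the .systemprompt file or default.
print(env_prompt is not None and env_prompt.strip() != "")  # False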
230 changes: 156 additions & 74 deletions services/api/src/api/pipeline_config.py
@@ -1,105 +1,187 @@
import os
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, Optional

from api.config import system_prompt_value
from api.config import SYSTEM_PROMPT_VALUE, logger
from core.pipeline_config import ModelProvider, QueryPipelineConfig


def get_env_param(param_name, converter=None, default=None):
value = os.getenv(param_name)
class EnvKeys(str, Enum):
PROVIDER = "PROVIDER"
MODEL_NAME = "MODEL_NAME"
HF_MODEL_NAME = "HF_MODEL_NAME"
EMBEDDING_MODEL = "EMBEDDING_MODEL_NAME"
HF_EMBEDDING_MODEL = "HF_EMBEDDING_MODEL_NAME"
HF_API_KEY = "HF_API_KEY"
OLLAMA_URL = "OLLAMA_URL"
ALLOW_MODEL_PULL = "ALLOW_MODEL_PULL"
ES_URL = "ES_URL"
ES_INDEX = "ES_INDEX"
ES_TOP_K = "ES_TOP_K"
ES_NUM_CANDIDATES = "ES_NUM_CANDIDATES"
ES_BASIC_AUTH_USER = "ES_BASIC_AUTH_USERNAME"
ES_BASIC_AUTH_PASSWORD = "ES_BASIC_AUTH_PASSWORD"
ENABLE_CONVERSATION_LOGS = "ENABLE_CONVERSATION_LOGS"


@dataclass
class GenerationParams:
context_window: tuple[str, type, str] = ("CONTEXT_WINDOW", int, "8192")
temperature: tuple[str, type, None] = ("TEMPERATURE", float, None)
seed: tuple[str, type, None] = ("SEED", int, None)
top_k: tuple[str, type, None] = ("TOP_K", int, None)
top_p: tuple[str, type, None] = ("TOP_P", float, None)
min_p: tuple[str, type, None] = ("MIN_P", float, None)
repeat_last_n: tuple[str, type, None] = ("REPEAT_LAST_N", int, None)
repeat_penalty: tuple[str, type, None] = ("REPEAT_PENALTY", float, None)
num_predict: tuple[str, type, None] = ("NUM_PREDICT", int, None)
tfs_z: tuple[str, type, None] = ("TFS_Z", float, None)


def get_env_value(
key: str, converter: Optional[Callable] = None, default: Optional[str] = None
) -> Any:
"""Get and convert environment variable value with optional default."""
value = os.getenv(key)
if value is None:
return None

if converter is not None:
if converter:
try:
if default is not None and value == "":
return converter(default)
return converter(value)
return converter(default if value == "" else value)
except (ValueError, TypeError):
return None
return value
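The semantics of get_env_value are subtle enough to pin down with an example: the default applies only to present-but-empty values, unset variables return None, and failed conversions are swallowed. A sketch assuming get_env_value is imported from this module:

import os

os.environ["ES_TOP_K"] = ""            # present but empty
os.environ["CONTEXT_WINDOW"] = "oops"  # present but not an int

print(get_env_value("MISSING_KEY", int, "5"))        # None: unset vars never use the default
print(get_env_value("ES_TOP_K", int, "5"))           # 5: empty string falls back to the default
print(get_env_value("CONTEXT_WINDOW", int, "8192"))  # None: ValueError is caught and swallowed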


def create_pipeline_config(model: str = None, index: str = None) -> QueryPipelineConfig:
provider_name = os.getenv("PROVIDER", "ollama")
def get_provider_specific_config() -> dict[str, Any]:
"""Get provider-specific configuration."""
provider = (
ModelProvider.HUGGINGFACE
if provider_name.lower() == "hf"
if os.getenv(EnvKeys.PROVIDER, "ollama").lower() == "hf"
else ModelProvider.OLLAMA
)

if provider == ModelProvider.HUGGINGFACE:
model_name = model or os.getenv("HF_MODEL_NAME")
embedding_model = os.getenv("HF_EMBEDDING_MODEL_NAME")
else:
model_name = model or os.getenv("MODEL_NAME")
embedding_model = os.getenv("EMBEDDING_MODEL_NAME")

config_params = {
config = {
"provider": provider,
"embedding_model": embedding_model,
"model_name": model_name,
"system_prompt": system_prompt_value,
"model_name": os.getenv(
EnvKeys.HF_MODEL_NAME
if provider == ModelProvider.HUGGINGFACE
else EnvKeys.MODEL_NAME
),
"embedding_model": os.getenv(
EnvKeys.HF_EMBEDDING_MODEL
if provider == ModelProvider.HUGGINGFACE
else EnvKeys.EMBEDDING_MODEL
),
"system_prompt": SYSTEM_PROMPT_VALUE,
}

# Provider specific parameters
if provider == ModelProvider.HUGGINGFACE:
if (hf_key := os.getenv("HF_API_KEY")) is not None:
config_params["hf_api_key"] = hf_key
else:
if (ollama_url := os.getenv("OLLAMA_URL")) is not None:
config_params["ollama_url"] = ollama_url

# Model pull configuration
allow_pull = os.getenv("ALLOW_MODEL_PULL")
if allow_pull is not None:
config_params["allow_model_pull"] = allow_pull.lower() == "true"

# Core generation parameters
if (context_window := get_env_param("CONTEXT_WINDOW", int, "8192")) is not None:
config_params["context_window"] = context_window

for param in ["TEMPERATURE", "SEED", "TOP_K"]:
if (
value := get_env_param(param, float if param == "TEMPERATURE" else int)
) is not None:
config_params[param.lower()] = value

# Advanced sampling parameters
for param in ["TOP_P", "MIN_P"]:
if (value := get_env_param(param, float)) is not None:
config_params[param.lower()] = value

# Mirostat parameters
if (mirostat := get_env_param("MIROSTAT", int)) is not None:
config_params["mirostat"] = mirostat
for param in ["MIROSTAT_ETA", "MIROSTAT_TAU"]:
if (value := get_env_param(param, float)) is not None:
config_params[param.lower()] = value
config["hf_api_key"] = os.getenv(EnvKeys.HF_API_KEY)
elif ollama_url := os.getenv(EnvKeys.OLLAMA_URL):
config["ollama_url"] = ollama_url

return config


def get_elasticsearch_config(index: Optional[str] = None) -> dict[str, Any]:
"""Get Elasticsearch configuration if enabled."""
if not (es_url := os.getenv(EnvKeys.ES_URL)):
return {}

# Elasticsearch parameters
if (es_url := os.getenv("ES_URL")) is not None:
config_params["es_url"] = es_url
config = {
"es_url": es_url,
"es_index": index or os.getenv(EnvKeys.ES_INDEX),
"es_basic_auth_user": os.getenv(EnvKeys.ES_BASIC_AUTH_USER),
"es_basic_auth_password": os.getenv(EnvKeys.ES_BASIC_AUTH_PASSWORD),
}

for env_key, default in [
(EnvKeys.ES_TOP_K, "5"),
(EnvKeys.ES_NUM_CANDIDATES, "-1"),
]:
if value := get_env_value(env_key, int, default):
config[env_key.lower()] = value

return config


def create_pipeline_config(
model: Optional[str] = None,
index: Optional[str] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
min_p: Optional[float] = None,
repeat_last_n: Optional[int] = None,
repeat_penalty: Optional[float] = None,
num_predict: Optional[int] = None,
tfs_z: Optional[float] = None,
context_window: Optional[int] = None,
seed: Optional[int] = None,
**additional_params: Dict[str, Any],
) -> QueryPipelineConfig:
"""Create pipeline configuration from environment variables with optional parameter overrides."""
config = get_provider_specific_config()
if model:
config["model_name"] = model

# Add generation parameters from environment first
params = GenerationParams()
for param in params.__annotations__:
env_key, converter, default = getattr(params, param)
if value := get_env_value(env_key, converter, default):
config[param] = value

# Override with any provided parameters
generation_params = {
"temperature": temperature,
"top_k": top_k,
"top_p": top_p,
"min_p": min_p,
"repeat_last_n": repeat_last_n,
"repeat_penalty": repeat_penalty,
"num_predict": num_predict,
"tfs_z": tfs_z,
"context_window": context_window,
"seed": seed,
}

if index is not None:
config_params["es_index"] = index
elif (es_index := os.getenv("ES_INDEX")) is not None:
config_params["es_index"] = es_index
# Update config with provided non-None parameters
config.update({k: v for k, v in generation_params.items() if v is not None})

# Add any additional parameters passed
config.update(additional_params)

# Add mirostat parameters
if mirostat := get_env_value("MIROSTAT", int):
config["mirostat"] = mirostat
for param in ["MIROSTAT_ETA", "MIROSTAT_TAU"]:
if value := get_env_value(param, float):
config[param.lower()] = value

if (es_top_k := get_env_param("ES_TOP_K", int, "5")) is not None:
config_params["es_top_k"] = es_top_k
# Add model pull configuration
if allow_pull := os.getenv(EnvKeys.ALLOW_MODEL_PULL):
config["allow_model_pull"] = allow_pull.lower() == "true"

if (
es_num_candidates := get_env_param("ES_NUM_CANDIDATES", int, "-1")
) is not None:
config_params["es_num_candidates"] = es_num_candidates
# Add conversation logs setting
if value := os.getenv(EnvKeys.ENABLE_CONVERSATION_LOGS):
config["enable_conversation_logs"] = value.lower() == "true"

if (es_user := os.getenv("ES_BASIC_AUTH_USERNAME")) is not None:
config_params["es_basic_auth_user"] = es_user
# Add stop sequence
if stop_sequence := os.getenv("STOP_SEQUENCE"):
config["stop_sequence"] = stop_sequence

if (es_pass := os.getenv("ES_BASIC_AUTH_PASSWORD")) is not None:
config_params["es_basic_auth_password"] = es_pass
# Add Elasticsearch config
config.update(get_elasticsearch_config(index))

if (enable_conversation_logs := os.getenv("ENABLE_CONVERSATION_LOGS")) is not None:
config_params["enable_conversation_logs"] = enable_conversation_logs
logger.info("\nPipeline Configuration:")
for key, value in sorted(config.items()):
if any(sensitive in key.lower() for sensitive in ["password", "key", "auth"]):
logger.info(f" {key}: ****")
else:
logger.info(f" {key}: {value}")

return QueryPipelineConfig(**config_params)
return QueryPipelineConfig(**config)
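Taken together, the precedence in create_pipeline_config is: GenerationParams environment values are applied first, then explicit keyword arguments win. A hedged usage sketch (model name and values are illustrative):

import os

os.environ["TEMPERATURE"] = "0.8"  # environment baseline
os.environ["CONTEXT_WINDOW"] = ""  # empty, so the "8192" default from GenerationParams applies

# The explicit keyword beats the TEMPERATURE env var; context_window ends up 8192.
config = create_pipeline_config(model="llama3", temperature=0.2)

One caveat worth knowing: the walrus checks (if value := ...) skip falsy results, so an environment value like SEED=0 is silently dropped, while an explicit seed=0 keyword survives because the override filter only excludes None.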
30 changes: 26 additions & 4 deletions services/api/src/api/routes.py
@@ -4,9 +4,11 @@
from api.config import (
ALLOW_INDEX_CHANGE,
ALLOW_MODEL_CHANGE,
ALLOW_MODEL_PARAMETER_CHANGE,
APP_VERSION,
BUILD_NUMBER,
DEBUG,
IGNORE_MODEL_REQUEST,
logger,
)
from api.handlers import handle_standard_response, handle_streaming_response
@@ -66,9 +68,11 @@ def chat():
if not messages:
abort(400, description="No messages provided")

model = data.get("model")
if model and not ALLOW_MODEL_CHANGE:
abort(403, description="Model changes are not allowed")
model = None
if not IGNORE_MODEL_REQUEST:
model = data.get("model")
if model and not ALLOW_MODEL_CHANGE:
abort(403, description="Model changes are not allowed")

# Validate message format
for message in messages:
@@ -88,6 +92,17 @@ def chat():
options = data.get("options", {})
stream = data.get("stream", True)

temperature = None
top_k = None
top_p = None
seed = None

if ALLOW_MODEL_PARAMETER_CHANGE:
temperature = data.get("temperature", None)
top_k = data.get("top_k", None)
top_p = data.get("top_p", None)
seed = data.get("seed", None)

# Handle index parameter
index = options.get("index")
if index and not ALLOW_INDEX_CHANGE:
@@ -99,7 +114,14 @@ def chat():
abort(400, description="Images must be provided as a list")

# Create configuration
config = create_pipeline_config(model, index)
config = create_pipeline_config(
model=model,
index=index,
temperature=temperature,
top_k=top_k,
top_p=top_p,
seed=seed,
)

# Get the latest message with content
query = None
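From the client side, the new parameters ride along at the top level of the chat payload and are honored only when ALLOW_MODEL_PARAMETER_CHANGE is enabled. A sketch of such a request; the URL, port, and endpoint path are assumptions, not confirmed by this diff:

import json
import urllib.request

payload = {
    "model": "llama3",  # honored only when IGNORE_MODEL_REQUEST is false
    "messages": [{"role": "user", "content": "Summarize the release notes."}],
    "temperature": 0.2,  # read only when ALLOW_MODEL_PARAMETER_CHANGE is true
    "top_k": 40,
    "top_p": 0.9,
    "seed": 42,
    "stream": False,
}

req = urllib.request.Request(
    "http://localhost:8000/api/chat",  # assumed endpoint
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(resp.read().decode())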
9 changes: 5 additions & 4 deletions services/api/src/api/routes_setup.py
@@ -1,14 +1,15 @@
from api.config import logger
from api.config import PROVIDER_IS_OLLAMA, logger
from api.ollama_routes import setup_ollama_routes
from api.routes import register_chat_routes, register_health_routes
from flask import Flask


def setup_all_routes(app: Flask):
try:
# Setup Ollama-specific routes
setup_ollama_routes(app)
logger.info("Ollama routes registered successfully")
if PROVIDER_IS_OLLAMA:
# Setup Ollama-specific routes
setup_ollama_routes(app)
logger.info("Ollama routes registered successfully")

# Setup chat routes (chat, streaming, etc)
register_chat_routes(app)
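The guard means a HuggingFace deployment never exposes Ollama-specific endpoints; probing them returns 404 instead of hitting a half-configured proxy. A small verification sketch; note that PROVIDER must be set before api.config is imported, since PROVIDER_IS_OLLAMA is evaluated at import time (the import path below is an assumption):

import os

os.environ["PROVIDER"] = "hf"  # anything other than "ollama"

from flask import Flask
from api.routes_setup import setup_all_routes  # assumed import path

app = Flask(__name__)
setup_all_routes(app)  # registers chat and health routes only

# Ollama-specific rules are absent from the URL map.
print([str(rule) for rule in app.url_map.iter_rules()])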