From c5079afc9a0acf835eedb60fa8f907d7446006b8 Mon Sep 17 00:00:00 2001 From: Tilman Griesel Date: Wed, 29 Jan 2025 09:16:53 +0100 Subject: [PATCH] feat: Mirror Ollama chat API --- docker/docker-compose.dev.yml | 2 +- docker/docker-compose.prod.yml | 2 +- services/api/Makefile | 2 +- services/api/src/api/config.py | 86 +++ services/api/src/api/handlers.py | 255 +++++++++ services/api/src/api/middleware.py | 104 ++++ .../api/src/{core => api}/ollama_proxy.py | 0 services/api/src/api/ollama_routes.py | 48 ++ services/api/src/api/pipeline_config.py | 105 ++++ services/api/src/api/routes.py | 148 +++++ services/api/src/api/routes_setup.py | 23 + services/api/src/core/rag_pipeline.py | 3 +- services/api/src/main.py | 512 +----------------- tools/cli/tools/api_mirror_tester/src/main.py | 2 +- tools/cli/tools/index.html | 2 +- tools/cli/tools/test_non_streaming.sh | 2 +- tools/cli/tools/test_streaming.sh | 2 +- 17 files changed, 804 insertions(+), 494 deletions(-) create mode 100644 services/api/src/api/config.py create mode 100644 services/api/src/api/handlers.py create mode 100644 services/api/src/api/middleware.py rename services/api/src/{core => api}/ollama_proxy.py (100%) create mode 100644 services/api/src/api/ollama_routes.py create mode 100644 services/api/src/api/pipeline_config.py create mode 100644 services/api/src/api/routes.py create mode 100644 services/api/src/api/routes_setup.py diff --git a/docker/docker-compose.dev.yml b/docker/docker-compose.dev.yml index a3f238c..e8acc78 100644 --- a/docker/docker-compose.dev.yml +++ b/docker/docker-compose.dev.yml @@ -27,7 +27,7 @@ services: args: BUILD_ENV: development ports: - - 21210:8000 + - 21434:8000 env_file: ../services/api/.env depends_on: elasticsearch: diff --git a/docker/docker-compose.prod.yml b/docker/docker-compose.prod.yml index 164d20c..8d8ba2f 100644 --- a/docker/docker-compose.prod.yml +++ b/docker/docker-compose.prod.yml @@ -11,7 +11,7 @@ services: api: image: griesel/chipper:api-latest ports: - - 127.0.0.1:21210:8000 + - 127.0.0.1:21434:8000 env_file: ../services/api/.env depends_on: elasticsearch: diff --git a/services/api/Makefile b/services/api/Makefile index 671b987..2f4237d 100644 --- a/services/api/Makefile +++ b/services/api/Makefile @@ -34,7 +34,7 @@ run-dev: --hostname api \ --env-file .env \ -v "$(CURDIR)/src:/app/src:z" \ - -p 21210:8000 \ + -p 21434:8000 \ --name $(CONTAINER_NAME) \ --network=$(NETWORK) \ $(IMAGE_NAME) diff --git a/services/api/src/api/config.py b/services/api/src/api/config.py new file mode 100644 index 0000000..a307d98 --- /dev/null +++ b/services/api/src/api/config.py @@ -0,0 +1,86 @@ +import logging +import os +import secrets +from pathlib import Path + +from dotenv import load_dotenv +from flask import Flask +from flask_limiter import Limiter +from flask_limiter.util import get_remote_address +from werkzeug.middleware.proxy_fix import ProxyFix + +# Load environment variables +load_dotenv() + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# App configuration +app = Flask(__name__) +app.wsgi_app = ProxyFix(app.wsgi_app) + +# Version information +APP_VERSION = os.getenv("APP_VERSION", "[DEV]") +BUILD_NUMBER = os.getenv("APP_BUILD_NUM", "0") + +# Feature flags +ALLOW_MODEL_CHANGE = os.getenv("ALLOW_MODEL_CHANGE", "true").lower() == "true" +ALLOW_INDEX_CHANGE = os.getenv("ALLOW_INDEX_CHANGE", "true").lower() == "true" +DEBUG = os.getenv("DEBUG", "true").lower() == "true" + +# Rate limiting configuration +DAILY_LIMIT = 
int(os.getenv("DAILY_RATE_LIMIT", "86400")) +MINUTE_LIMIT = int(os.getenv("MINUTE_RATE_LIMIT", "60")) +STORAGE_URI = os.getenv("RATE_LIMIT_STORAGE", "memory://") + +limiter = Limiter( + key_func=get_remote_address, + app=app, + default_limits=[f"{DAILY_LIMIT} per day", f"{MINUTE_LIMIT} per minute"], + storage_uri=STORAGE_URI, +) + +# API Key configuration +API_KEY = os.getenv("API_KEY") +if not API_KEY: + API_KEY = secrets.token_urlsafe(32) + logger.info(f"Generated API key: {API_KEY}") + + +def load_systemprompt(base_path: str) -> str: + default_prompt = "" + env_var_name = "SYSTEM_PROMPT" + env_prompt = os.getenv(env_var_name) + + if env_prompt is not None: + content = env_prompt.strip() + logger.info( + f"Using system prompt from '{env_var_name}' environment variable; content: '{content}'" + ) + return content + + file = Path(base_path) / ".systemprompt" + if not file.exists(): + logger.info("No .systemprompt file found. Using default prompt.") + return default_prompt + + try: + with open(file, "r", encoding="utf-8") as f: + content = f.read().strip() + + if not content: + logger.warning("System prompt file is empty. Using default prompt.") + return default_prompt + + logger.info( + f"Successfully loaded system prompt from {file}; content: '{content}'" + ) + return content + + except Exception as e: + logger.error(f"Error reading system prompt file: {e}") + return default_prompt + + +system_prompt_value = load_systemprompt(os.getenv("SYSTEM_PROMPT_PATH", os.getcwd())) diff --git a/services/api/src/api/handlers.py b/services/api/src/api/handlers.py new file mode 100644 index 0000000..41ba26a --- /dev/null +++ b/services/api/src/api/handlers.py @@ -0,0 +1,255 @@ +import json +import queue +import threading +import time +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +import elasticsearch +from api.config import DEBUG, logger +from core.pipeline_config import QueryPipelineConfig +from core.rag_pipeline import RAGQueryPipeline +from flask import Response, jsonify, stream_with_context + + +def format_stream_response( + config: QueryPipelineConfig, + content: str = "", + done: bool = False, + done_reason: Optional[str] = None, + images: Optional[List[str]] = None, + tool_calls: Optional[List[Dict[str, Any]]] = None, + **metrics, +) -> Dict[str, Any]: + """Format streaming response according to Ollama-API specification.""" + response = { + "model": config.model_name, + "created_at": datetime.now(timezone.utc).isoformat(), + "done": done, + } + + if not done: + message = {"role": "assistant", "content": content} + if images: + message["images"] = images + if tool_calls: + message["tool_calls"] = tool_calls + response["message"] = message + else: + if done_reason: + response["done_reason"] = done_reason + response.update( + { + "total_duration": metrics.get("total_duration", 0), + "load_duration": metrics.get("load_duration", 0), + "prompt_eval_count": metrics.get("prompt_eval_count", 0), + "prompt_eval_duration": metrics.get("prompt_eval_duration", 0), + "eval_count": metrics.get("eval_count", 0), + "eval_duration": metrics.get("eval_duration", 0), + } + ) + + return response + + +def handle_streaming_response( + config: QueryPipelineConfig, + query: str, + conversation: List[Dict[str, str]], + format_schema: Optional[Dict[str, Any]] = None, + options: Optional[Dict[str, Any]] = None, +) -> Response: + q = queue.Queue() + start_time = time.time_ns() + prompt_start = None + + def streaming_callback(chunk): + nonlocal prompt_start + if prompt_start is 
None: + prompt_start = time.time_ns() + + if chunk.content: + if format_schema and chunk.is_final: + try: + content = json.loads(chunk.content) + response_data = format_stream_response( + config, json.dumps(content), done=True, done_reason="stop" + ) + except json.JSONDecodeError: + response_data = format_stream_response( + config, + "Error: Failed to generate valid JSON response.", + done=True, + done_reason="error", + ) + else: + response_data = format_stream_response( + config, + chunk.content, + images=getattr(chunk, "images", None), + tool_calls=getattr(chunk, "tool_calls", None), + ) + + q.put(json.dumps(response_data) + "\n") + + rag = RAGQueryPipeline(config=config, streaming_callback=streaming_callback) + + def run_rag(): + try: + # Track model loading + load_start = time.time_ns() + for status in rag.initialize_and_check_models(): + if status.get("status") == "error": + error_data = format_stream_response( + config, + f"Error: Model initialization failed - {status.get('error')}", + done=True, + done_reason="error", + ) + q.put(json.dumps(error_data) + "\n") + return + + load_duration = time.time_ns() - load_start + + rag.create_query_pipeline() + result = rag.run_query( + query=query, conversation=conversation, print_response=DEBUG + ) + + # Calculate final metrics + end_time = time.time_ns() + final_data = format_stream_response( + config, + done=True, + done_reason="stop", + total_duration=end_time - start_time, + load_duration=load_duration, + prompt_eval_count=len(conversation) + 1, + prompt_eval_duration=end_time - (prompt_start or start_time), + eval_count=len(result["llm"]["replies"][0].split()) + if result and "llm" in result and "replies" in result["llm"] + else 0, + eval_duration=end_time - (prompt_start or start_time), + ) + q.put(json.dumps(final_data) + "\n") + + except elasticsearch.BadRequestError as e: + error_data = format_stream_response( + config, + f"Error: Embedding retriever error - {str(e)}", + done=True, + done_reason="error", + ) + q.put(json.dumps(error_data) + "\n") + + except Exception as e: + error_data = format_stream_response( + config, f"Error: {str(e)}", done=True, done_reason="error" + ) + logger.error(f"Error in RAG pipeline: {e}", exc_info=True) + q.put(json.dumps(error_data) + "\n") + + thread = threading.Thread(target=run_rag, daemon=True) + thread.start() + + def generate(): + while True: + try: + data = q.get(timeout=120) + if data: + yield data + + if '"done": true' in data: + logger.info("Streaming completed.") + break + + except queue.Empty: + # Send an empty object for heartbeat + yield json.dumps({}) + "\n" + logger.warning("Queue timeout. 
Sending heartbeat.") + except Exception as e: + logger.error(f"Streaming error: {e}") + error_data = format_stream_response( + config, "Streaming error occurred.", done=True, done_reason="error" + ) + yield json.dumps(error_data) + "\n" + break + + return Response( + stream_with_context(generate()), + mimetype="application/x-ndjson", + headers={ + "Cache-Control": "no-cache", + "X-Accel-Buffering": "no", + "Connection": "keep-alive", + }, + ) + + +def handle_standard_response( + config: QueryPipelineConfig, + query: str, + conversation: List[Dict[str, str]], + format_schema: Optional[Dict[str, Any]] = None, + options: Optional[Dict[str, Any]] = None, +) -> Response: + """Handle non-streaming responses with support for structured outputs.""" + start_time = time.time_ns() + rag = RAGQueryPipeline(config=config) + + try: + # Track model loading time + load_start = time.time_ns() + for status in rag.initialize_and_check_models(): + if status.get("status") == "error": + raise Exception(f"Model initialization failed: {status.get('error')}") + load_duration = time.time_ns() - load_start + + rag.create_query_pipeline() + + # Track query execution time + prompt_start = time.time_ns() + result = rag.run_query( + query=query, conversation=conversation, print_response=False + ) + end_time = time.time_ns() + + if result and "llm" in result and "replies" in result["llm"]: + response_content = result["llm"]["replies"][0] + + # Handle structured output if format_schema is provided + if format_schema: + try: + content = json.loads(response_content) + response_content = json.dumps(content) + except json.JSONDecodeError: + raise Exception("Failed to generate valid JSON response") + + eval_count = len(response_content.split()) if response_content else 0 + + response = { + "model": config.model_name, + "created_at": datetime.now(timezone.utc).isoformat(), + "message": {"role": "assistant", "content": response_content}, + "done": True, + "done_reason": "stop", + "total_duration": end_time - start_time, + "load_duration": load_duration, + "prompt_eval_count": len(conversation) + 1, + "prompt_eval_duration": end_time - prompt_start, + "eval_count": eval_count, + "eval_duration": end_time - prompt_start, + } + + return jsonify(response) + + except Exception as e: + logger.error(f"Error in RAG pipeline: {e}", exc_info=True) + error_response = { + "model": config.model_name, + "created_at": datetime.now(timezone.utc).isoformat(), + "done": True, + "done_reason": "error", + "error": str(e), + } + return jsonify(error_response) diff --git a/services/api/src/api/middleware.py b/services/api/src/api/middleware.py new file mode 100644 index 0000000..a857f60 --- /dev/null +++ b/services/api/src/api/middleware.py @@ -0,0 +1,104 @@ +import os +from functools import wraps + +from api.config import API_KEY, app, logger +from flask import abort, request + + +def require_api_key(f): + @wraps(f) + def decorated_function(*args, **kwargs): + require_api_key = os.getenv("REQUIRE_API_KEY", "true") + require_api_key = require_api_key.lower() == "true" + + if not require_api_key: + return f(*args, **kwargs) + + api_key = request.headers.get("X-API-Key") + if not api_key or api_key != API_KEY: + logger.warning(f"Invalid API key attempt from {request.remote_addr}") + abort(401, description="Invalid or missing API key") + + return f(*args, **kwargs) + + return decorated_function + + +def setup_security_middleware(app): + @app.before_request + def before_request(): + logger.info( + f"Request {request.method} {request.path} from 
{request.remote_addr}" + ) + if ( + os.getenv("REQUIRE_SECURE", "False").lower() == "true" + and not request.is_secure + ): + logger.warning(f"Insecure request attempt from {request.remote_addr}") + abort(403, description="HTTPS required") + + @app.after_request + def after_request(response): + response.headers.update( + { + "Strict-Transport-Security": "max-age=31536000; includeSubDomains", + "X-Content-Type-Options": "nosniff", + "X-Frame-Options": "DENY", + "X-XSS-Protection": "1; mode=block", + "Content-Security-Policy": "default-src 'self'", + "Referrer-Policy": "strict-origin-when-cross-origin", + } + ) + + if os.getenv("ENABLE_CORS", "False").lower() == "true": + allowed_origins = os.getenv("CORS_ALLOWED_ORIGINS", "*") + response.headers["Access-Control-Allow-Origin"] = allowed_origins + response.headers["Access-Control-Allow-Methods"] = "GET, POST, OPTIONS" + response.headers["Access-Control-Allow-Headers"] = "Content-Type, X-API-Key" + + return response + + @app.errorhandler(401) + def unauthorized_error(error): + return {"error": "Unauthorized", "message": str(error.description)}, 401 + + @app.errorhandler(403) + def forbidden_error(error): + return {"error": "Forbidden", "message": str(error.description)}, 403 + + @app.errorhandler(500) + def internal_error(error): + logger.error(f"Internal server error: {error}", exc_info=True) + return { + "error": "Internal Server Error", + "message": "An unexpected error occurred", + }, 500 + + +def setup_request_logging_middleware(app): + @app.before_request + def log_request_info(): + if request.path == "/health": + return + + log_data = { + "method": request.method, + "path": request.path, + "remote_addr": request.remote_addr, + "user_agent": request.headers.get("User-Agent"), + "request_id": request.headers.get("X-Request-ID"), + } + + logger.info("Incoming request", extra=log_data) + + +def init_middleware(app): + """ + Initialize all middleware + """ + setup_security_middleware(app) + setup_request_logging_middleware(app) + logger.info("Middleware initialized successfully") + + +init_middleware(app) diff --git a/services/api/src/core/ollama_proxy.py b/services/api/src/api/ollama_proxy.py similarity index 100% rename from services/api/src/core/ollama_proxy.py rename to services/api/src/api/ollama_proxy.py diff --git a/services/api/src/api/ollama_routes.py b/services/api/src/api/ollama_routes.py new file mode 100644 index 0000000..7f570eb --- /dev/null +++ b/services/api/src/api/ollama_routes.py @@ -0,0 +1,48 @@ +import os + +from api.config import logger +from api.middleware import require_api_key +from api.ollama_proxy import OllamaProxy + + +class OllamaRoutes: + def __init__(self, app, proxy: OllamaProxy): + self.app = app + self.proxy = proxy + self.register_routes() + + def register_routes(self): + @self.app.route("/api/generate", methods=["POST"]) + @require_api_key + def generate(): + try: + return self.proxy.generate() + except Exception as e: + logger.error(f"Error in generate endpoint: {e}") + return {"error": str(e)}, 500 + + @self.app.route("/api/tags", methods=["GET"]) + @require_api_key + def tags(): + try: + return self.proxy.tags() + except Exception as e: + logger.error(f"Error in tags endpoint: {e}") + return {"error": str(e)}, 500 + + @self.app.route("/api/pull", methods=["POST"]) + @require_api_key + def pull(): + try: + return self.proxy.pull() + except Exception as e: + logger.error(f"Error in pull endpoint: {e}") + return {"error": str(e)}, 500 + + +def setup_ollama_routes(app): + ollama_url = 
os.getenv("OLLAMA_URL", "http://localhost:11434") + proxy = OllamaProxy(ollama_url) + OllamaRoutes(app, proxy) + logger.info(f"Initialized Ollama routes with URL: {ollama_url}") + return proxy diff --git a/services/api/src/api/pipeline_config.py b/services/api/src/api/pipeline_config.py new file mode 100644 index 0000000..fd23a0a --- /dev/null +++ b/services/api/src/api/pipeline_config.py @@ -0,0 +1,105 @@ +import os + +from api.config import system_prompt_value +from core.pipeline_config import ModelProvider, QueryPipelineConfig + + +def get_env_param(param_name, converter=None, default=None): + value = os.getenv(param_name) + if value is None: + return None + + if converter is not None: + try: + if default is not None and value == "": + return converter(default) + return converter(value) + except (ValueError, TypeError): + return None + return value + + +def create_pipeline_config(model: str = None, index: str = None) -> QueryPipelineConfig: + provider_name = os.getenv("PROVIDER", "ollama") + provider = ( + ModelProvider.HUGGINGFACE + if provider_name.lower() == "hf" + else ModelProvider.OLLAMA + ) + + if provider == ModelProvider.HUGGINGFACE: + model_name = model or os.getenv("HF_MODEL_NAME") + embedding_model = os.getenv("HF_EMBEDDING_MODEL_NAME") + else: + model_name = model or os.getenv("MODEL_NAME") + embedding_model = os.getenv("EMBEDDING_MODEL_NAME") + + config_params = { + "provider": provider, + "embedding_model": embedding_model, + "model_name": model_name, + "system_prompt": system_prompt_value, + } + + # Provider specific parameters + if provider == ModelProvider.HUGGINGFACE: + if (hf_key := os.getenv("HF_API_KEY")) is not None: + config_params["hf_api_key"] = hf_key + else: + if (ollama_url := os.getenv("OLLAMA_URL")) is not None: + config_params["ollama_url"] = ollama_url + + # Model pull configuration + allow_pull = os.getenv("ALLOW_MODEL_PULL") + if allow_pull is not None: + config_params["allow_model_pull"] = allow_pull.lower() == "true" + + # Core generation parameters + if (context_window := get_env_param("CONTEXT_WINDOW", int, "8192")) is not None: + config_params["context_window"] = context_window + + for param in ["TEMPERATURE", "SEED", "TOP_K"]: + if ( + value := get_env_param(param, float if param == "TEMPERATURE" else int) + ) is not None: + config_params[param.lower()] = value + + # Advanced sampling parameters + for param in ["TOP_P", "MIN_P"]: + if (value := get_env_param(param, float)) is not None: + config_params[param.lower()] = value + + # Mirostat parameters + if (mirostat := get_env_param("MIROSTAT", int)) is not None: + config_params["mirostat"] = mirostat + for param in ["MIROSTAT_ETA", "MIROSTAT_TAU"]: + if (value := get_env_param(param, float)) is not None: + config_params[param.lower()] = value + + # Elasticsearch parameters + if (es_url := os.getenv("ES_URL")) is not None: + config_params["es_url"] = es_url + + if index is not None: + config_params["es_index"] = index + elif (es_index := os.getenv("ES_INDEX")) is not None: + config_params["es_index"] = es_index + + if (es_top_k := get_env_param("ES_TOP_K", int, "5")) is not None: + config_params["es_top_k"] = es_top_k + + if ( + es_num_candidates := get_env_param("ES_NUM_CANDIDATES", int, "-1") + ) is not None: + config_params["es_num_candidates"] = es_num_candidates + + if (es_user := os.getenv("ES_BASIC_AUTH_USERNAME")) is not None: + config_params["es_basic_auth_user"] = es_user + + if (es_pass := os.getenv("ES_BASIC_AUTH_PASSWORD")) is not None: + config_params["es_basic_auth_password"] = 
es_pass + + if (enable_conversation_logs := os.getenv("ENABLE_CONVERSATION_LOGS")) is not None: + config_params["enable_conversation_logs"] = enable_conversation_logs + + return QueryPipelineConfig(**config_params) diff --git a/services/api/src/api/routes.py b/services/api/src/api/routes.py new file mode 100644 index 0000000..16e3813 --- /dev/null +++ b/services/api/src/api/routes.py @@ -0,0 +1,148 @@ +import json +from datetime import datetime, timezone + +from api.config import ( + ALLOW_INDEX_CHANGE, + ALLOW_MODEL_CHANGE, + APP_VERSION, + BUILD_NUMBER, + DEBUG, + logger, +) +from api.handlers import handle_standard_response, handle_streaming_response +from api.middleware import require_api_key +from api.pipeline_config import create_pipeline_config +from flask import Flask, abort, jsonify, request + + +def log_request_info(request): + request_info = { + "timestamp": datetime.utcnow().isoformat(), + "metadata": { + "endpoint": request.endpoint, + "method": request.method, + "remote_addr": request.remote_addr, + "path": request.path, + }, + "headers": dict(request.headers), + "params": { + "url": dict(request.args) if request.args else None, + "form": dict(request.form) if request.form else None, + "cookies": dict(request.cookies) if request.cookies else None, + }, + } + + if request.data: + content_type = request.headers.get("Content-Type", "") + if "application/json" in content_type: + try: + request_info["body"] = request.get_json() + except Exception as e: + request_info["body"] = { + "error": f"Failed to parse JSON body: {str(e)}", + "raw": request.data.decode("utf-8", errors="replace"), + } + else: + request_info["body"] = request.data.decode("utf-8", errors="replace") + + logger.info("Request: %s", json.dumps(request_info, indent=None, sort_keys=True)) + + +def register_chat_routes(app: Flask): + @app.route("/api/chat", methods=["POST"]) + @require_api_key + def chat(): + try: + if DEBUG: + log_request_info(request) + + data = request.get_json() + + if not data: + logger.error("No JSON payload received.") + abort(400, description="Invalid JSON payload.") + + messages = data.get("messages", []) + if not messages: + abort(400, description="No messages provided") + + model = data.get("model") + if model and not ALLOW_MODEL_CHANGE: + abort(403, description="Model changes are not allowed") + + # Validate message format + for message in messages: + if ( + not isinstance(message, dict) + or "role" not in message + or "content" not in message + ): + abort(400, description="Invalid message format") + if message["role"] not in ["system", "user", "assistant", "tool"]: + abort(400, description="Invalid message role") + + # Optional parameters + # tools = data.get("tools", []) + # format_param = data.get("format") + options = data.get("options", {}) + stream = data.get("stream", True) + # keep_alive = data.get("keep_alive", "5m") + + # Handle index parameter + index = options.get("index") + if index and not ALLOW_INDEX_CHANGE: + abort(403, description="Index changes are not allowed") + + # Handle images in messages + for message in messages: + if "images" in message and not isinstance(message["images"], list): + abort(400, description="Images must be provided as a list") + + # Create configuration + config = create_pipeline_config(model, index) + + # Get the latest message with content + query = None + for message in reversed(messages): + content = message.get("content") + if content: + query = content + break + + if not query: + abort(400, description="No message with content found") + + # 
Handle conversation context + conversation = messages[:-1] if len(messages) > 1 else [] + + # Handle streaming vs non-streaming response + if stream: + return handle_streaming_response(config, query, conversation) + else: + return handle_standard_response(config, query, conversation) + + except Exception as e: + logger.error(f"Error processing chat request: {str(e)}", exc_info=True) + abort(500, description="Internal Server Error.") + + +def register_health_routes(app: Flask): + @app.route("/health", methods=["GET"]) + def health_check(): + return jsonify( + { + "service": "chipper-api", + "version": APP_VERSION, + "build": BUILD_NUMBER, + "status": "healthy", + "timestamp": datetime.now(timezone.utc).isoformat(), + } + ) + + @app.route("/", methods=["GET"]) + def root(): + return "Chipper is running", 200 + + @app.errorhandler(404) + def not_found_error(error): + return "", 404 diff --git a/services/api/src/api/routes_setup.py b/services/api/src/api/routes_setup.py new file mode 100644 index 0000000..87ba415 --- /dev/null +++ b/services/api/src/api/routes_setup.py @@ -0,0 +1,23 @@ +from api.config import logger +from api.ollama_routes import setup_ollama_routes +from api.routes import register_chat_routes, register_health_routes +from flask import Flask + + +def setup_all_routes(app: Flask): + try: + # Setup Ollama-specific routes + setup_ollama_routes(app) + logger.info("Ollama routes registered successfully") + + # Setup chat routes (chat, streaming, etc) + register_chat_routes(app) + logger.info("Chat routes registered successfully") + + # Setup health check and basic routes + register_health_routes(app) + logger.info("Health check routes registered successfully") + + except Exception as e: + logger.error(f"Error setting up routes: {e}", exc_info=True) + raise diff --git a/services/api/src/core/rag_pipeline.py b/services/api/src/core/rag_pipeline.py index 56e20c9..8e20384 100644 --- a/services/api/src/core/rag_pipeline.py +++ b/services/api/src/core/rag_pipeline.py @@ -208,8 +208,7 @@ def run_query( self.conversation_logger.log_conversation(query, response, conversation) if print_response and response["llm"]["replies"]: - print(response["llm"]["replies"][0]) - print("\n") + logging.info("Response: " + response["llm"]["replies"][0]) return response diff --git a/services/api/src/main.py b/services/api/src/main.py index cb4590c..87fb799 100755 --- a/services/api/src/main.py +++ b/services/api/src/main.py @@ -1,40 +1,7 @@ -import json -import logging import os -import queue -import secrets -import threading -from datetime import datetime, timezone -from functools import wraps -from pathlib import Path -import elasticsearch -from core.pipeline_config import ModelProvider, QueryPipelineConfig -from core.rag_pipeline import RAGQueryPipeline -from core.ollama_proxy import OllamaProxy -from dotenv import load_dotenv -from flask import Flask, Response, abort, jsonify, request, stream_with_context -from flask_limiter import Limiter -from flask_limiter.util import get_remote_address -from werkzeug.middleware.proxy_fix import ProxyFix - -load_dotenv() - -app = Flask(__name__) -app.wsgi_app = ProxyFix(app.wsgi_app) - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -APP_VERSION = os.getenv("APP_VERSION", "[DEV]") -BUILD_NUMBER = os.getenv("APP_BUILD_NUM", "0") - -ALLOW_MODEL_CHANGE = os.getenv("ALLOW_MODEL_CHANGE", "true").lower() == "true" -ALLOW_INDEX_CHANGE = os.getenv("ALLOW_INDEX_CHANGE", "true").lower() == "true" - -DAILY_LIMIT = 
int(os.getenv("DAILY_RATE_LIMIT", "86400")) -MINUTE_LIMIT = int(os.getenv("MINUTE_RATE_LIMIT", "60")) -STORAGE_URI = os.getenv("RATE_LIMIT_STORAGE", "memory://") +from api.config import APP_VERSION, BUILD_NUMBER, app, logger +from api.routes_setup import setup_all_routes def show_welcome(): @@ -55,470 +22,45 @@ def show_welcome(): print(f"{RESET}\n", flush=True) -show_welcome() - -limiter = Limiter( - key_func=get_remote_address, - app=app, - default_limits=[f"{DAILY_LIMIT} per day", f"{MINUTE_LIMIT} per minute"], - storage_uri=STORAGE_URI, -) - -API_KEY = os.getenv("API_KEY") -if not API_KEY: - API_KEY = secrets.token_urlsafe(32) - logger.info(f"Generated API key: {API_KEY}") - - -def load_systemprompt(base_path: str) -> str: - default_prompt = "" - - # Use environment variable if available - env_var_name = "SYSTEM_PROMPT" - env_prompt = os.getenv(env_var_name) - if env_prompt is not None: - content = env_prompt.strip() - logger.info( - f"Using system prompt from '{env_var_name}' environment variable; content: '{content}'" - ) - return content - - # Try reading from file - file = Path(base_path) / ".systemprompt" - - if not file.exists(): - logger.info("No .systemprompt file found. Using default prompt.") - return default_prompt - +def initialize_app(): try: - with open(file, "r", encoding="utf-8") as f: - content = f.read().strip() - - if not content: - logger.warning("System prompt file is empty. Using default prompt.") - return default_prompt - - logger.info( - f"Successfully loaded system prompt from {file}; content: '{content}'" - ) - return content - + setup_all_routes(app) + logger.info(f"Initialized Chipper API {APP_VERSION}.{BUILD_NUMBER}") except Exception as e: - logger.error(f"Error reading system prompt file: {e}") - return default_prompt - - -system_prompt_value = load_systemprompt(os.getenv("SYSTEM_PROMPT_PATH", os.getcwd())) + logger.error(f"Failed to initialize application: {e}", exc_info=True) + raise -def get_env_param(param_name, converter=None, default=None): - value = os.getenv(param_name) - if value is None: - return None - if converter is not None: - try: - if default is not None and value == "": - return converter(default) - return converter(value) - except (ValueError, TypeError): - return None - return value - - -def create_pipeline_config(model: str = None, index: str = None) -> QueryPipelineConfig: - provider_name = os.getenv("PROVIDER", "ollama") - provider = ( - ModelProvider.HUGGINGFACE - if provider_name.lower() == "hf" - else ModelProvider.OLLAMA - ) - - if provider == ModelProvider.HUGGINGFACE: - model_name = model or os.getenv("HF_MODEL_NAME") - embedding_model = os.getenv("HF_EMBEDDING_MODEL_NAME") - else: - model_name = model or os.getenv("MODEL_NAME") - embedding_model = os.getenv("EMBEDDING_MODEL_NAME") - - config_params = { - "provider": provider, - "embedding_model": embedding_model, - "model_name": model_name, - "system_prompt": system_prompt_value, +def get_server_config(): + return { + "host": os.getenv("HOST", "0.0.0.0"), + "port": int(os.getenv("PORT", "8000")), + "debug": os.getenv("DEBUG", "False").lower() == "true", } - # Provider specific parameters - if provider == ModelProvider.HUGGINGFACE: - if (hf_key := os.getenv("HF_API_KEY")) is not None: - config_params["hf_api_key"] = hf_key - else: - if (ollama_url := os.getenv("OLLAMA_URL")) is not None: - config_params["ollama_url"] = ollama_url - - # Model pull configuration - allow_pull = os.getenv("ALLOW_MODEL_PULL") - if allow_pull is not None: - config_params["allow_model_pull"] = 
allow_pull.lower() == "true" - - # Core generation parameters - if (context_window := get_env_param("CONTEXT_WINDOW", int, "8192")) is not None: - config_params["context_window"] = context_window - - for param in ["TEMPERATURE", "SEED", "TOP_K"]: - if ( - value := get_env_param(param, float if param == "TEMPERATURE" else int) - ) is not None: - config_params[param.lower()] = value - - # Advanced sampling parameters - for param in ["TOP_P", "MIN_P"]: - if (value := get_env_param(param, float)) is not None: - config_params[param.lower()] = value - - # Mirostat parameters - if (mirostat := get_env_param("MIROSTAT", int)) is not None: - config_params["mirostat"] = mirostat - # Only add eta and tau if mirostat is defined - for param in ["MIROSTAT_ETA", "MIROSTAT_TAU"]: - if (value := get_env_param(param, float)) is not None: - config_params[param.lower()] = value - - # Repetition control parameters - for param in ["REPEAT_LAST_N", "REPEAT_PENALTY"]: - if ( - value := get_env_param(param, int if param == "REPEAT_LAST_N" else float) - ) is not None: - config_params[param.lower()] = value - - # Generation control parameters - if (num_predict := get_env_param("NUM_PREDICT", int)) is not None: - config_params["num_predict"] = num_predict - - if (tfs_z := get_env_param("TFS_Z", float)) is not None: - config_params["tfs_z"] = tfs_z - - if (stop := os.getenv("STOP")) is not None: - config_params["stop_sequence"] = stop - - # Elasticsearch parameters - if (es_url := os.getenv("ES_URL")) is not None: - config_params["es_url"] = es_url - - if index is not None: - config_params["es_index"] = index - elif (es_index := os.getenv("ES_INDEX")) is not None: - config_params["es_index"] = es_index - - if (es_top_k := get_env_param("ES_TOP_K", int, "5")) is not None: - config_params["es_top_k"] = es_top_k - - if ( - es_num_candidates := get_env_param("ES_NUM_CANDIDATES", int, "-1") - ) is not None: - config_params["es_num_candidates"] = es_num_candidates - - if (es_user := os.getenv("ES_BASIC_AUTH_USERNAME")) is not None: - config_params["es_basic_auth_user"] = es_user - - if (es_pass := os.getenv("ES_BASIC_AUTH_PASSWORD")) is not None: - config_params["es_basic_auth_password"] = es_pass - - if (enable_conversation_logs := os.getenv("ENABLE_CONVERSATION_LOGS")) is not None: - config_params["enable_conversation_logs"] = enable_conversation_logs - - return QueryPipelineConfig(**config_params) - -def require_api_key(f): - @wraps(f) - def decorated_function(*args, **kwargs): - require_api_key = os.getenv("REQUIRE_API_KEY") - require_api_key = require_api_key.lower() == "true" - - if not require_api_key: - return f(*args, **kwargs) - - api_key = request.headers.get("X-API-Key") - if not api_key or api_key != API_KEY: - abort(401) - return f(*args, **kwargs) - - return decorated_function - - -def register_ollama_routes(app, proxy: OllamaProxy): - @app.route("/api/generate", methods=["POST"]) - @require_api_key - def generate(): - return proxy.generate() - - @app.route("/api/tags", methods=["GET"]) - @require_api_key - def tags(): - return proxy.tags() - - @app.route("/api/pull", methods=["POST"]) - @require_api_key - def pull(): - return proxy.pull() - -ollama_proxy = OllamaProxy(os.getenv("OLLAMA_URL", "http://localhost:11434")) -register_ollama_routes(app, ollama_proxy) - -@app.before_request -def before_request(): - logger.info(f"Request {request.method} {request.path} from {request.remote_addr}") - if os.getenv("REQUIRE_SECURE", "False").lower() == "true" and not request.is_secure: - abort(403) - - 
-@app.after_request -def after_request(response): - response.headers.update( - { - "Strict-Transport-Security": "max-age=31536000; includeSubDomains", - "X-Content-Type-Options": "nosniff", - "X-Frame-Options": "DENY", - "X-XSS-Protection": "1; mode=block", - } - ) - return response - - -@app.route("/api/chat", methods=["POST"]) -@require_api_key -def chat(): +def main(): try: - data = request.get_json() - if not data: - logger.error("No JSON payload received.") - abort(400, description="Invalid JSON payload.") - - messages = data.get("messages", []) - if not messages: - abort(400, description="No messages provided") - - query = messages[-1].get("content") - if not query: - abort(400, description="Invalid message format") - - model = data.get("model") - if model and not ALLOW_MODEL_CHANGE: - abort(403, description="Model changes are not allowed") - - options = data.get("options", {}) - index = options.get("index") - if index and not ALLOW_INDEX_CHANGE: - abort(403, description="Index changes are not allowed") - - config = create_pipeline_config(model, index) - stream = data.get("stream", True) - conversation = messages[:-1] if len(messages) > 1 else [] - if stream: - return handle_streaming_response(config, query, conversation) - else: - return handle_standard_response(config, query, conversation) - - except Exception as e: - logger.error(f"Error processing chat request: {str(e)}", exc_info=True) - abort(500, description="Internal Server Error.") - - -def handle_streaming_response( - config: QueryPipelineConfig, query: str, conversation: list -) -> Response: - q = queue.Queue() + show_welcome() + initialize_app() - def format_model_status(status): - model = status.get("model", "unknown") - status_type = status.get("status") + server_config = get_server_config() - allow_model_pull = os.getenv("ALLOW_MODEL_PULL", "True").lower() == "true" - if not allow_model_pull: - return None - - if status_type == "pulling": - return f"Starting to download model {model}..." - elif status_type == "progress": - percentage = status.get("percentage", 0) - return f"Downloading model {model}: {percentage}% complete" - elif status_type == "complete": - return f"Successfully downloaded model {model}" - elif status_type == "error" and "pull" in status.get("error", "").lower(): - error_msg = status.get("error", "Unknown error") - return f"Error downloading model {model}: {error_msg}" - - return None - - def streaming_callback(chunk): - if chunk.content: - response_data = { - "type": "chat_response", - "chunk": chunk.content, - "done": False, - "full_response": None, - } - q.put(f"data: {json.dumps(response_data)}\n\n") - - rag = RAGQueryPipeline(config=config, streaming_callback=streaming_callback) - - def run_rag(): - try: - for status in rag.initialize_and_check_models(): - message = format_model_status(status) - if message: - response_data = { - "type": "chat_response", - "chunk": message + "\n", - "done": False, - "full_response": None, - } - q.put(f"data: {json.dumps(response_data)}\n\n") - - rag.create_query_pipeline() - result = rag.run_query( - query=query, conversation=conversation, print_response=False - ) - final_data = { - "type": "chat_response", - "chunk": "", - "done": True, - "full_response": result, - } - q.put(f"data: {json.dumps(final_data)}\n\n") - except elasticsearch.BadRequestError as e: - error_data = { - "type": "chat_response", - "chunk": f"Error: Embedding retriever error. 
{str(e)}.\n", - "done": True, - } - q.put(f"data: {json.dumps(error_data)}\n\n") - except Exception as e: - error_data = { - "type": "chat_response", - "chunk": f"Error: {str(e)}\n", - "done": True, - } - logger.error(f"Error in RAG pipeline: {e}", exc_info=True) - q.put(f"data: {json.dumps(error_data)}\n\n") - - thread = threading.Thread(target=run_rag, daemon=True) - thread.start() - - def generate(): - while True: - try: - data_item = q.get(timeout=120) - yield data_item - - json_data = json.loads(data_item.replace("data: ", "").strip()) - if json_data.get("done") is True: - logger.info("Streaming completed.") - break - - except queue.Empty: - yield "event: heartbeat\ndata: {}\n\n" - logger.warning("Queue timeout. Sending heartbeat.") - except json.JSONDecodeError as e: - logger.error(f"JSON decode error: {e} | Data: {data_item}") - error_message = { - "type": "error", - "error": "Invalid JSON format received.", - "done": True, - } - yield f"data: {json.dumps(error_message)}\n\n" - break - - return Response( - stream_with_context(generate()), - mimetype="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "X-Accel-Buffering": "no", - "Connection": "keep-alive", - }, - ) - - -def handle_standard_response( - config: QueryPipelineConfig, - query: str, - conversation: list -) -> Response: - rag = RAGQueryPipeline(config=config) - - success = True - result = None - try: - result = rag.run_query( - query=query, - conversation=conversation, - print_response=False + logger.info( + f"Starting server on {server_config['host']}:{server_config['port']}" ) - if result: - latest_message = { - "role": "assistant", - "content": result["llm"]["replies"][0], - "timestamp": datetime.now().isoformat(), - } - conversation.append(latest_message) - - # Add model metrics to result - result["model_metrics"] = { - "model": "llama3.2", - "created_at": datetime.now().isoformat(), - "done": True, - "total_duration": 4883583458, - "load_duration": 1334875, - "prompt_eval_count": 26, - "prompt_eval_duration": 342546000, - "eval_count": 282, - "eval_duration": 4535599000 - } - + app.run( + host=server_config["host"], + port=server_config["port"], + debug=server_config["debug"], + ) except Exception as e: - success = False - logger.error(f"Error in RAG pipeline: {e}", exc_info=True) - - if success and result: - latest_message = { - "role": "assistant", - "content": result["llm"]["replies"][0], - "timestamp": datetime.now().isoformat(), - } - conversation.append(latest_message) - - return jsonify( - { - "success": success, - "timestamp": datetime.now().isoformat(), - "result": result, - "messages": conversation, - } - ) - -@app.route("/", methods=["GET"]) -@app.route("/health", methods=["GET"]) -def health_check(): - return jsonify( - { - "service": "chipper-api", - "version": APP_VERSION, - "build": BUILD_NUMBER, - "status": "healthy", - "timestamp": datetime.now(timezone.utc).isoformat(), - } - ) + logger.error(f"Failed to start server: {e}", exc_info=True) + exit(1) -@app.errorhandler(404) -def not_found_error(error): - return "", 404 +main() if __name__ == "__main__": - app.run( - host=os.getenv("HOST", "0.0.0.0"), - port=int(os.getenv("PORT", "8000")), - debug=os.getenv("DEBUG", "False").lower() == "true", - ) + main() diff --git a/tools/cli/tools/api_mirror_tester/src/main.py b/tools/cli/tools/api_mirror_tester/src/main.py index 58ce2b5..abfb620 100644 --- a/tools/cli/tools/api_mirror_tester/src/main.py +++ b/tools/cli/tools/api_mirror_tester/src/main.py @@ -234,7 +234,7 @@ def print_results(self, 
results: List[ComparisonResult]): async def main(): tester = ApiMirrorTester( - Chipper_api_base="http://localhost:21210/", + Chipper_api_base="http://localhost:21434/", ollama_api_base="http://localhost:11434", verify_ssl=False, ) diff --git a/tools/cli/tools/index.html b/tools/cli/tools/index.html index e7b3c15..f074ca4 100644 --- a/tools/cli/tools/index.html +++ b/tools/cli/tools/index.html @@ -133,7 +133,7 @@
Chat API Test
}, }; - const response = await fetch("http://localhost:21210/api/chat", { + const response = await fetch("http://localhost:21434/api/chat", { method: "POST", headers: { "Content-Type": "application/json", diff --git a/tools/cli/tools/test_non_streaming.sh b/tools/cli/tools/test_non_streaming.sh index 3a5c591..f3355df 100644 --- a/tools/cli/tools/test_non_streaming.sh +++ b/tools/cli/tools/test_non_streaming.sh @@ -1,4 +1,4 @@ -curl -X POST http://localhost:21210/api/chat \ +curl -X POST http://localhost:21434/api/chat \ -H "Content-Type: application/json" \ -H "X-API-Key: EXAMPLE_API_KEY" \ -d '{ diff --git a/tools/cli/tools/test_streaming.sh b/tools/cli/tools/test_streaming.sh index 6682cd8..09b4533 100644 --- a/tools/cli/tools/test_streaming.sh +++ b/tools/cli/tools/test_streaming.sh @@ -1,4 +1,4 @@ -curl -X POST http://localhost:21210/api/chat \ +curl -X POST http://localhost:21434/api/chat \ -H "Content-Type: application/json" \ -H "X-API-Key: EXAMPLE_API_KEY" \ -d '{