diff --git a/services/api/.env.example b/services/api/.env.example
index d66f233..3aa0a6a 100644
--- a/services/api/.env.example
+++ b/services/api/.env.example
@@ -20,6 +20,20 @@ ALLOW_MODEL_PARAMETER_CHANGE=true
 # Ollama clients and HuggingFace by using the API default model instead.
 IGNORE_MODEL_REQUEST=false
 
+# Enables proxying of non-RAG Ollama API endpoints directly to the Ollama instance.
+# If API key authentication is enabled, requests must be authenticated accordingly.
+# This includes endpoints such as:
+# - /api/tags
+# - /api/pull
+# - /api/generate
+ENABLE_OLLAMA_PROXY=true
+
+# Enabling the Chipper RAG bypass routes chat messages directly to the Ollama instance
+# without applying any RAG embedding steps. This can be useful for debugging but goes
+# against the intended purpose of this project. It can also be used to put API key
+# authentication in front of a plain Ollama instance, though this is not recommended.
+BYPASS_OLLAMA_RAG=false
+
 # Embedding
 EMBEDDING_MODEL_NAME=snowflake-arctic-embed2
 HF_EMBEDDING_MODEL_NAME=sentence-transformers/all-mpnet-base-v2
diff --git a/services/api/src/api/config.py b/services/api/src/api/config.py
index 14cb17f..3e5be49 100644
--- a/services/api/src/api/config.py
+++ b/services/api/src/api/config.py
@@ -34,6 +34,9 @@
     os.getenv("ALLOW_MODEL_PARAMETER_CHANGE", "true").lower() == "true"
 )
 IGNORE_MODEL_REQUEST = os.getenv("IGNORE_MODEL_REQUEST", "true").lower() == "true"
+ENABLE_OLLAMA_PROXY = os.getenv("ENABLE_OLLAMA_PROXY", "true").lower() == "true"
+BYPASS_OLLAMA_RAG = os.getenv("BYPASS_OLLAMA_RAG", "false").lower() == "true"
+
 DEBUG = os.getenv("DEBUG", "true").lower() == "true"
 
 # Rate limiting configuration
diff --git a/services/api/src/api/middleware.py b/services/api/src/api/middleware.py
index 6d42d2e..e409757 100644
--- a/services/api/src/api/middleware.py
+++ b/services/api/src/api/middleware.py
@@ -78,7 +78,7 @@ def internal_error(error):
 def setup_request_logging_middleware(app):
     @app.before_request
     def log_request_info():
-        if request.path == "/health":
+        if request.path == "/" or request.path == "/health":
             return
 
         log_data = {
@@ -89,7 +89,7 @@ def log_request_info():
             "request_id": request.headers.get("X-Request-ID"),
         }
 
-        logger.info("Incoming request", extra=log_data)
+        logger.debug("Incoming request", extra=log_data)
 
 
 def init_middleware(app):
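
Note on the two new flags (not part of the patch): config.py parses them with .lower() == "true", so only the literal string "true" in any casing enables a flag; values such as "1" or "yes" are treated as false. A minimal sketch mirroring the parsing above:

# Sketch: how ENABLE_OLLAMA_PROXY / BYPASS_OLLAMA_RAG evaluate under the parsing in config.py.
import os

os.environ["ENABLE_OLLAMA_PROXY"] = "TRUE"  # any casing of "true" -> enabled
os.environ["BYPASS_OLLAMA_RAG"] = "1"       # "1" / "yes" do NOT count as true here

ENABLE_OLLAMA_PROXY = os.getenv("ENABLE_OLLAMA_PROXY", "true").lower() == "true"
BYPASS_OLLAMA_RAG = os.getenv("BYPASS_OLLAMA_RAG", "false").lower() == "true"

print(ENABLE_OLLAMA_PROXY, BYPASS_OLLAMA_RAG)  # True False
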
diff --git a/services/api/src/api/ollama_proxy.py b/services/api/src/api/ollama_proxy.py
index a5b36c8..29a30e1 100644
--- a/services/api/src/api/ollama_proxy.py
+++ b/services/api/src/api/ollama_proxy.py
@@ -10,12 +10,45 @@
 
 
 class OllamaProxy:
+    """
+    A proxy class for interacting with the Ollama API.
+
+    This class provides methods for all Ollama API endpoints, handling both streaming
+    and non-streaming responses, and managing various model operations.
+    Ref: https://github.com/ollama/ollama/blob/main/docs/api.md
+    """
+
     def __init__(self, base_url: Optional[str] = None):
+        """
+        Initialize the OllamaProxy with a base URL.
+
+        Args:
+            base_url: The base URL for the Ollama API. Defaults to environment variable
+                OLLAMA_URL or 'http://localhost:11434'
+        """
         self.base_url = base_url or os.getenv("OLLAMA_URL", "http://localhost:11434")
 
-    def _proxy_request(self, path: str, method: str = "GET", stream: bool = False):
+    def _proxy_request(
+        self, path: str, method: str = "GET", stream: bool = False
+    ) -> Response:
+        """
+        Make a proxied request to the Ollama API.
+
+        Args:
+            path: The API endpoint path
+            method: The HTTP method to use
+            stream: Whether to stream the response
+
+        Returns:
+            A Flask Response object
+        """
         url = f"{self.base_url}{path}"
-        headers = {k: v for k, v in request.headers if k != "Host"}
+        headers = {
+            k: v
+            for k, v in request.headers.items()
+            if k.lower() not in ["host", "transfer-encoding"]
+        }
+
         data = request.get_data() if method != "GET" else None
 
         try:
@@ -33,23 +66,30 @@ def _proxy_request(self, path: str, method: str = "GET", stream: bool = False):
                 json.dumps({"error": str(e)}), status=500, mimetype="application/json"
             )
 
-    def _handle_streaming_response(self, response):
+    def _handle_streaming_response(self, response: requests.Response) -> Response:
+        """Handle streaming responses from the Ollama API."""
+
         def generate():
-            for chunk in response.iter_content(chunk_size=None):
-                yield chunk
+            try:
+                for chunk in response.iter_content(chunk_size=None):
+                    if chunk:
+                        yield chunk
+            except Exception as e:
+                logger.error(f"Error streaming response: {str(e)}")
+                yield json.dumps({"error": str(e)}).encode()
+
+        response_headers = {
+            "Content-Type": response.headers.get("Content-Type", "application/json")
+        }
 
         return Response(
             stream_with_context(generate()),
             status=response.status_code,
-            headers={
-                "Content-Type": response.headers.get(
-                    "Content-Type", "application/json"
-                ),
-                "Transfer-Encoding": "chunked",
-            },
+            headers=response_headers,
         )
 
-    def _handle_standard_response(self, response):
+    def _handle_standard_response(self, response: requests.Response) -> Response:
+        """Handle non-streaming responses from the Ollama API."""
         return Response(
             response.content,
             status=response.status_code,
@@ -58,11 +98,66 @@ def _handle_standard_response(self, response):
             },
         )
 
-    def generate(self):
+    # Generation endpoints
+    def generate(self) -> Response:
+        """Generate a completion for a given prompt."""
         return self._proxy_request("/api/generate", "POST", stream=True)
 
-    def tags(self):
-        return self._proxy_request("/api/tags", "GET")
+    def chat(self) -> Response:
+        """Generate the next message in a chat conversation."""
+        return self._proxy_request("/api/chat", "POST", stream=True)
+
+    def embeddings(self) -> Response:
+        """Generate embeddings (legacy endpoint)."""
+        return self._proxy_request("/api/embeddings", "POST")
 
-    def pull(self):
+    def embed(self) -> Response:
+        """Generate embeddings from a model."""
+        return self._proxy_request("/api/embed", "POST")
+
+    # Model management endpoints
+    def create(self) -> Response:
+        """Create a model."""
+        return self._proxy_request("/api/create", "POST", stream=True)
+
+    def show(self) -> Response:
+        """Show model information."""
+        return self._proxy_request("/api/show", "POST")
+
+    def copy(self) -> Response:
+        """Copy a model."""
+        return self._proxy_request("/api/copy", "POST")
+
+    def delete(self) -> Response:
+        """Delete a model."""
+        return self._proxy_request("/api/delete", "DELETE")
+
+    def pull(self) -> Response:
+        """Pull a model from the Ollama library."""
         return self._proxy_request("/api/pull", "POST", stream=True)
+
+    def push(self) -> Response:
+        """Push a model to the Ollama library."""
+        return self._proxy_request("/api/push", "POST", stream=True)
+
+    # Blob management endpoints
+    def check_blob(self, digest: str) -> Response:
+        """Check if a blob exists."""
+        return self._proxy_request(f"/api/blobs/{digest}", "HEAD")
+
+    def push_blob(self, digest: str) -> Response:
+        """Push a blob to the server."""
+        return self._proxy_request(f"/api/blobs/{digest}", "POST")
+
+    # Model listing and status endpoints
+    def list_local_models(self) -> Response:
+        """List models available locally."""
+        return self._proxy_request("/api/tags", "GET")
+
+    def list_running_models(self) -> Response:
+        """List models currently loaded in memory."""
+        return self._proxy_request("/api/ps", "GET")
+
+    def version(self) -> Response:
+        """Get the Ollama version."""
+        return self._proxy_request("/api/version", "GET")
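
Because OllamaProxy builds its outgoing request from Flask's request object, it can be exercised without running the full API. A minimal sketch (not part of the patch), assuming services/api/src is on PYTHONPATH and an Ollama instance is reachable at localhost:11434:

# Sketch: calling OllamaProxy directly inside a Flask test request context.
from flask import Flask

from api.ollama_proxy import OllamaProxy

app = Flask(__name__)
proxy = OllamaProxy("http://localhost:11434")

# _proxy_request reads headers and body from the active Flask request,
# so the call is wrapped in a test request context.
with app.test_request_context("/api/tags", method="GET"):
    resp = proxy.list_local_models()
    print(resp.status_code, resp.get_data(as_text=True))
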
diff --git a/services/api/src/api/ollama_routes.py b/services/api/src/api/ollama_routes.py
index 7f570eb..a8a6eb5 100644
--- a/services/api/src/api/ollama_routes.py
+++ b/services/api/src/api/ollama_routes.py
@@ -1,6 +1,6 @@
 import os
 
-from api.config import logger
+from api.config import BYPASS_OLLAMA_RAG, logger
 from api.middleware import require_api_key
 from api.ollama_proxy import OllamaProxy
 
@@ -11,7 +11,21 @@ def __init__(self, app, proxy: OllamaProxy):
         self.proxy = proxy
         self.register_routes()
 
+        if BYPASS_OLLAMA_RAG:
+            self.register_bypass_routes()
+
+    def register_bypass_routes(self):
+        @self.app.route("/api/chat", methods=["POST"])
+        @require_api_key
+        def chat():
+            try:
+                return self.proxy.chat()
+            except Exception as e:
+                logger.error(f"Error in chat endpoint: {e}")
+                return {"error": str(e)}, 500
+
     def register_routes(self):
+        # Generation endpoints
         @self.app.route("/api/generate", methods=["POST"])
         @require_api_key
         def generate():
@@ -21,13 +35,59 @@ def generate():
                 logger.error(f"Error in generate endpoint: {e}")
                 return {"error": str(e)}, 500
 
-        @self.app.route("/api/tags", methods=["GET"])
+        @self.app.route("/api/embeddings", methods=["POST"])
+        @require_api_key
+        def embeddings():
+            try:
+                return self.proxy.embeddings()
+            except Exception as e:
+                logger.error(f"Error in embeddings endpoint: {e}")
+                return {"error": str(e)}, 500
+
+        @self.app.route("/api/embed", methods=["POST"])
+        @require_api_key
+        def embed():
+            try:
+                return self.proxy.embed()
+            except Exception as e:
+                logger.error(f"Error in embed endpoint: {e}")
+                return {"error": str(e)}, 500
+
+        # Model management endpoints
+        @self.app.route("/api/create", methods=["POST"])
+        @require_api_key
+        def create():
+            try:
+                return self.proxy.create()
+            except Exception as e:
+                logger.error(f"Error in create endpoint: {e}")
+                return {"error": str(e)}, 500
+
+        @self.app.route("/api/show", methods=["POST"])
+        @require_api_key
+        def show():
+            try:
+                return self.proxy.show()
+            except Exception as e:
+                logger.error(f"Error in show endpoint: {e}")
+                return {"error": str(e)}, 500
+
+        @self.app.route("/api/copy", methods=["POST"])
+        @require_api_key
+        def copy():
+            try:
+                return self.proxy.copy()
+            except Exception as e:
+                logger.error(f"Error in copy endpoint: {e}")
+                return {"error": str(e)}, 500
+
+        @self.app.route("/api/delete", methods=["DELETE"])
         @require_api_key
-        def tags():
+        def delete():
             try:
-                return self.proxy.tags()
+                return self.proxy.delete()
             except Exception as e:
-                logger.error(f"Error in tags endpoint: {e}")
+                logger.error(f"Error in delete endpoint: {e}")
                 return {"error": str(e)}, 500
 
         @self.app.route("/api/pull", methods=["POST"])
@@ -39,10 +99,66 @@ def pull():
                 logger.error(f"Error in pull endpoint: {e}")
                 return {"error": str(e)}, 500
 
+        @self.app.route("/api/push", methods=["POST"])
+        @require_api_key
+        def push():
+            try:
+                return self.proxy.push()
+            except Exception as e:
+                logger.error(f"Error in push endpoint: {e}")
+                return {"error": str(e)}, 500
+
+        # Blob management endpoints
+        @self.app.route("/api/blobs/<digest>", methods=["HEAD"])
+        @require_api_key
+        def check_blob(digest):
+            try:
+                return self.proxy.check_blob(digest)
+            except Exception as e:
+                logger.error(f"Error in check_blob endpoint: {e}")
+                return {"error": str(e)}, 500
+
+        @self.app.route("/api/blobs/<digest>", methods=["POST"])
+        @require_api_key
+        def push_blob(digest):
+            try:
+                return self.proxy.push_blob(digest)
+            except Exception as e:
+                logger.error(f"Error in push_blob endpoint: {e}")
+                return {"error": str(e)}, 500
+
+        # Model listing and status endpoints
+        @self.app.route("/api/tags", methods=["GET"])
+        @require_api_key
+        def list_local_models():
+            try:
+                return self.proxy.list_local_models()
+            except Exception as e:
+                logger.error(f"Error in list_local_models endpoint: {e}")
+                return {"error": str(e)}, 500
+
+        @self.app.route("/api/ps", methods=["GET"])
+        @require_api_key
+        def list_running_models():
+            try:
+                return self.proxy.list_running_models()
+            except Exception as e:
+                logger.error(f"Error in list_running_models endpoint: {e}")
+                return {"error": str(e)}, 500
+
+        @self.app.route("/api/version", methods=["GET"])
+        @require_api_key
+        def version():
+            try:
+                return self.proxy.version()
+            except Exception as e:
+                logger.error(f"Error in version endpoint: {e}")
+                return {"error": str(e)}, 500
+
 
 
-def setup_ollama_routes(app):
+def setup_ollama_proxy_routes(app):
     ollama_url = os.getenv("OLLAMA_URL", "http://localhost:11434")
     proxy = OllamaProxy(ollama_url)
     OllamaRoutes(app, proxy)
-    logger.info(f"Initialized Ollama routes with URL: {ollama_url}")
+    logger.info(f"Initialized Ollama proxy routes with Ollama URL: {ollama_url}")
     return proxy
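
At the HTTP level the proxied endpoints behave like plain Ollama calls behind the Chipper API and require_api_key. A hedged client sketch (not part of the patch), assuming the API listens on localhost:8000 and the key is sent as a bearer token; how require_api_key actually reads the key is not shown in this diff, so adjust host, port, and header to your deployment:

# Sketch: exercising the proxied endpoints from a client.
import json

import requests

BASE_URL = "http://localhost:8000"                    # assumed Chipper API address
HEADERS = {"Authorization": "Bearer <your-api-key>"}  # assumed auth header

# Non-streaming proxied call: list locally available models.
print(requests.get(f"{BASE_URL}/api/tags", headers=HEADERS, timeout=30).json())

# Streaming proxied call: /api/generate forwards Ollama's newline-delimited JSON chunks.
with requests.post(
    f"{BASE_URL}/api/generate",
    headers=HEADERS,
    json={"model": "llama3.2", "prompt": "Hello", "stream": True},  # any pulled model
    stream=True,
    timeout=300,
) as resp:
    for line in resp.iter_lines():
        if line:
            print(json.loads(line))
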
diff --git a/services/api/src/api/routes.py b/services/api/src/api/routes.py
index 6ede8b5..717684f 100644
--- a/services/api/src/api/routes.py
+++ b/services/api/src/api/routes.py
@@ -50,7 +50,7 @@ def log_request_info(request):
     logger.info("Request: %s", json.dumps(request_info, indent=None, sort_keys=True))
 
 
-def register_chat_routes(app: Flask):
+def register_rag_chat_route(app: Flask):
     @app.route("/api/chat", methods=["POST"])
     @require_api_key
     def chat():
@@ -82,7 +82,12 @@ def chat():
                 or "content" not in message
             ):
                 abort(400, description="Invalid message format")
-            if message["role"] not in ["system", "user", "assistant", "tool"]:
+            if message["role"] != "" and message["role"] not in [
+                "system",
+                "user",
+                "assistant",
+                "tool",
+            ]:
                 abort(400, description="Invalid message role")
 
         # Optional parameters
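
The relaxed check above now admits an empty role string while still rejecting unknown roles. A standalone sketch of the same condition (not part of the patch):

# Sketch: the relaxed message-role validation from routes.py.
VALID_ROLES = ["system", "user", "assistant", "tool"]

def role_is_valid(role: str) -> bool:
    return role == "" or role in VALID_ROLES

print(role_is_valid(""))           # True  (now accepted)
print(role_is_valid("assistant"))  # True
print(role_is_valid("bot"))        # False (still rejected)
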
diff --git a/services/api/src/api/routes_setup.py b/services/api/src/api/routes_setup.py
index af03842..dce1674 100644
--- a/services/api/src/api/routes_setup.py
+++ b/services/api/src/api/routes_setup.py
@@ -1,21 +1,33 @@
-from api.config import PROVIDER_IS_OLLAMA, logger
-from api.ollama_routes import setup_ollama_routes
-from api.routes import register_chat_routes, register_health_routes
+from api.config import (
+    BYPASS_OLLAMA_RAG,
+    ENABLE_OLLAMA_PROXY,
+    PROVIDER_IS_OLLAMA,
+    logger,
+)
+from api.ollama_routes import setup_ollama_proxy_routes
+from api.routes import register_health_routes, register_rag_chat_route
 from flask import Flask
 
 
 def setup_all_routes(app: Flask):
     try:
-        if PROVIDER_IS_OLLAMA:
-            # Setup Ollama-specific routes
-            setup_ollama_routes(app)
-            logger.info("Ollama routes registered successfully")
+        if PROVIDER_IS_OLLAMA and ENABLE_OLLAMA_PROXY:
+            # Setup Ollama proxy routes
+            setup_ollama_proxy_routes(app)
+            logger.info("Ollama proxy routes registered successfully")
 
-        # Setup chat routes (chat, streaming, etc)
-        register_chat_routes(app)
-        logger.info("Chat routes registered successfully")
+        # Setup internal RAG pipeline routes
+        if not BYPASS_OLLAMA_RAG or not PROVIDER_IS_OLLAMA:
+            register_rag_chat_route(app)
+            logger.info(
+                "Chat routes registered successfully: RAG and embedding enabled."
+            )
+        else:
+            logger.warning(
+                "Chat routes bypassed! RAG is disabled, and embeddings will not be used."
+            )
 
-        # Setup health check and basic routes
+        # Setup health check routes
         register_health_routes(app)
         logger.info("Health check routes registered successfully")
diff --git a/services/api/src/core/rag_pipeline.py b/services/api/src/core/rag_pipeline.py
index 8e20384..c6a58f7 100644
--- a/services/api/src/core/rag_pipeline.py
+++ b/services/api/src/core/rag_pipeline.py
@@ -180,8 +180,10 @@ def run_query(
         print_response: bool = False,
         use_embeddings: bool = True,
     ) -> Optional[dict]:
-        self.logger.info(f"\nProcessing Query: {query}")
-        self.logger.info(f"Conversation history present: {bool(conversation)}")
+        self.logger.info("Processing Query...")
+        self.logger.info(
+            f"Conversation history present: {bool(conversation)}; history length: {len(conversation) if conversation else 0}"
+        )
 
         if not self.query_pipeline:
             self.logger.info("Query pipeline not initialized. Creating new pipeline...")
@@ -208,7 +210,8 @@ def run_query(
             self.conversation_logger.log_conversation(query, response, conversation)
 
         if print_response and response["llm"]["replies"]:
-            logging.info("Response: " + response["llm"]["replies"][0])
+            self.logger.info("Query: " + query)
+            self.logger.info("Response: " + response["llm"]["replies"][0])
 
         return response
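
Taken together, ENABLE_OLLAMA_PROXY, BYPASS_OLLAMA_RAG, and the provider decide which /api/chat handler is registered. A sketch of that decision logic as it follows from routes_setup.py and ollama_routes.py (not part of the patch); note that bypass enabled with the proxy disabled on an Ollama provider leaves /api/chat unregistered:

# Sketch: which /api/chat handler a given configuration registers.
def chat_handler(provider_is_ollama: bool, enable_proxy: bool, bypass_rag: bool) -> str:
    # The bypass chat route is registered by OllamaRoutes, which only exists when the proxy is set up.
    proxy_chat = provider_is_ollama and enable_proxy and bypass_rag
    # The RAG chat route is skipped only when bypassing on an Ollama provider.
    rag_chat = not bypass_rag or not provider_is_ollama
    if proxy_chat:
        return "proxied /api/chat (no RAG)"
    if rag_chat:
        return "RAG /api/chat"
    return "no /api/chat route registered"

print(chat_handler(True, True, False))  # RAG /api/chat
print(chat_handler(True, True, True))   # proxied /api/chat (no RAG)
print(chat_handler(True, False, True))  # no /api/chat route registered
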