Skip to content

Commit

Permalink
Fully reflect the Ollama api (#93)
Browse files Browse the repository at this point in the history
* Add option to disable Ollama proxy

* Make Ollama proxy Ollama CLI compatible

* Allow empty message role

* Fully reflect the Ollama API

* Suppress logging of health ping calls

* Only log queries if debug is enabled

* Implement option to bypass RAG pipeline

* Apply formatting

* Make route naming more distinct
  • Loading branch information
TilmanGriesel authored Feb 1, 2025
1 parent 70e5cce commit 1bae97c
Show file tree
Hide file tree
Showing 8 changed files with 289 additions and 41 deletions.
14 changes: 14 additions & 0 deletions services/api/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,20 @@ ALLOW_MODEL_PARAMETER_CHANGE=true
# Ollama clients and HuggingFace by using the API default model instead.
IGNORE_MODEL_REQUEST=false

# Enables proxying of non-RAG Ollama API endpoints directly to the Ollama instance.
# If API key authentication is enabled, requests must be authenticated accordingly.
# This includes endpoints such as:
# - /api/tags
# - /api/pull
# - /api/generate
ENABLE_OLLAMA_PROXY=true

# Enabling Chipper RAG bypass will route chat messages directly to the Ollama instance
# without applying any RAG embedding steps. This can be useful for debugging, but it goes
# against the intended purpose of this project. It can also be used to put API key
# authentication in front of your Ollama instance, though this is not recommended.
BYPASS_OLLAMA_RAG=false

# Embedding
EMBEDDING_MODEL_NAME=snowflake-arctic-embed2
HF_EMBEDDING_MODEL_NAME=sentence-transformers/all-mpnet-base-v2
Expand Down
3 changes: 3 additions & 0 deletions services/api/src/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
os.getenv("ALLOW_MODEL_PARAMETER_CHANGE", "true").lower() == "true"
)
IGNORE_MODEL_REQUEST = os.getenv("IGNORE_MODEL_REQUEST", "true").lower() == "true"
ENABLE_OLLAMA_PROXY = os.getenv("ENABLE_OLLAMA_PROXY", "true").lower() == "true"
BYPASS_OLLAMA_RAG = os.getenv("BYPASS_OLLAMA_RAG", "false").lower() == "true"

DEBUG = os.getenv("DEBUG", "true").lower() == "true"

# Rate limiting configuration
Expand Down
4 changes: 2 additions & 2 deletions services/api/src/api/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def internal_error(error):
def setup_request_logging_middleware(app):
@app.before_request
def log_request_info():
if request.path == "/health":
if request.path == "/" or request.path == "/health":
return

log_data = {
Expand All @@ -89,7 +89,7 @@ def log_request_info():
"request_id": request.headers.get("X-Request-ID"),
}

logger.info("Incoming request", extra=log_data)
logger.debug("Incoming request", extra=log_data)


def init_middleware(app):
Expand Down
127 changes: 111 additions & 16 deletions services/api/src/api/ollama_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,45 @@


class OllamaProxy:
"""
A proxy class for interacting with the Ollama API.
This class provides methods for all Ollama API endpoints, handling both streaming
and non-streaming responses, and managing various model operations.
Ref: https://github.com/ollama/ollama/blob/main/docs/api.md
"""

def __init__(self, base_url: Optional[str] = None):
"""
Initialize the OllamaProxy with a base URL.
Args:
base_url: The base URL for the Ollama API. Defaults to environment variable
OLLAMA_URL or 'http://localhost:11434'
"""
self.base_url = base_url or os.getenv("OLLAMA_URL", "http://localhost:11434")

def _proxy_request(self, path: str, method: str = "GET", stream: bool = False):
def _proxy_request(
self, path: str, method: str = "GET", stream: bool = False
) -> Response:
"""
Make a proxied request to the Ollama API.
Args:
path: The API endpoint path
method: The HTTP method to use
stream: Whether to stream the response
Returns:
A Flask Response object
"""
url = f"{self.base_url}{path}"
headers = {k: v for k, v in request.headers if k != "Host"}
headers = {
k: v
for k, v in request.headers.items()
if k.lower() not in ["host", "transfer-encoding"]
}

data = request.get_data() if method != "GET" else None

try:
Expand All @@ -33,23 +66,30 @@ def _proxy_request(self, path: str, method: str = "GET", stream: bool = False):
json.dumps({"error": str(e)}), status=500, mimetype="application/json"
)

def _handle_streaming_response(self, response):
def _handle_streaming_response(self, response: requests.Response) -> Response:
"""Handle streaming responses from the Ollama API."""

def generate():
for chunk in response.iter_content(chunk_size=None):
yield chunk
try:
for chunk in response.iter_content(chunk_size=None):
if chunk:
yield chunk
except Exception as e:
logger.error(f"Error streaming response: {str(e)}")
yield json.dumps({"error": str(e)}).encode()

response_headers = {
"Content-Type": response.headers.get("Content-Type", "application/json")
}

return Response(
stream_with_context(generate()),
status=response.status_code,
headers={
"Content-Type": response.headers.get(
"Content-Type", "application/json"
),
"Transfer-Encoding": "chunked",
},
headers=response_headers,
)

def _handle_standard_response(self, response):
def _handle_standard_response(self, response: requests.Response) -> Response:
"""Handle non-streaming responses from the Ollama API."""
return Response(
response.content,
status=response.status_code,
Expand All @@ -58,11 +98,66 @@ def _handle_standard_response(self, response):
},
)

def generate(self):
# Generation endpoints
def generate(self) -> Response:
"""Generate a completion for a given prompt."""
return self._proxy_request("/api/generate", "POST", stream=True)

def tags(self):
return self._proxy_request("/api/tags", "GET")
def chat(self) -> Response:
"""Generate the next message in a chat conversation."""
return self._proxy_request("/api/chat", "POST", stream=True)

def embeddings(self) -> Response:
"""Generate embeddings (legacy endpoint)."""
return self._proxy_request("/api/embeddings", "POST")

def pull(self):
def embed(self) -> Response:
"""Generate embeddings from a model."""
return self._proxy_request("/api/embed", "POST")

# Model management endpoints
def create(self) -> Response:
"""Create a model."""
return self._proxy_request("/api/create", "POST", stream=True)

def show(self) -> Response:
"""Show model information."""
return self._proxy_request("/api/show", "POST")

def copy(self) -> Response:
"""Copy a model."""
return self._proxy_request("/api/copy", "POST")

def delete(self) -> Response:
"""Delete a model."""
return self._proxy_request("/api/delete", "DELETE")

def pull(self) -> Response:
"""Pull a model from the Ollama library."""
return self._proxy_request("/api/pull", "POST", stream=True)

def push(self) -> Response:
"""Push a model to the Ollama library."""
return self._proxy_request("/api/push", "POST", stream=True)

# Blob management endpoints
def check_blob(self, digest: str) -> Response:
"""Check if a blob exists."""
return self._proxy_request(f"/api/blobs/{digest}", "HEAD")

def push_blob(self, digest: str) -> Response:
"""Push a blob to the server."""
return self._proxy_request(f"/api/blobs/{digest}", "POST")

# Model listing and status endpoints
def list_local_models(self) -> Response:
"""List models available locally."""
return self._proxy_request("/api/tags", "GET")

def list_running_models(self) -> Response:
"""List models currently loaded in memory."""
return self._proxy_request("/api/ps", "GET")

def version(self) -> Response:
"""Get the Ollama version."""
return self._proxy_request("/api/version", "GET")
130 changes: 123 additions & 7 deletions services/api/src/api/ollama_routes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os

from api.config import logger
from api.config import BYPASS_OLLAMA_RAG, logger
from api.middleware import require_api_key
from api.ollama_proxy import OllamaProxy

Expand All @@ -11,7 +11,21 @@ def __init__(self, app, proxy: OllamaProxy):
self.proxy = proxy
self.register_routes()

if BYPASS_OLLAMA_RAG:
self.register_bypass_routes()

def register_bypass_routes(self):
@self.app.route("/api/chat", methods=["POST"])
@require_api_key
def chat():
try:
return self.proxy.chat()
except Exception as e:
logger.error(f"Error in chat endpoint: {e}")
return {"error": str(e)}, 500

def register_routes(self):
# Generation endpoints
@self.app.route("/api/generate", methods=["POST"])
@require_api_key
def generate():
Expand All @@ -21,13 +35,59 @@ def generate():
logger.error(f"Error in generate endpoint: {e}")
return {"error": str(e)}, 500

@self.app.route("/api/tags", methods=["GET"])
@self.app.route("/api/embeddings", methods=["POST"])
@require_api_key
def embeddings():
try:
return self.proxy.embeddings()
except Exception as e:
logger.error(f"Error in embeddings endpoint: {e}")
return {"error": str(e)}, 500

@self.app.route("/api/embed", methods=["POST"])
@require_api_key
def embed():
try:
return self.proxy.embed()
except Exception as e:
logger.error(f"Error in embed endpoint: {e}")
return {"error": str(e)}, 500

# Model management endpoints
@self.app.route("/api/create", methods=["POST"])
@require_api_key
def create():
try:
return self.proxy.create()
except Exception as e:
logger.error(f"Error in create endpoint: {e}")
return {"error": str(e)}, 500

@self.app.route("/api/show", methods=["POST"])
@require_api_key
def show():
try:
return self.proxy.show()
except Exception as e:
logger.error(f"Error in show endpoint: {e}")
return {"error": str(e)}, 500

@self.app.route("/api/copy", methods=["POST"])
@require_api_key
def copy():
try:
return self.proxy.copy()
except Exception as e:
logger.error(f"Error in copy endpoint: {e}")
return {"error": str(e)}, 500

@self.app.route("/api/delete", methods=["DELETE"])
@require_api_key
def tags():
def delete():
try:
return self.proxy.tags()
return self.proxy.delete()
except Exception as e:
logger.error(f"Error in tags endpoint: {e}")
logger.error(f"Error in delete endpoint: {e}")
return {"error": str(e)}, 500

@self.app.route("/api/pull", methods=["POST"])
Expand All @@ -39,10 +99,66 @@ def pull():
logger.error(f"Error in pull endpoint: {e}")
return {"error": str(e)}, 500

@self.app.route("/api/push", methods=["POST"])
@require_api_key
def push():
try:
return self.proxy.push()
except Exception as e:
logger.error(f"Error in push endpoint: {e}")
return {"error": str(e)}, 500

# Blob management endpoints
@self.app.route("/api/blobs/<digest>", methods=["HEAD"])
@require_api_key
def check_blob(digest):
try:
return self.proxy.check_blob(digest)
except Exception as e:
logger.error(f"Error in check_blob endpoint: {e}")
return {"error": str(e)}, 500

@self.app.route("/api/blobs/<digest>", methods=["POST"])
@require_api_key
def push_blob(digest):
try:
return self.proxy.push_blob(digest)
except Exception as e:
logger.error(f"Error in push_blob endpoint: {e}")
return {"error": str(e)}, 500

# Model listing and status endpoints
@self.app.route("/api/tags", methods=["GET"])
@require_api_key
def list_local_models():
try:
return self.proxy.list_local_models()
except Exception as e:
logger.error(f"Error in list_local_models endpoint: {e}")
return {"error": str(e)}, 500

@self.app.route("/api/ps", methods=["GET"])
@require_api_key
def list_running_models():
try:
return self.proxy.list_running_models()
except Exception as e:
logger.error(f"Error in list_running_models endpoint: {e}")
return {"error": str(e)}, 500

@self.app.route("/api/version", methods=["GET"])
@require_api_key
def version():
try:
return self.proxy.version()
except Exception as e:
logger.error(f"Error in version endpoint: {e}")
return {"error": str(e)}, 500


def setup_ollama_routes(app):
def setup_ollama_proxy_routes(app):
ollama_url = os.getenv("OLLAMA_URL", "http://localhost:11434")
proxy = OllamaProxy(ollama_url)
OllamaRoutes(app, proxy)
logger.info(f"Initialized Ollama routes with URL: {ollama_url}")
logger.info(f"Initialized Ollama proxy routes with Ollama URL: {ollama_url}")
return proxy
9 changes: 7 additions & 2 deletions services/api/src/api/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def log_request_info(request):
logger.info("Request: %s", json.dumps(request_info, indent=None, sort_keys=True))


def register_chat_routes(app: Flask):
def register_rag_chat_route(app: Flask):
@app.route("/api/chat", methods=["POST"])
@require_api_key
def chat():
Expand Down Expand Up @@ -82,7 +82,12 @@ def chat():
or "content" not in message
):
abort(400, description="Invalid message format")
if message["role"] not in ["system", "user", "assistant", "tool"]:
if message["role"] != "" and message["role"] not in [
"system",
"user",
"assistant",
"tool",
]:
abort(400, description="Invalid message role")

# Optional parameters
Expand Down
Loading

0 comments on commit 1bae97c

Please sign in to comment.