feat: Mirror Ollama chat API
TilmanGriesel committed Feb 1, 2025
1 parent 9b56b1c commit c5079af
Showing 17 changed files with 804 additions and 494 deletions.
2 changes: 1 addition & 1 deletion docker/docker-compose.dev.yml
@@ -27,7 +27,7 @@ services:
args:
BUILD_ENV: development
ports:
-      - 21210:8000
+      - 21434:8000
env_file: ../services/api/.env
depends_on:
elasticsearch:
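The host port moves from 21210 to 21434, echoing Ollama's default 11434, so stock Ollama chat clients can point at the dev container. A minimal smoke test against the remapped port, assuming the Ollama-compatible /api/chat route this commit adds (the model name is illustrative):

import requests

resp = requests.post(
    "http://localhost:21434/api/chat",
    json={
        "model": "llama3.2",  # illustrative model name
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": False,
    },
    timeout=120,
)
print(resp.json()["message"]["content"])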
2 changes: 1 addition & 1 deletion docker/docker-compose.prod.yml
@@ -11,7 +11,7 @@ services:
api:
image: griesel/chipper:api-latest
ports:
-      - 127.0.0.1:21210:8000
+      - 127.0.0.1:21434:8000
env_file: ../services/api/.env
depends_on:
elasticsearch:
2 changes: 1 addition & 1 deletion services/api/Makefile
@@ -34,7 +34,7 @@ run-dev:
--hostname api \
--env-file .env \
-v "$(CURDIR)/src:/app/src:z" \
-    -p 21210:8000 \
+    -p 21434:8000 \
--name $(CONTAINER_NAME) \
--network=$(NETWORK) \
$(IMAGE_NAME)
86 changes: 86 additions & 0 deletions services/api/src/api/config.py
@@ -0,0 +1,86 @@
import logging
import os
import secrets
from pathlib import Path

from dotenv import load_dotenv
from flask import Flask
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from werkzeug.middleware.proxy_fix import ProxyFix

# Load environment variables
load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# App configuration
app = Flask(__name__)
app.wsgi_app = ProxyFix(app.wsgi_app)

# Version information
APP_VERSION = os.getenv("APP_VERSION", "[DEV]")
BUILD_NUMBER = os.getenv("APP_BUILD_NUM", "0")

# Feature flags
ALLOW_MODEL_CHANGE = os.getenv("ALLOW_MODEL_CHANGE", "true").lower() == "true"
ALLOW_INDEX_CHANGE = os.getenv("ALLOW_INDEX_CHANGE", "true").lower() == "true"
DEBUG = os.getenv("DEBUG", "true").lower() == "true"

# Rate limiting configuration
DAILY_LIMIT = int(os.getenv("DAILY_RATE_LIMIT", "86400"))
MINUTE_LIMIT = int(os.getenv("MINUTE_RATE_LIMIT", "60"))
STORAGE_URI = os.getenv("RATE_LIMIT_STORAGE", "memory://")

limiter = Limiter(
    key_func=get_remote_address,
    app=app,
    default_limits=[f"{DAILY_LIMIT} per day", f"{MINUTE_LIMIT} per minute"],
    storage_uri=STORAGE_URI,
)

# API Key configuration
API_KEY = os.getenv("API_KEY")
if not API_KEY:
    API_KEY = secrets.token_urlsafe(32)
    logger.info(f"Generated API key: {API_KEY}")


def load_systemprompt(base_path: str) -> str:
    default_prompt = ""
    env_var_name = "SYSTEM_PROMPT"
    env_prompt = os.getenv(env_var_name)

    if env_prompt is not None:
        content = env_prompt.strip()
        logger.info(
            f"Using system prompt from '{env_var_name}' environment variable; content: '{content}'"
        )
        return content

    file = Path(base_path) / ".systemprompt"
    if not file.exists():
        logger.info("No .systemprompt file found. Using default prompt.")
        return default_prompt

    try:
        with open(file, "r", encoding="utf-8") as f:
            content = f.read().strip()

        if not content:
            logger.warning("System prompt file is empty. Using default prompt.")
            return default_prompt

        logger.info(
            f"Successfully loaded system prompt from {file}; content: '{content}'"
        )
        return content

    except Exception as e:
        logger.error(f"Error reading system prompt file: {e}")
        return default_prompt


system_prompt_value = load_systemprompt(os.getenv("SYSTEM_PROMPT_PATH", os.getcwd()))
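Prompt resolution above is strictly ordered: the SYSTEM_PROMPT environment variable wins, then a .systemprompt file under SYSTEM_PROMPT_PATH (or the working directory), then the empty default. A quick sketch of that precedence, assuming the module is importable as api.config:

import os

os.environ["SYSTEM_PROMPT"] = "You are a helpful assistant."

from api.config import load_systemprompt

# The environment variable takes precedence even if a .systemprompt file exists.
print(load_systemprompt(os.getcwd()))  # -> "You are a helpful assistant."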
255 changes: 255 additions & 0 deletions services/api/src/api/handlers.py
@@ -0,0 +1,255 @@
import json
import queue
import threading
import time
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import elasticsearch
from api.config import DEBUG, logger
from core.pipeline_config import QueryPipelineConfig
from core.rag_pipeline import RAGQueryPipeline
from flask import Response, jsonify, stream_with_context


def format_stream_response(
    config: QueryPipelineConfig,
    content: str = "",
    done: bool = False,
    done_reason: Optional[str] = None,
    images: Optional[List[str]] = None,
    tool_calls: Optional[List[Dict[str, Any]]] = None,
    **metrics,
) -> Dict[str, Any]:
    """Format a streaming response chunk according to the Ollama API specification."""
    response = {
        "model": config.model_name,
        "created_at": datetime.now(timezone.utc).isoformat(),
        "done": done,
    }

    if not done:
        message = {"role": "assistant", "content": content}
        if images:
            message["images"] = images
        if tool_calls:
            message["tool_calls"] = tool_calls
        response["message"] = message
    else:
        # Terminal chunks can still carry content (structured output or an
        # error description); without this, that content would be dropped.
        if content:
            response["message"] = {"role": "assistant", "content": content}
        if done_reason:
            response["done_reason"] = done_reason
        response.update(
            {
                "total_duration": metrics.get("total_duration", 0),
                "load_duration": metrics.get("load_duration", 0),
                "prompt_eval_count": metrics.get("prompt_eval_count", 0),
                "prompt_eval_duration": metrics.get("prompt_eval_duration", 0),
                "eval_count": metrics.get("eval_count", 0),
                "eval_duration": metrics.get("eval_duration", 0),
            }
        )

    return response
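# Illustration (not part of the module): the two chunk shapes this helper
# emits. Any object with a model_name attribute works as a stand-in config.
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(model_name="llama3.2")
#   format_stream_response(cfg, content="Hel")
#   # -> {"model": "llama3.2", "created_at": "...", "done": False,
#   #     "message": {"role": "assistant", "content": "Hel"}}
#   format_stream_response(cfg, done=True, done_reason="stop", eval_count=42)
#   # -> {"model": "llama3.2", "created_at": "...", "done": True,
#   #     "done_reason": "stop", "total_duration": 0, "load_duration": 0,
#   #     "prompt_eval_count": 0, "prompt_eval_duration": 0,
#   #     "eval_count": 42, "eval_duration": 0}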


def handle_streaming_response(
    config: QueryPipelineConfig,
    query: str,
    conversation: List[Dict[str, str]],
    format_schema: Optional[Dict[str, Any]] = None,
    options: Optional[Dict[str, Any]] = None,
) -> Response:
    q = queue.Queue()
    start_time = time.time_ns()
    prompt_start = None

    def streaming_callback(chunk):
        nonlocal prompt_start
        if prompt_start is None:
            prompt_start = time.time_ns()

        if chunk.content:
            if format_schema and chunk.is_final:
                try:
                    content = json.loads(chunk.content)
                    response_data = format_stream_response(
                        config, json.dumps(content), done=True, done_reason="stop"
                    )
                except json.JSONDecodeError:
                    response_data = format_stream_response(
                        config,
                        "Error: Failed to generate valid JSON response.",
                        done=True,
                        done_reason="error",
                    )
            else:
                response_data = format_stream_response(
                    config,
                    chunk.content,
                    images=getattr(chunk, "images", None),
                    tool_calls=getattr(chunk, "tool_calls", None),
                )

            q.put(json.dumps(response_data) + "\n")

    rag = RAGQueryPipeline(config=config, streaming_callback=streaming_callback)

    def run_rag():
        try:
            # Track model loading
            load_start = time.time_ns()
            for status in rag.initialize_and_check_models():
                if status.get("status") == "error":
                    error_data = format_stream_response(
                        config,
                        f"Error: Model initialization failed - {status.get('error')}",
                        done=True,
                        done_reason="error",
                    )
                    q.put(json.dumps(error_data) + "\n")
                    return

            load_duration = time.time_ns() - load_start

            rag.create_query_pipeline()
            result = rag.run_query(
                query=query, conversation=conversation, print_response=DEBUG
            )

            # Calculate final metrics
            end_time = time.time_ns()
            final_data = format_stream_response(
                config,
                done=True,
                done_reason="stop",
                total_duration=end_time - start_time,
                load_duration=load_duration,
                prompt_eval_count=len(conversation) + 1,
                prompt_eval_duration=end_time - (prompt_start or start_time),
                eval_count=len(result["llm"]["replies"][0].split())
                if result and "llm" in result and "replies" in result["llm"]
                else 0,
                eval_duration=end_time - (prompt_start or start_time),
            )
            q.put(json.dumps(final_data) + "\n")

        except elasticsearch.BadRequestError as e:
            error_data = format_stream_response(
                config,
                f"Error: Embedding retriever error - {str(e)}",
                done=True,
                done_reason="error",
            )
            q.put(json.dumps(error_data) + "\n")

        except Exception as e:
            error_data = format_stream_response(
                config, f"Error: {str(e)}", done=True, done_reason="error"
            )
            logger.error(f"Error in RAG pipeline: {e}", exc_info=True)
            q.put(json.dumps(error_data) + "\n")

    thread = threading.Thread(target=run_rag, daemon=True)
    thread.start()

    def generate():
        while True:
            try:
                data = q.get(timeout=120)
                if data:
                    yield data

                if '"done": true' in data:
                    logger.info("Streaming completed.")
                    break

            except queue.Empty:
                # Send an empty object as a heartbeat
                yield json.dumps({}) + "\n"
                logger.warning("Queue timeout. Sending heartbeat.")
            except Exception as e:
                logger.error(f"Streaming error: {e}")
                error_data = format_stream_response(
                    config, "Streaming error occurred.", done=True, done_reason="error"
                )
                yield json.dumps(error_data) + "\n"
                break

    return Response(
        stream_with_context(generate()),
        mimetype="application/x-ndjson",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
            "Connection": "keep-alive",
        },
    )
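# Client-side sketch (illustrative, assuming the /api/chat route and the
# 21434 port mapping from docker-compose): the body is newline-delimited
# JSON, empty {} objects are heartbeats, and a chunk with "done": true
# terminates the stream.
#
#   import json, requests
#   with requests.post(
#       "http://localhost:21434/api/chat",
#       json={"model": "llama3.2",
#             "messages": [{"role": "user", "content": "Hi"}],
#             "stream": True},
#       stream=True,
#   ) as r:
#       for line in r.iter_lines():
#           if not line:
#               continue
#           chunk = json.loads(line)
#           if not chunk:  # heartbeat
#               continue
#           if chunk.get("done"):
#               # terminal chunks may carry error or structured-output content
#               print(chunk.get("message", {}).get("content", ""), end="")
#               break
#           print(chunk["message"]["content"], end="", flush=True)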


def handle_standard_response(
    config: QueryPipelineConfig,
    query: str,
    conversation: List[Dict[str, str]],
    format_schema: Optional[Dict[str, Any]] = None,
    options: Optional[Dict[str, Any]] = None,
) -> Response:
    """Handle non-streaming responses with support for structured outputs."""
    start_time = time.time_ns()
    rag = RAGQueryPipeline(config=config)

    try:
        # Track model loading time
        load_start = time.time_ns()
        for status in rag.initialize_and_check_models():
            if status.get("status") == "error":
                raise Exception(f"Model initialization failed: {status.get('error')}")
        load_duration = time.time_ns() - load_start

        rag.create_query_pipeline()

        # Track query execution time
        prompt_start = time.time_ns()
        result = rag.run_query(
            query=query, conversation=conversation, print_response=False
        )
        end_time = time.time_ns()

        if result and "llm" in result and "replies" in result["llm"]:
            response_content = result["llm"]["replies"][0]

            # Handle structured output if a format_schema is provided
            if format_schema:
                try:
                    content = json.loads(response_content)
                    response_content = json.dumps(content)
                except json.JSONDecodeError:
                    raise Exception("Failed to generate valid JSON response")

            eval_count = len(response_content.split()) if response_content else 0

            response = {
                "model": config.model_name,
                "created_at": datetime.now(timezone.utc).isoformat(),
                "message": {"role": "assistant", "content": response_content},
                "done": True,
                "done_reason": "stop",
                "total_duration": end_time - start_time,
                "load_duration": load_duration,
                "prompt_eval_count": len(conversation) + 1,
                "prompt_eval_duration": end_time - prompt_start,
                "eval_count": eval_count,
                "eval_duration": end_time - prompt_start,
            }

            return jsonify(response)

        # The pipeline produced no replies; surface that as an error response
        # instead of falling through and implicitly returning None.
        raise Exception("RAG pipeline returned no response")

    except Exception as e:
        logger.error(f"Error in RAG pipeline: {e}", exc_info=True)
        error_response = {
            "model": config.model_name,
            "created_at": datetime.now(timezone.utc).isoformat(),
            "done": True,
            "done_reason": "error",
            "error": str(e),
        }
        return jsonify(error_response)
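Neither handler registers a Flask route itself; the dispatching endpoint lives in one of the other files in this commit. A hypothetical wiring sketch, with build_pipeline_config standing in for whatever actually constructs QueryPipelineConfig from the request body:

from flask import request

from api.config import app
from api.handlers import handle_standard_response, handle_streaming_response


@app.route("/api/chat", methods=["POST"])
def chat():
    body = request.get_json(force=True)
    messages = body.get("messages", [])
    query = messages[-1]["content"] if messages else ""
    config = build_pipeline_config(body)  # hypothetical helper, not shown in this diff
    handler = (
        handle_streaming_response if body.get("stream", True) else handle_standard_response
    )
    return handler(config, query, messages, body.get("format"), body.get("options"))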