From 7bf8b04a7153f2d5fc5260b9dabe01f416379665 Mon Sep 17 00:00:00 2001 From: Juan Antonio Osorio Date: Wed, 29 Jan 2025 16:50:44 +0200 Subject: [PATCH] Bootstrap provider models on addition When adding a provider via the API, this bootstraps the available models for a provider. Signed-off-by: Juan Antonio Osorio --- src/codegate/api/v1.py | 20 ++- src/codegate/api/v1_models.py | 10 +- src/codegate/db/connection.py | 15 +++ src/codegate/providers/anthropic/provider.py | 29 +++-- src/codegate/providers/base.py | 6 +- src/codegate/providers/crud/crud.py | 121 ++++++++++++++++++- src/codegate/providers/llamacpp/provider.py | 20 ++- src/codegate/providers/ollama/provider.py | 19 ++- src/codegate/providers/openai/provider.py | 17 ++- src/codegate/providers/registry.py | 17 +++ src/codegate/providers/vllm/provider.py | 20 ++- src/codegate/server.py | 4 +- tests/test_server.py | 2 +- 13 files changed, 261 insertions(+), 39 deletions(-) diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index b256635f..28309c8b 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -109,13 +109,18 @@ async def get_provider_endpoint( status_code=201, ) async def add_provider_endpoint( - request: v1_models.ProviderEndpoint, + request: v1_models.AddProviderEndpointRequest, ) -> v1_models.ProviderEndpoint: """Add a provider endpoint.""" try: provend = await pcrud.add_endpoint(request) except AlreadyExistsError: raise HTTPException(status_code=409, detail="Provider endpoint already exists") + except ValueError as e: + raise HTTPException( + status_code=400, + detail=str(e), + ) except ValidationError as e: # TODO: This should be more specific raise HTTPException( @@ -123,6 +128,7 @@ async def add_provider_endpoint( detail=str(e), ) except Exception: + logger.exception("Error while adding provider endpoint") raise HTTPException(status_code=500, detail="Internal server error") return provend @@ -154,20 +160,24 @@ async def configure_auth_material( ) async def 
update_provider_endpoint( provider_id: UUID, - request: v1_models.ProviderEndpoint, + request: v1_models.AddProviderEndpointRequest, ) -> v1_models.ProviderEndpoint: """Update a provider endpoint by ID.""" try: - request.id = provider_id + request.id = str(provider_id) provend = await pcrud.update_endpoint(request) + except provendcrud.ProviderNotFoundError: + raise HTTPException(status_code=404, detail="Provider endpoint not found") + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) except ValidationError as e: # TODO: This should be more specific raise HTTPException( status_code=400, detail=str(e), ) - except Exception: - raise HTTPException(status_code=500, detail="Internal server error") + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) return provend diff --git a/src/codegate/api/v1_models.py b/src/codegate/api/v1_models.py index 65c4acc1..b448ce6b 100644 --- a/src/codegate/api/v1_models.py +++ b/src/codegate/api/v1_models.py @@ -222,7 +222,7 @@ class ProviderEndpoint(pydantic.BaseModel): name: str description: str = "" provider_type: ProviderType - endpoint: str + endpoint: str = "" # Some providers have defaults we can leverage auth_type: Optional[ProviderAuthType] = ProviderAuthType.none @staticmethod @@ -250,6 +250,14 @@ def get_from_registry(self, registry: ProviderRegistry) -> Optional[BaseProvider return registry.get_provider(self.provider_type) +class AddProviderEndpointRequest(ProviderEndpoint): + """ + Represents a request to add a provider endpoint. + """ + + api_key: Optional[str] = None + + class ConfigureAuthMaterial(pydantic.BaseModel): """ Represents a request to configure auth material for a provider. 
diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index caed5276..8ed39deb 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -459,6 +459,21 @@ async def add_provider_model(self, model: ProviderModel) -> ProviderModel: added_model = await self._execute_update_pydantic_model(model, sql, should_raise=True) return added_model + async def delete_provider_models(self, provider_id: str) -> Optional[ProviderModel]: + sql = text( + """ + DELETE FROM provider_models + WHERE provider_endpoint_id = :provider_endpoint_id + RETURNING * + """ + ) + await self._execute_update_pydantic_model( + ProviderModel( + provider_endpoint_id=provider_id, + name="Fake name to respect the signature of the function" + ), sql, should_raise=True + ) + class DbReader(DbCodeGate): diff --git a/src/codegate/providers/anthropic/provider.py b/src/codegate/providers/anthropic/provider.py index 48821de0..9c656794 100644 --- a/src/codegate/providers/anthropic/provider.py +++ b/src/codegate/providers/anthropic/provider.py @@ -8,7 +8,7 @@ from codegate.pipeline.factory import PipelineFactory from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer from codegate.providers.anthropic.completion_handler import AnthropicCompletion -from codegate.providers.base import BaseProvider +from codegate.providers.base import BaseProvider, ModelFetchError from codegate.providers.litellmshim import anthropic_stream_generator @@ -29,16 +29,23 @@ def __init__( def provider_route_name(self) -> str: return "anthropic" - def models(self) -> List[str]: - # TODO: This won't work since we need an API Key being set. - resp = httpx.get("https://api.anthropic.com/models") - # If Anthropic returned 404, it means it's not accepting our - # requests. We should throw an error. - if resp.status_code == 404: - raise HTTPException( - status_code=404, - detail="The Anthropic API is not accepting requests. 
Please check your API key.", - ) + def models(self, endpoint: str = None, api_key: str = None) -> List[str]: + headers = { + "Content-Type": "application/json", + "anthropic-version": "2023-06-01", + } + if api_key: + headers["x-api-key"] = api_key + if not endpoint: + endpoint = "https://api.anthropic.com" + + resp = httpx.get( + f"{endpoint}/v1/models", + headers=headers, + ) + + if resp.status_code != 200: + raise ModelFetchError(f"Failed to fetch models from Anthropic API: {resp.text}") respjson = resp.json() diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py index 8e9a4d40..050cea06 100644 --- a/src/codegate/providers/base.py +++ b/src/codegate/providers/base.py @@ -24,6 +24,10 @@ StreamGenerator = Callable[[AsyncIterator[Any]], AsyncIterator[str]] +class ModelFetchError(Exception): + pass + + class BaseProvider(ABC): """ The provider class is responsible for defining the API routes and @@ -55,7 +59,7 @@ def _setup_routes(self) -> None: pass @abstractmethod - def models(self) -> List[str]: + def models(self, endpoint, str=None, api_key: str = None) -> List[str]: pass @property diff --git a/src/codegate/providers/crud/crud.py b/src/codegate/providers/crud/crud.py index ebae2b97..5207b7d6 100644 --- a/src/codegate/providers/crud/crud.py +++ b/src/codegate/providers/crud/crud.py @@ -11,7 +11,7 @@ from codegate.db import models as dbmodels from codegate.db.connection import DbReader, DbRecorder from codegate.providers.base import BaseProvider -from codegate.providers.registry import ProviderRegistry +from codegate.providers.registry import ProviderRegistry, get_provider_registry logger = structlog.get_logger("codegate") @@ -62,23 +62,106 @@ async def get_endpoint_by_name(self, name: str) -> Optional[apimodelsv1.Provider return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) async def add_endpoint( - self, endpoint: apimodelsv1.ProviderEndpoint + self, endpoint: apimodelsv1.AddProviderEndpointRequest ) -> 
apimodelsv1.ProviderEndpoint: """Add an endpoint.""" + + if not endpoint.endpoint: + endpoint.endpoint = provider_default_endpoints(endpoint.provider_type) + + # If we STILL don't have an endpoint, we can't continue + if not endpoint.endpoint: + raise ValueError("No endpoint provided and no default found for provider type") + dbend = endpoint.to_db_model() + provider_registry = get_provider_registry() # We override the ID here, as we want to generate it. dbend.id = str(uuid4()) - dbendpoint = await self._db_writer.add_provider_endpoint() + prov = endpoint.get_from_registry(provider_registry) + if prov is None: + raise ValueError("Unknown provider type: {}".format(endpoint.provider_type)) + + models = [] + if endpoint.auth_type == apimodelsv1.ProviderAuthType.api_key and not endpoint.api_key: + raise ValueError("API key must be provided for API auth type") + if endpoint.auth_type != apimodelsv1.ProviderAuthType.passthrough: + try: + models = prov.models(endpoint=endpoint.endpoint, api_key=endpoint.api_key) + except Exception as err: + raise ValueError("Unable to get models from provider: {}".format(str(err))) + + dbendpoint = await self._db_writer.add_provider_endpoint(dbend) + + await self._db_writer.push_provider_auth_material( + dbmodels.ProviderAuthMaterial( + provider_endpoint_id=dbendpoint.id, + auth_type=endpoint.auth_type, + auth_blob=endpoint.api_key if endpoint.api_key else "", + ) + ) + + for model in models: + await self._db_writer.add_provider_model( + dbmodels.ProviderModel( + provider_endpoint_id=dbendpoint.id, + name=model, + ) + ) return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) async def update_endpoint( - self, endpoint: apimodelsv1.ProviderEndpoint + self, endpoint: apimodelsv1.AddProviderEndpointRequest ) -> apimodelsv1.ProviderEndpoint: """Update an endpoint.""" + if not endpoint.endpoint: + endpoint.endpoint = provider_default_endpoints(endpoint.provider_type) + + # If we STILL don't have an endpoint, we can't continue + if not 
endpoint.endpoint: + raise ValueError("No endpoint provided and no default found for provider type") + + provider_registry = get_provider_registry() + prov = endpoint.get_from_registry(provider_registry) + if prov is None: + raise ValueError("Unknown provider type: {}".format(endpoint.provider_type)) + + founddbe = await self._db_reader.get_provider_endpoint_by_id(str(endpoint.id)) + if founddbe is None: + raise ProviderNotFoundError("Provider not found") + + models = [] + if endpoint.auth_type == apimodelsv1.ProviderAuthType.api_key and not endpoint.api_key: + raise ValueError("API key must be provided for API auth type") + if endpoint.auth_type != apimodelsv1.ProviderAuthType.passthrough: + try: + models = prov.models(endpoint=endpoint.endpoint, api_key=endpoint.api_key) + except Exception as err: + raise ValueError("Unable to get models from provider: {}".format(str(err))) + + # Reset all provider models. + await self._db_writer.delete_provider_models(str(endpoint.id)) + + for model in models: + await self._db_writer.add_provider_model( + dbmodels.ProviderModel( + provider_endpoint_id=founddbe.id, + name=model, + ) + ) + dbendpoint = await self._db_writer.update_provider_endpoint(endpoint.to_db_model()) + + await self._db_writer.push_provider_auth_material( + dbmodels.ProviderAuthMaterial( + provider_endpoint_id=dbendpoint.id, + auth_type=endpoint.auth_type, + auth_blob=endpoint.api_key if endpoint.api_key else "", + ) + ) + return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) async def configure_auth_material( @@ -175,6 +258,13 @@ async def initialize_provider_endpoints(preg: ProviderRegistry): continue pimpl = provend.get_from_registry(preg) + if pimpl is None: + logger.warning( + "Provider not found in registry", + provider=provend.name, + endpoint=provend.endpoint, + ) + continue await try_initialize_provider_endpoints(provend, pimpl, db_writer) @@ -240,7 +330,7 @@ def __provider_endpoint_from_cfg( description=("Endpoint for the {} provided via the 
CodeGate configuration.").format( provider_name ), - provider_type=provider_name, + provider_type=provider_overrides(provider_name), auth_type=apimodelsv1.ProviderAuthType.passthrough, ) except ValidationError as err: @@ -251,3 +341,24 @@ def __provider_endpoint_from_cfg( err=str(err), ) return None + + +def provider_default_endpoints(provider_type: str) -> str: + defaults = { + "openai": "https://api.openai.com", + "anthropic": "https://api.anthropic.com", + } + + # If we have a default, we return it + # Otherwise, we return an empty string + return defaults.get(provider_type, "") + + +def provider_overrides(provider_type: str) -> str: + overrides = { + "lm_studio": "openai", + } + + # If we have an override, we return it + # Otherwise, we return the type + return overrides.get(provider_type, provider_type) diff --git a/src/codegate/providers/llamacpp/provider.py b/src/codegate/providers/llamacpp/provider.py index 4478d137..ebeb6c06 100644 --- a/src/codegate/providers/llamacpp/provider.py +++ b/src/codegate/providers/llamacpp/provider.py @@ -1,11 +1,12 @@ import json +from typing import List import httpx import structlog from fastapi import HTTPException, Request from codegate.pipeline.factory import PipelineFactory -from codegate.providers.base import BaseProvider +from codegate.providers.base import BaseProvider, ModelFetchError from codegate.providers.llamacpp.completion_handler import LlamaCppCompletionHandler from codegate.providers.llamacpp.normalizer import LLamaCppInputNormalizer, LLamaCppOutputNormalizer @@ -27,9 +28,22 @@ def __init__( def provider_route_name(self) -> str: return "llamacpp" - def models(self): + def models(self, endpoint: str = None, api_key: str = None) -> List[str]: + headers = {} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + if not endpoint: + endpoint = self.base_url + # HACK: This is using OpenAI's /v1/models endpoint to get the list of models - resp = httpx.get(f"{self.base_url}/v1/models") + resp = httpx.get( + 
f"{endpoint}/v1/models", + headers=headers, + ) + + if resp.status_code != 200: + raise ModelFetchError(f"Failed to fetch models from Llama API: {resp.text}") + jsonresp = resp.json() return [model["id"] for model in jsonresp.get("data", [])] diff --git a/src/codegate/providers/ollama/provider.py b/src/codegate/providers/ollama/provider.py index b8e0477b..66ea38ef 100644 --- a/src/codegate/providers/ollama/provider.py +++ b/src/codegate/providers/ollama/provider.py @@ -1,4 +1,5 @@ import json +from typing import List import httpx import structlog @@ -6,7 +7,7 @@ from codegate.config import Config from codegate.pipeline.factory import PipelineFactory -from codegate.providers.base import BaseProvider +from codegate.providers.base import BaseProvider, ModelFetchError from codegate.providers.ollama.adapter import OllamaInputNormalizer, OllamaOutputNormalizer from codegate.providers.ollama.completion_handler import OllamaShim @@ -34,8 +35,20 @@ def __init__( def provider_route_name(self) -> str: return "ollama" - def models(self): - resp = httpx.get(f"{self.base_url}/api/tags") + def models(self, endpoint: str = None, api_key: str = None) -> List[str]: + headers = {} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + if not endpoint: + endpoint = self.base_url + resp = httpx.get( + f"{endpoint}/api/tags", + headers=headers, + ) + + if resp.status_code != 200: + raise ModelFetchError(f"Failed to fetch models from Ollama API: {resp.text}") + jsonresp = resp.json() return [model["name"] for model in jsonresp.get("models", [])] diff --git a/src/codegate/providers/openai/provider.py b/src/codegate/providers/openai/provider.py index 87588265..be9ddabe 100644 --- a/src/codegate/providers/openai/provider.py +++ b/src/codegate/providers/openai/provider.py @@ -8,7 +8,7 @@ from codegate.config import Config from codegate.pipeline.factory import PipelineFactory -from codegate.providers.base import BaseProvider +from codegate.providers.base import BaseProvider, 
ModelFetchError from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator from codegate.providers.openai.adapter import OpenAIInputNormalizer, OpenAIOutputNormalizer @@ -35,9 +35,18 @@ def __init__( def provider_route_name(self) -> str: return "openai" - def models(self) -> List[str]: - # NOTE: This won't work since we need an API Key being set. - resp = httpx.get(f"{self.lm_studio_url}/v1/models") + def models(self, endpoint: str = None, api_key: str = None) -> List[str]: + headers = {} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + if not endpoint: + endpoint = "https://api.openai.com" + + resp = httpx.get(f"{endpoint}/v1/models", headers=headers) + + if resp.status_code != 200: + raise ModelFetchError(f"Failed to fetch models from OpenAI API: {resp.text}") + jsonresp = resp.json() return [model["id"] for model in jsonresp.get("data", [])] diff --git a/src/codegate/providers/registry.py b/src/codegate/providers/registry.py index 7450460f..3def8840 100644 --- a/src/codegate/providers/registry.py +++ b/src/codegate/providers/registry.py @@ -1,9 +1,26 @@ +from threading import Lock from typing import Dict, Optional from fastapi import FastAPI from codegate.providers.base import BaseProvider +_provider_registry_lock = Lock() +_provider_registry_singleton: Optional["ProviderRegistry"] = None + + +def get_provider_registry(app: FastAPI = None) -> "ProviderRegistry": + global _provider_registry_singleton + + if _provider_registry_singleton is None: + if app is None: + raise ValueError("Cannot initialize a ProviderRegistry without an app") + with _provider_registry_lock: + if _provider_registry_singleton is None: + _provider_registry_singleton = ProviderRegistry(app) + + return _provider_registry_singleton + class ProviderRegistry: def __init__(self, app: FastAPI): diff --git a/src/codegate/providers/vllm/provider.py b/src/codegate/providers/vllm/provider.py index 303b907b..3448db84 100644 --- a/src/codegate/providers/vllm/provider.py 
+++ b/src/codegate/providers/vllm/provider.py @@ -1,4 +1,5 @@ import json +from typing import List import httpx import structlog @@ -7,7 +8,7 @@ from codegate.config import Config from codegate.pipeline.factory import PipelineFactory -from codegate.providers.base import BaseProvider +from codegate.providers.base import BaseProvider, ModelFetchError from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator from codegate.providers.vllm.adapter import VLLMInputNormalizer, VLLMOutputNormalizer @@ -31,8 +32,21 @@ def __init__( def provider_route_name(self) -> str: return "vllm" - def models(self): - resp = httpx.get(f"{self.base_url}/v1/models") + def models(self, endpoint: str = None, api_key: str = None) -> List[str]: + headers = {} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + if not endpoint: + endpoint = self.base_url + + resp = httpx.get( + f"{endpoint}/v1/models", + headers=headers, + ) + + if resp.status_code != 200: + raise ModelFetchError(f"Failed to fetch models from vLLM API: {resp.text}") + jsonresp = resp.json() return [model["id"] for model in jsonresp.get("data", [])] diff --git a/src/codegate/server.py b/src/codegate/server.py index ece60c0c..216ba95e 100644 --- a/src/codegate/server.py +++ b/src/codegate/server.py @@ -15,7 +15,7 @@ from codegate.providers.llamacpp.provider import LlamaCppProvider from codegate.providers.ollama.provider import OllamaProvider from codegate.providers.openai.provider import OpenAIProvider -from codegate.providers.registry import ProviderRegistry +from codegate.providers.registry import ProviderRegistry, get_provider_registry from codegate.providers.vllm.provider import VLLMProvider logger = structlog.get_logger("codegate") @@ -64,7 +64,7 @@ async def log_user_agent(request: Request, call_next): app.add_middleware(ServerErrorMiddleware, handler=custom_error_handler) # Create provider registry - registry = ProviderRegistry(app) + registry = get_provider_registry(app) 
app.set_provider_registry(registry) # Register all known providers diff --git a/tests/test_server.py b/tests/test_server.py index f7b7a12f..80bb7cb0 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -97,7 +97,7 @@ def test_version_endpoint(mock_fetch_latest_version, test_client: TestClient) -> @patch("codegate.pipeline.secrets.manager.SecretsManager") -@patch("codegate.server.ProviderRegistry") +@patch("codegate.server.get_provider_registry") def test_provider_registration(mock_registry, mock_secrets_mgr, mock_pipeline_factory) -> None: """Test that all providers are registered correctly.""" init_app(mock_pipeline_factory)