Skip to content

Commit

Permalink
Bootstrap provider models on addition
Browse files Browse the repository at this point in the history
When adding a provider via the API, this bootstraps the available models
for that provider.

Signed-off-by: Juan Antonio Osorio <[email protected]>
  • Loading branch information
JAORMX committed Jan 29, 2025
1 parent 5526ead commit 7bf8b04
Show file tree
Hide file tree
Showing 13 changed files with 261 additions and 39 deletions.
20 changes: 15 additions & 5 deletions src/codegate/api/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,20 +109,26 @@ async def get_provider_endpoint(
status_code=201,
)
async def add_provider_endpoint(
request: v1_models.ProviderEndpoint,
request: v1_models.AddProviderEndpointRequest,
) -> v1_models.ProviderEndpoint:
"""Add a provider endpoint."""
try:
provend = await pcrud.add_endpoint(request)
except AlreadyExistsError:
raise HTTPException(status_code=409, detail="Provider endpoint already exists")
except ValueError as e:
raise HTTPException(
status_code=400,
detail=str(e),
)
except ValidationError as e:
# TODO: This should be more specific
raise HTTPException(
status_code=400,
detail=str(e),
)
except Exception:
logger.exception("Error while adding provider endpoint")
raise HTTPException(status_code=500, detail="Internal server error")

return provend
Expand Down Expand Up @@ -154,20 +160,24 @@ async def configure_auth_material(
)
async def update_provider_endpoint(
provider_id: UUID,
request: v1_models.ProviderEndpoint,
request: v1_models.AddProviderEndpointRequest,
) -> v1_models.ProviderEndpoint:
"""Update a provider endpoint by ID."""
try:
request.id = provider_id
request.id = str(provider_id)
provend = await pcrud.update_endpoint(request)
except provendcrud.ProviderNotFoundError:
raise HTTPException(status_code=404, detail="Provider endpoint not found")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except ValidationError as e:
# TODO: This should be more specific
raise HTTPException(
status_code=400,
detail=str(e),
)
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

return provend

Expand Down
10 changes: 9 additions & 1 deletion src/codegate/api/v1_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ class ProviderEndpoint(pydantic.BaseModel):
name: str
description: str = ""
provider_type: ProviderType
endpoint: str
endpoint: str = "" # Some providers have defaults we can leverage
auth_type: Optional[ProviderAuthType] = ProviderAuthType.none

@staticmethod
Expand Down Expand Up @@ -250,6 +250,14 @@ def get_from_registry(self, registry: ProviderRegistry) -> Optional[BaseProvider
return registry.get_provider(self.provider_type)


class AddProviderEndpointRequest(ProviderEndpoint):
    """
    Represents a request to add a provider endpoint.

    Extends ProviderEndpoint with an optional API key that can be supplied
    together with the endpoint definition.
    """

    # Optional credential sent with the request; None when the caller omits it.
    api_key: Optional[str] = None


class ConfigureAuthMaterial(pydantic.BaseModel):
"""
Represents a request to configure auth material for a provider.
Expand Down
15 changes: 15 additions & 0 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,21 @@ async def add_provider_model(self, model: ProviderModel) -> ProviderModel:
added_model = await self._execute_update_pydantic_model(model, sql, should_raise=True)
return added_model

async def delete_provider_models(self, provider_id: str) -> Optional[ProviderModel]:
    """
    Delete every row of provider_models that belongs to one provider endpoint.

    :param provider_id: value bound to the :provider_endpoint_id placeholder
        of the DELETE statement.
    """
    # NOTE(review): the annotation promises Optional[ProviderModel], but this
    # method has no return statement, so callers always receive None.
    sql = text(
        """
DELETE FROM provider_models
WHERE provider_endpoint_id = :provider_endpoint_id
RETURNING *
"""
    )
    # A throwaway ProviderModel is built only to carry provider_endpoint_id;
    # presumably the helper reads its bind parameters from the model's
    # fields -- TODO confirm against _execute_update_pydantic_model.
    await self._execute_update_pydantic_model(
        ProviderModel(
            provider_endpoint_id=provider_id,
            name="Fake name to respect the signature of the function"
        ), sql, should_raise=True
    )


class DbReader(DbCodeGate):

Expand Down
29 changes: 18 additions & 11 deletions src/codegate/providers/anthropic/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from codegate.pipeline.factory import PipelineFactory
from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer
from codegate.providers.anthropic.completion_handler import AnthropicCompletion
from codegate.providers.base import BaseProvider
from codegate.providers.base import BaseProvider, ModelFetchError
from codegate.providers.litellmshim import anthropic_stream_generator


Expand All @@ -29,16 +29,23 @@ def __init__(
def provider_route_name(self) -> str:
return "anthropic"

def models(self) -> List[str]:
# TODO: This won't work since we need an API Key being set.
resp = httpx.get("https://api.anthropic.com/models")
# If Anthropic returned 404, it means it's not accepting our
# requests. We should throw an error.
if resp.status_code == 404:
raise HTTPException(
status_code=404,
detail="The Anthropic API is not accepting requests. Please check your API key.",
)
def models(self, endpoint: str = None, api_key: str = None) -> List[str]:
headers = {
"Content-Type": "application/json",
"anthropic-version": "2023-06-01",
}
if api_key:
headers["x-api-key"] = api_key
if not endpoint:
endpoint = "https://api.anthropic.com"

resp = httpx.get(
f"{endpoint}/v1/models",
headers=headers,
)

if resp.status_code != 200:
raise ModelFetchError(f"Failed to fetch models from Anthropic API: {resp.text}")

respjson = resp.json()

Expand Down
6 changes: 5 additions & 1 deletion src/codegate/providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
StreamGenerator = Callable[[AsyncIterator[Any]], AsyncIterator[str]]


class ModelFetchError(Exception):
    """Raised when a provider's list of models cannot be retrieved."""


class BaseProvider(ABC):
"""
The provider class is responsible for defining the API routes and
Expand Down Expand Up @@ -55,7 +59,7 @@ def _setup_routes(self) -> None:
pass

@abstractmethod
def models(self) -> List[str]:
def models(self, endpoint: str = None, api_key: str = None) -> List[str]:
    """
    Return the names of the models this provider makes available.

    Concrete providers must override this method.

    :param endpoint: base URL of the provider to query. Optional, so that
        implementations can fall back to their own default when it is omitted.
    :param api_key: optional credential to send along with the request.
    :return: list of model names.
    """
    # Fixed signature: the previous "endpoint, str=None" made `endpoint` a
    # required positional argument and declared a stray parameter named
    # `str` that shadowed the builtin.
    pass

@property
Expand Down
121 changes: 116 additions & 5 deletions src/codegate/providers/crud/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from codegate.db import models as dbmodels
from codegate.db.connection import DbReader, DbRecorder
from codegate.providers.base import BaseProvider
from codegate.providers.registry import ProviderRegistry
from codegate.providers.registry import ProviderRegistry, get_provider_registry

logger = structlog.get_logger("codegate")

Expand Down Expand Up @@ -62,23 +62,106 @@ async def get_endpoint_by_name(self, name: str) -> Optional[apimodelsv1.Provider
return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)

async def add_endpoint(
self, endpoint: apimodelsv1.ProviderEndpoint
self, endpoint: apimodelsv1.AddProviderEndpointRequest
) -> apimodelsv1.ProviderEndpoint:
"""Add an endpoint."""

if not endpoint.endpoint:
endpoint.endpoint = provider_default_endpoints(endpoint.provider_type)

# If we STILL don't have an endpoint, we can't continue
if not endpoint.endpoint:
raise ValueError("No endpoint provided and no default found for provider type")

dbend = endpoint.to_db_model()
provider_registry = get_provider_registry()

# We override the ID here, as we want to generate it.
dbend.id = str(uuid4())

dbendpoint = await self._db_writer.add_provider_endpoint()
prov = endpoint.get_from_registry(provider_registry)
if prov is None:
raise ValueError("Unknown provider type: {}".format(endpoint.provider_type))

models = []
if endpoint.auth_type == apimodelsv1.ProviderAuthType.api_key and not endpoint.api_key:
raise ValueError("API key must be provided for API auth type")
if endpoint.auth_type != apimodelsv1.ProviderAuthType.passthrough:
try:
models = prov.models(endpoint=endpoint.endpoint, api_key=endpoint.api_key)
except Exception as err:
raise ValueError("Unable to get models from provider: {}".format(str(err)))

dbendpoint = await self._db_writer.add_provider_endpoint(dbend)

await self._db_writer.push_provider_auth_material(
dbmodels.ProviderAuthMaterial(
provider_endpoint_id=dbendpoint.id,
auth_type=endpoint.auth_type,
auth_blob=endpoint.api_key if endpoint.api_key else "",
)
)

for model in models:
await self._db_writer.add_provider_model(
dbmodels.ProviderModel(
provider_endpoint_id=dbendpoint.id,
name=model,
)
)
return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)

async def update_endpoint(
self, endpoint: apimodelsv1.ProviderEndpoint
self, endpoint: apimodelsv1.AddProviderEndpointRequest
) -> apimodelsv1.ProviderEndpoint:
"""Update an endpoint."""

if not endpoint.endpoint:
endpoint.endpoint = provider_default_endpoints(endpoint.provider_type)

# If we STILL don't have an endpoint, we can't continue
if not endpoint.endpoint:
raise ValueError("No endpoint provided and no default found for provider type")

provider_registry = get_provider_registry()
prov = endpoint.get_from_registry(provider_registry)
if prov is None:
raise ValueError("Unknown provider type: {}".format(endpoint.provider_type))

founddbe = await self._db_reader.get_provider_endpoint_by_id(str(endpoint.id))
if founddbe is None:
raise ProviderNotFoundError("Provider not found")

models = []
if endpoint.auth_type == apimodelsv1.ProviderAuthType.api_key and not endpoint.api_key:
raise ValueError("API key must be provided for API auth type")
if endpoint.auth_type != apimodelsv1.ProviderAuthType.passthrough:
try:
models = prov.models(endpoint=endpoint.endpoint, api_key=endpoint.api_key)
except Exception as err:
raise ValueError("Unable to get models from provider: {}".format(str(err)))

# Reset all provider models.
await self._db_writer.delete_provider_models(str(endpoint.id))

for model in models:
await self._db_writer.add_provider_model(
dbmodels.ProviderModel(
provider_endpoint_id=founddbe.id,
name=model,
)
)

dbendpoint = await self._db_writer.update_provider_endpoint(endpoint.to_db_model())

await self._db_writer.push_provider_auth_material(
dbmodels.ProviderAuthMaterial(
provider_endpoint_id=dbendpoint.id,
auth_type=endpoint.auth_type,
auth_blob=endpoint.api_key if endpoint.api_key else "",
)
)

return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)

async def configure_auth_material(
Expand Down Expand Up @@ -175,6 +258,13 @@ async def initialize_provider_endpoints(preg: ProviderRegistry):
continue

pimpl = provend.get_from_registry(preg)
if pimpl is None:
logger.warning(
"Provider not found in registry",
provider=provend.name,
endpoint=provend.endpoint,
)
continue
await try_initialize_provider_endpoints(provend, pimpl, db_writer)


Expand Down Expand Up @@ -240,7 +330,7 @@ def __provider_endpoint_from_cfg(
description=("Endpoint for the {} provided via the CodeGate configuration.").format(
provider_name
),
provider_type=provider_name,
provider_type=provider_overrides(provider_name),
auth_type=apimodelsv1.ProviderAuthType.passthrough,
)
except ValidationError as err:
Expand All @@ -251,3 +341,24 @@ def __provider_endpoint_from_cfg(
err=str(err),
)
return None


def provider_default_endpoints(provider_type: str) -> str:
    """
    Look up the built-in base URL for a provider type.

    :param provider_type: provider type name, e.g. "openai".
    :return: the default endpoint URL, or an empty string when the provider
        type has no built-in default.
    """
    known_urls = dict(
        openai="https://api.openai.com",
        anthropic="https://api.anthropic.com",
    )
    try:
        return known_urls[provider_type]
    except KeyError:
        # No built-in default exists for this provider type.
        return ""


def provider_overrides(provider_type: str) -> str:
    """
    Map a configured provider name onto the provider type that serves it.

    :param provider_type: provider name as it appears in the configuration.
    :return: the overriding provider type when an alias exists, otherwise
        the given name unchanged.
    """
    aliases = {"lm_studio": "openai"}
    if provider_type in aliases:
        return aliases[provider_type]
    # No alias registered: the name already is the provider type.
    return provider_type
20 changes: 17 additions & 3 deletions src/codegate/providers/llamacpp/provider.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import json
from typing import List

import httpx
import structlog
from fastapi import HTTPException, Request

from codegate.pipeline.factory import PipelineFactory
from codegate.providers.base import BaseProvider
from codegate.providers.base import BaseProvider, ModelFetchError
from codegate.providers.llamacpp.completion_handler import LlamaCppCompletionHandler
from codegate.providers.llamacpp.normalizer import LLamaCppInputNormalizer, LLamaCppOutputNormalizer

Expand All @@ -27,9 +28,22 @@ def __init__(
def provider_route_name(self) -> str:
return "llamacpp"

def models(self):
def models(self, endpoint: str = None, api_key: str = None) -> List[str]:
headers = {}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
if not endpoint:
endpoint = self.base_url

# HACK: This is using OpenAI's /v1/models endpoint to get the list of models
resp = httpx.get(f"{self.base_url}/v1/models")
resp = httpx.get(
f"{endpoint}/v1/models",
headers=headers,
)

if resp.status_code != 200:
raise ModelFetchError(f"Failed to fetch models from Llama API: {resp.text}")

jsonresp = resp.json()

return [model["id"] for model in jsonresp.get("data", [])]
Expand Down
Loading

0 comments on commit 7bf8b04

Please sign in to comment.