Skip to content

Commit

Permalink
Merge pull request #1074 from julep-ai/dev
Browse files Browse the repository at this point in the history
dev -> main
  • Loading branch information
Ahmad-mtos authored Jan 22, 2025
2 parents 34d2787 + faa767d commit bc2ae7e
Show file tree
Hide file tree
Showing 12 changed files with 1,945 additions and 6 deletions.
24 changes: 24 additions & 0 deletions agents-api/agents_api/clients/litellm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import wraps
from typing import Literal

import aiohttp
from beartype import beartype
from litellm import acompletion as _acompletion
from litellm import aembedding as _aembedding
Expand Down Expand Up @@ -109,3 +110,26 @@ async def aembedding(
for item in embedding_list
if len(item["embedding"]) >= dimensions
]


@beartype
async def get_model_list(*, custom_api_key: str | None = None) -> list[dict]:
    """
    Fetch the list of available models from the LiteLLM server.

    Args:
        custom_api_key: Optional key sent in the ``x-api-key`` header. When it
            is not provided, the LiteLLM master key is sent instead.

    Returns:
        list[dict]: The value of the ``data`` field of the JSON response.

    Raises:
        aiohttp.ClientResponseError: If the server responds with an error status.
    """

    headers = {
        "accept": "application/json",
        "x-api-key": custom_api_key or litellm_master_key,
    }

    # Always use the absolute LiteLLM URL. The session is created without a
    # base_url, so a bare "/models" path cannot be resolved by aiohttp and the
    # request would fail whenever a custom API key was supplied.
    # NOTE(review): assumes custom keys are meant for the same LiteLLM server — confirm.
    async with aiohttp.ClientSession() as session, session.get(
        url=f"{litellm_url}/models",
        headers=headers,
    ) as response:
        response.raise_for_status()
        data = await response.json()
        return data["data"]
4 changes: 2 additions & 2 deletions agents-api/agents_api/queries/chat/prepare_chat_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
agents.metadata,
agents.default_settings
FROM session_lookup
INNER JOIN agents ON session_lookup.participant_id = agents.agent_id
INNER JOIN agents ON session_lookup.participant_id = agents.agent_id AND agents.developer_id = session_lookup.developer_id
WHERE
session_lookup.developer_id = $1 AND
session_id = $2 AND
Expand Down Expand Up @@ -95,7 +95,7 @@
tools.updated_at,
tools.created_at
FROM session_lookup
INNER JOIN tools ON session_lookup.participant_id = tools.agent_id
INNER JOIN tools ON session_lookup.participant_id = tools.agent_id AND tools.developer_id = session_lookup.developer_id
WHERE
session_lookup.developer_id = $1 AND
session_id = $2 AND
Expand Down
6 changes: 5 additions & 1 deletion agents-api/agents_api/routers/agents/create_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)
from ...dependencies.developer_id import get_developer_id
from ...queries.agents.create_agent import create_agent as create_agent_query
from ..utils.model_validation import validate_model
from .router import router


Expand All @@ -18,7 +19,10 @@ async def create_agent(
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
data: CreateAgentRequest,
) -> ResourceCreatedResponse:
# TODO: Validate model name

if data.model:
await validate_model(data.model)

agent = await create_agent_query(
developer_id=x_developer_id,
data=data,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ...queries.agents.create_or_update_agent import (
create_or_update_agent as create_or_update_agent_query,
)
from ..utils.model_validation import validate_model
from .router import router


Expand All @@ -21,7 +22,10 @@ async def create_or_update_agent(
data: CreateOrUpdateAgentRequest,
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
) -> ResourceCreatedResponse:
# TODO: Validate model name

if data.model:
await validate_model(data.model)

agent = await create_or_update_agent_query(
developer_id=x_developer_id,
agent_id=agent_id,
Expand Down
5 changes: 5 additions & 0 deletions agents-api/agents_api/routers/agents/patch_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ...autogen.openapi_model import PatchAgentRequest, ResourceUpdatedResponse
from ...dependencies.developer_id import get_developer_id
from ...queries.agents.patch_agent import patch_agent as patch_agent_query
from ..utils.model_validation import validate_model
from .router import router


Expand All @@ -21,6 +22,10 @@ async def patch_agent(
agent_id: UUID,
data: PatchAgentRequest,
) -> ResourceUpdatedResponse:

if data.model:
await validate_model(data.model)

return await patch_agent_query(
agent_id=agent_id,
developer_id=x_developer_id,
Expand Down
7 changes: 6 additions & 1 deletion agents-api/agents_api/routers/agents/update_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateAgentRequest
from ...dependencies.developer_id import get_developer_id
from ...queries.agents.update_agent import update_agent as update_agent_query
from ..utils.model_validation import validate_model
from .router import router


Expand All @@ -20,7 +21,11 @@ async def update_agent(
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
agent_id: UUID,
data: UpdateAgentRequest,
) -> ResourceUpdatedResponse:
) -> ResourceUpdatedResponse:

if data.model:
await validate_model(data.model)

return await update_agent_query(
developer_id=x_developer_id,
agent_id=agent_id,
Expand Down
5 changes: 5 additions & 0 deletions agents-api/agents_api/routers/sessions/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ...queries.chat.prepare_chat_context import prepare_chat_context
from ...queries.entries.create_entries import create_entries
from ...queries.sessions.count_sessions import count_sessions as count_sessions_query
from ..utils.model_validation import validate_model
from .metrics import total_tokens_per_user
from .router import router

Expand Down Expand Up @@ -55,6 +56,10 @@ async def chat(
Returns:
ChatResponse: The chat response.
"""

if chat_input.model:
await validate_model(chat_input.model)

# check if the developer is paid
if "paid" not in developer.tags:
# get the session length
Expand Down
19 changes: 19 additions & 0 deletions agents-api/agents_api/routers/utils/model_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from fastapi import HTTPException
from starlette.status import HTTP_400_BAD_REQUEST

from ...clients.litellm import get_model_list


async def validate_model(model_name: str) -> None:
    """
    Check that ``model_name`` is one of the models reported by LiteLLM.

    Raises:
        HTTPException: With status 400 when the model id is not in the list
            returned by ``get_model_list``.
    """
    # Collect the ids of every model entry returned by the LiteLLM client.
    available_models = [entry["id"] for entry in await get_model_list()]

    # Known model: nothing more to do.
    if model_name in available_models:
        return

    raise HTTPException(
        status_code=HTTP_400_BAD_REQUEST,
        detail=f"Model {model_name} not available. Available models: {available_models}",
    )
11 changes: 10 additions & 1 deletion agents-api/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import random
import string
import sys
from unittest.mock import patch
from uuid import UUID

from agents_api.autogen.openapi_model import (
Expand Down Expand Up @@ -440,10 +441,18 @@ async def test_tool(
return tool


# Model entries (shaped like the LiteLLM model list) used to stub get_model_list in tests.
SAMPLE_MODELS = [
    {"id": model_id}
    for model_id in ("gpt-4", "gpt-3.5-turbo", "gpt-4o-mini")
]


@fixture(scope="global")
def client(_dsn=pg_dsn):
    """
    Yield a ``TestClient`` for the app with model validation stubbed.

    While the fixture is active, ``get_model_list`` as referenced by
    ``agents_api.routers.utils.model_validation`` is patched to return
    ``SAMPLE_MODELS``.
    """
    with TestClient(app=app) as client:
        # Yield exactly once, inside the patch, so the stub is active for the
        # whole time the client is in use.
        with patch(
            "agents_api.routers.utils.model_validation.get_model_list",
            return_value=SAMPLE_MODELS,
        ):
            yield client


@fixture(scope="global")
Expand Down
28 changes: 28 additions & 0 deletions agents-api/tests/test_model_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from unittest.mock import patch

from agents_api.routers.utils.model_validation import validate_model
from fastapi import HTTPException
from ward import raises, test

from tests.fixtures import SAMPLE_MODELS


@test("validate_model: succeeds when model is available")
async def _():
    # Patch get_model_list as referenced by the validation module so the real
    # LiteLLM lookup is not invoked; a regular (sync) context manager is used.
    target = "agents_api.routers.utils.model_validation.get_model_list"
    with patch(target, return_value=SAMPLE_MODELS) as fake_model_list:
        # A model id present in SAMPLE_MODELS must pass without raising.
        await validate_model("gpt-4o-mini")
        fake_model_list.assert_called_once()


@test("validate_model: fails when model is unavailable")
async def _():
    # Patch get_model_list as referenced by the validation module so the real
    # LiteLLM lookup is not invoked.
    target = "agents_api.routers.utils.model_validation.get_model_list"
    with patch(target, return_value=SAMPLE_MODELS) as fake_model_list:
        # An id missing from SAMPLE_MODELS must be rejected.
        with raises(HTTPException) as exc:
            await validate_model("non-existent-model")

        error = exc.raised
        assert error.status_code == 400
        assert "Model non-existent-model not available" in error.detail
        fake_model_list.assert_called_once()
Loading

0 comments on commit bc2ae7e

Please sign in to comment.