Skip to content

Commit

Permalink
Merge pull request #1089 from julep-ai/f/better-return-models
Browse files Browse the repository at this point in the history
feat(agents-api): Return full objects instead of ResourceCreated/Updated
  • Loading branch information
HamadaSalhab authored Jan 27, 2025
2 parents 3d38df9 + bf84810 commit dbde7a8
Show file tree
Hide file tree
Showing 78 changed files with 1,105 additions and 629 deletions.
36 changes: 0 additions & 36 deletions agents-api/agents_api/autogen/Common.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,24 +56,6 @@ class PyExpression(RootModel[str]):
"""


class ResourceCreatedResponse(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
)
id: UUID
"""
ID of created resource
"""
created_at: Annotated[AwareDatetime, Field(json_schema_extra={"readOnly": True})]
"""
When this resource was created as UTC date-time
"""
jobs: Annotated[list[UUID], Field(json_schema_extra={"readOnly": True})] = []
"""
IDs (if any) of jobs created as part of this request
"""


class ResourceDeletedResponse(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
Expand All @@ -92,24 +74,6 @@ class ResourceDeletedResponse(BaseModel):
"""


class ResourceUpdatedResponse(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
)
id: UUID
"""
ID of updated resource
"""
updated_at: Annotated[AwareDatetime, Field(json_schema_extra={"readOnly": True})]
"""
When this resource was updated as UTC date-time
"""
jobs: Annotated[list[UUID], Field(json_schema_extra={"readOnly": True})] = []
"""
IDs (if any) of jobs created as part of this request
"""


class Uuid(RootModel[UUID]):
model_config = ConfigDict(
populate_by_name=True,
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
DataT = TypeVar("DataT", bound=BaseModel)


class ListResponse(BaseModel, Generic[DataT]):
class ListResponse[DataT: BaseModel](BaseModel):
items: list[DataT]


Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/common/utils/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def render_template_string(


@beartype
async def render_template_nested(
async def render_template_nested[T: (str, dict, list[dict | list[dict]], None)](
input: T,
variables: dict,
check: bool = False,
Expand Down
6 changes: 3 additions & 3 deletions agents-api/agents_api/queries/agents/create_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from beartype import beartype
from uuid_extensions import uuid7

from ...autogen.openapi_model import CreateAgentRequest, ResourceCreatedResponse
from ...autogen.openapi_model import Agent, CreateAgentRequest
from ...common.utils.db_exceptions import common_db_exceptions
from ...metrics.counters import increase_counter
from ..utils import generate_canonical_name, pg_query, rewrap_exceptions, wrap_in_class
Expand Down Expand Up @@ -43,9 +43,9 @@

@rewrap_exceptions(common_db_exceptions("agent", ["create"]))
@wrap_in_class(
ResourceCreatedResponse,
Agent,
one=True,
transform=lambda d: {"id": d["agent_id"], "created_at": d["created_at"]},
transform=lambda d: {**d, "id": d["agent_id"]},
)
@increase_counter("create_agent")
@pg_query
Expand Down
11 changes: 7 additions & 4 deletions agents-api/agents_api/queries/agents/patch_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from beartype import beartype

from ...autogen.openapi_model import PatchAgentRequest, ResourceUpdatedResponse
from ...autogen.openapi_model import Agent, PatchAgentRequest
from ...common.utils.db_exceptions import common_db_exceptions
from ...metrics.counters import increase_counter
from ..utils import pg_query, rewrap_exceptions, wrap_in_class
Expand Down Expand Up @@ -51,15 +51,18 @@

@rewrap_exceptions(common_db_exceptions("agent", ["patch"]))
@wrap_in_class(
ResourceUpdatedResponse,
Agent,
one=True,
transform=lambda d: {"id": d["agent_id"], **d},
transform=lambda d: {**d, "id": d["agent_id"]},
)
@increase_counter("patch_agent")
@pg_query
@beartype
async def patch_agent(
*, agent_id: UUID, developer_id: UUID, data: PatchAgentRequest
*,
agent_id: UUID,
developer_id: UUID,
data: PatchAgentRequest,
) -> tuple[str, list]:
"""
Constructs the SQL query to partially update an agent's details.
Expand Down
6 changes: 3 additions & 3 deletions agents-api/agents_api/queries/agents/update_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from beartype import beartype

from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateAgentRequest
from ...autogen.openapi_model import Agent, UpdateAgentRequest
from ...common.utils.db_exceptions import common_db_exceptions
from ...metrics.counters import increase_counter
from ..utils import pg_query, rewrap_exceptions, wrap_in_class
Expand All @@ -28,9 +28,9 @@

@rewrap_exceptions(common_db_exceptions("agent", ["update"]))
@wrap_in_class(
ResourceUpdatedResponse,
Agent,
one=True,
transform=lambda d: {"id": d["agent_id"], **d},
transform=lambda d: {**d, "id": d["agent_id"]},
)
@increase_counter("update_agent")
@pg_query
Expand Down
6 changes: 3 additions & 3 deletions agents-api/agents_api/queries/developers/create_developer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from beartype import beartype
from uuid_extensions import uuid7

from ...autogen.openapi_model import ResourceCreatedResponse
from ...common.protocol.developers import Developer
from ...common.utils.db_exceptions import common_db_exceptions
from ..utils import pg_query, rewrap_exceptions, wrap_in_class

Expand All @@ -29,9 +29,9 @@

@rewrap_exceptions(common_db_exceptions("developer", ["create"]))
@wrap_in_class(
ResourceCreatedResponse,
Developer,
one=True,
transform=lambda d: {**d, "id": d["developer_id"], "created_at": d["created_at"]},
transform=lambda d: {**d, "id": d["developer_id"]},
)
@pg_query
@beartype
Expand Down
29 changes: 20 additions & 9 deletions agents-api/agents_api/queries/docs/create_doc.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from pydoc import Doc
from typing import Literal
from uuid import UUID

from beartype import beartype
from uuid_extensions import uuid7

from ...autogen.openapi_model import CreateDocRequest, ResourceCreatedResponse
from ...common.utils.datetime import utcnow
from ...autogen.openapi_model import CreateDocRequest, Doc
from ...common.utils.db_exceptions import common_db_exceptions
from ...metrics.counters import increase_counter
from ..utils import pg_query, rewrap_exceptions, wrap_in_class
from .get_doc import doc_without_embedding_query
from .utils import transform_doc

# Base INSERT for docs
doc_query = """
Expand Down Expand Up @@ -49,14 +51,9 @@

@rewrap_exceptions(common_db_exceptions("doc", ["create"]))
@wrap_in_class(
ResourceCreatedResponse,
Doc,
one=True,
transform=lambda d: {
"id": d["doc_id"],
"jobs": [],
"created_at": utcnow(),
**d,
},
transform=transform_doc,
)
@increase_counter("create_doc")
@pg_query
Expand Down Expand Up @@ -131,6 +128,13 @@ async def create_doc(
# Add the owner query
queries.append((doc_owner_query, final_params_owner, "fetchmany"))

# get the doc with embedding
queries.append((
doc_without_embedding_query,
[developer_id, current_doc_id],
"fetchrow",
))

else:
# Create the doc record
doc_params = [
Expand Down Expand Up @@ -159,4 +163,11 @@ async def create_doc(
# Add the owner query
queries.append((doc_owner_query, owner_params, "fetch"))

# get the doc with embedding
queries.append((
doc_without_embedding_query,
[developer_id, current_doc_id],
"fetchrow",
))

return queries
56 changes: 33 additions & 23 deletions agents-api/agents_api/queries/docs/get_doc.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import json
from uuid import UUID

from beartype import beartype

from ...autogen.openapi_model import Doc
from ...common.utils.db_exceptions import common_db_exceptions
from ..utils import pg_query, rewrap_exceptions, wrap_in_class
from .utils import transform_doc

# Update the query to use DISTINCT ON to prevent duplicates
# get doc with embedding
doc_with_embedding_query = """
SELECT
d.doc_id,
Expand Down Expand Up @@ -42,33 +42,43 @@
LIMIT 1;
"""


def transform_get_doc(d: dict) -> dict:
content = d["content"]

embeddings = d["embeddings"]

if isinstance(embeddings, str):
embeddings = json.loads(embeddings)
elif isinstance(embeddings, list) and all(isinstance(e, str) for e in embeddings):
embeddings = [json.loads(e) for e in embeddings]

if embeddings and all((e is None) for e in embeddings):
embeddings = None

return {
**d,
"id": d["doc_id"],
"content": content,
"embeddings": embeddings,
}
# get doc without embedding
doc_without_embedding_query = """
SELECT
d.doc_id,
d.developer_id,
d.title,
array_agg(d.content ORDER BY d.index) as content,
array_agg(d.index ORDER BY d.index) as indices,
d.modality,
d.embedding_model,
d.embedding_dimensions,
d.language,
d.metadata,
d.created_at
FROM docs d
WHERE d.developer_id = $1
AND d.doc_id = $2
GROUP BY
d.doc_id,
d.developer_id,
d.title,
d.modality,
d.embedding_model,
d.embedding_dimensions,
d.language,
d.metadata,
d.created_at
ORDER BY d.created_at DESC
LIMIT 1;
"""


@rewrap_exceptions(common_db_exceptions("doc", ["get"]))
@wrap_in_class(
Doc,
one=True,
transform=transform_get_doc,
transform=transform_doc,
)
@pg_query
@beartype
Expand Down
25 changes: 2 additions & 23 deletions agents-api/agents_api/queries/docs/list_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
It constructs and executes SQL queries to fetch document details based on various filters.
"""

import json
from typing import Any, Literal
from uuid import UUID

Expand All @@ -13,6 +12,7 @@
from ...autogen.openapi_model import Doc
from ...common.utils.db_exceptions import common_db_exceptions
from ..utils import pg_query, rewrap_exceptions, wrap_in_class
from .utils import transform_doc

# Base query for listing docs with aggregated content and embeddings
base_docs_query = """
Expand Down Expand Up @@ -51,32 +51,11 @@
"""


def transform_list_docs(d: dict) -> dict:
content = d["content"]

embeddings = d["embeddings"]

if isinstance(embeddings, str):
embeddings = json.loads(embeddings)
elif isinstance(embeddings, list) and all(isinstance(e, str) for e in embeddings):
embeddings = [json.loads(e) for e in embeddings]

if embeddings and all((e is None) for e in embeddings):
embeddings = None

return {
**d,
"id": d["doc_id"],
"content": content,
"embeddings": embeddings,
}


@rewrap_exceptions(common_db_exceptions("doc", ["list"]))
@wrap_in_class(
Doc,
one=False,
transform=transform_list_docs,
transform=transform_doc,
)
@pg_query
@beartype
Expand Down
22 changes: 22 additions & 0 deletions agents-api/agents_api/queries/docs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,31 @@ def transform_to_doc_reference(d: dict) -> dict:
metadata = d.pop("metadata")

return {
**d,
"id": id,
"owner": owner,
"snippet": snippet,
"metadata": metadata,
}


def transform_doc(d: dict) -> dict:
content = d["content"]

embeddings = d.get("embeddings") or []

if embeddings:
if isinstance(embeddings, str):
embeddings = json.loads(embeddings) if embeddings.strip() else None
elif isinstance(embeddings, list) and all(isinstance(e, str) for e in embeddings):
embeddings = [json.loads(e) for e in embeddings if e.strip()]

if embeddings and all((e is None) for e in embeddings):
embeddings = None

return {
**d,
"id": d["doc_id"],
"content": content,
"embeddings": embeddings,
}
Loading

0 comments on commit dbde7a8

Please sign in to comment.