Skip to content

Commit

Permalink
Use exceptions for handling workspace add error
Browse files Browse the repository at this point in the history
This stops using the boolean and instead will raise exceptions if
there's an issue adding a workspace. This will help us differentiate if
the operation failed due to a name already being taken, or the name
having invalid characters.

Signed-off-by: Juan Antonio Osorio <[email protected]>
  • Loading branch information
JAORMX committed Jan 17, 2025
1 parent b68186c commit 4cd4a57
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 39 deletions.
20 changes: 13 additions & 7 deletions src/codegate/api/v1.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from fastapi import APIRouter, Response
from fastapi.exceptions import HTTPException
from fastapi.routing import APIRoute
from pydantic import ValidationError

from codegate.api import v1_models
from codegate.db.connection import AlreadyExistsError
from codegate.workspaces.crud import WorkspaceCrud

v1 = APIRouter()
Expand Down Expand Up @@ -52,13 +54,17 @@ async def activate_workspace(request: v1_models.ActivateWorkspaceRequest, status
async def create_workspace(request: v1_models.CreateWorkspaceRequest):
"""Create a new workspace."""
# Input validation is done in the model
created = await wscrud.add_workspace(request.name)

# TODO: refactor to use a more specific exception
if not created:
raise HTTPException(status_code=400, detail="Failed to create workspace")

return v1_models.Workspace(name=request.name)
try:
created = await wscrud.add_workspace(request.name)
except AlreadyExistsError:
raise HTTPException(status_code=409, detail="Workspace already exists")
except ValidationError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")

if created:
return v1_models.Workspace(name=created.name)


@v1.delete(
Expand Down
4 changes: 2 additions & 2 deletions src/codegate/dashboard/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import structlog
from fastapi import APIRouter, Depends, FastAPI
from fastapi.responses import StreamingResponse
from codegate import __version__

from codegate import __version__
from codegate.dashboard.post_processing import (
parse_get_alert_conversation,
parse_messages_in_conversations,
Expand Down Expand Up @@ -82,7 +82,7 @@ def version_check():
latest_version_stripped = latest_version.lstrip('v')

is_latest: bool = latest_version_stripped == current_version

return {
"current_version": current_version,
"latest_version": latest_version_stripped,
Expand Down
44 changes: 26 additions & 18 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import structlog
from alembic import command as alembic_command
from alembic.config import Config as AlembicConfig
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel
from sqlalchemy import CursorResult, TextClause, text
from sqlalchemy.exc import OperationalError
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import create_async_engine

from codegate.db.fim_cache import FimCache
Expand All @@ -30,6 +30,8 @@
alert_queue = asyncio.Queue()
fim_cache = FimCache()

class AlreadyExistsError(Exception):
pass

class DbCodeGate:
_instance = None
Expand Down Expand Up @@ -70,11 +72,11 @@ def __init__(self, sqlite_path: Optional[str] = None):
super().__init__(sqlite_path)

async def _execute_update_pydantic_model(
self, model: BaseModel, sql_command: TextClause
self, model: BaseModel, sql_command: TextClause, should_raise: bool = False
) -> Optional[BaseModel]:
"""Execute an update or insert command for a Pydantic model."""
async with self._async_db_engine.begin() as conn:
try:
try:
async with self._async_db_engine.begin() as conn:
result = await conn.execute(sql_command, model.model_dump())
row = result.first()
if row is None:
Expand All @@ -83,9 +85,11 @@ async def _execute_update_pydantic_model(
# Get the class of the Pydantic object to create a new object
model_class = model.__class__
return model_class(**row._asdict())
except Exception as e:
logger.error(f"Failed to update model: {model}.", error=str(e))
return None
except Exception as e:
logger.error(f"Failed to update model: {model}.", error=str(e))
if should_raise:
raise e
return None

async def record_request(self, prompt_params: Optional[Prompt] = None) -> Optional[Prompt]:
if prompt_params is None:
Expand Down Expand Up @@ -243,11 +247,14 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
logger.error(f"Failed to record context: {context}.", error=str(e))

async def add_workspace(self, workspace_name: str) -> Optional[Workspace]:
try:
workspace = Workspace(id=str(uuid.uuid4()), name=workspace_name)
except ValidationError as e:
logger.error(f"Failed to create workspace with name: {workspace_name}: {str(e)}")
return None
"""Add a new workspace to the DB.
This handles validation and insertion of a new workspace.
It may raise a ValidationError if the workspace name is invalid.
or a AlreadyExistsError if the workspace already exists.
"""
workspace = Workspace(id=str(uuid.uuid4()), name=workspace_name)

sql = text(
"""
Expand All @@ -256,12 +263,13 @@ async def add_workspace(self, workspace_name: str) -> Optional[Workspace]:
RETURNING *
"""
)
try:
added_workspace = await self._execute_update_pydantic_model(workspace, sql)
except Exception as e:
logger.error(f"Failed to add workspace: {workspace_name}.", error=str(e))
return None

try:
added_workspace = await self._execute_update_pydantic_model(
workspace, sql, should_raise=True)
except IntegrityError as e:
logger.debug(f"Exception type: {type(e)}")
raise AlreadyExistsError(f"Workspace {workspace_name} already exists.")
return added_workspace

async def update_session(self, session: Session) -> Optional[Session]:
Expand Down
21 changes: 13 additions & 8 deletions src/codegate/pipeline/cli/commands.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from abc import ABC, abstractmethod
from typing import List

from pydantic import ValidationError

from codegate import __version__
from codegate.db.connection import AlreadyExistsError
from codegate.workspaces.crud import WorkspaceCrud


Expand Down Expand Up @@ -69,14 +72,16 @@ async def _add_workspace(self, args: List[str]) -> str:
if not new_workspace_name:
return "Please provide a name. Use `codegate workspace add your_workspace_name`"

workspace_created = await self.workspace_crud.add_workspace(new_workspace_name)
if not workspace_created:
return (
"Something went wrong. Workspace could not be added.\n"
"1. Check if the name is alphanumeric and only contains dashes, and underscores.\n"
"2. Check if the workspace already exists."
)
return f"Workspace **{new_workspace_name}** has been added"
try:
workspace_created = await self.workspace_crud.add_workspace(new_workspace_name)
except ValidationError as e:
return f"Invalid workspace name: {e}"
except AlreadyExistsError:
return f"Workspace **{new_workspace_name}** already exists"
except Exception:
return "An error occurred while adding the workspace"

return f"Workspace **{workspace_created.name}** has been added"

async def _activate_workspace(self, args: List[str]) -> str:
"""
Expand Down
11 changes: 7 additions & 4 deletions src/codegate/workspaces/crud.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import datetime
from typing import Optional, Tuple, List
from typing import List, Optional, Tuple

from codegate.db.connection import DbReader, DbRecorder
from codegate.db.models import Session, Workspace, WorkspaceActive, ActiveWorkspace
from codegate.db.models import ActiveWorkspace, Session, Workspace, WorkspaceActive


class WorkspaceCrudError(Exception):
pass

class WorkspaceCrud:

def __init__(self):
self._db_reader = DbReader()

async def add_workspace(self, new_workspace_name: str) -> bool:
async def add_workspace(self, new_workspace_name: str) -> Workspace:
"""
Add a workspace
Expand All @@ -19,7 +22,7 @@ async def add_workspace(self, new_workspace_name: str) -> bool:
"""
db_recorder = DbRecorder()
workspace_created = await db_recorder.add_workspace(new_workspace_name)
return bool(workspace_created)
return workspace_created

async def get_workspaces(self)-> List[WorkspaceActive]:
"""
Expand Down

0 comments on commit 4cd4a57

Please sign in to comment.