Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use exceptions for handling workspace add error #641

Merged
merged 1 commit into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions src/codegate/api/v1.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from fastapi import APIRouter, Response
from fastapi.exceptions import HTTPException
from fastapi.routing import APIRoute
from pydantic import ValidationError

from codegate.api import v1_models
from codegate.db.connection import AlreadyExistsError
from codegate.workspaces.crud import WorkspaceCrud

v1 = APIRouter()
Expand Down Expand Up @@ -52,13 +54,19 @@ async def activate_workspace(request: v1_models.ActivateWorkspaceRequest, status
async def create_workspace(request: v1_models.CreateWorkspaceRequest):
"""Create a new workspace."""
# Input validation is done in the model
created = await wscrud.add_workspace(request.name)

# TODO: refactor to use a more specific exception
if not created:
raise HTTPException(status_code=400, detail="Failed to create workspace")

return v1_models.Workspace(name=request.name)
try:
created = await wscrud.add_workspace(request.name)
except AlreadyExistsError:
raise HTTPException(status_code=409, detail="Workspace already exists")
except ValidationError:
raise HTTPException(status_code=400,
detail=("Invalid workspace name. "
"Please use only alphanumeric characters and dashes"))
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")

if created:
return v1_models.Workspace(name=created.name)


@v1.delete(
Expand Down
4 changes: 2 additions & 2 deletions src/codegate/dashboard/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import structlog
from fastapi import APIRouter, Depends, FastAPI
from fastapi.responses import StreamingResponse
from codegate import __version__

from codegate import __version__
from codegate.dashboard.post_processing import (
parse_get_alert_conversation,
parse_messages_in_conversations,
Expand Down Expand Up @@ -82,7 +82,7 @@ def version_check():
latest_version_stripped = latest_version.lstrip('v')

is_latest: bool = latest_version_stripped == current_version

return {
"current_version": current_version,
"latest_version": latest_version_stripped,
Expand Down
44 changes: 26 additions & 18 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import structlog
from alembic import command as alembic_command
from alembic.config import Config as AlembicConfig
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel
from sqlalchemy import CursorResult, TextClause, text
from sqlalchemy.exc import OperationalError
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import create_async_engine

from codegate.db.fim_cache import FimCache
Expand All @@ -30,6 +30,8 @@
alert_queue = asyncio.Queue()
fim_cache = FimCache()

class AlreadyExistsError(Exception):
pass

class DbCodeGate:
_instance = None
Expand Down Expand Up @@ -70,11 +72,11 @@ def __init__(self, sqlite_path: Optional[str] = None):
super().__init__(sqlite_path)

async def _execute_update_pydantic_model(
self, model: BaseModel, sql_command: TextClause
self, model: BaseModel, sql_command: TextClause, should_raise: bool = False
) -> Optional[BaseModel]:
"""Execute an update or insert command for a Pydantic model."""
async with self._async_db_engine.begin() as conn:
try:
try:
async with self._async_db_engine.begin() as conn:
result = await conn.execute(sql_command, model.model_dump())
row = result.first()
if row is None:
Expand All @@ -83,9 +85,11 @@ async def _execute_update_pydantic_model(
# Get the class of the Pydantic object to create a new object
model_class = model.__class__
return model_class(**row._asdict())
except Exception as e:
logger.error(f"Failed to update model: {model}.", error=str(e))
return None
except Exception as e:
logger.error(f"Failed to update model: {model}.", error=str(e))
if should_raise:
raise e
return None

async def record_request(self, prompt_params: Optional[Prompt] = None) -> Optional[Prompt]:
if prompt_params is None:
Expand Down Expand Up @@ -243,11 +247,14 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
logger.error(f"Failed to record context: {context}.", error=str(e))

async def add_workspace(self, workspace_name: str) -> Optional[Workspace]:
try:
workspace = Workspace(id=str(uuid.uuid4()), name=workspace_name)
except ValidationError as e:
logger.error(f"Failed to create workspace with name: {workspace_name}: {str(e)}")
return None
"""Add a new workspace to the DB.

This handles validation and insertion of a new workspace.

It may raise a ValidationError if the workspace name is invalid.
or a AlreadyExistsError if the workspace already exists.
"""
workspace = Workspace(id=str(uuid.uuid4()), name=workspace_name)

sql = text(
"""
Expand All @@ -256,12 +263,13 @@ async def add_workspace(self, workspace_name: str) -> Optional[Workspace]:
RETURNING *
"""
)
try:
added_workspace = await self._execute_update_pydantic_model(workspace, sql)
except Exception as e:
logger.error(f"Failed to add workspace: {workspace_name}.", error=str(e))
return None

try:
added_workspace = await self._execute_update_pydantic_model(
workspace, sql, should_raise=True)
except IntegrityError as e:
logger.debug(f"Exception type: {type(e)}")
raise AlreadyExistsError(f"Workspace {workspace_name} already exists.")
return added_workspace

async def update_session(self, session: Session) -> Optional[Session]:
Expand Down
19 changes: 12 additions & 7 deletions src/codegate/pipeline/cli/commands.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from abc import ABC, abstractmethod
from typing import List

from pydantic import ValidationError

from codegate import __version__
from codegate.db.connection import AlreadyExistsError
from codegate.workspaces.crud import WorkspaceCrud


Expand Down Expand Up @@ -69,13 +72,15 @@ async def _add_workspace(self, args: List[str]) -> str:
if not new_workspace_name:
return "Please provide a name. Use `codegate workspace add your_workspace_name`"

workspace_created = await self.workspace_crud.add_workspace(new_workspace_name)
if not workspace_created:
return (
"Something went wrong. Workspace could not be added.\n"
"1. Check if the name is alphanumeric and only contains dashes, and underscores.\n"
"2. Check if the workspace already exists."
)
try:
_ = await self.workspace_crud.add_workspace(new_workspace_name)
except ValidationError:
return "Invalid workspace name: It should be alphanumeric and dashes"
except AlreadyExistsError:
return f"Workspace **{new_workspace_name}** already exists"
except Exception:
return "An error occurred while adding the workspace"

return f"Workspace **{new_workspace_name}** has been added"

async def _activate_workspace(self, args: List[str]) -> str:
Expand Down
11 changes: 7 additions & 4 deletions src/codegate/workspaces/crud.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import datetime
from typing import Optional, Tuple, List
from typing import List, Optional, Tuple

from codegate.db.connection import DbReader, DbRecorder
from codegate.db.models import Session, Workspace, WorkspaceActive, ActiveWorkspace
from codegate.db.models import ActiveWorkspace, Session, Workspace, WorkspaceActive


class WorkspaceCrudError(Exception):
pass

class WorkspaceCrud:

def __init__(self):
self._db_reader = DbReader()

async def add_workspace(self, new_workspace_name: str) -> bool:
async def add_workspace(self, new_workspace_name: str) -> Workspace:
"""
Add a workspace

Expand All @@ -19,7 +22,7 @@ async def add_workspace(self, new_workspace_name: str) -> bool:
"""
db_recorder = DbRecorder()
workspace_created = await db_recorder.add_workspace(new_workspace_name)
return bool(workspace_created)
return workspace_created

async def get_workspaces(self)-> List[WorkspaceActive]:
"""
Expand Down
Loading