Skip to content

Commit

Permalink
feat: Let the user add their own system prompts (#643)
Browse files Browse the repository at this point in the history
* feat: Let the user add their own system prompts

Related: #454

This PR is not ready yet. For the moment it adds the system prompts
to DB and associates it to a workspace.

It's missing to use the system prompt and actually send it to the
LLM

* Finished functionality to add wrkspace system prompt

* separated into it's own command system-prompt

* Raise WorkspaceDoesNotExistError on update system-prompt

* Added show system-prompt command

* comment changes

* unit test fixes

* added some docstrings and mentioned dashboard
  • Loading branch information
aponcedeleonch authored Jan 20, 2025
1 parent 28af062 commit c35d3b1
Show file tree
Hide file tree
Showing 9 changed files with 342 additions and 47 deletions.
26 changes: 26 additions & 0 deletions migrations/versions/a692c8b52308_add_workspace_system_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""add_workspace_system_prompt
Revision ID: a692c8b52308
Revises: 5c2f3eee5f90
Create Date: 2025-01-17 16:33:58.464223
"""

from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "a692c8b52308"
down_revision: Union[str, None] = "5c2f3eee5f90"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# Add column to workspaces table
op.execute("ALTER TABLE workspaces ADD COLUMN system_prompt TEXT DEFAULT NULL;")


def downgrade() -> None:
op.execute("ALTER TABLE workspaces DROP COLUMN system_prompt;")
26 changes: 20 additions & 6 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,16 +248,15 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
except Exception as e:
logger.error(f"Failed to record context: {context}.", error=str(e))

async def add_workspace(self, workspace_name: str) -> Optional[Workspace]:
async def add_workspace(self, workspace_name: str) -> Workspace:
"""Add a new workspace to the DB.
This handles validation and insertion of a new workspace.
It may raise a ValidationError if the workspace name is invalid.
or a AlreadyExistsError if the workspace already exists.
"""
workspace = Workspace(id=str(uuid.uuid4()), name=workspace_name)

workspace = Workspace(id=str(uuid.uuid4()), name=workspace_name, system_prompt=None)
sql = text(
"""
INSERT INTO workspaces (id, name)
Expand All @@ -275,6 +274,21 @@ async def add_workspace(self, workspace_name: str) -> Optional[Workspace]:
raise AlreadyExistsError(f"Workspace {workspace_name} already exists.")
return added_workspace

async def update_workspace(self, workspace: Workspace) -> Workspace:
sql = text(
"""
UPDATE workspaces SET
name = :name,
system_prompt = :system_prompt
WHERE id = :id
RETURNING *
"""
)
updated_workspace = await self._execute_update_pydantic_model(
workspace, sql, should_raise=True
)
return updated_workspace

async def update_session(self, session: Session) -> Optional[Session]:
sql = text(
"""
Expand Down Expand Up @@ -392,11 +406,11 @@ async def get_workspaces(self) -> List[WorkspaceActive]:
workspaces = await self._execute_select_pydantic_model(WorkspaceActive, sql)
return workspaces

async def get_workspace_by_name(self, name: str) -> List[Workspace]:
async def get_workspace_by_name(self, name: str) -> Optional[Workspace]:
sql = text(
"""
SELECT
id, name
id, name, system_prompt
FROM workspaces
WHERE name = :name
"""
Expand All @@ -422,7 +436,7 @@ async def get_active_workspace(self) -> Optional[ActiveWorkspace]:
sql = text(
"""
SELECT
w.id, w.name, s.id as session_id, s.last_update
w.id, w.name, w.system_prompt, s.id as session_id, s.last_update
FROM sessions s
INNER JOIN workspaces w ON w.id = s.active_workspace_id
"""
Expand Down
2 changes: 2 additions & 0 deletions src/codegate/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class Setting(BaseModel):
class Workspace(BaseModel):
id: str
name: str
system_prompt: Optional[str]

@field_validator("name", mode="plain")
@classmethod
Expand Down Expand Up @@ -98,5 +99,6 @@ class WorkspaceActive(BaseModel):
class ActiveWorkspace(BaseModel):
id: str
name: str
system_prompt: Optional[str]
session_id: str
last_update: datetime.datetime
3 changes: 2 additions & 1 deletion src/codegate/pipeline/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
PipelineResult,
PipelineStep,
)
from codegate.pipeline.cli.commands import Version, Workspace
from codegate.pipeline.cli.commands import SystemPrompt, Version, Workspace

HELP_TEXT = """
## CodeGate CLI\n
Expand All @@ -32,6 +32,7 @@ async def codegate_cli(command):
available_commands = {
"version": Version().exec,
"workspace": Workspace().exec,
"system-prompt": SystemPrompt().exec,
}
out_func = available_commands.get(command[0])
if out_func is None:
Expand Down
Loading

0 comments on commit c35d3b1

Please sign in to comment.