diff --git a/.gitignore b/.gitignore index ae1c1eed..a4d6dc2d 100644 --- a/.gitignore +++ b/.gitignore @@ -59,3 +59,7 @@ npm-debug.log* yarn-debug.log* yarn-error.log* pnpm-debug.log* + + +# Local db files +opengpts.db \ No newline at end of file diff --git a/README.md b/README.md index a5b4beb8..fc26e8d3 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ Because this is open source, if you do not like those architectures or want to m ## Quickstart with Docker This project supports a Docker-based setup, streamlining installation and execution. It automatically builds images for -the frontend and backend and sets up Postgres using docker-compose. +the frontend and backend and sets up either SQLite or Postgres using docker-compose. 1. **Prerequisites:** @@ -71,14 +71,18 @@ the frontend and backend and sets up Postgres using docker-compose. 4. **Run with Docker Compose:** - In the root directory of the project, execute: + In the root directory of the project, execute one of the following commands to start the services: - ``` + ```shell + # For SQLite-based setup docker compose up + + # For Postgres-based setup + docker compose -f docker-compose.pg.yml up ``` This command builds the Docker images for the frontend and backend from their respective Dockerfiles and starts all - necessary services, including Postgres. + necessary services, including SQLite or Postgres. 5. **Access the Application:** With the services running, access the frontend at [http://localhost:5173](http://localhost:5173), substituting `5173` with the @@ -87,8 +91,12 @@ the frontend and backend and sets up Postgres using docker-compose. 6. **Rebuilding After Changes:** If you make changes to either the frontend or backend, rebuild the Docker images to reflect these changes. Run: - ``` + ```shell + # For SQLite-based setup docker compose up --build + + # For Postgres-based setup + docker compose -f docker-compose.pg.yml up --build ``` This command rebuilds the images with your latest changes and restarts the services. @@ -115,6 +123,16 @@ pip install poetry pip install langchain-community ``` +### Persistence Layer + +The backend supports both SQLite and Postgres for saving agent configurations and chat message history. Set the `STORAGE_TYPE` environment variable to `sqlite` or `postgres`: + +```shell +export STORAGE_TYPE=postgres +``` + +SQLite requires no configuration (apart from [running migrations](#migrations)). The database file will be created in the `backend` directory. However, to configure and use Postgres, follow the instructions below: + **Install Postgres and the Postgres Vector Extension** ``` brew install postgresql pgvector brew services start postgresql ``` **Configure persistence layer** -The backend uses Postgres for saving agent configurations and chat message history. -In order to use this, you need to set the following environment variables: +Set the following environment variables: ```shell export POSTGRES_HOST=localhost @@ -148,9 +165,9 @@ psql -d opengpts CREATE ROLE postgres WITH LOGIN SUPERUSER CREATEDB CREATEROLE; ``` -**Install Golang Migrate** +#### Migrations -Database migrations are managed with [golang-migrate](https://github.com/golang-migrate/migrate). +Database migrations for both SQLite and Postgres are managed with [golang-migrate](https://github.com/golang-migrate/migrate). On MacOS, you can install it with `brew install golang-migrate`.
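For reference, the updated `make migrate` target further down in this diff simply shells out to the golang-migrate CLI with a backend-specific database URL and migrations directory. A minimal Python sketch of the equivalent calls, assuming the `migrate` binary is on `PATH` and the working directory is `backend/`:

```python
# Sketch only (not part of this PR): what `make migrate` runs for each backend.
import os
import subprocess


def run_migrations(storage_type: str = "sqlite") -> None:
    if storage_type == "postgres":
        database = (
            f"postgres://{os.environ['POSTGRES_USER']}:{os.environ['POSTGRES_PASSWORD']}"
            f"@{os.environ['POSTGRES_HOST']}:{os.environ['POSTGRES_PORT']}"
            f"/{os.environ['POSTGRES_DB']}?sslmode=disable"
        )
        path = "./migrations/postgres"
    else:
        # SQLite: the database file lives in the backend directory.
        database = f"sqlite3://{os.getcwd()}/opengpts.db"
        path = "./migrations/sqlite"
    subprocess.run(["migrate", "-database", database, "-path", path, "up"], check=True)


run_migrations(os.environ.get("STORAGE_TYPE", "sqlite"))
```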
Instructions for other OSs or the Golang toolchain, can be found [here](https://github.com/golang-migrate/migrate/blob/master/cmd/migrate/README.md#installation). @@ -160,7 +177,7 @@ Once `golang-migrate` is installed, you can run all the migrations with: make migrate ``` -This will enable the backend to use Postgres as a vector database and create the initial tables. +This will create the initial tables. **Install backend dependencies** diff --git a/backend/Makefile b/backend/Makefile index 3e8dc5b2..a4502f35 100644 --- a/backend/Makefile +++ b/backend/Makefile @@ -17,11 +17,17 @@ start: poetry run uvicorn app.server:app --reload --port 8100 migrate: - migrate -database postgres://$(POSTGRES_USER):$(POSTGRES_PASSWORD)@$(POSTGRES_HOST):$(POSTGRES_PORT)/$(POSTGRES_DB)?sslmode=disable -path ./migrations up + ifeq ($(STORAGE_TYPE),postgres) + @echo "Running Postgres migrations..." + migrate -database postgres://$(POSTGRES_USER):$(POSTGRES_PASSWORD)@$(POSTGRES_HOST):$(POSTGRES_PORT)/$(POSTGRES_DB)?sslmode=disable -path ./migrations/postgres up + else + @echo "Running SQLite migrations..." + migrate -database sqlite3://$(PWD)/opengpts.db -path ./migrations/sqlite up + endif test: # We need to update handling of env variables for tests - YDC_API_KEY=placeholder OPENAI_API_KEY=placeholder poetry run pytest $(TEST_FILE) + STORAGE_TYPE=postgres YDC_API_KEY=placeholder OPENAI_API_KEY=placeholder poetry run pytest $(TEST_FILE) test_watch: diff --git a/backend/app/agent.py b/backend/app/agent.py index 2658c9b7..45b92f2d 100644 --- a/backend/app/agent.py +++ b/backend/app/agent.py @@ -14,7 +14,7 @@ from app.agent_types.tools_agent import get_tools_agent_executor from app.agent_types.xml_agent import get_xml_agent_executor from app.chatbot import get_chatbot_executor -from app.checkpoint import PostgresCheckpoint +from app.checkpoint import Checkpointer from app.llms import ( get_anthropic_llm, get_google_llm, @@ -73,7 +73,7 @@ class AgentType(str, Enum): DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant." 
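Both the checkpointer change below and the new storage singleton are selected once, at import time, from the `STORAGE_TYPE` environment variable via `app/storage/settings.py`, which is added later in this diff. A small sketch of how those settings resolve, with placeholder values; note that, as written, `Settings` constructs `PostgresSettings()` eagerly, so the `POSTGRES_*` variables still need to be present even when `STORAGE_TYPE=sqlite`:

```python
# Sketch only (not part of this PR): resolving the new storage settings.
import os

os.environ.update(
    {
        "STORAGE_TYPE": "sqlite",
        # PostgresSettings has no field defaults and is built when Settings is
        # defined, so these are required even for the SQLite backend.
        "POSTGRES_HOST": "localhost",
        "POSTGRES_PORT": "5432",
        "POSTGRES_DB": "opengpts",
        "POSTGRES_USER": "postgres",
        "POSTGRES_PASSWORD": "postgres",
    }
)

from app.storage.settings import StorageType, settings  # noqa: E402

assert settings.storage_type == StorageType.SQLITE
print(settings.postgres.host, settings.postgres.port)  # localhost 5432
```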
-CHECKPOINTER = PostgresCheckpoint(serde=pickle, at=CheckpointAt.END_OF_STEP) +CHECKPOINTER = Checkpointer(serde=pickle, at=CheckpointAt.END_OF_STEP) def get_agent_executor( diff --git a/backend/app/api/assistants.py b/backend/app/api/assistants.py index 8458b48c..1374c927 100644 --- a/backend/app/api/assistants.py +++ b/backend/app/api/assistants.py @@ -4,9 +4,9 @@ from fastapi import APIRouter, HTTPException, Path from pydantic import BaseModel, Field -import app.storage as storage from app.auth.handlers import AuthedUser from app.schema import Assistant +from app.storage.storage import storage router = APIRouter() diff --git a/backend/app/api/runs.py b/backend/app/api/runs.py index 2d4a83da..44a97a9e 100644 --- a/backend/app/api/runs.py +++ b/backend/app/api/runs.py @@ -13,7 +13,7 @@ from app.agent import agent from app.auth.handlers import AuthedUser -from app.storage import get_assistant, get_thread +from app.storage.storage import storage from app.stream import astream_state, to_sse router = APIRouter() @@ -30,11 +30,11 @@ class CreateRunPayload(BaseModel): async def _run_input_and_config(payload: CreateRunPayload, user_id: str): - thread = await get_thread(user_id, payload.thread_id) + thread = await storage.get_thread(user_id, payload.thread_id) if not thread: raise HTTPException(status_code=404, detail="Thread not found") - assistant = await get_assistant(user_id, str(thread["assistant_id"])) + assistant = await storage.get_assistant(user_id, str(thread["assistant_id"])) if not assistant: raise HTTPException(status_code=404, detail="Assistant not found") diff --git a/backend/app/api/threads.py b/backend/app/api/threads.py index dd6441b6..7ab937df 100644 --- a/backend/app/api/threads.py +++ b/backend/app/api/threads.py @@ -5,9 +5,9 @@ from langchain.schema.messages import AnyMessage from pydantic import BaseModel, Field -import app.storage as storage from app.auth.handlers import AuthedUser from app.schema import Thread +from app.storage.storage import storage router = APIRouter() diff --git a/backend/app/auth/handlers.py b/backend/app/auth/handlers.py index 630d45ff..1486f4f9 100644 --- a/backend/app/auth/handlers.py +++ b/backend/app/auth/handlers.py @@ -7,9 +7,9 @@ from fastapi import Depends, HTTPException, Request from fastapi.security.http import HTTPBearer -import app.storage as storage from app.auth.settings import AuthType, settings from app.schema import User +from app.storage.storage import storage class AuthHandler(ABC): diff --git a/backend/app/checkpoint.py b/backend/app/checkpoint.py index abb3e9e3..5404a1cd 100644 --- a/backend/app/checkpoint.py +++ b/backend/app/checkpoint.py @@ -2,9 +2,11 @@ from datetime import datetime from typing import AsyncIterator, Optional +import aiosqlite from langchain_core.messages import BaseMessage from langchain_core.runnables import ConfigurableFieldSpec, RunnableConfig from langgraph.checkpoint import BaseCheckpointSaver +from langgraph.checkpoint.aiosqlite import AsyncSqliteSaver from langgraph.checkpoint.base import ( Checkpoint, CheckpointAt, @@ -13,7 +15,9 @@ SerializerProtocol, ) -from app.lifespan import get_pg_pool +from app.storage.settings import StorageType +from app.storage.settings import settings as storage_settings +from app.storage.storage import storage def loads(value: bytes) -> Checkpoint: @@ -24,7 +28,7 @@ def loads(value: bytes) -> Checkpoint: return loaded -class PostgresCheckpoint(BaseCheckpointSaver): +class PostgresCheckpointer(BaseCheckpointSaver): def __init__( self, *, @@ -54,7 +58,7 @@ def put(self, 
config: RunnableConfig, checkpoint: Checkpoint) -> RunnableConfig: raise NotImplementedError async def alist(self, config: RunnableConfig) -> AsyncIterator[CheckpointTuple]: - async with get_pg_pool().acquire() as db, db.transaction(): + async with storage.get_pool().acquire() as db, db.transaction(): thread_id = config["configurable"]["thread_id"] async for value in db.cursor( "SELECT checkpoint, thread_ts, parent_ts FROM checkpoints WHERE thread_id = $1 ORDER BY thread_ts DESC", @@ -81,7 +85,7 @@ async def alist(self, config: RunnableConfig) -> AsyncIterator[CheckpointTuple]: async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: thread_id = config["configurable"]["thread_id"] thread_ts = config["configurable"].get("thread_ts") - async with get_pg_pool().acquire() as conn: + async with storage.get_pool().acquire() as conn: if thread_ts: if value := await conn.fetchrow( "SELECT checkpoint, parent_ts FROM checkpoints WHERE thread_id = $1 AND thread_ts = $2", @@ -125,7 +129,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: async def aput(self, config: RunnableConfig, checkpoint: Checkpoint) -> None: thread_id = config["configurable"]["thread_id"] - async with get_pg_pool().acquire() as conn: + async with storage.get_pool().acquire() as conn: await conn.execute( """ INSERT INTO checkpoints (thread_id, thread_ts, parent_ts, checkpoint) @@ -145,3 +149,43 @@ async def aput(self, config: RunnableConfig, checkpoint: Checkpoint) -> None: "thread_ts": checkpoint["ts"], } } + + +class SqliteCheckpointer(AsyncSqliteSaver): + conn: aiosqlite.Connection = None + + def __init__( + self, + *, + serde: Optional[SerializerProtocol] = None, + at: Optional[CheckpointAt] = None, + ) -> None: + super().__init__(conn=None, serde=serde, at=at) + + @property + def config_specs(self) -> list[ConfigurableFieldSpec]: + return [ + ConfigurableFieldSpec( + id="thread_id", + annotation=Optional[str], + name="Thread ID", + description=None, + default=None, + is_shared=True, + ), + CheckpointThreadTs, + ] + + async def setup(self) -> None: + if self.is_setup: + return + self.conn = storage.get_conn() + self.is_setup = True + + +if storage_settings.storage_type == StorageType.POSTGRES: + Checkpointer = PostgresCheckpointer +elif storage_settings.storage_type == StorageType.SQLITE: + Checkpointer = SqliteCheckpointer +else: + raise NotImplementedError() diff --git a/backend/app/lifespan.py b/backend/app/lifespan.py index 8e15f139..a6ef3f80 100644 --- a/backend/app/lifespan.py +++ b/backend/app/lifespan.py @@ -1,41 +1,12 @@ -import os from contextlib import asynccontextmanager -import asyncpg -import orjson from fastapi import FastAPI -_pg_pool = None - - -def get_pg_pool() -> asyncpg.pool.Pool: - return _pg_pool - - -async def _init_connection(conn) -> None: - await conn.set_type_codec( - "json", - encoder=lambda v: orjson.dumps(v).decode(), - decoder=orjson.loads, - schema="pg_catalog", - ) - await conn.set_type_codec( - "uuid", encoder=lambda v: str(v), decoder=lambda v: v, schema="pg_catalog" - ) +from app.storage.storage import storage @asynccontextmanager async def lifespan(app: FastAPI): - global _pg_pool - - _pg_pool = await asyncpg.create_pool( - database=os.environ["POSTGRES_DB"], - user=os.environ["POSTGRES_USER"], - password=os.environ["POSTGRES_PASSWORD"], - host=os.environ["POSTGRES_HOST"], - port=os.environ["POSTGRES_PORT"], - init=_init_connection, - ) + await storage.setup() yield - await _pg_pool.close() - _pg_pool = None + await 
storage.teardown() diff --git a/backend/app/server.py b/backend/app/server.py index 9dd65743..b51c7510 100644 --- a/backend/app/server.py +++ b/backend/app/server.py @@ -7,10 +7,10 @@ from fastapi.exceptions import HTTPException from fastapi.staticfiles import StaticFiles -import app.storage as storage from app.api import router as api_router from app.auth.handlers import AuthedUser from app.lifespan import lifespan +from app.storage.storage import storage from app.upload import convert_ingestion_input_to_blob, ingest_runnable logger = logging.getLogger(__name__) diff --git a/backend/app/storage.py b/backend/app/storage.py deleted file mode 100644 index edfbc585..00000000 --- a/backend/app/storage.py +++ /dev/null @@ -1,204 +0,0 @@ -from datetime import datetime, timezone -from typing import Any, List, Optional, Sequence, Union - -from langchain_core.messages import AnyMessage -from langchain_core.runnables import RunnableConfig - -from app.agent import agent -from app.lifespan import get_pg_pool -from app.schema import Assistant, Thread, User - - -async def list_assistants(user_id: str) -> List[Assistant]: - """List all assistants for the current user.""" - async with get_pg_pool().acquire() as conn: - return await conn.fetch("SELECT * FROM assistant WHERE user_id = $1", user_id) - - -async def get_assistant(user_id: str, assistant_id: str) -> Optional[Assistant]: - """Get an assistant by ID.""" - async with get_pg_pool().acquire() as conn: - return await conn.fetchrow( - "SELECT * FROM assistant WHERE assistant_id = $1 AND (user_id = $2 OR public IS true)", - assistant_id, - user_id, - ) - - -async def list_public_assistants() -> List[Assistant]: - """List all the public assistants.""" - async with get_pg_pool().acquire() as conn: - return await conn.fetch(("SELECT * FROM assistant WHERE public IS true;")) - - -async def put_assistant( - user_id: str, assistant_id: str, *, name: str, config: dict, public: bool = False -) -> Assistant: - """Modify an assistant. - - Args: - user_id: The user ID. - assistant_id: The assistant ID. - name: The assistant name. - config: The assistant config. - public: Whether the assistant is public. - - Returns: - return the assistant model if no exception is raised. 
- """ - updated_at = datetime.now(timezone.utc) - async with get_pg_pool().acquire() as conn: - async with conn.transaction(): - await conn.execute( - ( - "INSERT INTO assistant (assistant_id, user_id, name, config, updated_at, public) VALUES ($1, $2, $3, $4, $5, $6) " - "ON CONFLICT (assistant_id) DO UPDATE SET " - "user_id = EXCLUDED.user_id, " - "name = EXCLUDED.name, " - "config = EXCLUDED.config, " - "updated_at = EXCLUDED.updated_at, " - "public = EXCLUDED.public;" - ), - assistant_id, - user_id, - name, - config, - updated_at, - public, - ) - return { - "assistant_id": assistant_id, - "user_id": user_id, - "name": name, - "config": config, - "updated_at": updated_at, - "public": public, - } - - -async def list_threads(user_id: str) -> List[Thread]: - """List all threads for the current user.""" - async with get_pg_pool().acquire() as conn: - return await conn.fetch("SELECT * FROM thread WHERE user_id = $1", user_id) - - -async def get_thread(user_id: str, thread_id: str) -> Optional[Thread]: - """Get a thread by ID.""" - async with get_pg_pool().acquire() as conn: - return await conn.fetchrow( - "SELECT * FROM thread WHERE thread_id = $1 AND user_id = $2", - thread_id, - user_id, - ) - - -async def get_thread_state(*, user_id: str, thread_id: str, assistant_id: str): - """Get state for a thread.""" - assistant = await get_assistant(user_id, assistant_id) - state = await agent.aget_state( - { - "configurable": { - **assistant["config"]["configurable"], - "thread_id": thread_id, - "assistant_id": assistant_id, - } - } - ) - return { - "values": state.values, - "next": state.next, - } - - -async def update_thread_state( - config: RunnableConfig, - values: Union[Sequence[AnyMessage], dict[str, Any]], - *, - user_id: str, - assistant_id: str, -): - """Add state to a thread.""" - assistant = await get_assistant(user_id, assistant_id) - await agent.aupdate_state( - { - "configurable": { - **assistant["config"]["configurable"], - **config["configurable"], - "assistant_id": assistant_id, - } - }, - values, - ) - - -async def get_thread_history(*, user_id: str, thread_id: str, assistant_id: str): - """Get the history of a thread.""" - assistant = await get_assistant(user_id, assistant_id) - return [ - { - "values": c.values, - "next": c.next, - "config": c.config, - "parent": c.parent_config, - } - async for c in agent.aget_state_history( - { - "configurable": { - **assistant["config"]["configurable"], - "thread_id": thread_id, - "assistant_id": assistant_id, - } - } - ) - ] - - -async def put_thread( - user_id: str, thread_id: str, *, assistant_id: str, name: str -) -> Thread: - """Modify a thread.""" - updated_at = datetime.now(timezone.utc) - async with get_pg_pool().acquire() as conn: - await conn.execute( - ( - "INSERT INTO thread (thread_id, user_id, assistant_id, name, updated_at) VALUES ($1, $2, $3, $4, $5) " - "ON CONFLICT (thread_id) DO UPDATE SET " - "user_id = EXCLUDED.user_id," - "assistant_id = EXCLUDED.assistant_id, " - "name = EXCLUDED.name, " - "updated_at = EXCLUDED.updated_at;" - ), - thread_id, - user_id, - assistant_id, - name, - updated_at, - ) - return { - "thread_id": thread_id, - "user_id": user_id, - "assistant_id": assistant_id, - "name": name, - "updated_at": updated_at, - } - - -async def get_or_create_user(sub: str) -> tuple[User, bool]: - """Returns a tuple of the user and a boolean indicating whether the user was created.""" - async with get_pg_pool().acquire() as conn: - if user := await conn.fetchrow('SELECT * FROM "user" WHERE sub = $1', sub): - return user, 
False - user = await conn.fetchrow( - 'INSERT INTO "user" (sub) VALUES ($1) RETURNING *', sub - ) - return user, True - - -async def delete_thread(user_id: str, thread_id: str): - """Delete a thread by ID.""" - async with get_pg_pool().acquire() as conn: - await conn.execute( - "DELETE FROM thread WHERE thread_id = $1 AND user_id = $2", - thread_id, - user_id, - ) diff --git a/backend/app/storage/__init__.py b/backend/app/storage/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/app/storage/base.py b/backend/app/storage/base.py new file mode 100644 index 00000000..9b97b815 --- /dev/null +++ b/backend/app/storage/base.py @@ -0,0 +1,83 @@ +from abc import ABC, abstractmethod +from typing import Any, Optional, Sequence, Union + +from langchain_core.messages import AnyMessage + +from app.schema import Assistant, Thread, User + + +class BaseStorage(ABC): + @abstractmethod + async def setup(self) -> None: + """Setup the storage.""" + + @abstractmethod + async def teardown(self) -> None: + """Teardown the storage.""" + + @abstractmethod + async def list_assistants(self, user_id: str) -> list[Assistant]: + """List all assistants for the current user.""" + + @abstractmethod + async def get_assistant( + self, user_id: str, assistant_id: str + ) -> Optional[Assistant]: + """Get an assistant by ID.""" + + @abstractmethod + async def list_public_assistants( + self, assistant_ids: Sequence[str] + ) -> list[Assistant]: + """List all the public assistants.""" + + @abstractmethod + async def put_assistant( + self, + user_id: str, + assistant_id: str, + *, + name: str, + config: dict, + public: bool = False, + ) -> Assistant: + """Modify an assistant.""" + + @abstractmethod + async def list_threads(self, user_id: str) -> list[Thread]: + """List all threads for the current user.""" + + @abstractmethod + async def get_thread(self, user_id: str, thread_id: str) -> Optional[Thread]: + """Get a thread by ID.""" + + @abstractmethod + async def get_thread_state(self, user_id: str, thread_id: str): + """Get state for a thread.""" + + @abstractmethod + async def update_thread_state( + self, + user_id: str, + thread_id: str, + values: Union[Sequence[AnyMessage], dict[str, Any]], + ): + """Add state to a thread.""" + + @abstractmethod + async def get_thread_history(self, user_id: str, thread_id: str): + """Get the history of a thread.""" + + @abstractmethod + async def put_thread( + self, user_id: str, thread_id: str, *, assistant_id: str, name: str + ) -> Thread: + """Modify a thread.""" + + @abstractmethod + async def get_or_create_user(self, sub: str) -> tuple[User, bool]: + """Returns a tuple of the user and a boolean indicating whether the user was created.""" + + @abstractmethod + async def delete_thread(self, user_id: str, thread_id: str) -> None: + """Delete a thread by ID.""" diff --git a/backend/app/storage/postgres.py b/backend/app/storage/postgres.py new file mode 100644 index 00000000..75193fa3 --- /dev/null +++ b/backend/app/storage/postgres.py @@ -0,0 +1,256 @@ +from datetime import datetime, timezone +from typing import Any, List, Optional, Sequence, Union + +import asyncpg +import orjson +from langchain_core.messages import AnyMessage +from langchain_core.runnables import RunnableConfig + +from app.schema import Assistant, Thread, User +from app.storage.base import BaseStorage +from app.storage.settings import settings as storage_settings + + +class PostgresStorage(BaseStorage): + _pool: asyncpg.pool.Pool = None + _is_setup: bool = False + + async def setup(self) -> None: 
+ if self._is_setup: + return + self._pool = await asyncpg.create_pool( + database=storage_settings.postgres.db, + user=storage_settings.postgres.user, + password=storage_settings.postgres.password, + host=storage_settings.postgres.host, + port=storage_settings.postgres.port, + init=self._init_connection, + ) + self._is_setup = True + + async def teardown(self) -> None: + await self._pool.close() + self._pool = None + self._is_setup = False + + async def _init_connection(self, conn) -> None: + await conn.set_type_codec( + "json", + encoder=lambda v: orjson.dumps(v).decode(), + decoder=orjson.loads, + schema="pg_catalog", + ) + await conn.set_type_codec( + "uuid", encoder=lambda v: str(v), decoder=lambda v: v, schema="pg_catalog" + ) + + def get_pool(self) -> asyncpg.pool.Pool: + if not self._is_setup: + raise RuntimeError("Storage is not set up.") + return self._pool + + async def list_assistants(self, user_id: str) -> List[Assistant]: + """List all assistants for the current user.""" + async with self.get_pool().acquire() as conn: + return await conn.fetch( + "SELECT * FROM assistant WHERE user_id = $1", user_id + ) + + async def get_assistant( + self, user_id: str, assistant_id: str + ) -> Optional[Assistant]: + """Get an assistant by ID.""" + async with self.get_pool().acquire() as conn: + return await conn.fetchrow( + "SELECT * FROM assistant WHERE assistant_id = $1 AND (user_id = $2 OR public IS true)", + assistant_id, + user_id, + ) + + async def list_public_assistants(self) -> List[Assistant]: + """List all the public assistants.""" + async with self.get_pool().acquire() as conn: + return await conn.fetch(("SELECT * FROM assistant WHERE public IS true;")) + + async def put_assistant( + self, + user_id: str, + assistant_id: str, + *, + name: str, + config: dict, + public: bool = False, + ) -> Assistant: + """Modify an assistant. + + Args: + user_id: The user ID. + assistant_id: The assistant ID. + name: The assistant name. + config: The assistant config. + public: Whether the assistant is public. + + Returns: + return the assistant model if no exception is raised. 
+ """ + updated_at = datetime.now(timezone.utc) + async with self.get_pool().acquire() as conn: + async with conn.transaction(): + await conn.execute( + ( + "INSERT INTO assistant (assistant_id, user_id, name, config, updated_at, public) VALUES ($1, $2, $3, $4, $5, $6) " + "ON CONFLICT (assistant_id) DO UPDATE SET " + "user_id = EXCLUDED.user_id, " + "name = EXCLUDED.name, " + "config = EXCLUDED.config, " + "updated_at = EXCLUDED.updated_at, " + "public = EXCLUDED.public;" + ), + assistant_id, + user_id, + name, + config, + updated_at, + public, + ) + return { + "assistant_id": assistant_id, + "user_id": user_id, + "name": name, + "config": config, + "updated_at": updated_at, + "public": public, + } + + async def list_threads(self, user_id: str) -> List[Thread]: + """List all threads for the current user.""" + async with self.get_pool().acquire() as conn: + return await conn.fetch("SELECT * FROM thread WHERE user_id = $1", user_id) + + async def get_thread(self, user_id: str, thread_id: str) -> Optional[Thread]: + """Get a thread by ID.""" + async with self.get_pool().acquire() as conn: + return await conn.fetchrow( + "SELECT * FROM thread WHERE thread_id = $1 AND user_id = $2", + thread_id, + user_id, + ) + + async def get_thread_state( + self, *, user_id: str, thread_id: str, assistant_id: str + ): + """Get state for a thread.""" + from app.agent import agent + + assistant = await self.get_assistant(user_id, assistant_id) + state = await agent.aget_state( + { + "configurable": { + **assistant["config"]["configurable"], + "thread_id": thread_id, + "assistant_id": assistant_id, + } + } + ) + return { + "values": state.values, + "next": state.next, + } + + async def update_thread_state( + self, + config: RunnableConfig, + values: Union[Sequence[AnyMessage], dict[str, Any]], + *, + user_id: str, + assistant_id: str, + ): + """Add state to a thread.""" + from app.agent import agent + + assistant = await self.get_assistant(user_id, assistant_id) + await agent.aupdate_state( + { + "configurable": { + **assistant["config"]["configurable"], + **config["configurable"], + "assistant_id": assistant_id, + } + }, + values, + ) + + async def get_thread_history( + self, *, user_id: str, thread_id: str, assistant_id: str + ): + """Get the history of a thread.""" + from app.agent import agent + + assistant = await self.get_assistant(user_id, assistant_id) + return [ + { + "values": c.values, + "next": c.next, + "config": c.config, + "parent": c.parent_config, + } + async for c in agent.aget_state_history( + { + "configurable": { + **assistant["config"]["configurable"], + "thread_id": thread_id, + "assistant_id": assistant_id, + } + } + ) + ] + + async def put_thread( + self, user_id: str, thread_id: str, *, assistant_id: str, name: str + ) -> Thread: + """Modify a thread.""" + updated_at = datetime.now(timezone.utc) + async with self.get_pool().acquire() as conn: + await conn.execute( + ( + "INSERT INTO thread (thread_id, user_id, assistant_id, name, updated_at) VALUES ($1, $2, $3, $4, $5) " + "ON CONFLICT (thread_id) DO UPDATE SET " + "user_id = EXCLUDED.user_id," + "assistant_id = EXCLUDED.assistant_id, " + "name = EXCLUDED.name, " + "updated_at = EXCLUDED.updated_at;" + ), + thread_id, + user_id, + assistant_id, + name, + updated_at, + ) + return { + "thread_id": thread_id, + "user_id": user_id, + "assistant_id": assistant_id, + "name": name, + "updated_at": updated_at, + } + + async def get_or_create_user(self, sub: str) -> tuple[User, bool]: + """Returns a tuple of the user and a boolean indicating 
whether the user was created.""" + async with self.get_pool().acquire() as conn: + user = await conn.fetchrow( + 'INSERT INTO "user" (sub) VALUES ($1) ON CONFLICT (sub) DO NOTHING RETURNING *', + sub, + ) + if user: + return user, True + user = await conn.fetchrow('SELECT * FROM "user" WHERE sub = $1', sub) + return user, False + + async def delete_thread(self, user_id: str, thread_id: str) -> None: + """Delete a thread by ID.""" + async with self.get_pool().acquire() as conn: + await conn.execute( + "DELETE FROM thread WHERE thread_id = $1 AND user_id = $2", + thread_id, + user_id, + ) diff --git a/backend/app/storage/settings.py b/backend/app/storage/settings.py new file mode 100644 index 00000000..a5ffee12 --- /dev/null +++ b/backend/app/storage/settings.py @@ -0,0 +1,27 @@ +from enum import Enum + +from pydantic import BaseSettings + + +class StorageType(Enum): + POSTGRES = "postgres" + SQLITE = "sqlite" + + +class PostgresSettings(BaseSettings): + host: str + port: int + db: str + user: str + password: str + + class Config: + env_prefix = "postgres_" + + +class Settings(BaseSettings): + storage_type: StorageType = StorageType.SQLITE + postgres: PostgresSettings = PostgresSettings() + + +settings = Settings() diff --git a/backend/app/storage/sqlite.py b/backend/app/storage/sqlite.py new file mode 100644 index 00000000..48d2d44e --- /dev/null +++ b/backend/app/storage/sqlite.py @@ -0,0 +1,291 @@ +import json +from datetime import datetime, timezone +from typing import Any, Optional, Sequence, Union +from uuid import uuid4 + +import aiosqlite +from langchain_core.messages import AnyMessage +from langchain_core.runnables import RunnableConfig + +from app.schema import Assistant, Thread, User +from app.storage.base import BaseStorage + + +def _deserialize_assistant(row: aiosqlite.Row) -> Assistant: + """Deserialize an assistant from a SQLite row.""" + return { + "assistant_id": row["assistant_id"], + "user_id": row["user_id"], + "name": row["name"], + "config": json.loads(row["config"]), + "updated_at": datetime.fromisoformat(row["updated_at"]), + "public": bool(row["public"]), + } + + +def _deserialize_thread(row: aiosqlite.Row) -> Thread: + """Deserialize a thread from a SQLite row.""" + return { + "thread_id": row["thread_id"], + "user_id": row["user_id"], + "assistant_id": row["assistant_id"], + "name": row["name"], + "updated_at": datetime.fromisoformat(row["updated_at"]), + } + + +def _deserialize_user(row: aiosqlite.Row) -> User: + """Deserialize a user from a SQLite row.""" + return { + "user_id": row["user_id"], + "sub": row["sub"], + "created_at": datetime.fromisoformat(row["created_at"]), + } + + +class SqliteStorage(BaseStorage): + _conn: aiosqlite.Connection = None + _is_setup: bool = False + + async def setup(self) -> None: + if self._is_setup: + return + self._conn = await aiosqlite.connect("opengpts.db") + self._conn.row_factory = aiosqlite.Row + await self._conn.execute("pragma journal_mode=wal") + self._is_setup = True + + # TODO remove + await self._conn.set_trace_callback(print) + + async def teardown(self) -> None: + await self._conn.close() + self._conn = None + self._is_setup = False + + def get_conn(self) -> aiosqlite.Connection: + if not self._is_setup: + raise RuntimeError("Storage is not set up.") + return self._conn + + async def list_assistants(self, user_id: str) -> list[Assistant]: + """List all assistants for the current user.""" + async with self.get_conn().cursor() as cur: + await cur.execute("SELECT * FROM assistant WHERE user_id = ?", (user_id,)) + rows 
= await cur.fetchall() + return [_deserialize_assistant(row) for row in rows] + + async def get_assistant( + self, user_id: str, assistant_id: str + ) -> Optional[Assistant]: + """Get an assistant by ID.""" + async with self.get_conn().cursor() as cur: + await cur.execute( + "SELECT * FROM assistant WHERE assistant_id = ? AND (user_id = ? OR public = 1)", + (assistant_id, user_id), + ) + row = await cur.fetchone() + return _deserialize_assistant(row) if row else None + + async def list_public_assistants(self) -> list[Assistant]: + """List all the public assistants.""" + async with self.get_conn().cursor() as cur: + await cur.execute("SELECT * FROM assistant WHERE public = 1") + rows = await cur.fetchall() + return [_deserialize_assistant(row) for row in rows] + + async def put_assistant( + self, + user_id: str, + assistant_id: str, + *, + name: str, + config: dict, + public: bool = False, + ) -> Assistant: + """Modify an assistant.""" + updated_at = datetime.now(timezone.utc) + conn = self.get_conn() + async with conn.cursor() as cur: + await cur.execute( + """ + INSERT INTO assistant (assistant_id, user_id, name, config, updated_at, public) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(assistant_id) + DO UPDATE SET + user_id = EXCLUDED.user_id, + name = EXCLUDED.name, + config = EXCLUDED.config, + updated_at = EXCLUDED.updated_at, + public = EXCLUDED.public + """, + ( + assistant_id, + user_id, + name, + json.dumps(config), + updated_at.isoformat(), + public, + ), + ) + await conn.commit() + return { + "assistant_id": assistant_id, + "user_id": user_id, + "name": name, + "config": config, + "updated_at": updated_at, + "public": public, + } + + async def list_threads(self, user_id: str) -> list[Thread]: + """List all threads for the current user.""" + async with self.get_conn().cursor() as cur: + await cur.execute("SELECT * FROM thread WHERE user_id = ?", (user_id,)) + rows = await cur.fetchall() + return [_deserialize_thread(row) for row in rows] + + async def get_thread(self, user_id: str, thread_id: str) -> Optional[Thread]: + """Get a thread by ID.""" + async with self.get_conn().cursor() as cur: + await cur.execute( + "SELECT * FROM thread WHERE thread_id = ? 
AND user_id = ?", + (thread_id, user_id), + ) + row = await cur.fetchone() + return _deserialize_thread(row) if row else None + + async def get_thread_state( + self, *, user_id: str, thread_id: str, assistant_id: str + ): + """Get state for a thread.""" + from app.agent import agent + + assistant = await self.get_assistant(user_id, assistant_id) + state = await agent.aget_state( + { + "configurable": { + **assistant["config"]["configurable"], + "thread_id": thread_id, + "assistant_id": assistant_id, + } + } + ) + return {"values": state.values, "next": state.next} + + async def update_thread_state( + self, + config: RunnableConfig, + values: Union[Sequence[AnyMessage], dict[str, Any]], + *, + user_id: str, + assistant_id: str, + ): + """Add state to a thread.""" + from app.agent import agent + + assistant = await self.get_assistant(user_id, assistant_id) + await agent.aupdate_state( + { + "configurable": { + **assistant["config"]["configurable"], + **config["configurable"], + "assistant_id": assistant_id, + } + }, + values, + ) + + async def get_thread_history( + self, *, user_id: str, thread_id: str, assistant_id: str + ): + """Get the history of a thread.""" + from app.agent import agent + + assistant = await self.get_assistant(user_id, assistant_id) + return [ + { + "values": c.values, + "next": c.next, + "config": c.config, + "parent": c.parent_config, + } + async for c in agent.aget_state_history( + { + "configurable": { + **assistant["config"]["configurable"], + "thread_id": thread_id, + "assistant_id": assistant_id, + } + } + ) + ] + + async def put_thread( + self, user_id: str, thread_id: str, *, assistant_id: str, name: str + ) -> Thread: + """Modify a thread.""" + updated_at = datetime.now(timezone.utc) + conn = self.get_conn() + async with conn.cursor() as cur: + await cur.execute( + """ + INSERT INTO thread (thread_id, user_id, assistant_id, name, updated_at) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT(thread_id) + DO UPDATE SET + user_id = EXCLUDED.user_id, + assistant_id = EXCLUDED.assistant_id, + name = EXCLUDED.name, + updated_at = EXCLUDED.updated_at + """, + (thread_id, user_id, assistant_id, name, updated_at.isoformat()), + ) + await conn.commit() + return { + "thread_id": thread_id, + "user_id": user_id, + "assistant_id": assistant_id, + "name": name, + "updated_at": updated_at, + } + + async def get_or_create_user(self, sub: str) -> tuple[User, bool]: + """Returns a tuple of the user and a boolean indicating whether the user was created.""" + conn = self.get_conn() + async with conn.cursor() as cur: + # Start a write transaction to avoid the unique contraint error due to + # concurrent inserts. + # TODO worked when connection wasn't shared across app + await cur.execute("BEGIN EXCLUSIVE") + await cur.execute('SELECT * FROM "user" WHERE sub = ?', (sub,)) + row = await cur.fetchone() + if row: + # Since we are using a single connection in the whole application, + # we can't leave the transaction open, so we need to commit it here. + await conn.commit() + return _deserialize_user(row), False + + # SQLite doesn't support RETURNING *, so we need to manually fetch the created user. 
+ await cur.execute( + 'INSERT INTO "user" (user_id, sub, created_at) VALUES (?, ?, ?)', + (str(uuid4()), sub, datetime.now(timezone.utc).isoformat()), + ) + await conn.commit() + + await cur.execute('SELECT * FROM "user" WHERE sub = ?', (sub,)) + row = await cur.fetchone() + return _deserialize_user(row), True + + async def delete_thread(self, user_id: str, thread_id: str) -> None: + """Delete a thread by ID.""" + conn = self.get_conn() + async with conn.cursor() as cur: + await cur.execute( + "DELETE FROM thread WHERE thread_id = ? AND user_id = ?", + (thread_id, user_id), + ) + await conn.commit() + + +storage = SqliteStorage() diff --git a/backend/app/storage/storage.py b/backend/app/storage/storage.py new file mode 100644 index 00000000..7c15f826 --- /dev/null +++ b/backend/app/storage/storage.py @@ -0,0 +1,10 @@ +from app.storage.postgres import PostgresStorage +from app.storage.settings import StorageType, settings +from app.storage.sqlite import SqliteStorage + +if settings.storage_type == StorageType.SQLITE: + storage = SqliteStorage() +elif settings.storage_type == StorageType.POSTGRES: + storage = PostgresStorage() +else: + raise NotImplementedError() diff --git a/backend/app/upload.py b/backend/app/upload.py index e2dac7e9..59bf520c 100644 --- a/backend/app/upload.py +++ b/backend/app/upload.py @@ -11,7 +11,7 @@ import mimetypes import os -from typing import BinaryIO, List, Optional +from typing import BinaryIO, List, Optional, Union from fastapi import UploadFile from langchain_community.vectorstores.pgvector import PGVector @@ -27,6 +27,7 @@ from app.ingest import ingest_blob from app.parsing import MIMETYPE_BASED_PARSER +from app.storage.settings import StorageType, settings def _guess_mimetype(file_name: str, file_bytes: bytes) -> str: @@ -82,28 +83,38 @@ def convert_ingestion_input_to_blob(file: UploadFile) -> Blob: ) -def _determine_azure_or_openai_embeddings() -> PGVector: +def _get_embedding_function() -> Union[OpenAIEmbeddings, AzureOpenAIEmbeddings]: if os.environ.get("OPENAI_API_KEY"): - return PGVector( - connection_string=PG_CONNECTION_STRING, - embedding_function=OpenAIEmbeddings(), - use_jsonb=True, + return OpenAIEmbeddings() + elif os.environ.get("AZURE_OPENAI_API_KEY"): + return AzureOpenAIEmbeddings( + azure_endpoint=os.environ.get("AZURE_OPENAI_API_BASE"), + azure_deployment=os.environ.get("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT_NAME"), + openai_api_version=os.environ.get("AZURE_OPENAI_API_VERSION"), + ) + raise ValueError( + "Either OPENAI_API_KEY or AZURE_OPENAI_API_KEY needs to be set for embeddings to work." + ) + + +def _get_vstore() -> VectorStore: + # TODO Need to add a sqlite-based vectorstore for StorageType.SQLITE. + # Using PGVector is temporary. 
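Continuing into `_get_vstore()` below: as the TODO above notes, the function returns a `PGVector` store for both storage types, so document uploads still need a reachable Postgres instance with the pgvector extension (plus the `POSTGRES_*` settings and an embeddings API key) even when `STORAGE_TYPE=sqlite`. A sketch of that assumption:

```python
# Sketch only (not part of this PR). Importing app.upload builds the vector
# store eagerly, so this needs a running Postgres with pgvector and either
# OPENAI_API_KEY or the AZURE_OPENAI_* variables set.
from langchain_community.vectorstores.pgvector import PGVector

from app.upload import vstore

assert isinstance(vstore, PGVector)  # holds for the SQLite storage type too
```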
+ if settings.storage_type in (StorageType.POSTGRES, StorageType.SQLITE): + PG_CONNECTION_STRING = PGVector.connection_string_from_db_params( + driver="psycopg2", + host=settings.postgres.host, + port=settings.postgres.port, + database=settings.postgres.db, + user=settings.postgres.user, + password=settings.postgres.password, ) - if os.environ.get("AZURE_OPENAI_API_KEY"): return PGVector( connection_string=PG_CONNECTION_STRING, - embedding_function=AzureOpenAIEmbeddings( - azure_endpoint=os.environ.get("AZURE_OPENAI_API_BASE"), - azure_deployment=os.environ.get( - "AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT_NAME" - ), - openai_api_version=os.environ.get("AZURE_OPENAI_API_VERSION"), - ), + embedding_function=_get_embedding_function(), use_jsonb=True, ) - raise ValueError( - "Either OPENAI_API_KEY or AZURE_OPENAI_API_KEY needs to be set for embeddings to work." - ) + raise NotImplementedError() class IngestRunnable(RunnableSerializable[BinaryIO, List[str]]): @@ -144,15 +155,7 @@ def invoke(self, blob: Blob, config: Optional[RunnableConfig] = None) -> List[st return out -PG_CONNECTION_STRING = PGVector.connection_string_from_db_params( - driver="psycopg2", - host=os.environ["POSTGRES_HOST"], - port=int(os.environ["POSTGRES_PORT"]), - database=os.environ["POSTGRES_DB"], - user=os.environ["POSTGRES_USER"], - password=os.environ["POSTGRES_PASSWORD"], -) -vstore = _determine_azure_or_openai_embeddings() +vstore = _get_vstore() ingest_runnable = IngestRunnable( diff --git a/backend/migrations/000001_create_extensions_and_first_tables.down.sql b/backend/migrations/postgres/000001_create_extensions_and_first_tables.down.sql similarity index 100% rename from backend/migrations/000001_create_extensions_and_first_tables.down.sql rename to backend/migrations/postgres/000001_create_extensions_and_first_tables.down.sql diff --git a/backend/migrations/000001_create_extensions_and_first_tables.up.sql b/backend/migrations/postgres/000001_create_extensions_and_first_tables.up.sql similarity index 100% rename from backend/migrations/000001_create_extensions_and_first_tables.up.sql rename to backend/migrations/postgres/000001_create_extensions_and_first_tables.up.sql diff --git a/backend/migrations/000002_checkpoints_update_schema.down.sql b/backend/migrations/postgres/000002_checkpoints_update_schema.down.sql similarity index 100% rename from backend/migrations/000002_checkpoints_update_schema.down.sql rename to backend/migrations/postgres/000002_checkpoints_update_schema.down.sql diff --git a/backend/migrations/000002_checkpoints_update_schema.up.sql b/backend/migrations/postgres/000002_checkpoints_update_schema.up.sql similarity index 100% rename from backend/migrations/000002_checkpoints_update_schema.up.sql rename to backend/migrations/postgres/000002_checkpoints_update_schema.up.sql diff --git a/backend/migrations/000003_create_user.down.sql b/backend/migrations/postgres/000003_create_user.down.sql similarity index 100% rename from backend/migrations/000003_create_user.down.sql rename to backend/migrations/postgres/000003_create_user.down.sql diff --git a/backend/migrations/000003_create_user.up.sql b/backend/migrations/postgres/000003_create_user.up.sql similarity index 100% rename from backend/migrations/000003_create_user.up.sql rename to backend/migrations/postgres/000003_create_user.up.sql diff --git a/backend/migrations/sqlite/000001_create_extensions_and_first_tables.down.sql b/backend/migrations/sqlite/000001_create_extensions_and_first_tables.down.sql new file mode 100644 index 00000000..08c8d5d5 --- 
/dev/null +++ b/backend/migrations/sqlite/000001_create_extensions_and_first_tables.down.sql @@ -0,0 +1,3 @@ +DROP TABLE IF EXISTS thread; +DROP TABLE IF EXISTS assistant; +DROP TABLE IF EXISTS checkpoints; diff --git a/backend/migrations/sqlite/000001_create_extensions_and_first_tables.up.sql b/backend/migrations/sqlite/000001_create_extensions_and_first_tables.up.sql new file mode 100644 index 00000000..63ec1278 --- /dev/null +++ b/backend/migrations/sqlite/000001_create_extensions_and_first_tables.up.sql @@ -0,0 +1,25 @@ +CREATE TABLE IF NOT EXISTS assistant ( + assistant_id TEXT PRIMARY KEY NOT NULL, -- Manually ensure this is a UUID v4 + user_id TEXT NOT NULL, + name TEXT NOT NULL, + config TEXT NOT NULL, -- Store JSON data as text + updated_at DATETIME DEFAULT (datetime('now')), -- Stores in UTC by default + public BOOLEAN NOT NULL CHECK (public IN (0,1)) -- SQLite uses 0 and 1 for BOOLEAN +); + +CREATE TABLE IF NOT EXISTS thread ( + thread_id TEXT PRIMARY KEY NOT NULL, -- Manually ensure this is a UUID v4 + assistant_id TEXT, -- Store as text and ensure it's a UUID in your application + user_id TEXT NOT NULL, + name TEXT NOT NULL, + updated_at DATETIME DEFAULT (datetime('now')), -- Stores in UTC by default + FOREIGN KEY (assistant_id) REFERENCES assistant(assistant_id) ON DELETE SET NULL +); + +CREATE TABLE IF NOT EXISTS checkpoints ( + thread_id TEXT NOT NULL, + thread_ts DATETIME NOT NULL, + parent_ts DATETIME, + checkpoint BLOB, -- BLOB for binary data, assuming pickle serialization + PRIMARY KEY (thread_id, thread_ts) +); diff --git a/backend/migrations/sqlite/000002_checkpoints_update_schema.down.sql b/backend/migrations/sqlite/000002_checkpoints_update_schema.down.sql new file mode 100644 index 00000000..3baf8e2d --- /dev/null +++ b/backend/migrations/sqlite/000002_checkpoints_update_schema.down.sql @@ -0,0 +1,18 @@ +-- Step 1: Create a new temporary table that reflects the desired final structure, +-- excluding thread_ts and parent_ts columns, and setting thread_id as the primary key. +CREATE TABLE IF NOT EXISTS temp_checkpoints ( + thread_id TEXT NOT NULL, + checkpoint BLOB, + PRIMARY KEY (thread_id) +); + +-- Step 2: Copy relevant data from the original table to the temporary table. +-- Since thread_ts and parent_ts are being dropped, they are not included in the copy. +INSERT INTO temp_checkpoints (thread_id, checkpoint) +SELECT thread_id, checkpoint FROM checkpoints; + +-- Step 3: Drop the original checkpoints table. +DROP TABLE checkpoints; + +-- Step 4: Rename the temporary table to 'checkpoints', effectively recreating the original table structure. +ALTER TABLE temp_checkpoints RENAME TO checkpoints; diff --git a/backend/migrations/sqlite/000002_checkpoints_update_schema.up.sql b/backend/migrations/sqlite/000002_checkpoints_update_schema.up.sql new file mode 100644 index 00000000..50b382c0 --- /dev/null +++ b/backend/migrations/sqlite/000002_checkpoints_update_schema.up.sql @@ -0,0 +1,27 @@ +-- Step 2: Update the newly added columns with current UTC datetime where they are NULL. +-- This assumes you handle NULL values appropriately in your application if these columns are expected to have meaningful timestamps. +UPDATE checkpoints +SET thread_ts = datetime('now') +WHERE thread_ts IS NULL; + +-- Since SQLite does not allow altering a table to drop or add a primary key constraint directly, +-- you need to create a new table with the desired structure, copy the data, drop the old table, and rename the new one. 
+ +-- Step 3: Create a new table with the correct structure and primary key. +CREATE TABLE IF NOT EXISTS new_checkpoints ( + thread_id TEXT NOT NULL, + thread_ts DATETIME NOT NULL, + parent_ts DATETIME, + checkpoint BLOB, + PRIMARY KEY (thread_id, thread_ts) +); + +-- Step 4: Copy data from the old table to the new table. +INSERT INTO new_checkpoints (thread_id, thread_ts, parent_ts, checkpoint) +SELECT thread_id, thread_ts, parent_ts, checkpoint FROM checkpoints; + +-- Step 5: Drop the old table. +DROP TABLE checkpoints; + +-- Step 6: Rename the new table to the original table's name. +ALTER TABLE new_checkpoints RENAME TO checkpoints; diff --git a/backend/migrations/sqlite/000003_create_user.down.sql b/backend/migrations/sqlite/000003_create_user.down.sql new file mode 100644 index 00000000..c83349bd --- /dev/null +++ b/backend/migrations/sqlite/000003_create_user.down.sql @@ -0,0 +1,9 @@ +-- SQLite doesn't support ALTER TABLE to drop constraints or change column types directly. +-- Similar to the "up" migration, if you need to reverse the changes, +-- you would have to recreate each table without the foreign keys and with the original column types. + +-- For "assistant" and "thread", remove the foreign keys by recreating the tables without them. +-- Follow a similar process as described in the "up" migration, but omit the FOREIGN KEY definitions. + +-- Drop the "user" table. +DROP TABLE IF EXISTS "user"; \ No newline at end of file diff --git a/backend/migrations/sqlite/000003_create_user.up.sql b/backend/migrations/sqlite/000003_create_user.up.sql new file mode 100644 index 00000000..8a5c0a7d --- /dev/null +++ b/backend/migrations/sqlite/000003_create_user.up.sql @@ -0,0 +1,51 @@ +-- Create the "user" table. Use TEXT for UUID and store timestamps as TEXT. +CREATE TABLE IF NOT EXISTS "user" ( + user_id TEXT PRIMARY KEY, + sub TEXT UNIQUE NOT NULL, + created_at DATETIME DEFAULT (datetime('now')) +); + +-- Insert distinct users from the "assistant" table. +-- SQLite doesn't support ON CONFLICT DO NOTHING in the same way, so use INSERT OR IGNORE. +-- The casting (user_id::uuid) isn't needed since we treat all UUIDs as TEXT. +INSERT OR IGNORE INTO "user" (user_id, sub) +SELECT DISTINCT user_id, user_id +FROM assistant +WHERE user_id IS NOT NULL; + +-- Insert distinct users from the "thread" table. +INSERT OR IGNORE INTO "user" (user_id, sub) +SELECT DISTINCT user_id, user_id +FROM thread +WHERE user_id IS NOT NULL; + +-- SQLite does not support adding foreign keys via ALTER TABLE. +-- You will need to recreate tables to add foreign key constraints, as shown previously. +-- Here's a simplified approach for "assistant" assuming dropping and recreating is acceptable. + +-- Example for "assistant", assuming it's acceptable to drop & recreate it: +-- 1. Rename existing table. +ALTER TABLE assistant RENAME TO assistant_old; + +-- 2. Create new table with foreign key constraint. +CREATE TABLE assistant ( + assistant_id TEXT PRIMARY KEY NOT NULL, -- Manually ensure this is a UUID v4 + user_id TEXT NOT NULL, + name TEXT NOT NULL, + config TEXT NOT NULL, -- Store JSON data as text + updated_at DATETIME DEFAULT (datetime('now')), -- Stores in UTC by default + public BOOLEAN NOT NULL CHECK (public IN (0,1)), + FOREIGN KEY (user_id) REFERENCES "user" (user_id) +); + +-- Version 3 - Create user table. +CREATE TABLE IF NOT EXISTS "user" ( + user_id TEXT PRIMARY KEY NOT NULL, + sub TEXT UNIQUE NOT NULL, + created_at DATETIME DEFAULT (datetime('now')) +); + +-- 4. Drop old table. 
+DROP TABLE assistant_old; + +-- Repeat similar steps for "thread" table to add the foreign key constraint. \ No newline at end of file diff --git a/backend/poetry.lock b/backend/poetry.lock index 62ebed8a..22082937 100644 --- a/backend/poetry.lock +++ b/backend/poetry.lock @@ -110,6 +110,24 @@ files = [ [package.dependencies] frozenlist = ">=1.1.0" +[[package]] +name = "aiosqlite" +version = "0.20.0" +description = "asyncio bridge to the standard sqlite3 module" +optional = false +python-versions = ">=3.8" +files = [ + {file = "aiosqlite-0.20.0-py3-none-any.whl", hash = "sha256:36a1deaca0cac40ebe32aac9977a6e2bbc7f5189f23f4a54d5908986729e5bd6"}, + {file = "aiosqlite-0.20.0.tar.gz", hash = "sha256:6d35c8c256637f4672f843c31021464090805bf925385ac39473fb16eaaca3d7"}, +] + +[package.dependencies] +typing_extensions = ">=4.0" + +[package.extras] +dev = ["attribution (==1.7.0)", "black (==24.2.0)", "coverage[toml] (==7.4.1)", "flake8 (==7.0.0)", "flake8-bugbear (==24.2.6)", "flit (==3.9.0)", "mypy (==1.8.0)", "ufmt (==2.3.0)", "usort (==1.0.8.post1)"] +docs = ["sphinx (==7.2.6)", "sphinx-mdinclude (==0.5.3)"] + [[package]] name = "anthropic" version = "0.25.2" @@ -1935,6 +1953,7 @@ description = "Powerful and Pythonic XML processing library combining libxml2/li optional = false python-versions = ">=3.6" files = [ + {file = "lxml-5.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:704f5572ff473a5f897745abebc6df40f22d4133c1e0a1f124e4f2bd3330ff7e"}, {file = "lxml-5.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9d3c0f8567ffe7502d969c2c1b809892dc793b5d0665f602aad19895f8d508da"}, {file = "lxml-5.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5fcfbebdb0c5d8d18b84118842f31965d59ee3e66996ac842e21f957eb76138c"}, {file = "lxml-5.1.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2f37c6d7106a9d6f0708d4e164b707037b7380fcd0b04c5bd9cae1fb46a856fb"}, @@ -1944,6 +1963,7 @@ files = [ {file = "lxml-5.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:82bddf0e72cb2af3cbba7cec1d2fd11fda0de6be8f4492223d4a268713ef2147"}, {file = "lxml-5.1.0-cp310-cp310-win32.whl", hash = "sha256:b66aa6357b265670bb574f050ffceefb98549c721cf28351b748be1ef9577d93"}, {file = "lxml-5.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:4946e7f59b7b6a9e27bef34422f645e9a368cb2be11bf1ef3cafc39a1f6ba68d"}, + {file = "lxml-5.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:14deca1460b4b0f6b01f1ddc9557704e8b365f55c63070463f6c18619ebf964f"}, {file = "lxml-5.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ed8c3d2cd329bf779b7ed38db176738f3f8be637bb395ce9629fc76f78afe3d4"}, {file = "lxml-5.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:436a943c2900bb98123b06437cdd30580a61340fbdb7b28aaf345a459c19046a"}, {file = "lxml-5.1.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:acb6b2f96f60f70e7f34efe0c3ea34ca63f19ca63ce90019c6cbca6b676e81fa"}, @@ -1953,6 +1973,7 @@ files = [ {file = "lxml-5.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f4c9bda132ad108b387c33fabfea47866af87f4ea6ffb79418004f0521e63204"}, {file = "lxml-5.1.0-cp311-cp311-win32.whl", hash = "sha256:bc64d1b1dab08f679fb89c368f4c05693f58a9faf744c4d390d7ed1d8223869b"}, {file = "lxml-5.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:a5ab722ae5a873d8dcee1f5f45ddd93c34210aed44ff2dc643b5025981908cda"}, + {file = "lxml-5.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = 
"sha256:9aa543980ab1fbf1720969af1d99095a548ea42e00361e727c58a40832439114"}, {file = "lxml-5.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6f11b77ec0979f7e4dc5ae081325a2946f1fe424148d3945f943ceaede98adb8"}, {file = "lxml-5.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a36c506e5f8aeb40680491d39ed94670487ce6614b9d27cabe45d94cd5d63e1e"}, {file = "lxml-5.1.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f643ffd2669ffd4b5a3e9b41c909b72b2a1d5e4915da90a77e119b8d48ce867a"}, @@ -1978,8 +1999,8 @@ files = [ {file = "lxml-5.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8f52fe6859b9db71ee609b0c0a70fea5f1e71c3462ecf144ca800d3f434f0764"}, {file = "lxml-5.1.0-cp37-cp37m-win32.whl", hash = "sha256:d42e3a3fc18acc88b838efded0e6ec3edf3e328a58c68fbd36a7263a874906c8"}, {file = "lxml-5.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:eac68f96539b32fce2c9b47eb7c25bb2582bdaf1bbb360d25f564ee9e04c542b"}, + {file = "lxml-5.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ae15347a88cf8af0949a9872b57a320d2605ae069bcdf047677318bc0bba45b1"}, {file = "lxml-5.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c26aab6ea9c54d3bed716b8851c8bfc40cb249b8e9880e250d1eddde9f709bf5"}, - {file = "lxml-5.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cfbac9f6149174f76df7e08c2e28b19d74aed90cad60383ad8671d3af7d0502f"}, {file = "lxml-5.1.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:342e95bddec3a698ac24378d61996b3ee5ba9acfeb253986002ac53c9a5f6f84"}, {file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:725e171e0b99a66ec8605ac77fa12239dbe061482ac854d25720e2294652eeaa"}, {file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d184e0d5c918cff04cdde9dbdf9600e960161d773666958c9d7b565ccc60c45"}, @@ -1987,6 +2008,7 @@ files = [ {file = "lxml-5.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6d48fc57e7c1e3df57be5ae8614bab6d4e7b60f65c5457915c26892c41afc59e"}, {file = "lxml-5.1.0-cp38-cp38-win32.whl", hash = "sha256:7ec465e6549ed97e9f1e5ed51c657c9ede767bc1c11552f7f4d022c4df4a977a"}, {file = "lxml-5.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:b21b4031b53d25b0858d4e124f2f9131ffc1530431c6d1321805c90da78388d1"}, + {file = "lxml-5.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:52427a7eadc98f9e62cb1368a5079ae826f94f05755d2d567d93ee1bc3ceb354"}, {file = "lxml-5.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6a2a2c724d97c1eb8cf966b16ca2915566a4904b9aad2ed9a09c748ffe14f969"}, {file = "lxml-5.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:843b9c835580d52828d8f69ea4302537337a21e6b4f1ec711a52241ba4a824f3"}, {file = "lxml-5.1.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9b99f564659cfa704a2dd82d0684207b1aadf7d02d33e54845f9fc78e06b7581"}, @@ -2904,6 +2926,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -4212,4 +4235,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.9.0,<3.12" -content-hash = "fc96cf95416874baa59fc1b85463f4e1e9e5ade9e1c76febb74cadf341da11bb" +content-hash = "16cd528b4f7ee3c971f435c204690a07cdab1c508c72cfd926600cf8ed376a64" diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 4c99f56e..07bdf209 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -43,6 +43,7 @@ asyncpg = "^0.29.0" langchain-core = "^0.1.44" pyjwt = {extras = ["crypto"], version = "^2.8.0"} langchain-anthropic = "^0.1.8" +aiosqlite = "^0.20.0" [tool.poetry.group.dev.dependencies] uvicorn = "^0.23.2" diff --git a/backend/tests/unit_tests/agent_executor/test_upload.py b/backend/tests/unit_tests/agent_executor/test_upload.py index e239ef02..19736027 100644 --- a/backend/tests/unit_tests/agent_executor/test_upload.py +++ b/backend/tests/unit_tests/agent_executor/test_upload.py @@ -1,7 +1,8 @@ from io import BytesIO -from langchain.text_splitter import RecursiveCharacterTextSplitter from fastapi import UploadFile +from langchain.text_splitter import RecursiveCharacterTextSplitter + from app.upload import IngestRunnable, _guess_mimetype, convert_ingestion_input_to_blob from tests.unit_tests.fixtures import get_sample_paths from tests.unit_tests.utils import InMemoryVectorStore diff --git a/backend/tests/unit_tests/conftest.py b/backend/tests/unit_tests/conftest.py index 4d21da0d..29635759 100644 --- a/backend/tests/unit_tests/conftest.py +++ b/backend/tests/unit_tests/conftest.py @@ -7,8 +7,10 @@ from app.auth.settings import AuthType from app.auth.settings import settings as auth_settings -from app.lifespan import get_pg_pool, lifespan +from app.lifespan import lifespan from app.server import app +from app.storage.settings import settings as storage_settings +from app.storage.storage import storage auth_settings.auth_type = AuthType.NOOP @@ -16,16 +18,16 @@ os.environ["OPENAI_API_KEY"] = "test" TEST_DB = "test" -assert os.environ["POSTGRES_DB"] != TEST_DB, "Test and main database conflict." -os.environ["POSTGRES_DB"] = TEST_DB +assert storage_settings.postgres.db != TEST_DB, "Test and main database conflict." 
+storage_settings.postgres.db = TEST_DB async def _get_conn() -> asyncpg.Connection: return await asyncpg.connect( - user=os.environ["POSTGRES_USER"], - password=os.environ["POSTGRES_PASSWORD"], - host=os.environ["POSTGRES_HOST"], - port=os.environ["POSTGRES_PORT"], + user=storage_settings.postgres.user, + password=storage_settings.postgres.password, + host=storage_settings.postgres.host, + port=storage_settings.postgres.port, database="postgres", ) @@ -49,7 +51,21 @@ async def _drop_test_db() -> None: def _migrate_test_db() -> None: - subprocess.run(["make", "migrate"], check=True) + subprocess.run( + [ + "migrate", + "-database", + ( + f"postgres://{storage_settings.postgres.user}:{storage_settings.postgres.password}" + f"@{storage_settings.postgres.host}:{storage_settings.postgres.port}" + f"/{storage_settings.postgres.db}?sslmode=disable" + ), + "-path", + "./migrations/postgres", + "up", + ], + check=True, + ) @pytest.fixture(scope="session") @@ -58,7 +74,7 @@ async def pool(): await _create_test_db() _migrate_test_db() async with lifespan(app): - yield get_pg_pool() + yield storage.get_pool() await _drop_test_db() diff --git a/docker-compose-prod.yml b/docker-compose-prod.yml index e94ac4ac..157ebcc9 100644 --- a/docker-compose-prod.yml +++ b/docker-compose-prod.yml @@ -36,4 +36,5 @@ services: env_file: - .env environment: + STORAGE_TYPE: "postgres" POSTGRES_HOST: "postgres" diff --git a/docker-compose.pg.yml b/docker-compose.pg.yml new file mode 100644 index 00000000..381b9484 --- /dev/null +++ b/docker-compose.pg.yml @@ -0,0 +1,58 @@ +version: "3" + +services: + postgres: + image: pgvector/pgvector:pg16 + healthcheck: + test: pg_isready -U $POSTGRES_USER + start_interval: 1s + start_period: 5s + interval: 5s + retries: 5 + ports: + - "5433:5432" + env_file: + - .env + volumes: + - ./postgres-volume:/var/lib/postgresql/data + postgres-setup: + image: migrate/migrate + depends_on: + postgres: + condition: service_healthy + volumes: + - ./backend/migrations:/migrations + env_file: + - .env + command: ["-path", "/migrations/postgres", "-database", "postgres://$POSTGRES_USER:$POSTGRES_PASSWORD@postgres:$POSTGRES_PORT/$POSTGRES_DB?sslmode=disable", "up"] + backend: + container_name: opengpts-backend + build: + context: backend + ports: + - "8100:8000" # Backend is accessible on localhost:8100 + depends_on: + postgres-setup: + condition: service_completed_successfully + env_file: + - .env + volumes: + - ./backend:/backend + environment: + STORAGE_TYPE: "postgres" + POSTGRES_HOST: "postgres" + command: + - --reload + frontend: + container_name: opengpts-frontend + build: + context: frontend + depends_on: + backend: + condition: service_healthy + volumes: + - ./frontend/src:/frontend/src + ports: + - "5173:5173" # Frontend is accessible on localhost:5173 + environment: + VITE_BACKEND_URL: "http://backend:8000" diff --git a/docker-compose.yml b/docker-compose.yml index 6d59cb42..36aada47 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -24,7 +24,13 @@ services: - ./backend/migrations:/migrations env_file: - .env - command: ["-path", "/migrations", "-database", "postgres://$POSTGRES_USER:$POSTGRES_PASSWORD@postgres:$POSTGRES_PORT/$POSTGRES_DB?sslmode=disable", "up"] + command: ["-path", "/migrations/postgres", "-database", "postgres://$POSTGRES_USER:$POSTGRES_PASSWORD@postgres:$POSTGRES_PORT/$POSTGRES_DB?sslmode=disable", "up"] + sqlite-setup: + build: + context: tools/sqlite_migrate + volumes: + - ./backend:/backend + command: ["migrate", "-path", "/backend/migrations/sqlite", 
"-database", "sqlite3://opengpts.db", "up"] backend: container_name: opengpts-backend build: @@ -32,6 +38,8 @@ services: ports: - "8100:8000" # Backend is accessible on localhost:8100 depends_on: + sqlite-setup: + condition: service_completed_successfully postgres-setup: condition: service_completed_successfully env_file: @@ -39,6 +47,7 @@ services: volumes: - ./backend:/backend environment: + STORAGE_TYPE: "sqlite" POSTGRES_HOST: "postgres" command: - --reload diff --git a/tools/redis_to_postgres/migrate_data.py b/tools/redis_to_postgres/migrate_data.py index 84cfd9a0..bd20e27e 100644 --- a/tools/redis_to_postgres/migrate_data.py +++ b/tools/redis_to_postgres/migrate_data.py @@ -21,7 +21,8 @@ from redis.client import Redis as RedisType from app.checkpoint import PostgresCheckpoint -from app.lifespan import get_pg_pool, lifespan +from app.lifespan import lifespan +from app.storage.storage import storage from app.server import app logging.basicConfig( @@ -265,7 +266,7 @@ def _get_embedding(doc: dict) -> str: async def migrate_data(): logger.info("Starting to migrate data from Redis to Postgres.") - async with get_pg_pool().acquire() as conn, conn.transaction(): + async with storage.get_pool().acquire() as conn, conn.transaction(): await migrate_assistants(conn) await migrate_threads(conn) await migrate_checkpoints() diff --git a/tools/sqlite_migrate/Dockerfile b/tools/sqlite_migrate/Dockerfile new file mode 100644 index 00000000..94114c0c --- /dev/null +++ b/tools/sqlite_migrate/Dockerfile @@ -0,0 +1,10 @@ +# Can't use the migrate/migrate image like we do for Postgres because it doesn't +# support sqlite3 driver by default. We need to install it via Go toolchain: +# https://github.com/golang-migrate/migrate/issues/899#issuecomment-1483741684 +# +# Using the alpine3.18 tag because of: https://github.com/nkanaev/yarr/issues/187 + +FROM golang:alpine3.18 +RUN apk add build-base +RUN go install -tags 'sqlite3' github.com/golang-migrate/migrate/v4/cmd/migrate@v4.17.1 +WORKDIR /backend \ No newline at end of file