Skip to content

Commit

Permalink
WIP/POC sharing a single sqlite connection across the application.
Browse files Browse the repository at this point in the history
  • Loading branch information
bakar-io committed May 3, 2024
1 parent bf3bd82 commit 5e2a9dd
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 37 deletions.
2 changes: 1 addition & 1 deletion backend/app/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def config_specs(self) -> list[ConfigurableFieldSpec]:
async def setup(self) -> None:
if self.is_setup:
return
self.conn = await storage.create_sqlite_conn(global_=True)
self.conn = storage.get_conn()
self.is_setup = True


Expand Down
73 changes: 37 additions & 36 deletions backend/app/storage/sqlite.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from typing import Any, AsyncGenerator, Optional, Sequence, Union
from typing import Any, Optional, Sequence, Union
from uuid import uuid4

import aiosqlite
Expand Down Expand Up @@ -45,39 +44,33 @@ def _deserialize_user(row: aiosqlite.Row) -> User:


class SqliteStorage(BaseStorage):
_global_sqlite_connections = []
_conn: aiosqlite.Connection = None
_is_setup: bool = False

async def setup(self) -> None:
pass
if self._is_setup:
return
self._conn = await aiosqlite.connect("opengpts.db")
self._conn.row_factory = aiosqlite.Row
await self._conn.execute("pragma journal_mode=wal")
self._is_setup = True

# TODO remove
await self._conn.set_trace_callback(print)

async def teardown(self) -> None:
await self._close_global_sqlite_connections()

async def create_sqlite_conn(
self, global_: bool = False, **kwargs
) -> aiosqlite.Connection:
conn = await aiosqlite.connect("opengpts.db", **kwargs)
conn.row_factory = aiosqlite.Row
if global_:
self._global_sqlite_connections.append(conn)
return conn

@asynccontextmanager
async def sqlite_conn(self, **kwargs) -> AsyncGenerator[aiosqlite.Connection, None]:
conn = await self.create_sqlite_conn(**kwargs)
try:
yield conn
finally:
await conn.close()

async def _close_global_sqlite_connections(self) -> None:
for conn in self._global_sqlite_connections:
await conn.close()
self._global_sqlite_connections.clear()
await self._conn.close()
self._conn = None
self._is_setup = False

def get_conn(self) -> aiosqlite.Connection:
if not self._is_setup:
raise RuntimeError("Storage is not set up.")
return self._conn

async def list_assistants(self, user_id: str) -> list[Assistant]:
"""List all assistants for the current user."""
async with self.sqlite_conn() as conn, conn.cursor() as cur:
async with self.get_conn().cursor() as cur:
await cur.execute("SELECT * FROM assistant WHERE user_id = ?", (user_id,))
rows = await cur.fetchall()
return [_deserialize_assistant(row) for row in rows]
Expand All @@ -86,7 +79,7 @@ async def get_assistant(
self, user_id: str, assistant_id: str
) -> Optional[Assistant]:
"""Get an assistant by ID."""
async with self.sqlite_conn() as conn, conn.cursor() as cur:
async with self.get_conn().cursor() as cur:
await cur.execute(
"SELECT * FROM assistant WHERE assistant_id = ? AND (user_id = ? OR public = 1)",
(assistant_id, user_id),
Expand All @@ -96,7 +89,7 @@ async def get_assistant(

async def list_public_assistants(self) -> list[Assistant]:
"""List all the public assistants."""
async with self.sqlite_conn() as conn, conn.cursor() as cur:
async with self.get_conn().cursor() as cur:
await cur.execute("SELECT * FROM assistant WHERE public = 1")
rows = await cur.fetchall()
return [_deserialize_assistant(row) for row in rows]
Expand All @@ -112,7 +105,8 @@ async def put_assistant(
) -> Assistant:
"""Modify an assistant."""
updated_at = datetime.now(timezone.utc)
async with self.sqlite_conn() as conn, conn.cursor() as cur:
conn = self.get_conn()
async with conn.cursor() as cur:
await cur.execute(
"""
INSERT INTO assistant (assistant_id, user_id, name, config, updated_at, public)
Expand Down Expand Up @@ -146,14 +140,14 @@ async def put_assistant(

async def list_threads(self, user_id: str) -> list[Thread]:
"""List all threads for the current user."""
async with self.sqlite_conn() as conn, conn.cursor() as cur:
async with self.get_conn().cursor() as cur:
await cur.execute("SELECT * FROM thread WHERE user_id = ?", (user_id,))
rows = await cur.fetchall()
return [_deserialize_thread(row) for row in rows]

async def get_thread(self, user_id: str, thread_id: str) -> Optional[Thread]:
"""Get a thread by ID."""
async with self.sqlite_conn() as conn, conn.cursor() as cur:
async with self.get_conn().cursor() as cur:
await cur.execute(
"SELECT * FROM thread WHERE thread_id = ? AND user_id = ?",
(thread_id, user_id),
Expand Down Expand Up @@ -232,7 +226,8 @@ async def put_thread(
) -> Thread:
"""Modify a thread."""
updated_at = datetime.now(timezone.utc)
async with self.sqlite_conn() as conn, conn.cursor() as cur:
conn = self.get_conn()
async with conn.cursor() as cur:
await cur.execute(
"""
INSERT INTO thread (thread_id, user_id, assistant_id, name, updated_at)
Expand All @@ -257,13 +252,18 @@ async def put_thread(

async def get_or_create_user(self, sub: str) -> tuple[User, bool]:
"""Returns a tuple of the user and a boolean indicating whether the user was created."""
async with self.sqlite_conn() as conn, conn.cursor() as cur:
conn = self.get_conn()
async with conn.cursor() as cur:
# Start a write transaction to avoid the unique contraint error due to
# concurrent inserts.
# TODO worked when connection wasn't shared across app
await cur.execute("BEGIN EXCLUSIVE")
await cur.execute('SELECT * FROM "user" WHERE sub = ?', (sub,))
row = await cur.fetchone()
if row:
# Since we are using a single connection in the whole application,
# we can't leave the transaction open, so we need to commit it here.
await conn.commit()
return _deserialize_user(row), False

# SQLite doesn't support RETURNING *, so we need to manually fetch the created user.
Expand All @@ -279,7 +279,8 @@ async def get_or_create_user(self, sub: str) -> tuple[User, bool]:

async def delete_thread(self, user_id: str, thread_id: str) -> None:
"""Delete a thread by ID."""
async with self.sqlite_conn() as conn, conn.cursor() as cur:
conn = self.get_conn()
async with conn.cursor() as cur:
await cur.execute(
"DELETE FROM thread WHERE thread_id = ? AND user_id = ?",
(thread_id, user_id),
Expand Down

0 comments on commit 5e2a9dd

Please sign in to comment.