Skip to content

Commit

Permalink
Switch to async sqlalchemy (#21)
Browse files Browse the repository at this point in the history
* Switch to async sqlalchemy

* Small fix

---------

Co-authored-by: Nicolas Frank <[email protected]>
  • Loading branch information
WonderPG and Nicolas Frank authored Oct 17, 2024
1 parent 335d222 commit edf17f1
Show file tree
Hide file tree
Showing 13 changed files with 217 additions and 144 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed
- Switched from OAUTH2 security on FASTAPI to HTTPBearer.
- Switched to async sqlalchemy.

### Added
- Add get morphoelectric (me) model tool
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ dynamic = ["version"]
dependencies = [
"aiosqlite",
"asgi-correlation-id",
"asyncpg",
"bluepyefe",
"efel",
"fastapi",
Expand All @@ -24,11 +25,10 @@ dependencies = [
"langgraph-checkpoint-sqlite",
"neurom",
"psycopg-binary",
"psycopg2-binary",
"pydantic-settings",
"python-dotenv",
"python-keycloak",
"sqlalchemy",
"sqlalchemy[asyncio]",
"uvicorn",
]

Expand Down
38 changes: 35 additions & 3 deletions src/neuroagent/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from langchain.chat_models.base import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.checkpoint.serde.base import SerializerProtocol
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from pydantic import BaseModel, ConfigDict

Expand Down Expand Up @@ -58,16 +60,46 @@ class AsyncSqliteSaverWithPrefix(AsyncSqliteSaver):
@asynccontextmanager
async def from_conn_string(
cls, conn_string: str
) -> AsyncIterator["AsyncSqliteSaver"]:
) -> AsyncIterator[AsyncSqliteSaver]:
"""Create a new AsyncSqliteSaver instance from a connection string.
Args:
conn_string (str): The SQLite connection string. It can have the 'sqlite:///' prefix.
conn_string (str): The async SQLite connection string. It can have the 'sqlite+aiosqlite:///' prefix.
Yields
------
AsyncSqliteSaverWithPrefix: A new AsyncSqliteSaverWithPrefix instance.
AsyncSqliteSaverWithPrefix: A new connected AsyncSqliteSaverWithPrefix instance.
"""
conn_string = conn_string.split("///")[-1]
async with super().from_conn_string(conn_string) as memory:
yield AsyncSqliteSaverWithPrefix(memory.conn)


class AsyncPostgresSaverWithPrefix(AsyncPostgresSaver):
"""Wrapper around the AsyncSqliteSaver that accepts a connection string with prefix."""

@classmethod
@asynccontextmanager
async def from_conn_string(
cls,
conn_string: str,
*,
pipeline: bool = False,
serde: SerializerProtocol | None = None,
) -> AsyncIterator[AsyncPostgresSaver]:
"""Create a new AsyncPostgresSaver instance from a connection string.
Args:
conn_string (str): The async Postgres connection string. It can have the 'postgresql+asyncpg://' prefix.
Yields
------
AsyncPostgresSaverWithPrefix: A new connected AsyncPostgresSaverWithPrefix instance.
"""
prefix, body = conn_string.split("://", maxsplit=1)
currated_prefix = prefix.split("+", maxsplit=1)[0] # Still works if + not there
conn_string = currated_prefix + "://" + body
async with super().from_conn_string(
conn_string, pipeline=pipeline, serde=serde
) as memory:
yield AsyncPostgresSaverWithPrefix(memory.conn, memory.pipe, memory.serde)
37 changes: 37 additions & 0 deletions src/neuroagent/app/app_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
"""App utilities functions."""

import logging
from typing import Any

from fastapi import HTTPException
from httpx import AsyncClient
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from starlette.status import HTTP_401_UNAUTHORIZED

from neuroagent.app.config import Settings

logger = logging.getLogger(__name__)


async def validate_project(
httpx_client: AsyncClient,
Expand All @@ -22,3 +31,31 @@ async def validate_project(
status_code=HTTP_401_UNAUTHORIZED,
detail="User does not belong to the project.",
)


def setup_engine(
settings: Settings, connection_string: str | None = None
) -> AsyncEngine | None:
"""Get the SQL engine."""
if connection_string:
engine_kwargs: dict[str, Any] = {"url": connection_string}
if "sqlite" in settings.db.prefix: # type: ignore
# https://fastapi.tiangolo.com/tutorial/sql-databases/#create-the-sqlalchemy-engine
engine_kwargs["connect_args"] = {"check_same_thread": False}
engine = create_async_engine(**engine_kwargs)
else:
logger.warning("The SQL db_prefix needs to be set to use the SQL DB.")
return None
try:
engine.connect()
logger.info(
"Successfully connected to the SQL database"
f" {connection_string if not settings.db.password else connection_string.replace(settings.db.password.get_secret_value(), '*****')}."
)
return engine
except SQLAlchemyError:
logger.warning(
"Failed connection to SQL database"
f" {connection_string if not settings.db.password else connection_string.replace(settings.db.password.get_secret_value(), '*****')}."
)
return None
58 changes: 17 additions & 41 deletions src/neuroagent/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,26 @@

import logging
from functools import cache
from typing import Annotated, Any, AsyncIterator, Iterator
from typing import Annotated, Any, AsyncIterator

from fastapi import Depends, HTTPException, Request
from fastapi.security import HTTPBearer
from httpx import AsyncClient, HTTPStatusError
from keycloak import KeycloakOpenID
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from starlette.status import HTTP_401_UNAUTHORIZED

from neuroagent.agents import (
BaseAgent,
SimpleAgent,
SimpleChatAgent,
)
from neuroagent.agents.base_agent import AsyncSqliteSaverWithPrefix
from neuroagent.agents.base_agent import (
AsyncPostgresSaverWithPrefix,
AsyncSqliteSaverWithPrefix,
)
from neuroagent.app.app_utils import validate_project
from neuroagent.app.config import Settings
from neuroagent.app.routers.database.schemas import Threads
Expand Down Expand Up @@ -101,39 +100,14 @@ def get_connection_string(
return None


@cache
def get_engine(
settings: Annotated[Settings, Depends(get_settings)],
connection_string: Annotated[str | None, Depends(get_connection_string)],
) -> Engine | None:
def get_engine(request: Request) -> AsyncEngine | None:
"""Get the SQL engine."""
if connection_string:
engine_kwargs: dict[str, Any] = {"url": connection_string}
if "sqlite" in settings.db.prefix: # type: ignore
# https://fastapi.tiangolo.com/tutorial/sql-databases/#create-the-sqlalchemy-engine
engine_kwargs["connect_args"] = {"check_same_thread": False}
engine = create_engine(**engine_kwargs)
else:
logger.warning("The SQL db_prefix needs to be set to use the SQL DB.")
return None
try:
engine.connect()
logger.info(
"Successfully connected to the SQL database"
f" {connection_string if not settings.db.password else connection_string.replace(settings.db.password.get_secret_value(), '*****')}."
)
return engine
except SQLAlchemyError:
logger.warning(
"Failed connection to SQL database"
f" {connection_string if not settings.db.password else connection_string.replace(settings.db.password.get_secret_value(), '*****')}."
)
return None
return request.app.state.engine


def get_session(
engine: Annotated[Engine | None, Depends(get_engine)],
) -> Iterator[Session]:
async def get_session(
engine: Annotated[AsyncEngine | None, Depends(get_engine)],
) -> AsyncIterator[AsyncSession]:
"""Yield a session per request."""
if not engine:
raise HTTPException(
Expand All @@ -142,7 +116,7 @@ def get_session(
"detail": "Couldn't connect to the SQL DB.",
},
)
with Session(engine) as session:
async with AsyncSession(engine) as session:
yield session


Expand Down Expand Up @@ -375,7 +349,9 @@ async def get_agent_memory(
await memory.conn.close()

elif connection_string.startswith("postgresql"):
async with AsyncPostgresSaver.from_conn_string(connection_string) as memory:
async with AsyncPostgresSaverWithPrefix.from_conn_string(
connection_string
) as memory:
await memory.setup()
yield memory
await memory.conn.close()
Expand All @@ -396,7 +372,7 @@ async def get_agent_memory(

async def get_vlab_and_project(
user_id: Annotated[str, Depends(get_user_id)],
session: Annotated[Session, Depends(get_session)],
session: Annotated[AsyncSession, Depends(get_session)],
request: Request,
settings: Annotated[Settings, Depends(get_settings)],
httpx_client: Annotated[AsyncClient, Depends(get_httpx_client)],
Expand All @@ -415,7 +391,7 @@ async def get_vlab_and_project(
}
else:
thread_id = request.path_params.get("thread_id")
thread = session.get(Threads, (thread_id, user_id))
thread = await session.get(Threads, (thread_id, user_id))
if thread and thread.vlab_id and thread.project_id:
vlab_and_project = {
"vlab_id": thread.vlab_id,
Expand Down
20 changes: 14 additions & 6 deletions src/neuroagent/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
from starlette.middleware.base import BaseHTTPMiddleware

from neuroagent import __version__
from neuroagent.app.app_utils import setup_engine
from neuroagent.app.config import Settings
from neuroagent.app.dependencies import (
get_agent_memory,
get_cell_types_kg_hierarchy,
get_connection_string,
get_engine,
get_kg_token,
get_settings,
get_update_kg_hierarchy,
Expand Down Expand Up @@ -71,17 +71,23 @@ async def lifespan(fastapi_app: FastAPI) -> AsyncContextManager[None]: # type:
"""Read environment (settings of the application)."""
# hacky but works: https://github.com/tiangolo/fastapi/issues/425
app_settings = fastapi_app.dependency_overrides.get(get_settings, get_settings)()
engine = fastapi_app.dependency_overrides.get(get_engine, get_engine)(
app_settings, get_connection_string(app_settings)
)
# This creates the checkpoints and writes tables.

# Get the sqlalchemy engine
conn_string = get_connection_string(app_settings)
engine = setup_engine(app_settings, conn_string)

# Store it in the state
fastapi_app.state.engine = engine

# Create the checkpoints and writes tables.
await anext(
fastapi_app.dependency_overrides.get(get_agent_memory, get_agent_memory)(
get_connection_string(app_settings)
)
)
if engine:
Base.metadata.create_all(bind=engine)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)

prefix = app_settings.misc.application_prefix
fastapi_app.openapi_url = f"{prefix}/openapi.json"
Expand Down Expand Up @@ -115,6 +121,8 @@ async def lifespan(fastapi_app: FastAPI) -> AsyncContextManager[None]: # type:
)

yield
if engine:
await engine.dispose()


app = FastAPI(
Expand Down
8 changes: 4 additions & 4 deletions src/neuroagent/app/routers/database/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@

from fastapi import Depends, HTTPException
from fastapi.security import HTTPBasic
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession

from neuroagent.app.dependencies import get_session, get_user_id
from neuroagent.app.routers.database.schemas import Threads

security = HTTPBasic()


def get_object(
session: Annotated[Session, Depends(get_session)],
async def get_object(
session: Annotated[AsyncSession, Depends(get_session)],
thread_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> Threads:
Expand All @@ -33,7 +33,7 @@ def get_object(
object
Relevant row of the relevant table in the SQL DB.
"""
sql_object = session.get(Threads, (thread_id, user_id))
sql_object = await session.get(Threads, (thread_id, user_id))
if not sql_object:
raise HTTPException(
status_code=404,
Expand Down
Loading

0 comments on commit edf17f1

Please sign in to comment.