Skip to content

Commit

Permalink
Wrap Existing FastAPI Lifespan Handlers (#191)
Browse files Browse the repository at this point in the history
Make the DBOS FastAPI lifespan handler wrap whatever handler already
exists (if one does) instead of overriding it.
  • Loading branch information
kraftp authored Jan 31, 2025
1 parent 9998cd8 commit 1b3cff2
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 20 deletions.
19 changes: 10 additions & 9 deletions dbos/_fastapi.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import uuid
from typing import Any, Callable, cast
from typing import Any, Callable, MutableMapping, cast

from fastapi import FastAPI
from fastapi import Request as FastAPIRequest
from fastapi.responses import JSONResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.types import ASGIApp, Receive, Scope, Send

from . import DBOS
from ._context import (
Expand Down Expand Up @@ -61,15 +61,16 @@ def __init__(self, app: ASGIApp, dbos: DBOS):

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "lifespan":
while True:
message = await receive()
if message["type"] == "lifespan.startup":

async def wrapped_send(message: MutableMapping[str, Any]) -> None:
if message["type"] == "lifespan.startup.complete":
self.dbos._launch()
await send({"type": "lifespan.startup.complete"})
elif message["type"] == "lifespan.shutdown":
elif message["type"] == "lifespan.shutdown.complete":
self.dbos._destroy()
await send({"type": "lifespan.shutdown.complete"})
break
await send(message)

# Call the original app with our wrapped functions
await self.app(scope, receive, wrapped_send)
else:
await self.app(scope, receive, send)

Expand Down
10 changes: 1 addition & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,7 @@ def dbos_fastapi(
) -> Generator[Tuple[DBOS, FastAPI], Any, None]:
DBOS.destroy()
app = FastAPI()

# ignore the on_event deprecation warnings
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=DeprecationWarning,
message=r"\s*on_event is deprecated, use lifespan event handlers instead\.",
)
dbos = DBOS(fastapi=app, config=config)
dbos = DBOS(fastapi=app, config=config)

# This is for test convenience.
# Usually fastapi itself does launch, but we are not completing the fastapi lifecycle
Expand Down
50 changes: 48 additions & 2 deletions tests/test_fastapi.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import asyncio
import logging
import uuid
from typing import Tuple
from contextlib import asynccontextmanager
from typing import Any, Tuple

import httpx
import pytest
import sqlalchemy as sa
import uvicorn
from fastapi import FastAPI
from fastapi.testclient import TestClient

# Public API
from dbos import DBOS
from dbos import DBOS, ConfigFile

# Private API because this is a unit test
from dbos._context import assert_current_dbos_context
Expand Down Expand Up @@ -159,6 +163,48 @@ def test_endpoint(var1: str, var2: str) -> dict[str, str]:
assert workflow_handles[0].get_result() == ("a", wfuuid)


@pytest.mark.asyncio
async def test_custom_lifespan(
config: ConfigFile, cleanup_test_databases: None
) -> None:
resource = None
port = 8000

@asynccontextmanager
async def lifespan(app: FastAPI) -> Any:
nonlocal resource
resource = 1
yield
resource = None

app = FastAPI(lifespan=lifespan)

DBOS.destroy()
DBOS(fastapi=app, config=config)

@app.get("/")
@DBOS.workflow()
async def resource_workflow() -> Any:
return {"resource": resource}

uvicorn_config = uvicorn.Config(
app=app, host="127.0.0.1", port=port, log_level="error"
)
server = uvicorn.Server(config=uvicorn_config)

# Run server in background task
server_task = asyncio.create_task(server.serve())
await asyncio.sleep(0.2) # Give server time to start

async with httpx.AsyncClient() as client:
r = await client.get(f"http://127.0.0.1:{port}")
assert r.json()["resource"] == 1

server.should_exit = True
await server_task
assert resource is None


def test_stacked_decorators_wf(dbos_fastapi: Tuple[DBOS, FastAPI]) -> None:
dbos, app = dbos_fastapi
client = TestClient(app)
Expand Down

0 comments on commit 1b3cff2

Please sign in to comment.