Skip to content

Commit

Permalink
Don't set workflow ID if the header is not set
Browse files Browse the repository at this point in the history
  • Loading branch information
qianl15 committed Jan 29, 2025
1 parent 983a028 commit b69ebf7
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 6 deletions.
8 changes: 6 additions & 2 deletions dbos/_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,11 @@ async def dbos_fastapi_middleware(
with EnterDBOSHandler(attributes):
ctx = assert_current_dbos_context()
ctx.request = _make_request(request)
workflow_id = request.headers.get("dbos-idempotency-key", "")
with SetWorkflowID(workflow_id):
workflow_id = request.headers.get("dbos-idempotency-key")
if workflow_id is not None:
# Set the workflow ID for the handler
with SetWorkflowID(workflow_id):
response = await call_next(request)
else:
response = await call_next(request)
return response
8 changes: 6 additions & 2 deletions dbos/_flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,12 @@ def __call__(self, environ: Any, start_response: Any) -> Any:
with EnterDBOSHandler(attributes):
ctx = assert_current_dbos_context()
ctx.request = _make_request(request)
workflow_id = request.headers.get("dbos-idempotency-key", "")
with SetWorkflowID(workflow_id):
workflow_id = request.headers.get("dbos-idempotency-key")
if workflow_id is not None:
# Set the workflow ID for the handler
with SetWorkflowID(workflow_id):
response = self.app(environ, start_response)
else:
response = self.app(environ, start_response)
return response

Expand Down
21 changes: 20 additions & 1 deletion tests/test_fastapi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
import uuid
from typing import Tuple

import pytest
import sqlalchemy as sa
from fastapi import FastAPI
from fastapi.testclient import TestClient
Expand All @@ -12,7 +14,9 @@
from dbos._context import assert_current_dbos_context


def test_simple_endpoint(dbos_fastapi: Tuple[DBOS, FastAPI]) -> None:
def test_simple_endpoint(
dbos_fastapi: Tuple[DBOS, FastAPI], caplog: pytest.LogCaptureFixture
) -> None:
dbos, app = dbos_fastapi
client = TestClient(app)

Expand All @@ -32,6 +36,7 @@ def test_workflow(var1: str, var2: str) -> str:
res2 = test_step(var2)
return res1 + res2

@app.get("/transaction/{var}")
@DBOS.transaction()
def test_transaction(var: str) -> str:
rows = DBOS.sql_session.execute(sa.text("SELECT 1")).fetchall()
Expand All @@ -41,13 +46,27 @@ def test_transaction(var: str) -> str:
def test_step(var: str) -> str:
return var

original_propagate = logging.getLogger("dbos").propagate
caplog.set_level(logging.WARNING, "dbos")
logging.getLogger("dbos").propagate = True

response = client.get("/workflow/bob/bob")
assert response.status_code == 200
assert response.text == '"bob1bob"'
assert caplog.text == ""

response = client.get("/endpoint/bob/bob")
assert response.status_code == 200
assert response.text == '"bob1bob"'
assert caplog.text == ""

response = client.get("/transaction/bob")
assert response.status_code == 200
assert response.text == '"bob1"'
assert caplog.text == ""

# Reset logging
logging.getLogger("dbos").propagate = original_propagate


def test_start_workflow(dbos_fastapi: Tuple[DBOS, FastAPI]) -> None:
Expand Down
21 changes: 20 additions & 1 deletion tests/test_flask.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import logging
import uuid
from typing import Tuple

import pytest
import sqlalchemy as sa
from flask import Flask, Response, jsonify

from dbos import DBOS
from dbos._context import assert_current_dbos_context


def test_flask_endpoint(dbos_flask: Tuple[DBOS, Flask]) -> None:
def test_flask_endpoint(
dbos_flask: Tuple[DBOS, Flask], caplog: pytest.LogCaptureFixture
) -> None:
_, app = dbos_flask

@app.route("/endpoint/<var1>/<var2>")
Expand All @@ -27,6 +31,7 @@ def test_workflow(var1: str, var2: str) -> Response:
result = res1 + res2
return jsonify({"result": result})

@app.route("/transaction/<var>")
@DBOS.transaction()
def test_transaction(var: str) -> str:
rows = DBOS.sql_session.execute(sa.text("SELECT 1")).fetchall()
Expand All @@ -39,13 +44,27 @@ def test_step(var: str) -> str:
app.config["TESTING"] = True
client = app.test_client()

original_propagate = logging.getLogger("dbos").propagate
caplog.set_level(logging.WARNING, "dbos")
logging.getLogger("dbos").propagate = True

response = client.get("/endpoint/a/b")
assert response.status_code == 200
assert response.json == {"result": "a1b"}
assert caplog.text == ""

response = client.get("/workflow/a/b")
assert response.status_code == 200
assert response.json == {"result": "a1b"}
assert caplog.text == ""

response = client.get("/transaction/bob")
assert response.status_code == 200
assert response.text == "bob1"
assert caplog.text == ""

# Reset logging
logging.getLogger("dbos").propagate = original_propagate


def test_endpoint_recovery(dbos_flask: Tuple[DBOS, Flask]) -> None:
Expand Down

0 comments on commit b69ebf7

Please sign in to comment.