Skip to content

Commit

Permalink
Support stacked decorators for async workflows (#192)
Browse files Browse the repository at this point in the history
Currently, You can't declare an async function to be both a FastAPI
handler and a DBOS workflow:

```python
@app.get("/endpoint/{var1}/{var2}")
@DBOS.workflow()
async def test_endpoint(var1: str, var2: str) -> str:
    return f"{var1}, {var2}!"
```

The problem stems from the function returned by the `@DBOS.workflow`
decorator. Both the decorated function and the function returned by the
decorator return a coroutine, but the function returned by the decorator
is not defined with `async def`. This causes FastAPI to mis-categorize
the workflow function as sync when it is actually async. The function
retuned by the decorator has to appear as a coroutine to
`inspect.iscoroutinefunction`.

For Python 3.12 and later, any function can be marked as a coroutine
using
[`inspect.markcoroutinefunction`](https://docs.python.org/3/library/inspect.html#inspect.markcoroutinefunction).

For Python 3.11 and earlier, we have to wrap the coroutine returning
function in an async function like this:

```python
def _mark_coroutine(func: Callable[P, R]) -> Callable[P, R]:
    @wraps(func)
    async def async_wrapper(*args: Any, **kwargs: Any) -> R:
        return await func(*args, **kwargs)  # type: ignore

    return async_wrapper  # type: ignore
```
  • Loading branch information
devhawk authored Jan 31, 2025
1 parent 862a59e commit 9998cd8
Show file tree
Hide file tree
Showing 3 changed files with 632 additions and 24 deletions.
22 changes: 21 additions & 1 deletion dbos/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,22 @@ def start_workflow(
return WorkflowHandleFuture(new_wf_id, future, dbos)


if sys.version_info < (3, 12):

def _mark_coroutine(func: Callable[P, R]) -> Callable[P, R]:
@wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> R:
return await func(*args, **kwargs) # type: ignore

return async_wrapper # type: ignore

else:

def _mark_coroutine(func: Callable[P, R]) -> Callable[P, R]:
inspect.markcoroutinefunction(func)
return func


def workflow_wrapper(
dbosreg: "DBOSRegistry",
func: Callable[P, R],
Expand Down Expand Up @@ -548,7 +564,7 @@ def init_wf() -> Callable[[Callable[[], R]], R]:
)
return outcome() # type: ignore

return wrapper
return _mark_coroutine(wrapper) if inspect.iscoroutinefunction(func) else wrapper


def decorate_workflow(
Expand Down Expand Up @@ -838,6 +854,10 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
assert tempwf
return tempwf(*args, **kwargs)

wrapper = (
_mark_coroutine(wrapper) if inspect.iscoroutinefunction(func) else wrapper # type: ignore
)

def temp_wf_sync(*args: Any, **kwargs: Any) -> Any:
return wrapper(*args, **kwargs)

Expand Down
Loading

0 comments on commit 9998cd8

Please sign in to comment.