Skip to content

Commit

Permalink
Merge pull request #1196 from fau-st/typing_fix_task_return_type
Browse files Browse the repository at this point in the history
  • Loading branch information
ewjoachim authored Sep 19, 2024
2 parents f7463ee + d8c43f5 commit 70509f4
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 26 deletions.
31 changes: 19 additions & 12 deletions procrastinate/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import functools
import logging
import sys
from typing import TYPE_CHECKING, Any, Callable, Literal, Union, cast, overload
from typing import TYPE_CHECKING, Callable, Literal, Union, cast, overload

from typing_extensions import Concatenate, ParamSpec, Unpack
from typing_extensions import Concatenate, ParamSpec, TypeVar, Unpack

from procrastinate import exceptions, jobs, periodic, retry, utils
from procrastinate.job_context import JobContext
Expand All @@ -17,6 +17,7 @@
logger = logging.getLogger(__name__)

P = ParamSpec("P")
R = TypeVar("R")


class Blueprint:
Expand Down Expand Up @@ -210,7 +211,7 @@ def task(
priority: int = jobs.DEFAULT_PRIORITY,
lock: str | None = None,
queueing_lock: str | None = None,
) -> Callable[[Callable[P]], Task[P, P]]:
) -> Callable[[Callable[P, R]], Task[P, R, P]]:
"""Declare a function as a task. This method is meant to be used as a decorator
Parameters
----------
Expand Down Expand Up @@ -265,8 +266,8 @@ def task(
lock: str | None = None,
queueing_lock: str | None = None,
) -> Callable[
[Callable[Concatenate[JobContext, P]]],
Task[Concatenate[JobContext, P], P],
[Callable[Concatenate[JobContext, P], R]],
Task[Concatenate[JobContext, P], R, P],
]:
"""Declare a function as a task. This method is meant to be used as a decorator
Parameters
Expand All @@ -277,7 +278,7 @@ def task(
...

@overload
def task(self, _func: Callable[P]) -> Task[P, P]:
def task(self, _func: Callable[P, R]) -> Task[P, R, P]:
"""Declare a function as a task. This method is meant to be used as a decorator
Parameters
----------
Expand All @@ -288,7 +289,7 @@ def task(self, _func: Callable[P]) -> Task[P, P]:

def task(
self,
_func: Callable[P] | None = None,
_func: Callable[P, R] | None = None,
*,
name: str | None = None,
aliases: list[str] | None = None,
Expand Down Expand Up @@ -316,7 +317,7 @@ def my_task(args):
The second form will use the default value for all parameters.
"""

def _wrap(func: Callable[P]) -> Callable[P, Task[P, P]]:
def _wrap(func: Callable[P, R]) -> Task[P, R, P]:
task = Task(
func,
blueprint=self,
Expand All @@ -331,15 +332,21 @@ def _wrap(func: Callable[P]) -> Callable[P, Task[P, P]]:
)
self._register_task(task)

return functools.update_wrapper(task, func, updated=())
# The signature of a function returned by functools.update_wrapper
# is the same as the signature of the wrapped function (at least on pyright).
# Here, we're actually returning a Task so a cast is needed to provide the correct signature.
return cast(
Task[P, R, P],
functools.update_wrapper(task, func, updated=()),
)

if _func is None: # Called as @app.task(...)
return cast(
Union[
Callable[[Callable[P, Any]], Task[P, P]],
Callable[[Callable[P, R]], Task[P, R, P]],
Callable[
[Callable[Concatenate[JobContext, P], Any]],
Task[Concatenate[JobContext, P], P],
[Callable[Concatenate[JobContext, P], R]],
Task[Concatenate[JobContext, P], R, P],
],
],
_wrap,
Expand Down
19 changes: 11 additions & 8 deletions procrastinate/periodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@

import attr
import croniter
from typing_extensions import Concatenate, ParamSpec, Unpack
from typing_extensions import Concatenate, ParamSpec, TypeVar, Unpack

from procrastinate import exceptions, tasks

P = ParamSpec("P")
R = TypeVar("R")
Args = ParamSpec("Args")

# The maximum delay after which tasks will be considered as
Expand All @@ -28,8 +29,8 @@


@attr.dataclass(frozen=True)
class PeriodicTask(Generic[P, Args]):
task: tasks.Task[P, Args]
class PeriodicTask(Generic[P, R, Args]):
task: tasks.Task[P, R, Args]
cron: str
periodic_id: str
configure_kwargs: tasks.ConfigureTaskOptions
Expand All @@ -51,31 +52,33 @@ def periodic_decorator(
cron: str,
periodic_id: str,
**configure_kwargs: Unpack[tasks.ConfigureTaskOptions],
) -> Callable[[tasks.Task[P, Concatenate[int, Args]]], tasks.Task[P, Args]]:
) -> Callable[[tasks.Task[P, R, Concatenate[int, Args]]], tasks.Task[P, R, Args]]:
"""
Decorator over a task definition that registers that task for periodic
launch. This decorator should not be used directly, ``@app.periodic()`` is meant
to be used instead.
"""

def wrapper(task: tasks.Task[P, Concatenate[int, Args]]) -> tasks.Task[P, Args]:
def wrapper(
task: tasks.Task[P, R, Concatenate[int, Args]],
) -> tasks.Task[P, R, Args]:
self.register_task(
task=task,
cron=cron,
periodic_id=periodic_id,
configure_kwargs=configure_kwargs,
)
return cast(tasks.Task[P, Args], task)
return cast(tasks.Task[P, R, Args], task)

return wrapper

def register_task(
self,
task: tasks.Task[P, Concatenate[int, Args]],
task: tasks.Task[P, R, Concatenate[int, Args]],
cron: str,
periodic_id: str,
configure_kwargs: tasks.ConfigureTaskOptions,
) -> PeriodicTask[P, Concatenate[int, Args]]:
) -> PeriodicTask[P, R, Concatenate[int, Args]]:
key = (task.name, periodic_id)
if key in self.periodic_tasks:
raise exceptions.TaskAlreadyRegistered(
Expand Down
13 changes: 7 additions & 6 deletions procrastinate/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import datetime
import logging
from typing import Any, Callable, Generic, TypedDict, cast
from typing import Callable, Generic, TypedDict, cast

from typing_extensions import NotRequired, ParamSpec, Unpack
from typing_extensions import NotRequired, ParamSpec, TypeVar, Unpack

from procrastinate import app as app_module
from procrastinate import blueprints, exceptions, jobs, manager, types, utils
Expand All @@ -15,6 +15,7 @@

Args = ParamSpec("Args")
P = ParamSpec("P")
R = TypeVar("R")


class ConfigureTaskOptions(TypedDict):
Expand Down Expand Up @@ -62,15 +63,15 @@ def configure_task(
)


class Task(Generic[P, Args]):
class Task(Generic[P, R, Args]):
"""
A task is a function that should be executed later. It is linked to a
default queue, and expects keyword arguments.
"""

def __init__(
self,
func: Callable[P],
func: Callable[P, R],
*,
blueprint: blueprints.Blueprint,
# task naming
Expand All @@ -94,7 +95,7 @@ def __init__(
#: priority is 0.
self.priority: int = priority
self.blueprint: blueprints.Blueprint = blueprint
self.func: Callable[P] = func
self.func: Callable[P, R] = func
#: Additional names for the task.
self.aliases: list[str] = aliases if aliases else []
#: Value indicating the retry conditions in case of
Expand Down Expand Up @@ -123,7 +124,7 @@ def add_namespace(self, namespace: str) -> None:
for alias in self.aliases
]

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Any:
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
return self.func(*args, **kwargs)

@property
Expand Down

0 comments on commit 70509f4

Please sign in to comment.