diff --git a/procrastinate/blueprints.py b/procrastinate/blueprints.py index 5f737d3b4..e39619189 100644 --- a/procrastinate/blueprints.py +++ b/procrastinate/blueprints.py @@ -3,9 +3,9 @@ import functools import logging import sys -from typing import TYPE_CHECKING, Any, Callable, Literal, Union, cast, overload +from typing import TYPE_CHECKING, Callable, Literal, Union, cast, overload -from typing_extensions import Concatenate, ParamSpec, Unpack +from typing_extensions import Concatenate, ParamSpec, TypeVar, Unpack from procrastinate import exceptions, jobs, periodic, retry, utils from procrastinate.job_context import JobContext @@ -17,6 +17,7 @@ logger = logging.getLogger(__name__) P = ParamSpec("P") +R = TypeVar("R") class Blueprint: @@ -210,7 +211,7 @@ def task( priority: int = jobs.DEFAULT_PRIORITY, lock: str | None = None, queueing_lock: str | None = None, - ) -> Callable[[Callable[P]], Task[P, P]]: + ) -> Callable[[Callable[P, R]], Task[P, R, P]]: """Declare a function as a task. This method is meant to be used as a decorator Parameters ---------- @@ -265,8 +266,8 @@ def task( lock: str | None = None, queueing_lock: str | None = None, ) -> Callable[ - [Callable[Concatenate[JobContext, P]]], - Task[Concatenate[JobContext, P], P], + [Callable[Concatenate[JobContext, P], R]], + Task[Concatenate[JobContext, P], R, P], ]: """Declare a function as a task. This method is meant to be used as a decorator Parameters @@ -277,7 +278,7 @@ def task( ... @overload - def task(self, _func: Callable[P]) -> Task[P, P]: + def task(self, _func: Callable[P, R]) -> Task[P, R, P]: """Declare a function as a task. This method is meant to be used as a decorator Parameters ---------- @@ -288,7 +289,7 @@ def task(self, _func: Callable[P]) -> Task[P, P]: def task( self, - _func: Callable[P] | None = None, + _func: Callable[P, R] | None = None, *, name: str | None = None, aliases: list[str] | None = None, @@ -316,7 +317,7 @@ def my_task(args): The second form will use the default value for all parameters. """ - def _wrap(func: Callable[P]) -> Callable[P, Task[P, P]]: + def _wrap(func: Callable[P, R]) -> Task[P, R, P]: task = Task( func, blueprint=self, @@ -331,15 +332,21 @@ def _wrap(func: Callable[P]) -> Callable[P, Task[P, P]]: ) self._register_task(task) - return functools.update_wrapper(task, func, updated=()) + # The signature of a function returned by functools.update_wrapper + # is the same as the signature of the wrapped function (at least on pyright). + # Here, we're actually returning a Task so a cast is needed to provide the correct signature. + return cast( + Task[P, R, P], + functools.update_wrapper(task, func, updated=()), + ) if _func is None: # Called as @app.task(...) return cast( Union[ - Callable[[Callable[P, Any]], Task[P, P]], + Callable[[Callable[P, R]], Task[P, R, P]], Callable[ - [Callable[Concatenate[JobContext, P], Any]], - Task[Concatenate[JobContext, P], P], + [Callable[Concatenate[JobContext, P], R]], + Task[Concatenate[JobContext, P], R, P], ], ], _wrap, diff --git a/procrastinate/periodic.py b/procrastinate/periodic.py index 938fa980d..78048bffb 100644 --- a/procrastinate/periodic.py +++ b/procrastinate/periodic.py @@ -8,11 +8,12 @@ import attr import croniter -from typing_extensions import Concatenate, ParamSpec, Unpack +from typing_extensions import Concatenate, ParamSpec, TypeVar, Unpack from procrastinate import exceptions, tasks P = ParamSpec("P") +R = TypeVar("R") Args = ParamSpec("Args") # The maximum delay after which tasks will be considered as @@ -28,8 +29,8 @@ @attr.dataclass(frozen=True) -class PeriodicTask(Generic[P, Args]): - task: tasks.Task[P, Args] +class PeriodicTask(Generic[P, R, Args]): + task: tasks.Task[P, R, Args] cron: str periodic_id: str configure_kwargs: tasks.ConfigureTaskOptions @@ -51,31 +52,33 @@ def periodic_decorator( cron: str, periodic_id: str, **configure_kwargs: Unpack[tasks.ConfigureTaskOptions], - ) -> Callable[[tasks.Task[P, Concatenate[int, Args]]], tasks.Task[P, Args]]: + ) -> Callable[[tasks.Task[P, R, Concatenate[int, Args]]], tasks.Task[P, R, Args]]: """ Decorator over a task definition that registers that task for periodic launch. This decorator should not be used directly, ``@app.periodic()`` is meant to be used instead. """ - def wrapper(task: tasks.Task[P, Concatenate[int, Args]]) -> tasks.Task[P, Args]: + def wrapper( + task: tasks.Task[P, R, Concatenate[int, Args]], + ) -> tasks.Task[P, R, Args]: self.register_task( task=task, cron=cron, periodic_id=periodic_id, configure_kwargs=configure_kwargs, ) - return cast(tasks.Task[P, Args], task) + return cast(tasks.Task[P, R, Args], task) return wrapper def register_task( self, - task: tasks.Task[P, Concatenate[int, Args]], + task: tasks.Task[P, R, Concatenate[int, Args]], cron: str, periodic_id: str, configure_kwargs: tasks.ConfigureTaskOptions, - ) -> PeriodicTask[P, Concatenate[int, Args]]: + ) -> PeriodicTask[P, R, Concatenate[int, Args]]: key = (task.name, periodic_id) if key in self.periodic_tasks: raise exceptions.TaskAlreadyRegistered( diff --git a/procrastinate/tasks.py b/procrastinate/tasks.py index 7ae3afb68..876da1c69 100644 --- a/procrastinate/tasks.py +++ b/procrastinate/tasks.py @@ -2,9 +2,9 @@ import datetime import logging -from typing import Any, Callable, Generic, TypedDict, cast +from typing import Callable, Generic, TypedDict, cast -from typing_extensions import NotRequired, ParamSpec, Unpack +from typing_extensions import NotRequired, ParamSpec, TypeVar, Unpack from procrastinate import app as app_module from procrastinate import blueprints, exceptions, jobs, manager, types, utils @@ -15,6 +15,7 @@ Args = ParamSpec("Args") P = ParamSpec("P") +R = TypeVar("R") class ConfigureTaskOptions(TypedDict): @@ -62,7 +63,7 @@ def configure_task( ) -class Task(Generic[P, Args]): +class Task(Generic[P, R, Args]): """ A task is a function that should be executed later. It is linked to a default queue, and expects keyword arguments. @@ -70,7 +71,7 @@ class Task(Generic[P, Args]): def __init__( self, - func: Callable[P], + func: Callable[P, R], *, blueprint: blueprints.Blueprint, # task naming @@ -94,7 +95,7 @@ def __init__( #: priority is 0. self.priority: int = priority self.blueprint: blueprints.Blueprint = blueprint - self.func: Callable[P] = func + self.func: Callable[P, R] = func #: Additional names for the task. self.aliases: list[str] = aliases if aliases else [] #: Value indicating the retry conditions in case of @@ -123,7 +124,7 @@ def add_namespace(self, namespace: str) -> None: for alias in self.aliases ] - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Any: + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: return self.func(*args, **kwargs) @property