diff --git a/CHANGELOG.md b/CHANGELOG.md index 47f1117..3eb03dc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ ## 0.16 +### 0.16.1 + +- feat: Support `Dep` on function based commands. + ### 0.16.0 - feat: Add support for `BinaryIO` and `TextIO` for representing preconfigured diff --git a/pyproject.toml b/pyproject.toml index 4754d74..ba3fb59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "cappa" -version = "0.16.0" +version = "0.16.1" description = "Declarative CLI argument parser." repository = "https://github.com/dancardin/cappa" diff --git a/src/cappa/class_inspect.py b/src/cappa/class_inspect.py index 3084036..d1b5b10 100644 --- a/src/cappa/class_inspect.py +++ b/src/cappa/class_inspect.py @@ -1,6 +1,7 @@ from __future__ import annotations import dataclasses +import functools import inspect import typing from enum import Enum @@ -8,7 +9,7 @@ import typing_inspect from typing_extensions import Self, get_args -from cappa.typing import MISSING, get_type_hints, missing +from cappa.typing import MISSING, find_type_annotation, get_type_hints, missing if typing.TYPE_CHECKING: from cappa import Arg, Subcommand @@ -201,28 +202,44 @@ def get_command_capable_object(obj): the arguments to the dataclass into the original callable. """ if inspect.isfunction(obj): + from cappa import Dep - def call(self): + function_args = [] + + @functools.wraps(obj) + def call(self, **deps): kwargs = dataclasses.asdict(self) - return obj(**kwargs) + return obj(**kwargs, **deps) + + # We need to create a fake signature for the above callable, which does + # not retain the `Arg` annotations + sig = inspect.signature(obj) + sig_params: dict = dict(sig.parameters) + sig._parameters = sig_params # type: ignore + call.__signature__ = sig # type: ignore args = get_type_hints(obj, include_extras=True) parameters = inspect.signature(obj).parameters - fields = [ - ( - name, - annotation, - dataclasses.field( - default=parameters[name].default - if parameters[name].default is not inspect.Parameter.empty - else dataclasses.MISSING - ), + for name, annotation in args.items(): + if find_type_annotation(annotation, Dep).obj: + continue + + sig_params.pop(name, None) + function_args.append( + ( + name, + annotation, + dataclasses.field( + default=parameters[name].default + if parameters[name].default is not inspect.Parameter.empty + else dataclasses.MISSING + ), + ) ) - for name, annotation in args.items() - ] + return dataclasses.make_dataclass( obj.__name__, - fields, + function_args, namespace={"__call__": call}, ) diff --git a/src/cappa/invoke.py b/src/cappa/invoke.py index ce9a93d..33157ac 100644 --- a/src/cappa/invoke.py +++ b/src/cappa/invoke.py @@ -264,7 +264,7 @@ def resolve_implicit_deps(command: Command, instance: HasCommand) -> dict: def fullfill_deps(fn: Callable, fullfilled_deps: dict) -> typing.Any: result = {} - signature = inspect.signature(fn) + signature = getattr(fn, "__signature__", None) or inspect.signature(fn) try: annotations = get_type_hints(fn, include_extras=True) except NameError as e: # pragma: no cover diff --git a/src/cappa/typing.py b/src/cappa/typing.py index 1fd077c..18b26c8 100644 --- a/src/cappa/typing.py +++ b/src/cappa/typing.py @@ -111,8 +111,9 @@ def is_subclass(typ, superclass): def get_type_hints(obj, include_extras=False): result = typing_extensions.get_type_hints(obj, include_extras=include_extras) if sys.version_info < (3, 11): # pragma: no cover - return fix_annotated_optional_type_hints(result) - return result + result = fix_annotated_optional_type_hints(result) + + return {k: v for k, v in result.items() if k not in {"return"}} def fix_annotated_optional_type_hints( diff --git a/tests/command/test_function_command.py b/tests/command/test_function_command.py index 6bb62a0..8d3713c 100644 --- a/tests/command/test_function_command.py +++ b/tests/command/test_function_command.py @@ -73,3 +73,22 @@ def function(sub: cappa.Subcommands[Union[Sub, None]] = None): result = invoke(function, "sub", "--bar", "34", backend=backend) assert result == 35 + + +def foo(): + return 5 + + +@backends +def test_invoke_partial_arg_partial_dep(backend): + def function( + dep: Annotated[int, cappa.Dep(foo)], + foo: Annotated[int, cappa.Arg(long=True)] = 15, + ): + return dep + foo + + result = invoke(function, backend=backend) + assert result == 20 + + result = invoke(function, "--foo", "53", backend=backend) + assert result == 58