diff --git a/CHANGELOG.md b/CHANGELOG.md index e0d1342..5b9d766 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,16 @@ ## 0.17 +### 0.17.1 + +- Fixes bounded-tuple options like `tuple[str, str]` to infer as `num_args=2` +- Fixes bounded-tuple options to fail parsing if you give it a different number + of values +- Fixes "double sequence" inference on explicit `num_args=N` values which would + produce sequences. I.e. infer `action=ArgAction.set` in such cases to avoid + e.x. `num_args=3, action=ArgAction.append`; resulting in nonsensical nested + sequence `["[]"]` + ### 0.17.0 - feat: Add `hidden=True/False` option to Command, which allows hiding diff --git a/docs/source/annotation.md b/docs/source/annotation.md index a631a61..23af3fd 100644 --- a/docs/source/annotation.md +++ b/docs/source/annotation.md @@ -122,6 +122,12 @@ Annotations like `list[...]`, `set[...]`, `tuple[...]`, etc are what we call assert prog == Prog(foo=['foo', 'bar', 'baz']) ``` + ```{note} + You can specify `Annotated[list[str], Arg(short=True, num_args=n)]` where + `n` would yield a sequence (`-1` or > 1). In such a case, `action` would instead + be inferred as `ArgAction.set`. + ``` + See [Argument](./arg.md) for more details on the difference between `ArgAction.append` and `num_args=-1`. diff --git a/pyproject.toml b/pyproject.toml index 7de23ba..d3c37f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "cappa" -version = "0.17.0" +version = "0.17.1" description = "Declarative CLI argument parser." repository = "https://github.com/dancardin/cappa" diff --git a/src/cappa/arg.py b/src/cappa/arg.py index ec85d14..deb9b35 100644 --- a/src/cappa/arg.py +++ b/src/cappa/arg.py @@ -176,19 +176,15 @@ def normalize( fallback_help: str | None = None, action: ArgAction | Callable | None = None, default: typing.Any = missing, - value_name: str | None = None, field_name: str | None = None, ) -> Arg: origin = typing.get_origin(annotation) or annotation type_args = typing.get_args(annotation) field_name = typing.cast(str, field_name or self.field_name) - value_name = value_name or ( - self.value_name if self.value_name is not missing else field_name - ) default = default if default is not missing else self.default - verify_type_compatibility(self, annotation, origin, type_args) + verify_type_compatibility(self, field_name, annotation, origin, type_args) short = infer_short(self, field_name) long = infer_long(self, origin, field_name) choices = infer_choices(self, origin, type_args) @@ -204,6 +200,8 @@ def normalize( group = infer_group(self, short, long) + value_name = infer_value_name(self, field_name, num_args) + return dataclasses.replace( self, default=default, @@ -237,7 +235,11 @@ def names_str(self, delimiter: str = ", ", *, n=0) -> str: def verify_type_compatibility( - arg: Arg, annotation: type, origin: type, type_args: tuple[type, ...] + arg: Arg, + field_name: str, + annotation: type, + origin: type, + type_args: tuple[type, ...], ): """Verify classes of annotations are compatible with one another. @@ -259,7 +261,7 @@ def verify_type_compatibility( } if len(all_same_arity) > 1: raise ValueError( - f"On field '{arg.field_name}', apparent mismatch of annotated type with `Arg` options. " + f"On field '{field_name}', apparent mismatch of annotated type with `Arg` options. " 'Unioning "sequence" types with non-sequence types is not currently supported, ' "unless using `Arg(parse=...)` or `Arg(action=)`. " "See [documentation](https://cappa.readthedocs.io/en/latest/annotation.html) for more details." @@ -267,17 +269,19 @@ def verify_type_compatibility( return num_args = arg.num_args + # print(is_sequence_type(origin), num_args, action) + # print(f" {num_args not in {0, 1} or action is ArgAction.append}") if is_sequence_type(origin): - if num_args == 1 or action not in {ArgAction.append, None}: + if num_args in {0, 1} and action not in {ArgAction.append, None}: raise ValueError( - f"On field '{arg.field_name}', apparent mismatch of annotated type with `Arg` options. " + f"On field '{field_name}', apparent mismatch of annotated type with `Arg` options. " f"'{annotation}' type produces a sequence, whereas `num_args=1`/`action={action}` do not. " "See [documentation](https://cappa.readthedocs.io/en/latest/annotation.html) for more details." ) else: - if num_args not in {None, 1} or action is ArgAction.append: + if num_args not in {None, 0, 1} or action is ArgAction.append: raise ValueError( - f"On field '{arg.field_name}', apparent mismatch of annotated type with `Arg` options. " + f"On field '{field_name}', apparent mismatch of annotated type with `Arg` options. " f"'{origin.__name__}' type produces a scalar, whereas `num_args={num_args}`/`action={action}` do not. " "See [documentation](https://cappa.readthedocs.io/en/latest/annotation.html) for more details." ) @@ -350,7 +354,7 @@ def infer_short(arg: Arg, name: str) -> list[str] | typing.Literal[False]: else: short = arg.short - return [item if item.startswith("-") else f"-{item[0]}" for item in short] + return [item if item.startswith("-") else f"-{item}" for item in short] def infer_long(arg: Arg, origin: type, name: str) -> list[str] | typing.Literal[False]: @@ -416,7 +420,12 @@ def infer_action( has_specific_num_args = arg.num_args is not None unbounded_num_args = arg.num_args == -1 - if arg.parse or unbounded_num_args or (is_positional and not has_specific_num_args): + if ( + arg.parse + or unbounded_num_args + or (is_positional and not has_specific_num_args) + or (has_specific_num_args and arg.num_args != 1) + ): return ArgAction.set if is_of_type(annotation, (typing.List, typing.Set)): @@ -436,7 +445,7 @@ def infer_num_args( type_args: tuple[type, ...], action: ArgAction | Callable, long, -) -> int | None: +) -> int: if arg.num_args is not None: return arg.num_args @@ -551,6 +560,19 @@ def infer_group( return typing.cast(typing.Tuple[int, str], group) +def infer_value_name(arg: Arg, field_name: str, num_args: int | None) -> str: + if arg.value_name is not missing: + return arg.value_name + + if num_args == -1: + return f"{field_name} ..." + + if num_args and num_args > 1: + return " ".join([field_name] * num_args) + + return field_name + + no_extra_arg_actions = { ArgAction.store_true, ArgAction.store_false, diff --git a/src/cappa/parser.py b/src/cappa/parser.py index 69c0950..2ee4bd8 100644 --- a/src/cappa/parser.py +++ b/src/cappa/parser.py @@ -484,16 +484,8 @@ def consume_arg( result = [] while num_args: if isinstance(context.peek_value(), RawOption): - if orig_num_args < 0: - break + break - raise BadArgumentError( - f"Argument requires {orig_num_args} values, " - f"only found {len(result)} ('{' '.join(result)}' so far)", - value=result, - command=context.command, - arg=arg, - ) try: next_val = typing.cast(RawArg, context.next_value()) except IndexError: @@ -538,6 +530,20 @@ def consume_arg( command=context.command, arg=arg, ) + else: + if orig_num_args > 0 and len(result) != orig_num_args: + quoted_result = [f"'{r}'" for r in result] + names_str = arg.names_str("/") + + message = f"Argument '{names_str}' requires {orig_num_args} values, found {len(result)}" + if quoted_result: + message += f" ({', '.join(quoted_result)} so far)" + raise BadArgumentError( + message, + value=result, + command=context.command, + arg=arg, + ) if option and arg.field_name in context.missing_options: context.missing_options.remove(arg.field_name) diff --git a/tests/arg/test_invalid_annotation_combination.py b/tests/arg/test_invalid_annotation_combination.py index d087add..77112fd 100644 --- a/tests/arg/test_invalid_annotation_combination.py +++ b/tests/arg/test_invalid_annotation_combination.py @@ -31,7 +31,7 @@ class Args: def test_sequence_with_scalar_action(backend): @dataclass class Args: - foo: Annotated[list[str], cappa.Arg(action=cappa.ArgAction.set)] + foo: Annotated[list[str], cappa.Arg(action=cappa.ArgAction.set, num_args=1)] with pytest.raises(ValueError) as e: parse(Args, "--help", backend=backend) @@ -50,13 +50,22 @@ def test_sequence_with_scalar_num_args(backend): class Args: foo: Annotated[list[str], cappa.Arg(num_args=1, short=True)] + args = parse(Args, "-f", "a", "-f", "b", backend=backend) + assert args == Args(["a", "b"]) + + @dataclass + class ArgsBad: + foo: Annotated[ + list[str], cappa.Arg(num_args=1, short=True, action=cappa.ArgAction.set) + ] + with pytest.raises(ValueError) as e: - parse(Args, "--help", backend=backend) + parse(ArgsBad, "--help", backend=backend) result = str(e.value).replace("typing.List", "list") assert result == ( "On field 'foo', apparent mismatch of annotated type with `Arg` options. " - "'list[str]' type produces a sequence, whereas `num_args=1`/`action=None` do not. " + "'list[str]' type produces a sequence, whereas `num_args=1`/`action=ArgAction.set` do not. " "See [documentation](https://cappa.readthedocs.io/en/latest/annotation.html) for more details." ) diff --git a/tests/arg/test_tuple.py b/tests/arg/test_tuple.py index 20974d7..caa7c4c 100644 --- a/tests/arg/test_tuple.py +++ b/tests/arg/test_tuple.py @@ -2,6 +2,10 @@ from dataclasses import dataclass +import cappa +import pytest +from typing_extensions import Annotated + from tests.utils import backends, parse @@ -14,3 +18,38 @@ class ArgTest: def test_valid(backend): test = parse(ArgTest, "1", "2", "3.4", backend=backend) assert test.numbers == (1, "2", 3.4) + + +@backends +def test_tuple_option(backend): + @dataclass + class Example: + start_project: Annotated[ + tuple[int, float], + cappa.Arg(short=True, default=(1, 9), required=False), + ] + + test = parse(Example, backend=backend) + assert test == Example(start_project=(1, 9)) + + test = parse(Example, "-s", "2", "2.4", backend=backend) + assert test == Example(start_project=(2, 2.4)) + + # Missing values + with pytest.raises(cappa.Exit) as e: + parse(Example, "-s", "1", backend=backend) + + assert e.value.code == 2 + + if backend: + assert str(e.value.message).lower() == "argument -s: expected 2 arguments" + else: + assert ( + e.value.message == "Argument '-s' requires 2 values, found 1 ('1' so far)" + ) + + # Extra values + with pytest.raises(cappa.Exit) as e: + parse(Example, "-s", "1", "2", "3", backend=backend) + assert e.value.code == 2 + assert e.value.message == "Unrecognized arguments: 3" diff --git a/tests/parser/test_missing_num_args.py b/tests/parser/test_missing_num_args.py index 9ce2df8..960e3fd 100644 --- a/tests/parser/test_missing_num_args.py +++ b/tests/parser/test_missing_num_args.py @@ -25,4 +25,4 @@ class Args: if backend == argparse.backend: assert "The following arguments are required: arg" in message else: - assert message == "Argument requires 2 values, only found 1 ('arg' so far)" + assert message == "Argument 'arg arg' requires 2 values, found 1 ('arg' so far)" diff --git a/tests/parser/test_num_args.py b/tests/parser/test_num_args.py index d02dfdf..1f069aa 100644 --- a/tests/parser/test_num_args.py +++ b/tests/parser/test_num_args.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass import cappa @@ -25,3 +27,34 @@ class Args: assert e.value.message is None assert e.value.code == 0 + + +@backends +def test_num_args_list(backend): + @dataclass + class Args: + foo: Annotated[ + list[str], + cappa.Arg( + short=True, + default=["", ""], + required=False, + num_args=2, + ), + ] + + args = parse(Args, backend=backend) + assert args == Args(["", ""]) + + args = parse(Args, "-f", "2", "4", backend=backend) + assert args == Args(["2", "4"]) + + with pytest.raises(cappa.Exit) as e: + parse(Args, "-f", backend=backend) + + assert e.value.code == 2 + + if backend: + assert str(e.value.message).lower() == "argument -f: expected 2 arguments" + else: + assert e.value.message == "Argument '-f' requires 2 values, found 0"