Skip to content

Commit

Permalink
Merge pull request #107 from DanCardin/dc/inference-handling
Browse files Browse the repository at this point in the history
fix: Add better sequence inference.
  • Loading branch information
DanCardin authored Mar 7, 2024
2 parents 77fbc19 + 391d7e0 commit 0bde925
Show file tree
Hide file tree
Showing 9 changed files with 153 additions and 28 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@

## 0.17

### 0.17.1

- Fixes bounded-tuple options like `tuple[str, str]` to infer as `num_args=2`
- Fixes bounded-tuple options to fail parsing if you give it a different number
of values
- Fixes "double sequence" inference on explicit `num_args=N` values which would
produce sequences. I.e. infer `action=ArgAction.set` in such cases to avoid
e.x. `num_args=3, action=ArgAction.append`; resulting in nonsensical nested
sequence `["[]"]`

### 0.17.0

- feat: Add `hidden=True/False` option to Command, which allows hiding
Expand Down
6 changes: 6 additions & 0 deletions docs/source/annotation.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ Annotations like `list[...]`, `set[...]`, `tuple[...]`, etc are what we call
assert prog == Prog(foo=['foo', 'bar', 'baz'])
```

```{note}
You can specify `Annotated[list[str], Arg(short=True, num_args=n)]` where
`n` would yield a sequence (`-1` or > 1). In such a case, `action` would instead
be inferred as `ArgAction.set`.
```

See [Argument](./arg.md) for more details on the difference between
`ArgAction.append` and `num_args=-1`.

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "cappa"
version = "0.17.0"
version = "0.17.1"
description = "Declarative CLI argument parser."

repository = "https://github.com/dancardin/cappa"
Expand Down
50 changes: 36 additions & 14 deletions src/cappa/arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,19 +176,15 @@ def normalize(
fallback_help: str | None = None,
action: ArgAction | Callable | None = None,
default: typing.Any = missing,
value_name: str | None = None,
field_name: str | None = None,
) -> Arg:
origin = typing.get_origin(annotation) or annotation
type_args = typing.get_args(annotation)

field_name = typing.cast(str, field_name or self.field_name)
value_name = value_name or (
self.value_name if self.value_name is not missing else field_name
)
default = default if default is not missing else self.default

verify_type_compatibility(self, annotation, origin, type_args)
verify_type_compatibility(self, field_name, annotation, origin, type_args)
short = infer_short(self, field_name)
long = infer_long(self, origin, field_name)
choices = infer_choices(self, origin, type_args)
Expand All @@ -204,6 +200,8 @@ def normalize(

group = infer_group(self, short, long)

value_name = infer_value_name(self, field_name, num_args)

return dataclasses.replace(
self,
default=default,
Expand Down Expand Up @@ -237,7 +235,11 @@ def names_str(self, delimiter: str = ", ", *, n=0) -> str:


def verify_type_compatibility(
arg: Arg, annotation: type, origin: type, type_args: tuple[type, ...]
arg: Arg,
field_name: str,
annotation: type,
origin: type,
type_args: tuple[type, ...],
):
"""Verify classes of annotations are compatible with one another.
Expand All @@ -259,25 +261,27 @@ def verify_type_compatibility(
}
if len(all_same_arity) > 1:
raise ValueError(
f"On field '{arg.field_name}', apparent mismatch of annotated type with `Arg` options. "
f"On field '{field_name}', apparent mismatch of annotated type with `Arg` options. "
'Unioning "sequence" types with non-sequence types is not currently supported, '
"unless using `Arg(parse=...)` or `Arg(action=<callable>)`. "
"See [documentation](https://cappa.readthedocs.io/en/latest/annotation.html) for more details."
)
return

num_args = arg.num_args
# print(is_sequence_type(origin), num_args, action)
# print(f" {num_args not in {0, 1} or action is ArgAction.append}")
if is_sequence_type(origin):
if num_args == 1 or action not in {ArgAction.append, None}:
if num_args in {0, 1} and action not in {ArgAction.append, None}:
raise ValueError(
f"On field '{arg.field_name}', apparent mismatch of annotated type with `Arg` options. "
f"On field '{field_name}', apparent mismatch of annotated type with `Arg` options. "
f"'{annotation}' type produces a sequence, whereas `num_args=1`/`action={action}` do not. "
"See [documentation](https://cappa.readthedocs.io/en/latest/annotation.html) for more details."
)
else:
if num_args not in {None, 1} or action is ArgAction.append:
if num_args not in {None, 0, 1} or action is ArgAction.append:
raise ValueError(
f"On field '{arg.field_name}', apparent mismatch of annotated type with `Arg` options. "
f"On field '{field_name}', apparent mismatch of annotated type with `Arg` options. "
f"'{origin.__name__}' type produces a scalar, whereas `num_args={num_args}`/`action={action}` do not. "
"See [documentation](https://cappa.readthedocs.io/en/latest/annotation.html) for more details."
)
Expand Down Expand Up @@ -350,7 +354,7 @@ def infer_short(arg: Arg, name: str) -> list[str] | typing.Literal[False]:
else:
short = arg.short

return [item if item.startswith("-") else f"-{item[0]}" for item in short]
return [item if item.startswith("-") else f"-{item}" for item in short]


def infer_long(arg: Arg, origin: type, name: str) -> list[str] | typing.Literal[False]:
Expand Down Expand Up @@ -416,7 +420,12 @@ def infer_action(
has_specific_num_args = arg.num_args is not None
unbounded_num_args = arg.num_args == -1

if arg.parse or unbounded_num_args or (is_positional and not has_specific_num_args):
if (
arg.parse
or unbounded_num_args
or (is_positional and not has_specific_num_args)
or (has_specific_num_args and arg.num_args != 1)
):
return ArgAction.set

if is_of_type(annotation, (typing.List, typing.Set)):
Expand All @@ -436,7 +445,7 @@ def infer_num_args(
type_args: tuple[type, ...],
action: ArgAction | Callable,
long,
) -> int | None:
) -> int:
if arg.num_args is not None:
return arg.num_args

Expand Down Expand Up @@ -551,6 +560,19 @@ def infer_group(
return typing.cast(typing.Tuple[int, str], group)


def infer_value_name(arg: Arg, field_name: str, num_args: int | None) -> str:
if arg.value_name is not missing:
return arg.value_name

if num_args == -1:
return f"{field_name} ..."

if num_args and num_args > 1:
return " ".join([field_name] * num_args)

return field_name


no_extra_arg_actions = {
ArgAction.store_true,
ArgAction.store_false,
Expand Down
24 changes: 15 additions & 9 deletions src/cappa/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,16 +484,8 @@ def consume_arg(
result = []
while num_args:
if isinstance(context.peek_value(), RawOption):
if orig_num_args < 0:
break
break

raise BadArgumentError(
f"Argument requires {orig_num_args} values, "
f"only found {len(result)} ('{' '.join(result)}' so far)",
value=result,
command=context.command,
arg=arg,
)
try:
next_val = typing.cast(RawArg, context.next_value())
except IndexError:
Expand Down Expand Up @@ -538,6 +530,20 @@ def consume_arg(
command=context.command,
arg=arg,
)
else:
if orig_num_args > 0 and len(result) != orig_num_args:
quoted_result = [f"'{r}'" for r in result]
names_str = arg.names_str("/")

message = f"Argument '{names_str}' requires {orig_num_args} values, found {len(result)}"
if quoted_result:
message += f" ({', '.join(quoted_result)} so far)"
raise BadArgumentError(
message,
value=result,
command=context.command,
arg=arg,
)

if option and arg.field_name in context.missing_options:
context.missing_options.remove(arg.field_name)
Expand Down
15 changes: 12 additions & 3 deletions tests/arg/test_invalid_annotation_combination.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class Args:
def test_sequence_with_scalar_action(backend):
@dataclass
class Args:
foo: Annotated[list[str], cappa.Arg(action=cappa.ArgAction.set)]
foo: Annotated[list[str], cappa.Arg(action=cappa.ArgAction.set, num_args=1)]

with pytest.raises(ValueError) as e:
parse(Args, "--help", backend=backend)
Expand All @@ -50,13 +50,22 @@ def test_sequence_with_scalar_num_args(backend):
class Args:
foo: Annotated[list[str], cappa.Arg(num_args=1, short=True)]

args = parse(Args, "-f", "a", "-f", "b", backend=backend)
assert args == Args(["a", "b"])

@dataclass
class ArgsBad:
foo: Annotated[
list[str], cappa.Arg(num_args=1, short=True, action=cappa.ArgAction.set)
]

with pytest.raises(ValueError) as e:
parse(Args, "--help", backend=backend)
parse(ArgsBad, "--help", backend=backend)

result = str(e.value).replace("typing.List", "list")
assert result == (
"On field 'foo', apparent mismatch of annotated type with `Arg` options. "
"'list[str]' type produces a sequence, whereas `num_args=1`/`action=None` do not. "
"'list[str]' type produces a sequence, whereas `num_args=1`/`action=ArgAction.set` do not. "
"See [documentation](https://cappa.readthedocs.io/en/latest/annotation.html) for more details."
)

Expand Down
39 changes: 39 additions & 0 deletions tests/arg/test_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

from dataclasses import dataclass

import cappa
import pytest
from typing_extensions import Annotated

from tests.utils import backends, parse


Expand All @@ -14,3 +18,38 @@ class ArgTest:
def test_valid(backend):
test = parse(ArgTest, "1", "2", "3.4", backend=backend)
assert test.numbers == (1, "2", 3.4)


@backends
def test_tuple_option(backend):
@dataclass
class Example:
start_project: Annotated[
tuple[int, float],
cappa.Arg(short=True, default=(1, 9), required=False),
]

test = parse(Example, backend=backend)
assert test == Example(start_project=(1, 9))

test = parse(Example, "-s", "2", "2.4", backend=backend)
assert test == Example(start_project=(2, 2.4))

# Missing values
with pytest.raises(cappa.Exit) as e:
parse(Example, "-s", "1", backend=backend)

assert e.value.code == 2

if backend:
assert str(e.value.message).lower() == "argument -s: expected 2 arguments"
else:
assert (
e.value.message == "Argument '-s' requires 2 values, found 1 ('1' so far)"
)

# Extra values
with pytest.raises(cappa.Exit) as e:
parse(Example, "-s", "1", "2", "3", backend=backend)
assert e.value.code == 2
assert e.value.message == "Unrecognized arguments: 3"
2 changes: 1 addition & 1 deletion tests/parser/test_missing_num_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ class Args:
if backend == argparse.backend:
assert "The following arguments are required: arg" in message
else:
assert message == "Argument requires 2 values, only found 1 ('arg' so far)"
assert message == "Argument 'arg arg' requires 2 values, found 1 ('arg' so far)"
33 changes: 33 additions & 0 deletions tests/parser/test_num_args.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from dataclasses import dataclass

import cappa
Expand Down Expand Up @@ -25,3 +27,34 @@ class Args:

assert e.value.message is None
assert e.value.code == 0


@backends
def test_num_args_list(backend):
@dataclass
class Args:
foo: Annotated[
list[str],
cappa.Arg(
short=True,
default=["", ""],
required=False,
num_args=2,
),
]

args = parse(Args, backend=backend)
assert args == Args(["", ""])

args = parse(Args, "-f", "2", "4", backend=backend)
assert args == Args(["2", "4"])

with pytest.raises(cappa.Exit) as e:
parse(Args, "-f", backend=backend)

assert e.value.code == 2

if backend:
assert str(e.value.message).lower() == "argument -f: expected 2 arguments"
else:
assert e.value.message == "Argument '-f' requires 2 values, found 0"

0 comments on commit 0bde925

Please sign in to comment.