Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add shell completion for --device and --reader #443

Merged
merged 8 commits into from
Jan 23, 2024
68 changes: 56 additions & 12 deletions ykman/_cli/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@

from .. import __version__
from ..pcsc import list_devices as list_ccid, list_readers
from ..device import scan_devices, list_all_devices
from ..device import scan_devices, list_all_devices as _list_all_devices
from ..util import get_windows_version
from ..logging import init_logging
from ..diagnostics import get_diagnostics, sys_info
from ..settings import AppData
from .util import YkmanContextObject, click_group, EnumChoice, CliFail, pretty_print
from .info import info
from .otp import otp
Expand All @@ -52,6 +53,7 @@
from .hsmauth import hsmauth

import click
import click.shell_completion
import ctypes
import time
import sys
Expand Down Expand Up @@ -114,6 +116,22 @@ def require_reader(connection_types, reader):
raise CliFail("Not a CCID command.")


def list_all_devices(*args, **kwargs):
devices = _list_all_devices(*args, **kwargs)
with_serial = [(dev, dev_info) for (dev, dev_info) in devices if dev_info.serial]
if with_serial:
history = AppData("history")
cache = history.setdefault("devices", {})
for dev, dev_info in with_serial:
if dev_info.serial:
k = str(dev_info.serial)
cache[k] = cache.pop(k, None) or _describe_device(dev, dev_info, False)
# 5, chosen by fair dice roll
[cache.pop(k) for k in list(cache.keys())[: -max(5, len(with_serial))]]
history.write()
return devices


def require_device(connection_types, serial=None):
# Find all connected devices
devices, state = scan_devices()
Expand All @@ -126,6 +144,7 @@ def require_device(connection_types, serial=None):
except TimeoutError:
raise CliFail("No YubiKey detected!")
if n_devs > 1:
list_all_devices() # Update device cache
raise CliFail(
"Multiple YubiKeys detected. Use --device SERIAL to specify "
"which one to use."
Expand All @@ -151,12 +170,19 @@ def require_device(connection_types, serial=None):
raise CliFail("Failed to connect to YubiKey.")
return devs[0]
else:
for _ in (0, 1): # If no match initially, wait a bit for state change.
for retry in (
True,
False,
): # If no match initially, wait a bit for state change.
devs = list_all_devices(connection_types)
for dev, nfo in devs:
if nfo.serial == serial:
return dev, nfo
devices, state = _scan_changes(state)
for dev, dev_info in devs:
if dev_info.serial == serial:
return dev, dev_info
try:
if retry:
_, state = _scan_changes(state)
except TimeoutError:
break

raise CliFail(
f"Failed connecting to a YubiKey with serial: {serial}.\n"
Expand All @@ -171,6 +197,14 @@ def require_device(connection_types, serial=None):
type=int,
metavar="SERIAL",
help="specify which YubiKey to interact with by serial number",
shell_complete=lambda ctx, param, incomplete: [
click.shell_completion.CompletionItem(
serial,
help=description,
)
for serial, description in AppData("history").get("devices", {}).items()
if serial.startswith(incomplete)
],
)
@click.option(
"-r",
Expand All @@ -179,6 +213,9 @@ def require_device(connection_types, serial=None):
"(can't be used with --device or list)",
metavar="NAME",
default=None,
shell_complete=lambda ctx, param, incomplete: [
f'"{reader.name}"' for reader in list_readers()
],
)
@click.option(
"-l",
Expand Down Expand Up @@ -306,13 +343,8 @@ def list_keys(ctx, serials, readers):
if dev_info.serial:
click.echo(dev_info.serial)
else:
if dev.pid is None: # Devices from list_all_devices should always have PID.
raise AssertionError("PID is None")
name = get_name(dev_info, dev.pid.yubikey_type)
version = dev_info.version or "unknown"
mode = dev.pid.name.split("_", 1)[1].replace("_", "+")
click.echo(
f"{name} ({version}) [{mode}]"
_describe_device(dev, dev_info)
+ (f" Serial: {dev_info.serial}" if dev_info.serial else "")
)
pids.add(dev.pid)
Expand All @@ -328,6 +360,18 @@ def list_keys(ctx, serials, readers):
click.echo(f"{name} [{mode}] <access denied>")


def _describe_device(dev, dev_info, include_mode=True):
if dev.pid is None: # Devices from list_all_devices should always have PID.
raise AssertionError("PID is None")
name = get_name(dev_info, dev.pid.yubikey_type)
version = dev_info.version or "unknown"
description = f"{name} ({version})"
if include_mode:
mode = dev.pid.name.split("_", 1)[1].replace("_", "+")
description += f" [{mode}]"
return description


COMMANDS = (
list_keys,
info,
Expand Down
Loading