Skip to content

Commit

Permalink
Cache seen serials for tab completion.
Browse files Browse the repository at this point in the history
  • Loading branch information
dainnilsson committed Jan 15, 2024
1 parent be354f8 commit 286ec5d
Showing 1 changed file with 44 additions and 30 deletions.
74 changes: 44 additions & 30 deletions ykman/_cli/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@

from .. import __version__
from ..pcsc import list_devices as list_ccid, list_readers
from ..device import scan_devices, list_all_devices
from ..device import scan_devices, list_all_devices as _list_all_devices
from ..util import get_windows_version
from ..logging import init_logging
from ..diagnostics import get_diagnostics, sys_info
from ..settings import AppData
from .util import YkmanContextObject, click_group, EnumChoice, CliFail, pretty_print
from .info import info
from .otp import otp
Expand All @@ -54,7 +55,6 @@
import click
import click.shell_completion
import ctypes
import os
import time
import sys

Expand Down Expand Up @@ -116,6 +116,20 @@ def require_reader(connection_types, reader):
raise CliFail("Not a CCID command.")


def list_all_devices(*args, **kwargs):
devices = _list_all_devices(*args, **kwargs)
if devices:
history = AppData("history")
cache = history.setdefault("devices", {})
for dev, dev_info in devices:
if dev_info.serial:
k = str(dev_info.serial)
cache[k] = cache.pop(k, None) or _describe_device(dev, info, False)
[cache.pop(k) for k in list(cache.keys())[:-3]]
history.write()
return devices


def require_device(connection_types, serial=None):
# Find all connected devices
devices, state = scan_devices()
Expand All @@ -128,6 +142,7 @@ def require_device(connection_types, serial=None):
except TimeoutError:
raise CliFail("No YubiKey detected!")
if n_devs > 1:
list_all_devices() # Update device cache
raise CliFail(
"Multiple YubiKeys detected. Use --device SERIAL to specify "
"which one to use."
Expand All @@ -153,45 +168,41 @@ def require_device(connection_types, serial=None):
raise CliFail("Failed to connect to YubiKey.")
return devs[0]
else:
for _ in (0, 1): # If no match initially, wait a bit for state change.
for retry in (
True,
False,
): # If no match initially, wait a bit for state change.
devs = list_all_devices(connection_types)
for dev, nfo in devs:
if nfo.serial == serial:
return dev, nfo
devices, state = _scan_changes(state)
for dev, dev_info in devs:
if dev_info.serial == serial:
return dev, dev_info
try:
if retry:
_, state = _scan_changes(state)
except TimeoutError:
break

raise CliFail(
f"Failed connecting to a YubiKey with serial: {serial}.\n"
"Make sure the application has the required permissions.",
)


def _experimental_completion(env_var_name, f):
if env_var_name in os.environ:
return f
else:
return lambda ctx, param, incomplete: []


@click_group(context_settings=CLICK_CONTEXT_SETTINGS)
@click.option(
"-d",
"--device",
type=int,
metavar="SERIAL",
help="specify which YubiKey to interact with by serial number",
shell_complete=_experimental_completion(
# Leading underscore for uniformity with _YKMAN_COMPLETE from Click
"_YKMAN_EXPERIMENTAL_COMPLETE_DEVICE",
lambda ctx, param, incomplete: [
click.shell_completion.CompletionItem(
str(dev_info.serial),
help=_describe_device(dev, dev_info),
)
for dev, dev_info in list_all_devices()
if dev_info.serial and str(dev_info.serial).startswith(incomplete)
],
),
shell_complete=lambda ctx, param, incomplete: [
click.shell_completion.CompletionItem(
serial,
help=description,
)
for serial, description in AppData("history").get("devices", {}).items()
if serial.startswith(incomplete)
],
)
@click.option(
"-r",
Expand All @@ -201,7 +212,7 @@ def _experimental_completion(env_var_name, f):
metavar="NAME",
default=None,
shell_complete=lambda ctx, param, incomplete: [
reader.name for reader in list_readers()
f'"{reader.name}"' for reader in list_readers()
],
)
@click.option(
Expand Down Expand Up @@ -347,13 +358,16 @@ def list_keys(ctx, serials, readers):
click.echo(f"{name} [{mode}] <access denied>")


def _describe_device(dev, dev_info):
def _describe_device(dev, dev_info, include_mode=True):
if dev.pid is None: # Devices from list_all_devices should always have PID.
raise AssertionError("PID is None")
name = get_name(dev_info, dev.pid.yubikey_type)
version = dev_info.version or "unknown"
mode = dev.pid.name.split("_", 1)[1].replace("_", "+")
return f"{name} ({version}) [{mode}]"
description = f"{name} ({version})"
if include_mode:
mode = dev.pid.name.split("_", 1)[1].replace("_", "+")
description += f" [{mode}]"
return description


COMMANDS = (
Expand Down

0 comments on commit 286ec5d

Please sign in to comment.