Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: eager input to *not* read input after writing it #314

Merged
merged 4 commits into from
Jan 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@ Changelog

- Expand `arista_eos` prompt pattern to handle super long config sections (things like qos queues and such). Thanks
to @MarkRudenko over in scrapli_cfg repo for finding this and providing the fix!

- Add `comms_roughly_match_inputs` option -- this uses a "rough" match when looking for inputs (commands/configs you
send) in output printed back on the channel. Basically, if all input characters show up in the output in the correct
order, then we assume the input was found. Of course this could be less "exacting" but it also *probably* is ok 99%
of the time :)
- Added an `eager_input` option to send operations -- this option completely skips checking for inputs being echoed back
on the channel. With the addition of the `comms_roughly_match_inputs` option this is *probably* unnecessary, but
could be useful for some corner cases.

## 2023.07.30

Expand Down
22 changes: 18 additions & 4 deletions scrapli/channel/async_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from scrapli.channel.base_channel import BaseChannel, BaseChannelArgs
from scrapli.decorators import timeout_wrapper
from scrapli.exceptions import ScrapliAuthenticationFailed, ScrapliTimeout
from scrapli.helper import output_roughly_contains_input
from scrapli.transport.base import AsyncTransport


Expand Down Expand Up @@ -104,9 +105,16 @@ async def _read_until_input(self, channel_input: bytes) -> bytes:
while True:
buf += await self.read()

# replace any backspace chars (particular problem w/ junos), and remove any added spaces
# this is just for comparison of the inputs to what was read from channel
if processed_channel_input in b"".join(buf.lower().replace(b"\x08", b"").split()):
if not self._base_channel_args.comms_roughly_match_inputs:
# replace any backspace chars (particular problem w/ junos), and remove any added
# spaces this is just for comparison of the inputs to what was read from channel
# note (2024) this would be worked around by using the roughly contains search,
# *but* that is slower (probably immaterially for most people but... ya know...)
processed_buf = b"".join(buf.lower().replace(b"\x08", b"").split())

if processed_channel_input in processed_buf:
return buf
elif output_roughly_contains_input(input_=processed_channel_input, output=buf):
return buf

async def _read_until_prompt(self, buf: bytes = b"") -> bytes:
Expand Down Expand Up @@ -455,6 +463,7 @@ async def send_input(
*,
strip_prompt: bool = True,
eager: bool = False,
eager_input: bool = False,
) -> Tuple[bytes, bytes]:
"""
Primary entry point to send data to devices in shell mode; accept input and returns result
Expand All @@ -465,6 +474,8 @@ async def send_input(
eager: eager mode reads and returns the `_read_until_input` value, but does not attempt
to read to the prompt pattern -- this should not be used manually! (only used by
`send_configs` with the eager flag set)
eager_input: when true does *not* try to read our input off the channel -- generally
this should be left alone unless you know what you are doing!

Returns:
Tuple[bytes, bytes]: tuple of "raw" output and "processed" (cleaned up/stripped) output
Expand All @@ -484,7 +495,10 @@ async def send_input(

async with self._channel_lock():
self.write(channel_input=channel_input)
_buf_until_input = await self._read_until_input(channel_input=bytes_channel_input)

if not eager_input:
_buf_until_input = await self._read_until_input(channel_input=bytes_channel_input)

self.send_return()

if not eager:
Expand Down
7 changes: 7 additions & 0 deletions scrapli/channel/base_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ class BaseChannelArgs:
comms_prompt_search_depth: depth of the buffer to search in for searching for the prompt
in "read_until_prompt"; smaller number here will generally be faster, though may be less
reliable; default value is 1000
comms_roughly_match_inputs: indicates if the channel should "roughly" match inputs sent
to the device. If False (default) inputs are strictly checked, as in any input
*must* be read back exactly on the channel. When set to True all input chars *must*
be read back in order in the output and all chars must be present, but the *exact*
input string does not need to be seen. This can be useful if a device echoes back
extra characters or rewrites the terminal during command input.
timeout_ops: timeout_ops to assign to the channel, see above
channel_log: log "channel" output -- this would be the output you would normally see on a
terminal. If `True` logs to `scrapli_channel.log`, if a string is provided, logs to
Expand All @@ -61,6 +67,7 @@ class BaseChannelArgs:
comms_prompt_pattern: str = r"^[a-z0-9.\-@()/:]{1,32}[#>$]$"
comms_return_char: str = "\n"
comms_prompt_search_depth: int = 1000
comms_roughly_match_inputs: bool = False
timeout_ops: float = 30.0
channel_log: Union[str, bool, BytesIO] = False
channel_log_mode: str = "write"
Expand Down
22 changes: 18 additions & 4 deletions scrapli/channel/sync_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from scrapli.channel.base_channel import BaseChannel, BaseChannelArgs
from scrapli.decorators import timeout_wrapper
from scrapli.exceptions import ScrapliAuthenticationFailed, ScrapliConnectionError, ScrapliTimeout
from scrapli.helper import output_roughly_contains_input
from scrapli.transport.base import Transport


Expand Down Expand Up @@ -104,9 +105,16 @@ def _read_until_input(self, channel_input: bytes) -> bytes:
while True:
buf += self.read()

# replace any backspace chars (particular problem w/ junos), and remove any added spaces
# this is just for comparison of the inputs to what was read from channel
if processed_channel_input in b"".join(buf.lower().replace(b"\x08", b"").split()):
if not self._base_channel_args.comms_roughly_match_inputs:
# replace any backspace chars (particular problem w/ junos), and remove any added
# spaces this is just for comparison of the inputs to what was read from channel
# note (2024) this would be worked around by using the roughly contains search,
# *but* that is slower (probably immaterially for most people but... ya know...)
processed_buf = b"".join(buf.lower().replace(b"\x08", b"").split())

if processed_channel_input in processed_buf:
return buf
elif output_roughly_contains_input(input_=processed_channel_input, output=buf):
return buf

def _read_until_prompt(self, buf: bytes = b"") -> bytes:
Expand Down Expand Up @@ -456,6 +464,7 @@ def send_input(
*,
strip_prompt: bool = True,
eager: bool = False,
eager_input: bool = False,
) -> Tuple[bytes, bytes]:
"""
Primary entry point to send data to devices in shell mode; accept input and returns result
Expand All @@ -466,6 +475,8 @@ def send_input(
eager: eager mode reads and returns the `_read_until_input` value, but does not attempt
to read to the prompt pattern -- this should not be used manually! (only used by
`send_configs` with the eager flag set)
eager_input: when true does *not* try to read our input off the channel -- generally
this should be left alone unless you know what you are doing!

Returns:
Tuple[bytes, bytes]: tuple of "raw" output and "processed" (cleaned up/stripped) output
Expand All @@ -485,7 +496,10 @@ def send_input(

with self._channel_lock():
self.write(channel_input=channel_input)
_buf_until_input = self._read_until_input(channel_input=bytes_channel_input)

if not eager_input:
_buf_until_input = self._read_until_input(channel_input=bytes_channel_input)

self.send_return()

if not eager:
Expand Down
89 changes: 88 additions & 1 deletion scrapli/driver/base/base_driver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""scrapli.driver.base.base_driver"""
"""scrapli.driver.base.base_driver""" # noqa: C0302
import importlib
from dataclasses import fields
from io import BytesIO
Expand Down Expand Up @@ -34,6 +34,7 @@
timeout_ops: float = 30.0,
comms_prompt_pattern: str = r"^[a-z0-9.\-@()/:]{1,48}[#>$]\s*$",
comms_return_char: str = "\n",
comms_roughly_match_inputs: bool = False,
ssh_config_file: Union[str, bool] = False,
ssh_known_hosts_file: Union[str, bool] = False,
on_init: Optional[Callable[..., Any]] = None,
Expand Down Expand Up @@ -83,6 +84,12 @@
should be mostly sorted for you if using network drivers (i.e. `IOSXEDriver`).
Lastly, the case insensitive is just a convenience factor so i can be lazy.
comms_return_char: character to use to send returns to host
comms_roughly_match_inputs: indicates if the channel should "roughly" match inputs sent
to the device. If False (default) inputs are strictly checked, as in any input
*must* be read back exactly on the channel. When set to True all input chars *must*
be read back in order in the output and all chars must be present, but the *exact*
input string does not need to be seen. This can be useful if a device echoes back
extra characters or rewrites the terminal during command input.
ssh_config_file: string to path for ssh config file, True to use default ssh config file
or False to ignore default ssh config file
ssh_known_hosts_file: string to path for ssh known hosts file, True to use default known
Expand Down Expand Up @@ -149,6 +156,7 @@
auth_passphrase_pattern=auth_passphrase_pattern,
comms_prompt_pattern=comms_prompt_pattern,
comms_return_char=comms_return_char,
comms_roughly_match_inputs=comms_roughly_match_inputs,
timeout_ops=timeout_ops,
channel_log=channel_log,
channel_log_mode=channel_log_mode,
Expand Down Expand Up @@ -249,6 +257,7 @@
f"timeout_ops={self._base_channel_args.timeout_ops!r}, "
f"comms_prompt_pattern={self._base_channel_args.comms_prompt_pattern!r}, "
f"comms_return_char={self._base_channel_args.comms_return_char!r}, "
f"comms_roughly_match_inputs={self._base_channel_args.comms_roughly_match_inputs!r}, "
f"ssh_config_file={self.ssh_config_file!r}, "
f"ssh_known_hosts_file={self.ssh_known_hosts_file!r}, "
f"on_init={self.on_init!r}, "
Expand Down Expand Up @@ -738,6 +747,84 @@

self._base_channel_args.comms_return_char = value

@property
def comms_prompt_search_depth(self) -> int:
"""
Getter for `comms_prompt_search_depth` attribute

Args:
N/A

Returns:
int: comms_prompt_search_depth int

Raises:
N/A

"""
return self._base_channel_args.comms_prompt_search_depth

Check warning on line 765 in scrapli/driver/base/base_driver.py

View check run for this annotation

Codecov / codecov/patch

scrapli/driver/base/base_driver.py#L765

Added line #L765 was not covered by tests

@comms_prompt_search_depth.setter
def comms_prompt_search_depth(self, value: int) -> None:
"""
Setter for `comms_prompt_search_depth` attribute

Args:
value: int value for comms_prompt_search_depth

Returns:
None

Raises:
ScrapliTypeError: if value is not of type int

"""
self.logger.debug(f"setting 'comms_prompt_search_depth' value to {value!r}")

Check warning on line 782 in scrapli/driver/base/base_driver.py

View check run for this annotation

Codecov / codecov/patch

scrapli/driver/base/base_driver.py#L782

Added line #L782 was not covered by tests

if not isinstance(value, int):
raise ScrapliTypeError

Check warning on line 785 in scrapli/driver/base/base_driver.py

View check run for this annotation

Codecov / codecov/patch

scrapli/driver/base/base_driver.py#L784-L785

Added lines #L784 - L785 were not covered by tests

self._base_channel_args.comms_prompt_search_depth = value

Check warning on line 787 in scrapli/driver/base/base_driver.py

View check run for this annotation

Codecov / codecov/patch

scrapli/driver/base/base_driver.py#L787

Added line #L787 was not covered by tests

@property
def comms_roughly_match_inputs(self) -> bool:
"""
Getter for `comms_roughly_match_inputs` attribute

Args:
N/A

Returns:
bool: comms_roughly_match_inputs bool

Raises:
N/A

"""
return self._base_channel_args.comms_roughly_match_inputs

Check warning on line 804 in scrapli/driver/base/base_driver.py

View check run for this annotation

Codecov / codecov/patch

scrapli/driver/base/base_driver.py#L804

Added line #L804 was not covered by tests

@comms_roughly_match_inputs.setter
def comms_roughly_match_inputs(self, value: bool) -> None:
"""
Setter for `comms_roughly_match_inputs` attribute

Args:
value: int value for comms_roughly_match_inputs

Returns:
None

Raises:
ScrapliTypeError: if value is not of type bool

"""
self.logger.debug(f"setting 'comms_roughly_match_inputs' value to {value!r}")

Check warning on line 821 in scrapli/driver/base/base_driver.py

View check run for this annotation

Codecov / codecov/patch

scrapli/driver/base/base_driver.py#L821

Added line #L821 was not covered by tests

if not isinstance(value, bool):
raise ScrapliTypeError

Check warning on line 824 in scrapli/driver/base/base_driver.py

View check run for this annotation

Codecov / codecov/patch

scrapli/driver/base/base_driver.py#L823-L824

Added lines #L823 - L824 were not covered by tests

self._base_channel_args.comms_roughly_match_inputs = value

Check warning on line 826 in scrapli/driver/base/base_driver.py

View check run for this annotation

Codecov / codecov/patch

scrapli/driver/base/base_driver.py#L826

Added line #L826 was not covered by tests

@property
def timeout_socket(self) -> float:
"""
Expand Down
2 changes: 2 additions & 0 deletions scrapli/driver/core/arista_eos/async_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
timeout_transport: float = 30.0,
timeout_ops: float = 30.0,
comms_return_char: str = "\n",
comms_roughly_match_inputs: bool = False,
ssh_config_file: Union[str, bool] = False,
ssh_known_hosts_file: Union[str, bool] = False,
on_init: Optional[Callable[..., Any]] = None,
Expand Down Expand Up @@ -146,6 +147,7 @@ def __init__(
timeout_transport=timeout_transport,
timeout_ops=timeout_ops,
comms_return_char=comms_return_char,
comms_roughly_match_inputs=comms_roughly_match_inputs,
ssh_config_file=ssh_config_file,
ssh_known_hosts_file=ssh_known_hosts_file,
on_init=on_init,
Expand Down
2 changes: 2 additions & 0 deletions scrapli/driver/core/arista_eos/sync_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
timeout_transport: float = 30.0,
timeout_ops: float = 30.0,
comms_return_char: str = "\n",
comms_roughly_match_inputs: bool = False,
ssh_config_file: Union[str, bool] = False,
ssh_known_hosts_file: Union[str, bool] = False,
on_init: Optional[Callable[..., Any]] = None,
Expand Down Expand Up @@ -146,6 +147,7 @@ def __init__(
timeout_transport=timeout_transport,
timeout_ops=timeout_ops,
comms_return_char=comms_return_char,
comms_roughly_match_inputs=comms_roughly_match_inputs,
ssh_config_file=ssh_config_file,
ssh_known_hosts_file=ssh_known_hosts_file,
on_init=on_init,
Expand Down
2 changes: 2 additions & 0 deletions scrapli/driver/core/cisco_iosxe/async_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
timeout_transport: float = 30.0,
timeout_ops: float = 30.0,
comms_return_char: str = "\n",
comms_roughly_match_inputs: bool = False,
ssh_config_file: Union[str, bool] = False,
ssh_known_hosts_file: Union[str, bool] = False,
on_init: Optional[Callable[..., Any]] = None,
Expand Down Expand Up @@ -143,6 +144,7 @@ def __init__(
timeout_transport=timeout_transport,
timeout_ops=timeout_ops,
comms_return_char=comms_return_char,
comms_roughly_match_inputs=comms_roughly_match_inputs,
ssh_config_file=ssh_config_file,
ssh_known_hosts_file=ssh_known_hosts_file,
on_init=on_init,
Expand Down
2 changes: 2 additions & 0 deletions scrapli/driver/core/cisco_iosxe/sync_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
timeout_transport: float = 30.0,
timeout_ops: float = 30.0,
comms_return_char: str = "\n",
comms_roughly_match_inputs: bool = False,
ssh_config_file: Union[str, bool] = False,
ssh_known_hosts_file: Union[str, bool] = False,
on_init: Optional[Callable[..., Any]] = None,
Expand Down Expand Up @@ -143,6 +144,7 @@ def __init__(
timeout_transport=timeout_transport,
timeout_ops=timeout_ops,
comms_return_char=comms_return_char,
comms_roughly_match_inputs=comms_roughly_match_inputs,
ssh_config_file=ssh_config_file,
ssh_known_hosts_file=ssh_known_hosts_file,
on_init=on_init,
Expand Down
2 changes: 2 additions & 0 deletions scrapli/driver/core/cisco_iosxr/async_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
timeout_transport: float = 30.0,
timeout_ops: float = 30.0,
comms_return_char: str = "\n",
comms_roughly_match_inputs: bool = False,
ssh_config_file: Union[str, bool] = False,
ssh_known_hosts_file: Union[str, bool] = False,
on_init: Optional[Callable[..., Any]] = None,
Expand Down Expand Up @@ -142,6 +143,7 @@ def __init__(
timeout_transport=timeout_transport,
timeout_ops=timeout_ops,
comms_return_char=comms_return_char,
comms_roughly_match_inputs=comms_roughly_match_inputs,
ssh_config_file=ssh_config_file,
ssh_known_hosts_file=ssh_known_hosts_file,
on_init=on_init,
Expand Down
2 changes: 2 additions & 0 deletions scrapli/driver/core/cisco_iosxr/sync_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
timeout_transport: float = 30.0,
timeout_ops: float = 30.0,
comms_return_char: str = "\n",
comms_roughly_match_inputs: bool = False,
ssh_config_file: Union[str, bool] = False,
ssh_known_hosts_file: Union[str, bool] = False,
on_init: Optional[Callable[..., Any]] = None,
Expand Down Expand Up @@ -145,6 +146,7 @@ def __init__(
timeout_transport=timeout_transport,
timeout_ops=timeout_ops,
comms_return_char=comms_return_char,
comms_roughly_match_inputs=comms_roughly_match_inputs,
ssh_config_file=ssh_config_file,
ssh_known_hosts_file=ssh_known_hosts_file,
on_init=on_init,
Expand Down
2 changes: 2 additions & 0 deletions scrapli/driver/core/cisco_nxos/async_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
timeout_transport: float = 30.0,
timeout_ops: float = 30.0,
comms_return_char: str = "\n",
comms_roughly_match_inputs: bool = False,
ssh_config_file: Union[str, bool] = False,
ssh_known_hosts_file: Union[str, bool] = False,
on_init: Optional[Callable[..., Any]] = None,
Expand Down Expand Up @@ -146,6 +147,7 @@ def __init__(
timeout_transport=timeout_transport,
timeout_ops=timeout_ops,
comms_return_char=comms_return_char,
comms_roughly_match_inputs=comms_roughly_match_inputs,
ssh_config_file=ssh_config_file,
ssh_known_hosts_file=ssh_known_hosts_file,
on_init=on_init,
Expand Down
Loading