diff --git a/docs/api_docs/channel/async_channel.md b/docs/api_docs/channel/async_channel.md index 8a04cc28..97c167fc 100644 --- a/docs/api_docs/channel/async_channel.md +++ b/docs/api_docs/channel/async_channel.md @@ -32,19 +32,13 @@ scrapli.channel.async_channel import asyncio import re import time -from io import BytesIO - -try: - from contextlib import asynccontextmanager -except ImportError: # pragma: nocover - # needed for 3.6 support, no asynccontextmanager until 3.7 - from async_generator import asynccontextmanager # type: ignore # pragma: nocover - +from contextlib import asynccontextmanager from datetime import datetime +from io import BytesIO from typing import AsyncIterator, List, Optional, Tuple from scrapli.channel.base_channel import BaseChannel, BaseChannelArgs -from scrapli.decorators import ChannelTimeout +from scrapli.decorators import timeout_wrapper from scrapli.exceptions import ScrapliAuthenticationFailed, ScrapliTimeout from scrapli.transport.base import AsyncTransport @@ -290,7 +284,7 @@ class AsyncChannel(BaseChannel): return read_buf.getvalue() - @ChannelTimeout(message="timed out during in channel ssh authentication") + @timeout_wrapper async def channel_authenticate_ssh( self, auth_password: str, auth_private_key_passphrase: str ) -> None: @@ -363,7 +357,7 @@ class AsyncChannel(BaseChannel): ): return - @ChannelTimeout(message="timed out during in channel telnet authentication") + @timeout_wrapper async def channel_authenticate_telnet( # noqa: C901 self, auth_username: str = "", auth_password: str = "" ) -> None: @@ -450,7 +444,7 @@ class AsyncChannel(BaseChannel): ): return - @ChannelTimeout(message="timed out getting prompt") + @timeout_wrapper async def get_prompt(self) -> str: """ Get current channel prompt @@ -486,7 +480,7 @@ class AsyncChannel(BaseChannel): current_prompt = channel_match.group(0) return current_prompt.decode().strip() - @ChannelTimeout(message="timed out sending input to device") + @timeout_wrapper async def send_input( self, channel_input: str, @@ -534,7 +528,7 @@ class AsyncChannel(BaseChannel): ) return buf, processed_buf - @ChannelTimeout(message="timed out sending input to device") + @timeout_wrapper async def send_input_and_read( self, channel_input: str, @@ -589,7 +583,7 @@ class AsyncChannel(BaseChannel): return buf, processed_buf - @ChannelTimeout(message="timed out sending interactive input to device") + @timeout_wrapper async def send_inputs_interact( self, interact_events: List[Tuple[str, str, Optional[bool]]], @@ -970,7 +964,7 @@ class AsyncChannel(BaseChannel): return read_buf.getvalue() - @ChannelTimeout(message="timed out during in channel ssh authentication") + @timeout_wrapper async def channel_authenticate_ssh( self, auth_password: str, auth_private_key_passphrase: str ) -> None: @@ -1043,7 +1037,7 @@ class AsyncChannel(BaseChannel): ): return - @ChannelTimeout(message="timed out during in channel telnet authentication") + @timeout_wrapper async def channel_authenticate_telnet( # noqa: C901 self, auth_username: str = "", auth_password: str = "" ) -> None: @@ -1130,7 +1124,7 @@ class AsyncChannel(BaseChannel): ): return - @ChannelTimeout(message="timed out getting prompt") + @timeout_wrapper async def get_prompt(self) -> str: """ Get current channel prompt @@ -1166,7 +1160,7 @@ class AsyncChannel(BaseChannel): current_prompt = channel_match.group(0) return current_prompt.decode().strip() - @ChannelTimeout(message="timed out sending input to device") + @timeout_wrapper async def send_input( self, channel_input: str, @@ -1214,7 +1208,7 @@ class AsyncChannel(BaseChannel): ) return buf, processed_buf - @ChannelTimeout(message="timed out sending input to device") + @timeout_wrapper async def send_input_and_read( self, channel_input: str, @@ -1269,7 +1263,7 @@ class AsyncChannel(BaseChannel): return buf, processed_buf - @ChannelTimeout(message="timed out sending interactive input to device") + @timeout_wrapper async def send_inputs_interact( self, interact_events: List[Tuple[str, str, Optional[bool]]], diff --git a/docs/api_docs/channel/sync_channel.md b/docs/api_docs/channel/sync_channel.md index 999cedab..d55df88f 100644 --- a/docs/api_docs/channel/sync_channel.md +++ b/docs/api_docs/channel/sync_channel.md @@ -38,7 +38,7 @@ from threading import Lock from typing import Iterator, List, Optional, Tuple from scrapli.channel.base_channel import BaseChannel, BaseChannelArgs -from scrapli.decorators import ChannelTimeout +from scrapli.decorators import timeout_wrapper from scrapli.exceptions import ScrapliAuthenticationFailed, ScrapliTimeout from scrapli.transport.base import Transport @@ -281,7 +281,7 @@ class Channel(BaseChannel): return read_buf.getvalue() - @ChannelTimeout(message="timed out during in channel ssh authentication") + @timeout_wrapper def channel_authenticate_ssh( self, auth_password: str, auth_private_key_passphrase: str ) -> None: @@ -353,7 +353,7 @@ class Channel(BaseChannel): ): return - @ChannelTimeout(message="timed out during in channel telnet authentication") + @timeout_wrapper def channel_authenticate_telnet(self, auth_username: str = "", auth_password: str = "") -> None: """ Handle Telnet Authentication @@ -434,7 +434,7 @@ class Channel(BaseChannel): ): return - @ChannelTimeout(message="timed out getting prompt") + @timeout_wrapper def get_prompt(self) -> str: """ Get current channel prompt @@ -470,7 +470,7 @@ class Channel(BaseChannel): current_prompt = channel_match.group(0) return current_prompt.decode().strip() - @ChannelTimeout(message="timed out sending input to device") + @timeout_wrapper def send_input( self, channel_input: str, @@ -518,7 +518,7 @@ class Channel(BaseChannel): ) return buf, processed_buf - @ChannelTimeout(message="timed out sending input to device") + @timeout_wrapper def send_input_and_read( self, channel_input: str, @@ -573,7 +573,7 @@ class Channel(BaseChannel): return buf, processed_buf - @ChannelTimeout(message="timed out sending interactive input to device") + @timeout_wrapper def send_inputs_interact( self, interact_events: List[Tuple[str, str, Optional[bool]]], @@ -951,7 +951,7 @@ class Channel(BaseChannel): return read_buf.getvalue() - @ChannelTimeout(message="timed out during in channel ssh authentication") + @timeout_wrapper def channel_authenticate_ssh( self, auth_password: str, auth_private_key_passphrase: str ) -> None: @@ -1023,7 +1023,7 @@ class Channel(BaseChannel): ): return - @ChannelTimeout(message="timed out during in channel telnet authentication") + @timeout_wrapper def channel_authenticate_telnet(self, auth_username: str = "", auth_password: str = "") -> None: """ Handle Telnet Authentication @@ -1104,7 +1104,7 @@ class Channel(BaseChannel): ): return - @ChannelTimeout(message="timed out getting prompt") + @timeout_wrapper def get_prompt(self) -> str: """ Get current channel prompt @@ -1140,7 +1140,7 @@ class Channel(BaseChannel): current_prompt = channel_match.group(0) return current_prompt.decode().strip() - @ChannelTimeout(message="timed out sending input to device") + @timeout_wrapper def send_input( self, channel_input: str, @@ -1188,7 +1188,7 @@ class Channel(BaseChannel): ) return buf, processed_buf - @ChannelTimeout(message="timed out sending input to device") + @timeout_wrapper def send_input_and_read( self, channel_input: str, @@ -1243,7 +1243,7 @@ class Channel(BaseChannel): return buf, processed_buf - @ChannelTimeout(message="timed out sending interactive input to device") + @timeout_wrapper def send_inputs_interact( self, interact_events: List[Tuple[str, str, Optional[bool]]], diff --git a/docs/api_docs/decorators.md b/docs/api_docs/decorators.md index 12c28830..dc0387c4 100644 --- a/docs/api_docs/decorators.md +++ b/docs/api_docs/decorators.md @@ -34,908 +34,359 @@ import signal import sys import threading from concurrent.futures import ThreadPoolExecutor, wait -from functools import update_wrapper +from functools import partial, update_wrapper from logging import Logger, LoggerAdapter -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Tuple from scrapli.exceptions import ScrapliTimeout if TYPE_CHECKING: - from scrapli.channel import Channel # pragma: no cover from scrapli.driver import AsyncGenericDriver, GenericDriver # pragma: no cover from scrapli.transport.base.base_transport import BaseTransport # pragma: no cover if TYPE_CHECKING: - LoggerAdapterT = LoggerAdapter[Logger] # pylint:disable=E1136 + LoggerAdapterT = LoggerAdapter[Logger] # pragma: no cover # pylint:disable=E1136 else: LoggerAdapterT = LoggerAdapter _IS_WINDOWS = sys.platform.startswith("win") -class TransportTimeout: - def __init__(self, message: str = "") -> None: - """ - Transport timeout decorator +FUNC_TIMEOUT_MESSAGE_MAP = { + "channel_authenticate_ssh": "timed out during in channel ssh authentication", + "channel_authenticate_telnet": "timed out during in channel telnet authentication", + "get_prompt": "timed out getting prompt", + "send_input": "timed out sending input to device", + "send_input_and_read": "timed out sending input to device", + "send_inputs_interact": "timed out sending interactive input to device", + "read": "timed out reading from transport", +} + + +def _get_timeout_message(func_name: str) -> str: + """ + Return appropriate timeout message for the given function name + + Args: + func_name: name of function to fetch timeout message for + + Returns: + str: timeout message + + Raises: + N/A + + """ + return FUNC_TIMEOUT_MESSAGE_MAP.get(func_name, "unspecified timeout occurred") + + +def _signal_raise_exception( + signum: Any, frame: Any, transport: "BaseTransport", logger: LoggerAdapterT, message: str +) -> None: + """ + Signal method exception handler + + Args: + signum: singum from the singal handler, unused here + frame: frame from the signal handler, unused here + transport: transport to close + logger: logger to write closing messages to + message: exception message + + Returns: + None + + Raises: + N/A + + """ + _, _ = signum, frame + + return _handle_timeout(transport=transport, logger=logger, message=message) + + +def _multiprocessing_timeout( + transport: "BaseTransport", + logger: LoggerAdapterT, + timeout: float, + wrapped_func: Callable[..., Any], + args: Any, + kwargs: Any, +) -> Any: + """ + Return appropriate timeout message for the given function name + + Args: + transport: transport to close (if timeout occurs) + logger: logger to write closing message to + timeout: timeout in seconds + wrapped_func: function being decorated + args: function args + kwargs: function kwargs + + Returns: + Any: result of the wrapped function + + Raises: + N/A + + """ + with ThreadPoolExecutor(max_workers=1) as pool: + future = pool.submit(wrapped_func, *args, **kwargs) + wait([future], timeout=timeout) + if not future.done(): + return _handle_timeout( + transport=transport, + logger=logger, + message=_get_timeout_message(func_name=wrapped_func.__name__), + ) + return future.result() - Args: - message: accepts message from decorated function to add context to any timeout - (if a timeout happens!) - Returns: - None +def _handle_timeout(transport: "BaseTransport", logger: LoggerAdapterT, message: str) -> None: + """ + Timeout handler method to close connections and raise ScrapliTimeout - Raises: - N/A + Args: + transport: transport to close + logger: logger to write closing message to + message: message to pass to ScrapliTimeout exception - """ - self.message = message - self.transport_instance: "BaseTransport" - self.transport_timeout_transport = 0.0 + Returns: + None - def __call__(self, wrapped_func: Callable[..., Any]) -> Callable[..., Any]: - """ - Decorate an "operation" to modify the timeout_transport value for duration of that operation + Raises: + ScrapliTimeout: always, if we hit this method we have already timed out! - This decorator wraps a transport read operation and is used to allow users to control the - transport timeout via the `timeout_transport` attribute. This decorator should be applied to - any transport "read" operations. + """ + logger.critical("operation timed out, closing connection") + transport.close() + raise ScrapliTimeout(message) - Args: - wrapped_func: function being decorated - Returns: - decorate: decorated func +def _get_transport_logger_timeout( + cls: Any, +) -> Tuple["BaseTransport", LoggerAdapterT, float]: + """ + Fetch the transport, logger and timeout from the channel or transport object - Raises: - N/A + Args: + cls: Channel or Transport object (self from wrapped function) to grab transport/logger and + timeout values from - """ + Returns: + Tuple: transport, logger, and timeout value - if asyncio.iscoroutinefunction(wrapped_func): + Raises: + N/A - async def decorate(*args: Any, **kwargs: Any) -> Any: - self.transport_instance = args[0] - self.transport_timeout_transport = self._get_timeout_transport() + """ + if hasattr(cls, "transport"): + return ( + cls.transport, + cls.logger, + cls._base_channel_args.timeout_ops, # pylint: disable=W0212 + ) - if not self.transport_timeout_transport: - return await wrapped_func(*args, **kwargs) + return ( + cls, + cls.logger, + cls._base_transport_args.timeout_transport, # pylint: disable=W0212 + ) - try: - return await asyncio.wait_for( - wrapped_func(*args, **kwargs), timeout=self.transport_timeout_transport - ) - except asyncio.TimeoutError: - self._handle_timeout() - - else: - # ignoring type error: - # "All conditional function variants must have identical signatures" - # one is sync one is async so never going to be identical here! - def decorate(*args: Any, **kwargs: Any) -> Any: # type: ignore - self.transport_instance = args[0] - self.transport_timeout_transport = self._get_timeout_transport() - - if not self.transport_timeout_transport: - return wrapped_func(*args, **kwargs) - - transport_instance_class_name = self.transport_instance.__class__.__name__ - - if ( - transport_instance_class_name in ("SystemTransport", "TelnetTransport") - or _IS_WINDOWS - or threading.current_thread() is not threading.main_thread() - ): - return self._multiprocessing_timeout( - wrapped_func=wrapped_func, - args=args, - kwargs=kwargs, - ) - - old = signal.signal(signal.SIGALRM, self._signal_raise_exception) - signal.setitimer(signal.ITIMER_REAL, self.transport_timeout_transport) - try: - return wrapped_func(*args, **kwargs) - finally: - if self.transport_timeout_transport: - signal.setitimer(signal.ITIMER_REAL, 0) - signal.signal(signal.SIGALRM, old) - - # ensures that the wrapped function is updated w/ the original functions docs/etc. -- - # necessary for introspection for the auto gen docs to work! - update_wrapper(wrapper=decorate, wrapped=wrapped_func) - return decorate - - def _get_timeout_transport(self) -> float: - """ - Fetch and return timeout transport from the transport object - - Args: - N/A - - Returns: - float: transport timeout value - - Raises: - N/A - - """ - transport_args = self.transport_instance._base_transport_args # pylint: disable=W0212 - return transport_args.timeout_transport - - def _handle_timeout(self) -> None: - """ - Timeout handler method to close connections and raise ScrapliTimeout - - Args: - N/A - - Returns: - None - - Raises: - ScrapliTimeout: always, if we hit this method we have already timed out! - - """ - self.transport_instance.logger.critical("transport operation timed out, closing transport") - self.transport_instance.close() - raise ScrapliTimeout(self.message) - - def _multiprocessing_timeout( - self, wrapped_func: Callable[..., Any], args: Any, kwargs: Any - ) -> Any: - """ - Multiprocessing method for timeouts; works in threads and on windows - - Args: - wrapped_func: function being decorated - args: function being decorated args - kwargs: function being decorated kwargs - - Returns: - Any: result of decorated function - - Raises: - N/A - - """ - with ThreadPoolExecutor(max_workers=1) as pool: - future = pool.submit(wrapped_func, *args, **kwargs) - wait([future], timeout=self.transport_timeout_transport) - if not future.done(): - self._handle_timeout() - return future.result() - def _signal_raise_exception(self, signum: Any, frame: Any) -> None: - """ - Signal method exception handler +def timeout_wrapper(wrapped_func: Callable[..., Any]) -> Callable[..., Any]: + """ + Timeout wrapper for transports - Args: - signum: singum from the singal handler, unused here - frame: frame from the signal handler, unused here + Args: + wrapped_func: function being wrapped -- must be a method of Channel or Transport - Returns: - None + Returns: + Any: result of wrapped function - Raises: - N/A + Raises: + N/A - """ - _, _ = signum, frame - self._handle_timeout() + """ + if asyncio.iscoroutinefunction(wrapped_func): + async def decorate(*args: Any, **kwargs: Any) -> Any: + transport, logger, timeout = _get_transport_logger_timeout(cls=args[0]) -class ChannelTimeout: - def __init__(self, message: str = "") -> None: - """ - Channel timeout decorator + if not timeout: + return await wrapped_func(*args, **kwargs) - Args: - message: accepts message from decorated function to add context to any timeout - (if a timeout happens!) + try: + return await asyncio.wait_for(wrapped_func(*args, **kwargs), timeout=timeout) + except asyncio.TimeoutError: + _handle_timeout( + transport=transport, + logger=logger, + message=_get_timeout_message(func_name=wrapped_func.__name__), + ) - Returns: - None + else: + # ignoring type error: + # "All conditional function variants must have identical signatures" + # one is sync one is async so never going to be identical here! + def decorate(*args: Any, **kwargs: Any) -> Any: # type: ignore + transport, logger, timeout = _get_transport_logger_timeout(cls=args[0]) + + if not timeout: + return wrapped_func(*args, **kwargs) + + cls_name = transport.__class__.__name__ + + if ( + cls_name in ("SystemTransport", "TelnetTransport") + or _IS_WINDOWS + or threading.current_thread() is not threading.main_thread() + ): + return _multiprocessing_timeout( + transport=transport, + logger=logger, + timeout=timeout, + wrapped_func=wrapped_func, + args=args, + kwargs=kwargs, + ) - Raises: - N/A + callback = partial( + _signal_raise_exception, + transport=transport, + logger=logger, + message=_get_timeout_message(wrapped_func.__name__), + ) - """ - self.message = message - self.channel_timeout_ops = 0.0 - self.channel_logger: LoggerAdapterT - self.transport_instance: "BaseTransport" + old = signal.signal(signal.SIGALRM, callback) + signal.setitimer(signal.ITIMER_REAL, timeout) + try: + return wrapped_func(*args, **kwargs) + finally: + if timeout: + signal.setitimer(signal.ITIMER_REAL, 0) + signal.signal(signal.SIGALRM, old) - def __call__(self, wrapped_func: Callable[..., Any]) -> Callable[..., Any]: - """ - Decorate an "operation" to modify the timeout_ops value for duration of that operation + # ensures that the wrapped function is updated w/ the original functions docs/etc. -- + # necessary for introspection for the auto gen docs to work! + update_wrapper(wrapper=decorate, wrapped=wrapped_func) + return decorate - This decorator wraps send command/config ops and is used to allow users to set a - `timeout_ops` value for the duration of a single method call -- this makes it so users don't - need to manually set/reset the value - Args: - wrapped_func: function being decorated +def timeout_modifier(wrapped_func: Callable[..., Any]) -> Callable[..., Any]: + """ + Decorate an "operation" to modify the timeout_ops value for duration of that operation - Returns: - decorate: decorated func + This decorator wraps send command/config ops and is used to allow users to set a + `timeout_ops` value for the duration of a single method call -- this makes it so users don't + need to manually set/reset the value - Raises: - N/A + Args: + wrapped_func: function being decorated - """ - if asyncio.iscoroutinefunction(wrapped_func): + Returns: + decorate: decorated func - async def decorate(*args: Any, **kwargs: Any) -> Any: - channel_instance: "Channel" = args[0] - self.channel_logger = channel_instance.logger - self.channel_timeout_ops = ( - channel_instance._base_channel_args.timeout_ops # pylint: disable=W0212 - ) + Raises: + N/A - if not self.channel_timeout_ops: - return await wrapped_func(*args, **kwargs) - - self.transport_instance = channel_instance.transport - - try: - return await asyncio.wait_for( - wrapped_func(*args, **kwargs), timeout=self.channel_timeout_ops - ) - except asyncio.TimeoutError: - self._handle_timeout() - - else: - # ignoring type error: - # "All conditional function variants must have identical signatures" - # one is sync one is async so never going to be identical here! - def decorate(*args: Any, **kwargs: Any) -> Any: # type: ignore - channel_instance: "Channel" = args[0] - self.channel_logger = channel_instance.logger - self.channel_timeout_ops = ( - channel_instance._base_channel_args.timeout_ops # pylint: disable=W0212 - ) + """ + if asyncio.iscoroutinefunction(wrapped_func): - if not self.channel_timeout_ops: - return wrapped_func(*args, **kwargs) - - self.transport_instance = channel_instance.transport - transport_instance_class_name = self.transport_instance.__class__.__name__ - - if ( - transport_instance_class_name in ("SystemTransport", "TelnetTransport") - or _IS_WINDOWS - or threading.current_thread() is not threading.main_thread() - ): - return self._multiprocessing_timeout( - wrapped_func=wrapped_func, - args=args, - kwargs=kwargs, - ) - - old = signal.signal(signal.SIGALRM, self._signal_raise_exception) - signal.setitimer(signal.ITIMER_REAL, self.channel_timeout_ops) - try: - return wrapped_func(*args, **kwargs) - finally: - if self.channel_timeout_ops: - signal.setitimer(signal.ITIMER_REAL, 0) - signal.signal(signal.SIGALRM, old) - - # ensures that the wrapped function is updated w/ the original functions docs/etc. -- - # necessary for introspection for the auto gen docs to work! - update_wrapper(wrapper=decorate, wrapped=wrapped_func) - return decorate - - def _handle_timeout(self) -> None: - """ - Timeout handler method to close connections and raise ScrapliTimeout - - Args: - N/A - - Returns: - None - - Raises: - ScrapliTimeout: always, if we hit this method we have already timed out! - - """ - self.channel_logger.critical("channel operation timed out, closing transport") - self.transport_instance.close() - raise ScrapliTimeout(self.message) - - def _multiprocessing_timeout( - self, wrapped_func: Callable[..., Any], args: Any, kwargs: Any - ) -> Any: - """ - Multiprocessing method for timeouts; works in threads and on windows - - Args: - wrapped_func: function being decorated - args: function being decorated args - kwargs: function being decorated kwargs - - Returns: - Any: result of decorated function - - Raises: - N/A - - """ - with ThreadPoolExecutor(max_workers=1) as pool: - future = pool.submit(wrapped_func, *args, **kwargs) - wait([future], timeout=self.channel_timeout_ops) - if not future.done(): - self._handle_timeout() - return future.result() + async def decorate(*args: Any, **kwargs: Any) -> Any: + driver_instance: "AsyncGenericDriver" = args[0] + driver_logger = driver_instance.logger + + timeout_ops_kwarg = kwargs.get("timeout_ops", None) - def _signal_raise_exception(self, signum: Any, frame: Any) -> None: - """ - Signal method exception handler - - Args: - signum: singum from the singal handler, unused here - frame: frame from the signal handler, unused here - - Returns: - None - - Raises: - N/A - - """ - _, _ = signum, frame - self._handle_timeout() - - -class TimeoutOpsModifier: - def __call__(self, wrapped_func: Callable[..., Any]) -> Callable[..., Any]: - """ - Decorate an "operation" to modify the timeout_ops value for duration of that operation - - This decorator wraps send command/config ops and is used to allow users to set a - `timeout_ops` value for the duration of a single method call -- this makes it so users don't - need to manually set/reset the value - - Args: - wrapped_func: function being decorated - - Returns: - decorate: decorated func - - Raises: - N/A - - """ - if asyncio.iscoroutinefunction(wrapped_func): - - async def decorate(*args: Any, **kwargs: Any) -> Any: - driver_instance: "AsyncGenericDriver" = args[0] - driver_logger = driver_instance.logger - - timeout_ops_kwarg = kwargs.get("timeout_ops", None) - - if timeout_ops_kwarg is None or timeout_ops_kwarg == driver_instance.timeout_ops: - result = await wrapped_func(*args, **kwargs) - else: - driver_logger.info( - "modifying driver timeout for current operation, temporary timeout_ops " - f"value: '{timeout_ops_kwarg}'" - ) - base_timeout_ops = driver_instance.timeout_ops - driver_instance.timeout_ops = kwargs["timeout_ops"] - result = await wrapped_func(*args, **kwargs) - driver_instance.timeout_ops = base_timeout_ops - return result - - else: - # ignoring type error: - # "All conditional function variants must have identical signatures" - # one is sync one is async so never going to be identical here! - def decorate(*args: Any, **kwargs: Any) -> Any: # type: ignore - driver_instance: "GenericDriver" = args[0] - driver_logger = driver_instance.logger - - timeout_ops_kwarg = kwargs.get("timeout_ops", None) - - if timeout_ops_kwarg is None or timeout_ops_kwarg == driver_instance.timeout_ops: - result = wrapped_func(*args, **kwargs) - else: - driver_logger.info( - "modifying driver timeout for current operation, temporary timeout_ops " - f"value: '{timeout_ops_kwarg}'" - ) - base_timeout_ops = driver_instance.timeout_ops - driver_instance.timeout_ops = kwargs["timeout_ops"] - result = wrapped_func(*args, **kwargs) - driver_instance.timeout_ops = base_timeout_ops - return result - - # ensures that the wrapped function is updated w/ the original functions docs/etc. -- - # necessary for introspection for the auto gen docs to work! - update_wrapper(wrapper=decorate, wrapped=wrapped_func) - return decorate + if timeout_ops_kwarg is None or timeout_ops_kwarg == driver_instance.timeout_ops: + result = await wrapped_func(*args, **kwargs) + else: + driver_logger.info( + "modifying driver timeout for current operation, temporary timeout_ops " + f"value: '{timeout_ops_kwarg}'" + ) + base_timeout_ops = driver_instance.timeout_ops + driver_instance.timeout_ops = kwargs["timeout_ops"] + result = await wrapped_func(*args, **kwargs) + driver_instance.timeout_ops = base_timeout_ops + return result + + else: + # ignoring type error: + # "All conditional function variants must have identical signatures" + # one is sync one is async so never going to be identical here! + def decorate(*args: Any, **kwargs: Any) -> Any: # type: ignore + driver_instance: "GenericDriver" = args[0] + driver_logger = driver_instance.logger + + timeout_ops_kwarg = kwargs.get("timeout_ops", None) + + if timeout_ops_kwarg is None or timeout_ops_kwarg == driver_instance.timeout_ops: + result = wrapped_func(*args, **kwargs) + else: + driver_logger.info( + "modifying driver timeout for current operation, temporary timeout_ops " + f"value: '{timeout_ops_kwarg}'" + ) + base_timeout_ops = driver_instance.timeout_ops + driver_instance.timeout_ops = kwargs["timeout_ops"] + result = wrapped_func(*args, **kwargs) + driver_instance.timeout_ops = base_timeout_ops + return result + + # ensures that the wrapped function is updated w/ the original functions docs/etc. -- + # necessary for introspection for the auto gen docs to work! + update_wrapper(wrapper=decorate, wrapped=wrapped_func) + return decorate +## Functions -## Classes - -### ChannelTimeout + +#### timeout_modifier +`timeout_modifier(wrapped_func: Callable[..., Any]) ‑> Callable[..., Any]` ```text -Channel timeout decorator +Decorate an "operation" to modify the timeout_ops value for duration of that operation + +This decorator wraps send command/config ops and is used to allow users to set a +`timeout_ops` value for the duration of a single method call -- this makes it so users don't +need to manually set/reset the value Args: - message: accepts message from decorated function to add context to any timeout - (if a timeout happens!) + wrapped_func: function being decorated Returns: - None + decorate: decorated func Raises: N/A ``` -
- - Expand source code - -
-        
-class ChannelTimeout:
-    def __init__(self, message: str = "") -> None:
-        """
-        Channel timeout decorator
-
-        Args:
-            message: accepts message from decorated function to add context to any timeout
-                (if a timeout happens!)
-
-        Returns:
-            None
 
-        Raises:
-            N/A
 
-        """
-        self.message = message
-        self.channel_timeout_ops = 0.0
-        self.channel_logger: LoggerAdapterT
-        self.transport_instance: "BaseTransport"
 
-    def __call__(self, wrapped_func: Callable[..., Any]) -> Callable[..., Any]:
-        """
-        Decorate an "operation" to modify the timeout_ops value for duration of that operation
-
-        This decorator wraps send command/config ops and is used to allow users to set a
-        `timeout_ops` value for the duration of a single method call -- this makes it so users don't
-        need to manually set/reset the value
-
-        Args:
-            wrapped_func: function being decorated
-
-        Returns:
-            decorate: decorated func
-
-        Raises:
-            N/A
-
-        """
-        if asyncio.iscoroutinefunction(wrapped_func):
-
-            async def decorate(*args: Any, **kwargs: Any) -> Any:
-                channel_instance: "Channel" = args[0]
-                self.channel_logger = channel_instance.logger
-                self.channel_timeout_ops = (
-                    channel_instance._base_channel_args.timeout_ops  # pylint: disable=W0212
-                )
-
-                if not self.channel_timeout_ops:
-                    return await wrapped_func(*args, **kwargs)
-
-                self.transport_instance = channel_instance.transport
-
-                try:
-                    return await asyncio.wait_for(
-                        wrapped_func(*args, **kwargs), timeout=self.channel_timeout_ops
-                    )
-                except asyncio.TimeoutError:
-                    self._handle_timeout()
-
-        else:
-            # ignoring type error:
-            # "All conditional function variants must have identical signatures"
-            # one is sync one is async so never going to be identical here!
-            def decorate(*args: Any, **kwargs: Any) -> Any:  # type: ignore
-                channel_instance: "Channel" = args[0]
-                self.channel_logger = channel_instance.logger
-                self.channel_timeout_ops = (
-                    channel_instance._base_channel_args.timeout_ops  # pylint: disable=W0212
-                )
-
-                if not self.channel_timeout_ops:
-                    return wrapped_func(*args, **kwargs)
-
-                self.transport_instance = channel_instance.transport
-                transport_instance_class_name = self.transport_instance.__class__.__name__
-
-                if (
-                    transport_instance_class_name in ("SystemTransport", "TelnetTransport")
-                    or _IS_WINDOWS
-                    or threading.current_thread() is not threading.main_thread()
-                ):
-                    return self._multiprocessing_timeout(
-                        wrapped_func=wrapped_func,
-                        args=args,
-                        kwargs=kwargs,
-                    )
-
-                old = signal.signal(signal.SIGALRM, self._signal_raise_exception)
-                signal.setitimer(signal.ITIMER_REAL, self.channel_timeout_ops)
-                try:
-                    return wrapped_func(*args, **kwargs)
-                finally:
-                    if self.channel_timeout_ops:
-                        signal.setitimer(signal.ITIMER_REAL, 0)
-                        signal.signal(signal.SIGALRM, old)
-
-        # ensures that the wrapped function is updated w/ the original functions docs/etc. --
-        # necessary for introspection for the auto gen docs to work!
-        update_wrapper(wrapper=decorate, wrapped=wrapped_func)
-        return decorate
-
-    def _handle_timeout(self) -> None:
-        """
-        Timeout handler method to close connections and raise ScrapliTimeout
-
-        Args:
-            N/A
-
-        Returns:
-            None
-
-        Raises:
-            ScrapliTimeout: always, if we hit this method we have already timed out!
-
-        """
-        self.channel_logger.critical("channel operation timed out, closing transport")
-        self.transport_instance.close()
-        raise ScrapliTimeout(self.message)
-
-    def _multiprocessing_timeout(
-        self, wrapped_func: Callable[..., Any], args: Any, kwargs: Any
-    ) -> Any:
-        """
-        Multiprocessing method for timeouts; works in threads and on windows
-
-        Args:
-            wrapped_func: function being decorated
-            args: function being decorated args
-            kwargs: function being decorated kwargs
-
-        Returns:
-            Any: result of decorated function
-
-        Raises:
-            N/A
-
-        """
-        with ThreadPoolExecutor(max_workers=1) as pool:
-            future = pool.submit(wrapped_func, *args, **kwargs)
-            wait([future], timeout=self.channel_timeout_ops)
-            if not future.done():
-                self._handle_timeout()
-        return future.result()
-
-    def _signal_raise_exception(self, signum: Any, frame: Any) -> None:
-        """
-        Signal method exception handler
-
-        Args:
-            signum: singum from the singal handler, unused here
-            frame: frame from the signal handler, unused here
-
-        Returns:
-            None
-
-        Raises:
-            N/A
-
-        """
-        _, _ = signum, frame
-        self._handle_timeout()
-        
-    
-
- - - - - -### TimeoutOpsModifier - - - -
- - Expand source code - -
-        
-class TimeoutOpsModifier:
-    def __call__(self, wrapped_func: Callable[..., Any]) -> Callable[..., Any]:
-        """
-        Decorate an "operation" to modify the timeout_ops value for duration of that operation
-
-        This decorator wraps send command/config ops and is used to allow users to set a
-        `timeout_ops` value for the duration of a single method call -- this makes it so users don't
-        need to manually set/reset the value
-
-        Args:
-            wrapped_func: function being decorated
-
-        Returns:
-            decorate: decorated func
-
-        Raises:
-            N/A
-
-        """
-        if asyncio.iscoroutinefunction(wrapped_func):
-
-            async def decorate(*args: Any, **kwargs: Any) -> Any:
-                driver_instance: "AsyncGenericDriver" = args[0]
-                driver_logger = driver_instance.logger
-
-                timeout_ops_kwarg = kwargs.get("timeout_ops", None)
-
-                if timeout_ops_kwarg is None or timeout_ops_kwarg == driver_instance.timeout_ops:
-                    result = await wrapped_func(*args, **kwargs)
-                else:
-                    driver_logger.info(
-                        "modifying driver timeout for current operation, temporary timeout_ops "
-                        f"value: '{timeout_ops_kwarg}'"
-                    )
-                    base_timeout_ops = driver_instance.timeout_ops
-                    driver_instance.timeout_ops = kwargs["timeout_ops"]
-                    result = await wrapped_func(*args, **kwargs)
-                    driver_instance.timeout_ops = base_timeout_ops
-                return result
-
-        else:
-            # ignoring type error:
-            # "All conditional function variants must have identical signatures"
-            # one is sync one is async so never going to be identical here!
-            def decorate(*args: Any, **kwargs: Any) -> Any:  # type: ignore
-                driver_instance: "GenericDriver" = args[0]
-                driver_logger = driver_instance.logger
-
-                timeout_ops_kwarg = kwargs.get("timeout_ops", None)
-
-                if timeout_ops_kwarg is None or timeout_ops_kwarg == driver_instance.timeout_ops:
-                    result = wrapped_func(*args, **kwargs)
-                else:
-                    driver_logger.info(
-                        "modifying driver timeout for current operation, temporary timeout_ops "
-                        f"value: '{timeout_ops_kwarg}'"
-                    )
-                    base_timeout_ops = driver_instance.timeout_ops
-                    driver_instance.timeout_ops = kwargs["timeout_ops"]
-                    result = wrapped_func(*args, **kwargs)
-                    driver_instance.timeout_ops = base_timeout_ops
-                return result
-
-        # ensures that the wrapped function is updated w/ the original functions docs/etc. --
-        # necessary for introspection for the auto gen docs to work!
-        update_wrapper(wrapper=decorate, wrapped=wrapped_func)
-        return decorate
-        
-    
-
- - - - - -### TransportTimeout + +#### timeout_wrapper +`timeout_wrapper(wrapped_func: Callable[..., Any]) ‑> Callable[..., Any]` ```text -Transport timeout decorator +Timeout wrapper for transports Args: - message: accepts message from decorated function to add context to any timeout - (if a timeout happens!) + wrapped_func: function being wrapped -- must be a method of Channel or Transport Returns: - None + Any: result of wrapped function Raises: N/A -``` - -
- - Expand source code - -
-        
-class TransportTimeout:
-    def __init__(self, message: str = "") -> None:
-        """
-        Transport timeout decorator
-
-        Args:
-            message: accepts message from decorated function to add context to any timeout
-                (if a timeout happens!)
-
-        Returns:
-            None
-
-        Raises:
-            N/A
-
-        """
-        self.message = message
-        self.transport_instance: "BaseTransport"
-        self.transport_timeout_transport = 0.0
-
-    def __call__(self, wrapped_func: Callable[..., Any]) -> Callable[..., Any]:
-        """
-        Decorate an "operation" to modify the timeout_transport value for duration of that operation
-
-        This decorator wraps a transport read operation and is used to allow users to control the
-        transport timeout via the `timeout_transport` attribute. This decorator should be applied to
-        any transport "read" operations.
-
-        Args:
-            wrapped_func: function being decorated
-
-        Returns:
-            decorate: decorated func
-
-        Raises:
-            N/A
-
-        """
-
-        if asyncio.iscoroutinefunction(wrapped_func):
-
-            async def decorate(*args: Any, **kwargs: Any) -> Any:
-                self.transport_instance = args[0]
-                self.transport_timeout_transport = self._get_timeout_transport()
-
-                if not self.transport_timeout_transport:
-                    return await wrapped_func(*args, **kwargs)
-
-                try:
-                    return await asyncio.wait_for(
-                        wrapped_func(*args, **kwargs), timeout=self.transport_timeout_transport
-                    )
-                except asyncio.TimeoutError:
-                    self._handle_timeout()
-
-        else:
-            # ignoring type error:
-            # "All conditional function variants must have identical signatures"
-            # one is sync one is async so never going to be identical here!
-            def decorate(*args: Any, **kwargs: Any) -> Any:  # type: ignore
-                self.transport_instance = args[0]
-                self.transport_timeout_transport = self._get_timeout_transport()
-
-                if not self.transport_timeout_transport:
-                    return wrapped_func(*args, **kwargs)
-
-                transport_instance_class_name = self.transport_instance.__class__.__name__
-
-                if (
-                    transport_instance_class_name in ("SystemTransport", "TelnetTransport")
-                    or _IS_WINDOWS
-                    or threading.current_thread() is not threading.main_thread()
-                ):
-                    return self._multiprocessing_timeout(
-                        wrapped_func=wrapped_func,
-                        args=args,
-                        kwargs=kwargs,
-                    )
-
-                old = signal.signal(signal.SIGALRM, self._signal_raise_exception)
-                signal.setitimer(signal.ITIMER_REAL, self.transport_timeout_transport)
-                try:
-                    return wrapped_func(*args, **kwargs)
-                finally:
-                    if self.transport_timeout_transport:
-                        signal.setitimer(signal.ITIMER_REAL, 0)
-                        signal.signal(signal.SIGALRM, old)
-
-        # ensures that the wrapped function is updated w/ the original functions docs/etc. --
-        # necessary for introspection for the auto gen docs to work!
-        update_wrapper(wrapper=decorate, wrapped=wrapped_func)
-        return decorate
-
-    def _get_timeout_transport(self) -> float:
-        """
-        Fetch and return timeout transport from the transport object
-
-        Args:
-            N/A
-
-        Returns:
-            float: transport timeout value
-
-        Raises:
-            N/A
-
-        """
-        transport_args = self.transport_instance._base_transport_args  # pylint: disable=W0212
-        return transport_args.timeout_transport
-
-    def _handle_timeout(self) -> None:
-        """
-        Timeout handler method to close connections and raise ScrapliTimeout
-
-        Args:
-            N/A
-
-        Returns:
-            None
-
-        Raises:
-            ScrapliTimeout: always, if we hit this method we have already timed out!
-
-        """
-        self.transport_instance.logger.critical("transport operation timed out, closing transport")
-        self.transport_instance.close()
-        raise ScrapliTimeout(self.message)
-
-    def _multiprocessing_timeout(
-        self, wrapped_func: Callable[..., Any], args: Any, kwargs: Any
-    ) -> Any:
-        """
-        Multiprocessing method for timeouts; works in threads and on windows
-
-        Args:
-            wrapped_func: function being decorated
-            args: function being decorated args
-            kwargs: function being decorated kwargs
-
-        Returns:
-            Any: result of decorated function
-
-        Raises:
-            N/A
-
-        """
-        with ThreadPoolExecutor(max_workers=1) as pool:
-            future = pool.submit(wrapped_func, *args, **kwargs)
-            wait([future], timeout=self.transport_timeout_transport)
-            if not future.done():
-                self._handle_timeout()
-        return future.result()
-
-    def _signal_raise_exception(self, signum: Any, frame: Any) -> None:
-        """
-        Signal method exception handler
-
-        Args:
-            signum: singum from the singal handler, unused here
-            frame: frame from the signal handler, unused here
-
-        Returns:
-            None
-
-        Raises:
-            N/A
-
-        """
-        _, _ = signum, frame
-        self._handle_timeout()
-        
-    
-
\ No newline at end of file +``` \ No newline at end of file diff --git a/docs/api_docs/driver/base/async_driver.md b/docs/api_docs/driver/base/async_driver.md index d233b9f1..4eedcb56 100644 --- a/docs/api_docs/driver/base/async_driver.md +++ b/docs/api_docs/driver/base/async_driver.md @@ -30,13 +30,15 @@ scrapli.driver.base.async_driver """scrapli.driver.base.async_driver""" from types import TracebackType -from typing import Any, Optional, Type +from typing import Any, Optional, Type, TypeVar from scrapli.channel import AsyncChannel from scrapli.driver.base.base_driver import BaseDriver from scrapli.exceptions import ScrapliValueError from scrapli.transport import ASYNCIO_TRANSPORTS +_T = TypeVar("_T", bound="AsyncDriver") + class AsyncDriver(BaseDriver): def __init__(self, **kwargs: Any): @@ -53,7 +55,7 @@ class AsyncDriver(BaseDriver): base_channel_args=self._base_channel_args, ) - async def __aenter__(self) -> "AsyncDriver": + async def __aenter__(self: _T) -> _T: """ Enter method for context manager @@ -61,7 +63,7 @@ class AsyncDriver(BaseDriver): N/A Returns: - AsyncDriver: opened AsyncDriver object + _T: a concrete implementation of the opened AsyncDriver object Raises: N/A @@ -371,7 +373,7 @@ class AsyncDriver(BaseDriver): base_channel_args=self._base_channel_args, ) - async def __aenter__(self) -> "AsyncDriver": + async def __aenter__(self: _T) -> _T: """ Enter method for context manager @@ -379,7 +381,7 @@ class AsyncDriver(BaseDriver): N/A Returns: - AsyncDriver: opened AsyncDriver object + _T: a concrete implementation of the opened AsyncDriver object Raises: N/A diff --git a/docs/api_docs/driver/base/sync_driver.md b/docs/api_docs/driver/base/sync_driver.md index f2a19063..8436fd20 100644 --- a/docs/api_docs/driver/base/sync_driver.md +++ b/docs/api_docs/driver/base/sync_driver.md @@ -30,13 +30,15 @@ scrapli.driver.base.sync_driver """scrapli.driver.base.sync_driver""" from types import TracebackType -from typing import Any, Optional, Type +from typing import Any, Optional, Type, TypeVar from scrapli.channel import Channel from scrapli.driver.base.base_driver import BaseDriver from scrapli.exceptions import ScrapliValueError from scrapli.transport import ASYNCIO_TRANSPORTS +_T = TypeVar("_T", bound="Driver") + class Driver(BaseDriver): def __init__(self, **kwargs: Any): @@ -53,7 +55,7 @@ class Driver(BaseDriver): base_channel_args=self._base_channel_args, ) - def __enter__(self) -> "Driver": + def __enter__(self: _T) -> _T: """ Enter method for context manager @@ -61,7 +63,7 @@ class Driver(BaseDriver): N/A Returns: - Driver: opened Driver object + _T: a concrete implementation of the opened Driver object Raises: N/A @@ -322,7 +324,7 @@ class Driver(BaseDriver): base_channel_args=self._base_channel_args, ) - def __enter__(self) -> "Driver": + def __enter__(self: _T) -> _T: """ Enter method for context manager @@ -330,7 +332,7 @@ class Driver(BaseDriver): N/A Returns: - Driver: opened Driver object + _T: a concrete implementation of the opened Driver object Raises: N/A diff --git a/docs/api_docs/driver/core/juniper_junos/base_driver.md b/docs/api_docs/driver/core/juniper_junos/base_driver.md index 50d71fcb..0e3bd9dc 100644 --- a/docs/api_docs/driver/core/juniper_junos/base_driver.md +++ b/docs/api_docs/driver/core/juniper_junos/base_driver.md @@ -90,7 +90,7 @@ PRIVS = { ), "root_shell": ( PrivilegeLevel( - pattern=r"^.*root@(?:\S*:\S*\s?)?[%\#]\s?$", + pattern=r"^.*root@(?:\S*:?\S*\s?)?[%\#]\s?$", name="root_shell", previous_priv="exec", deescalate="exit", diff --git a/docs/api_docs/driver/generic/async_driver.md b/docs/api_docs/driver/generic/async_driver.md index e5ab89f4..2f78fc15 100644 --- a/docs/api_docs/driver/generic/async_driver.md +++ b/docs/api_docs/driver/generic/async_driver.md @@ -33,7 +33,7 @@ import asyncio from io import BytesIO from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union -from scrapli.decorators import TimeoutOpsModifier +from scrapli.decorators import timeout_modifier from scrapli.driver import AsyncDriver from scrapli.driver.generic.base_driver import BaseGenericDriver from scrapli.exceptions import ScrapliTimeout, ScrapliValueError @@ -118,7 +118,7 @@ class AsyncGenericDriver(AsyncDriver, BaseGenericDriver): prompt: str = await self.channel.get_prompt() return prompt - @TimeoutOpsModifier() + @timeout_modifier async def _send_command( self, command: str, @@ -312,7 +312,7 @@ class AsyncGenericDriver(AsyncDriver, BaseGenericDriver): timeout_ops=timeout_ops, ) - @TimeoutOpsModifier() + @timeout_modifier async def send_and_read( self, channel_input: str, @@ -372,7 +372,7 @@ class AsyncGenericDriver(AsyncDriver, BaseGenericDriver): raw_response=raw_response, processed_response=processed_response, response=response ) - @TimeoutOpsModifier() + @timeout_modifier async def send_interactive( self, interact_events: Union[List[Tuple[str, str]], List[Tuple[str, str, bool]]], @@ -784,7 +784,7 @@ class AsyncGenericDriver(AsyncDriver, BaseGenericDriver): prompt: str = await self.channel.get_prompt() return prompt - @TimeoutOpsModifier() + @timeout_modifier async def _send_command( self, command: str, @@ -978,7 +978,7 @@ class AsyncGenericDriver(AsyncDriver, BaseGenericDriver): timeout_ops=timeout_ops, ) - @TimeoutOpsModifier() + @timeout_modifier async def send_and_read( self, channel_input: str, @@ -1038,7 +1038,7 @@ class AsyncGenericDriver(AsyncDriver, BaseGenericDriver): raw_response=raw_response, processed_response=processed_response, response=response ) - @TimeoutOpsModifier() + @timeout_modifier async def send_interactive( self, interact_events: Union[List[Tuple[str, str]], List[Tuple[str, str, bool]]], diff --git a/docs/api_docs/driver/generic/sync_driver.md b/docs/api_docs/driver/generic/sync_driver.md index 1e661c26..aa3ef0bd 100644 --- a/docs/api_docs/driver/generic/sync_driver.md +++ b/docs/api_docs/driver/generic/sync_driver.md @@ -33,7 +33,7 @@ import time from io import BytesIO from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union -from scrapli.decorators import TimeoutOpsModifier +from scrapli.decorators import timeout_modifier from scrapli.driver import Driver from scrapli.driver.generic.base_driver import BaseGenericDriver from scrapli.exceptions import ScrapliTimeout, ScrapliValueError @@ -119,7 +119,7 @@ class GenericDriver(Driver, BaseGenericDriver): prompt: str = self.channel.get_prompt() return prompt - @TimeoutOpsModifier() + @timeout_modifier def _send_command( self, command: str, @@ -313,7 +313,7 @@ class GenericDriver(Driver, BaseGenericDriver): timeout_ops=timeout_ops, ) - @TimeoutOpsModifier() + @timeout_modifier def send_and_read( self, channel_input: str, @@ -373,7 +373,7 @@ class GenericDriver(Driver, BaseGenericDriver): raw_response=raw_response, processed_response=processed_response, response=response ) - @TimeoutOpsModifier() + @timeout_modifier def send_interactive( self, interact_events: Union[List[Tuple[str, str]], List[Tuple[str, str, bool]]], @@ -783,7 +783,7 @@ class GenericDriver(Driver, BaseGenericDriver): prompt: str = self.channel.get_prompt() return prompt - @TimeoutOpsModifier() + @timeout_modifier def _send_command( self, command: str, @@ -977,7 +977,7 @@ class GenericDriver(Driver, BaseGenericDriver): timeout_ops=timeout_ops, ) - @TimeoutOpsModifier() + @timeout_modifier def send_and_read( self, channel_input: str, @@ -1037,7 +1037,7 @@ class GenericDriver(Driver, BaseGenericDriver): raw_response=raw_response, processed_response=processed_response, response=response ) - @TimeoutOpsModifier() + @timeout_modifier def send_interactive( self, interact_events: Union[List[Tuple[str, str]], List[Tuple[str, str, bool]]], diff --git a/docs/api_docs/driver/network/async_driver.md b/docs/api_docs/driver/network/async_driver.md index 4fc72d19..3c5f2a93 100644 --- a/docs/api_docs/driver/network/async_driver.md +++ b/docs/api_docs/driver/network/async_driver.md @@ -136,7 +136,10 @@ class AsyncNetworkDriver(AsyncGenericDriver, BaseNetworkDriver): (escalate_priv.escalate, escalate_priv.escalate_prompt, False), (self.auth_secondary, escalate_priv.pattern, True), ], - interaction_complete_patterns=[escalate_priv.pattern], + interaction_complete_patterns=[ + self.privilege_levels[escalate_priv.previous_priv].pattern, + escalate_priv.pattern, + ], ) except ScrapliTimeout as exc: raise ScrapliAuthenticationFailed( @@ -462,7 +465,7 @@ class AsyncNetworkDriver(AsyncGenericDriver, BaseNetworkDriver): if failed_when_contains is None: failed_when_contains = self.failed_when_contains - # type hint is due to the TimeoutModifier wrapper returning `Any` so that we dont anger the + # type hint is due to the timeout_modifier wrapper returning `Any` so that we dont anger the # asyncio parts (which will get an awaitable not a Response returned) response: Response = await super().send_interactive( interact_events=interact_events, @@ -860,7 +863,10 @@ class AsyncNetworkDriver(AsyncGenericDriver, BaseNetworkDriver): (escalate_priv.escalate, escalate_priv.escalate_prompt, False), (self.auth_secondary, escalate_priv.pattern, True), ], - interaction_complete_patterns=[escalate_priv.pattern], + interaction_complete_patterns=[ + self.privilege_levels[escalate_priv.previous_priv].pattern, + escalate_priv.pattern, + ], ) except ScrapliTimeout as exc: raise ScrapliAuthenticationFailed( @@ -1186,7 +1192,7 @@ class AsyncNetworkDriver(AsyncGenericDriver, BaseNetworkDriver): if failed_when_contains is None: failed_when_contains = self.failed_when_contains - # type hint is due to the TimeoutModifier wrapper returning `Any` so that we dont anger the + # type hint is due to the timeout_modifier wrapper returning `Any` so that we dont anger the # asyncio parts (which will get an awaitable not a Response returned) response: Response = await super().send_interactive( interact_events=interact_events, diff --git a/docs/api_docs/driver/network/base_driver.md b/docs/api_docs/driver/network/base_driver.md index f5ba00dd..fdfca7fd 100644 --- a/docs/api_docs/driver/network/base_driver.md +++ b/docs/api_docs/driver/network/base_driver.md @@ -42,7 +42,7 @@ from scrapli.helper import user_warning from scrapli.response import MultiResponse, Response if TYPE_CHECKING: - LoggerAdapterT = LoggerAdapter[Logger] # pylint:disable=E1136 + LoggerAdapterT = LoggerAdapter[Logger] # pragma: no cover # pylint:disable=E1136 else: LoggerAdapterT = LoggerAdapter @@ -142,7 +142,7 @@ class BaseNetworkDriver: rf"({priv_level_data.pattern})" for priv_level_data in self.privilege_levels.values() ) - @lru_cache() + @lru_cache(maxsize=64) def _determine_current_priv(self, current_prompt: str) -> List[str]: """ Determine current privilege level from prompt string @@ -664,7 +664,7 @@ class BaseNetworkDriver: rf"({priv_level_data.pattern})" for priv_level_data in self.privilege_levels.values() ) - @lru_cache() + @lru_cache(maxsize=64) def _determine_current_priv(self, current_prompt: str) -> List[str]: """ Determine current privilege level from prompt string diff --git a/docs/api_docs/driver/network/sync_driver.md b/docs/api_docs/driver/network/sync_driver.md index 7e4ced0d..cf8fd253 100644 --- a/docs/api_docs/driver/network/sync_driver.md +++ b/docs/api_docs/driver/network/sync_driver.md @@ -136,7 +136,10 @@ class NetworkDriver(GenericDriver, BaseNetworkDriver): (escalate_priv.escalate, escalate_priv.escalate_prompt, False), (self.auth_secondary, escalate_priv.pattern, True), ], - interaction_complete_patterns=[escalate_priv.pattern], + interaction_complete_patterns=[ + self.privilege_levels[escalate_priv.previous_priv].pattern, + escalate_priv.pattern, + ], ) except ScrapliTimeout as exc: raise ScrapliAuthenticationFailed( @@ -462,7 +465,7 @@ class NetworkDriver(GenericDriver, BaseNetworkDriver): if failed_when_contains is None: failed_when_contains = self.failed_when_contains - # type hint is due to the TimeoutModifier wrapper returning `Any` so that we dont anger the + # type hint is due to the timeout_modifier wrapper returning `Any` so that we dont anger the # asyncio parts (which will get an awaitable not a Response returned) response: Response = super().send_interactive( interact_events=interact_events, @@ -860,7 +863,10 @@ class NetworkDriver(GenericDriver, BaseNetworkDriver): (escalate_priv.escalate, escalate_priv.escalate_prompt, False), (self.auth_secondary, escalate_priv.pattern, True), ], - interaction_complete_patterns=[escalate_priv.pattern], + interaction_complete_patterns=[ + self.privilege_levels[escalate_priv.previous_priv].pattern, + escalate_priv.pattern, + ], ) except ScrapliTimeout as exc: raise ScrapliAuthenticationFailed( @@ -1186,7 +1192,7 @@ class NetworkDriver(GenericDriver, BaseNetworkDriver): if failed_when_contains is None: failed_when_contains = self.failed_when_contains - # type hint is due to the TimeoutModifier wrapper returning `Any` so that we dont anger the + # type hint is due to the timeout_modifier wrapper returning `Any` so that we dont anger the # asyncio parts (which will get an awaitable not a Response returned) response: Response = super().send_interactive( interact_events=interact_events, diff --git a/docs/api_docs/transport/plugins/asyncssh.md b/docs/api_docs/transport/plugins/asyncssh.md index 861d5c59..ecfba1d4 100644 --- a/docs/api_docs/transport/plugins/asyncssh.md +++ b/docs/api_docs/transport/plugins/asyncssh.md @@ -37,7 +37,7 @@ from asyncssh.connection import SSHClientConnection, connect from asyncssh.misc import ConnectionLost, PermissionDenied from asyncssh.stream import SSHReader, SSHWriter -from scrapli.decorators import TransportTimeout +from scrapli.decorators import timeout_wrapper from scrapli.exceptions import ( ScrapliAuthenticationFailed, ScrapliConnectionError, @@ -275,7 +275,7 @@ class AsyncsshTransport(AsyncTransport): pass return False - @TransportTimeout("timed out reading from transport") + @timeout_wrapper async def read(self) -> bytes: if not self.stdout: raise ScrapliConnectionNotOpened @@ -583,7 +583,7 @@ class AsyncsshTransport(AsyncTransport): pass return False - @TransportTimeout("timed out reading from transport") + @timeout_wrapper async def read(self) -> bytes: if not self.stdout: raise ScrapliConnectionNotOpened diff --git a/docs/api_docs/transport/plugins/asynctelnet.md b/docs/api_docs/transport/plugins/asynctelnet.md index 4e8d0584..540ac108 100644 --- a/docs/api_docs/transport/plugins/asynctelnet.md +++ b/docs/api_docs/transport/plugins/asynctelnet.md @@ -34,22 +34,14 @@ import socket from dataclasses import dataclass from typing import Optional -from scrapli.decorators import TransportTimeout +from scrapli.decorators import timeout_wrapper from scrapli.exceptions import ( ScrapliAuthenticationFailed, ScrapliConnectionError, ScrapliConnectionNotOpened, ) from scrapli.transport.base import AsyncTransport, BasePluginTransportArgs, BaseTransportArgs - -# telnet control characters we care about -IAC = bytes([255]) -DONT = bytes([254]) -DO = bytes([253]) -WONT = bytes([252]) -WILL = bytes([251]) -TERM_TYPE = bytes([24]) -SUPPRESS_GO_AHEAD = bytes([3]) +from scrapli.transport.base.telnet_common import DO, DONT, IAC, SUPPRESS_GO_AHEAD, WILL, WONT @dataclass() @@ -68,10 +60,9 @@ class AsynctelnetTransport(AsyncTransport): self.stdin: Optional[asyncio.StreamWriter] = None self._initial_buf = b"" - self._stdout_binary_transmission = False def _handle_control_chars_response(self, control_buf: bytes, c: bytes) -> bytes: - """ " + """ Handle the actual response to control characters Broken up to be easier to test as well as to appease mr. mccabe @@ -128,7 +119,7 @@ class AsynctelnetTransport(AsyncTransport): return control_buf async def _handle_control_chars(self) -> None: - """ " + """ Handle control characters -- nearly identical to CPython telnetlib Basically we want to read and "decline" any and all control options that the server proposes @@ -227,7 +218,7 @@ class AsynctelnetTransport(AsyncTransport): return False return not self.stdout.at_eof() - @TransportTimeout("timed out reading from transport") + @timeout_wrapper async def read(self) -> bytes: if not self.stdout: raise ScrapliConnectionNotOpened @@ -301,10 +292,9 @@ class AsynctelnetTransport(AsyncTransport): self.stdin: Optional[asyncio.StreamWriter] = None self._initial_buf = b"" - self._stdout_binary_transmission = False def _handle_control_chars_response(self, control_buf: bytes, c: bytes) -> bytes: - """ " + """ Handle the actual response to control characters Broken up to be easier to test as well as to appease mr. mccabe @@ -361,7 +351,7 @@ class AsynctelnetTransport(AsyncTransport): return control_buf async def _handle_control_chars(self) -> None: - """ " + """ Handle control characters -- nearly identical to CPython telnetlib Basically we want to read and "decline" any and all control options that the server proposes @@ -460,7 +450,7 @@ class AsynctelnetTransport(AsyncTransport): return False return not self.stdout.at_eof() - @TransportTimeout("timed out reading from transport") + @timeout_wrapper async def read(self) -> bytes: if not self.stdout: raise ScrapliConnectionNotOpened diff --git a/docs/api_docs/transport/plugins/paramiko.md b/docs/api_docs/transport/plugins/paramiko.md index 6509cd30..3ccf0670 100644 --- a/docs/api_docs/transport/plugins/paramiko.md +++ b/docs/api_docs/transport/plugins/paramiko.md @@ -216,7 +216,11 @@ class ParamikoTransport(Transport): raise ScrapliConnectionNotOpened if self._base_transport_args.transport_options.get("enable_rsa2", False) is False: - self.session.disabled_algorithms = {"keys": ["rsa-sha2-256", "rsa-sha2-512"]} + # do this for "keys" and "pubkeys": https://github.com/paramiko/paramiko/issues/1984 + self.session.disabled_algorithms = { + "keys": ["rsa-sha2-256", "rsa-sha2-512"], + "pubkeys": ["rsa-sha2-256", "rsa-sha2-512"], + } try: paramiko_key = RSAKey(filename=self.plugin_transport_args.auth_private_key) @@ -537,7 +541,11 @@ class ParamikoTransport(Transport): raise ScrapliConnectionNotOpened if self._base_transport_args.transport_options.get("enable_rsa2", False) is False: - self.session.disabled_algorithms = {"keys": ["rsa-sha2-256", "rsa-sha2-512"]} + # do this for "keys" and "pubkeys": https://github.com/paramiko/paramiko/issues/1984 + self.session.disabled_algorithms = { + "keys": ["rsa-sha2-256", "rsa-sha2-512"], + "pubkeys": ["rsa-sha2-256", "rsa-sha2-512"], + } try: paramiko_key = RSAKey(filename=self.plugin_transport_args.auth_private_key) diff --git a/docs/api_docs/transport/plugins/system.md b/docs/api_docs/transport/plugins/system.md index 5d0c65c8..eb21223c 100644 --- a/docs/api_docs/transport/plugins/system.md +++ b/docs/api_docs/transport/plugins/system.md @@ -33,7 +33,7 @@ import sys from dataclasses import dataclass from typing import List, Optional -from scrapli.decorators import TransportTimeout +from scrapli.decorators import timeout_wrapper from scrapli.exceptions import ( ScrapliConnectionError, ScrapliConnectionNotOpened, @@ -180,7 +180,7 @@ class SystemTransport(Transport): return True return False - @TransportTimeout("timed out reading from transport") + @timeout_wrapper def read(self) -> bytes: if not self.session: raise ScrapliConnectionNotOpened @@ -437,7 +437,7 @@ class SystemTransport(Transport): return True return False - @TransportTimeout("timed out reading from transport") + @timeout_wrapper def read(self) -> bytes: if not self.session: raise ScrapliConnectionNotOpened diff --git a/docs/api_docs/transport/plugins/telnet.md b/docs/api_docs/transport/plugins/telnet.md index 298f3339..14f2c972 100644 --- a/docs/api_docs/transport/plugins/telnet.md +++ b/docs/api_docs/transport/plugins/telnet.md @@ -30,12 +30,13 @@ scrapli.transport.plugins.telnet.transport """scrapli.transport.plugins.telnet.transport""" from dataclasses import dataclass -from telnetlib import Telnet from typing import Optional -from scrapli.decorators import TransportTimeout +from scrapli.decorators import timeout_wrapper from scrapli.exceptions import ScrapliConnectionError, ScrapliConnectionNotOpened from scrapli.transport.base import BasePluginTransportArgs, BaseTransportArgs, Transport +from scrapli.transport.base.base_socket import Socket +from scrapli.transport.base.telnet_common import DO, DONT, IAC, SUPPRESS_GO_AHEAD, WILL, WONT @dataclass() @@ -43,81 +44,196 @@ class PluginTransportArgs(BasePluginTransportArgs): pass -class ScrapliTelnet(Telnet): - def __init__(self, host: str, port: int, timeout: float) -> None: +class TelnetTransport(Transport): + def __init__( + self, base_transport_args: BaseTransportArgs, plugin_transport_args: PluginTransportArgs + ) -> None: + super().__init__(base_transport_args=base_transport_args) + self.plugin_transport_args = plugin_transport_args + + self.socket: Optional[Socket] = None + self._initial_buf = b"" + + def _set_socket_timeout(self, timeout: float) -> None: """ - ScrapliTelnet class for typing purposes + Set underlying socket timeout + + Mostly this exists just to assert that socket and socket.sock are not None to appease mypy! Args: - host: string of host - port: integer port to connect to - timeout: timeout value in seconds + timeout: float value to set as the timeout Returns: - None + N/A Raises: + ScrapliConnectionNotOpened: if either socket or socket.sock are None + """ + if self.socket is None: + raise ScrapliConnectionNotOpened + if self.socket.sock is None: + raise ScrapliConnectionNotOpened + self.socket.sock.settimeout(timeout) + + def _handle_control_chars_response(self, control_buf: bytes, c: bytes) -> bytes: + """ + Handle the actual response to control characters + + Broken up to be easier to test as well as to appease mr. mccabe + + NOTE: see the asynctelnet transport for additional comments inline about what is going on + here. + + Args: + control_buf: current control_buf to work with + c: currently read control char to process + + Returns: + bytes: updated control_buf + + Raises: + ScrapliConnectionNotOpened: if connection is not opened for some reason + + """ + if not self.socket: + raise ScrapliConnectionNotOpened + + if not control_buf: + if c != IAC: + self._initial_buf += c + else: + control_buf += c + + elif len(control_buf) == 1 and c in (DO, DONT, WILL, WONT): + control_buf += c + + elif len(control_buf) == 2: + cmd = control_buf[1:2] + control_buf = b"" + + if (cmd == DO) and (c == SUPPRESS_GO_AHEAD): + self.write(IAC + WILL + c) + elif cmd in (DO, DONT): + self.write(IAC + WONT + c) + elif cmd == WILL: + self.write(IAC + DO + c) + elif cmd == WONT: + self.write(IAC + DONT + c) + + return control_buf + + def _handle_control_chars(self) -> None: + """ + Handle control characters -- nearly identical to CPython (removed in 3.11) telnetlib + + Basically we want to read and "decline" any and all control options that the server proposes + to us -- so if they say "DO" XYZ directive, we say "DONT", if they say "WILL" we say "WONT". + + NOTE: see the asynctelnet transport for additional comments inline about what is going on + here. + + Args: N/A + Returns: + None + + Raises: + ScrapliConnectionNotOpened: if connection is not opened for some reason + ScrapliConnectionNotOpened: if we read an empty byte string from the reader -- this + indicates the server sent an EOF -- see #142 + """ - self.eof: bool - self.timeout: float + if not self.socket: + raise ScrapliConnectionNotOpened - super().__init__(host, port, int(timeout)) + control_buf = b"" + original_socket_timeout = self._base_transport_args.timeout_socket + self._set_socket_timeout(self._base_transport_args.timeout_socket / 4) -class TelnetTransport(Transport): - def __init__( - self, base_transport_args: BaseTransportArgs, plugin_transport_args: PluginTransportArgs - ) -> None: - super().__init__(base_transport_args=base_transport_args) - self.plugin_transport_args = plugin_transport_args + while True: + try: + c = self._read(1) + if not c: + raise ScrapliConnectionNotOpened("server returned EOF, connection not opened") + except TimeoutError: + # shouldn't really matter/need to be reset back to "normal", but don't really want + # to leave it modified as that would be confusing! + self._base_transport_args.timeout_socket = original_socket_timeout + return - self.session: Optional[ScrapliTelnet] = None + self._set_socket_timeout(self._base_transport_args.timeout_socket / 10) + control_buf = self._handle_control_chars_response(control_buf=control_buf, c=c) def open(self) -> None: self._pre_open_closing_log(closing=False) - # establish session with "socket" timeout, then reset timeout to "transport" timeout - try: - self.session = ScrapliTelnet( + if not self.socket: + self.socket = Socket( host=self._base_transport_args.host, port=self._base_transport_args.port, timeout=self._base_transport_args.timeout_socket, ) - self.session.timeout = self._base_transport_args.timeout_transport - except ConnectionError as exc: - msg = f"Failed to open telnet session to host {self._base_transport_args.host}" - if "connection refused" in str(exc).lower(): - msg = ( - f"Failed to open telnet session to host {self._base_transport_args.host}, " - "connection refused" - ) - raise ScrapliConnectionError(msg) from exc + + if not self.socket.isalive(): + self.socket.open() + + self._handle_control_chars() self._post_open_closing_log(closing=False) def close(self) -> None: self._pre_open_closing_log(closing=True) - if self.session: - self.session.close() + if self.socket: + self.socket.close() - self.session = None + self.socket = None self._post_open_closing_log(closing=True) def isalive(self) -> bool: - if not self.session: + if not self.socket: + return False + if not self.socket.isalive(): return False - return not self.session.eof + return True - @TransportTimeout("timed out reading from transport") + def _read(self, n: int = 65535) -> bytes: + """ + Read n bytes from the socket + + Mostly this exists just to assert that socket and socket.sock are not None to appease mypy! + + Args: + n: optional amount of bytes to try to recv from the underlying socket + + Returns: + N/A + + Raises: + ScrapliConnectionNotOpened: if either socket or socket.sock are None + """ + if self.socket is None: + raise ScrapliConnectionNotOpened + if self.socket.sock is None: + raise ScrapliConnectionNotOpened + return self.socket.sock.recv(n) + + @timeout_wrapper def read(self) -> bytes: - if not self.session: + if not self.socket: raise ScrapliConnectionNotOpened + + if self._initial_buf: + buf = self._initial_buf + self._initial_buf = b"" + return buf + try: - buf = self.session.read_eager() + buf = self._read() + buf = buf.replace(b"\x00", b"") except Exception as exc: raise ScrapliConnectionError( "encountered EOF reading from transport; typically means the device closed the " @@ -126,9 +242,11 @@ class TelnetTransport(Transport): return buf def write(self, channel_input: bytes) -> None: - if not self.session: + if self.socket is None: + raise ScrapliConnectionNotOpened + if self.socket.sock is None: raise ScrapliConnectionNotOpened - self.session.write(channel_input) + self.socket.sock.send(channel_input) @@ -164,67 +282,17 @@ class PluginTransportArgs(BasePluginTransportArgs): -### ScrapliTelnet +### TelnetTransport ```text -Telnet interface class. - -An instance of this class represents a connection to a telnet -server. The instance is initially not connected; the open() -method must be used to establish a connection. Alternatively, the -host name and optional port number can be passed to the -constructor, too. - -Don't try to reopen an already connected instance. - -This class has many read_*() methods. Note that some of them -raise EOFError when the end of the connection is read, because -they can return an empty string for other reasons. See the -individual doc strings. - -read_until(expected, [timeout]) - Read until the expected string has been seen, or a timeout is - hit (default is no timeout); may block. - -read_all() - Read all data until EOF; may block. - -read_some() - Read at least one byte or EOF; may block. - -read_very_eager() - Read all data available already queued or on the socket, - without blocking. - -read_eager() - Read either data already queued or some data available on the - socket, without blocking. - -read_lazy() - Read all data in the raw queue (processing it first), without - doing any socket I/O. - -read_very_lazy() - Reads all data in the cooked queue, without doing any socket - I/O. - -read_sb_data() - Reads available data between SB ... SE sequence. Don't block. - -set_option_negotiation_callback(callback) - Each time a telnet option is read on the input flow, this callback - (if set) is called with the following parameters : - callback(telnet socket, command, option) - option will be chr(0) when there is no option. - No other action is done afterwards by telnetlib. +Helper class that provides a standard way to create an ABC using +inheritance. -ScrapliTelnet class for typing purposes +Scrapli's transport base class Args: - host: string of host - port: integer port to connect to - timeout: timeout value in seconds + base_transport_args: base transport args dataclass Returns: None @@ -239,114 +307,196 @@ Raises:
         
-class ScrapliTelnet(Telnet):
-    def __init__(self, host: str, port: int, timeout: float) -> None:
+class TelnetTransport(Transport):
+    def __init__(
+        self, base_transport_args: BaseTransportArgs, plugin_transport_args: PluginTransportArgs
+    ) -> None:
+        super().__init__(base_transport_args=base_transport_args)
+        self.plugin_transport_args = plugin_transport_args
+
+        self.socket: Optional[Socket] = None
+        self._initial_buf = b""
+
+    def _set_socket_timeout(self, timeout: float) -> None:
         """
-        ScrapliTelnet class for typing purposes
+        Set underlying socket timeout
+
+        Mostly this exists just to assert that socket and socket.sock are not None to appease mypy!
 
         Args:
-            host: string of host
-            port: integer port to connect to
-            timeout: timeout value in seconds
+            timeout: float value to set as the timeout
 
         Returns:
-            None
+            N/A
 
         Raises:
-            N/A
+            ScrapliConnectionNotOpened: if either socket or socket.sock are None
+        """
+        if self.socket is None:
+            raise ScrapliConnectionNotOpened
+        if self.socket.sock is None:
+            raise ScrapliConnectionNotOpened
+        self.socket.sock.settimeout(timeout)
 
+    def _handle_control_chars_response(self, control_buf: bytes, c: bytes) -> bytes:
         """
-        self.eof: bool
-        self.timeout: float
+        Handle the actual response to control characters
 
-        super().__init__(host, port, int(timeout))
-        
-    
- + Broken up to be easier to test as well as to appease mr. mccabe + NOTE: see the asynctelnet transport for additional comments inline about what is going on + here. -#### Ancestors (in MRO) -- telnetlib.Telnet + Args: + control_buf: current control_buf to work with + c: currently read control char to process + Returns: + bytes: updated control_buf + Raises: + ScrapliConnectionNotOpened: if connection is not opened for some reason -### TelnetTransport + """ + if not self.socket: + raise ScrapliConnectionNotOpened + if not control_buf: + if c != IAC: + self._initial_buf += c + else: + control_buf += c -```text -Helper class that provides a standard way to create an ABC using -inheritance. + elif len(control_buf) == 1 and c in (DO, DONT, WILL, WONT): + control_buf += c -Scrapli's transport base class + elif len(control_buf) == 2: + cmd = control_buf[1:2] + control_buf = b"" -Args: - base_transport_args: base transport args dataclass + if (cmd == DO) and (c == SUPPRESS_GO_AHEAD): + self.write(IAC + WILL + c) + elif cmd in (DO, DONT): + self.write(IAC + WONT + c) + elif cmd == WILL: + self.write(IAC + DO + c) + elif cmd == WONT: + self.write(IAC + DONT + c) -Returns: - None + return control_buf -Raises: - N/A -``` + def _handle_control_chars(self) -> None: + """ + Handle control characters -- nearly identical to CPython (removed in 3.11) telnetlib -
- - Expand source code - -
-        
-class TelnetTransport(Transport):
-    def __init__(
-        self, base_transport_args: BaseTransportArgs, plugin_transport_args: PluginTransportArgs
-    ) -> None:
-        super().__init__(base_transport_args=base_transport_args)
-        self.plugin_transport_args = plugin_transport_args
+        Basically we want to read and "decline" any and all control options that the server proposes
+        to us -- so if they say "DO" XYZ directive, we say "DONT", if they say "WILL" we say "WONT".
+
+        NOTE: see the asynctelnet transport for additional comments inline about what is going on
+        here.
+
+        Args:
+            N/A
+
+        Returns:
+            None
+
+        Raises:
+            ScrapliConnectionNotOpened: if connection is not opened for some reason
+            ScrapliConnectionNotOpened: if we read an empty byte string from the reader -- this
+                indicates the server sent an EOF -- see #142
+
+        """
+        if not self.socket:
+            raise ScrapliConnectionNotOpened
+
+        control_buf = b""
+
+        original_socket_timeout = self._base_transport_args.timeout_socket
+        self._set_socket_timeout(self._base_transport_args.timeout_socket / 4)
 
-        self.session: Optional[ScrapliTelnet] = None
+        while True:
+            try:
+                c = self._read(1)
+                if not c:
+                    raise ScrapliConnectionNotOpened("server returned EOF, connection not opened")
+            except TimeoutError:
+                # shouldn't really matter/need to be reset back to "normal", but don't really want
+                # to leave it modified as that would be confusing!
+                self._base_transport_args.timeout_socket = original_socket_timeout
+                return
+
+            self._set_socket_timeout(self._base_transport_args.timeout_socket / 10)
+            control_buf = self._handle_control_chars_response(control_buf=control_buf, c=c)
 
     def open(self) -> None:
         self._pre_open_closing_log(closing=False)
 
-        # establish session with "socket" timeout, then reset timeout to "transport" timeout
-        try:
-            self.session = ScrapliTelnet(
+        if not self.socket:
+            self.socket = Socket(
                 host=self._base_transport_args.host,
                 port=self._base_transport_args.port,
                 timeout=self._base_transport_args.timeout_socket,
             )
-            self.session.timeout = self._base_transport_args.timeout_transport
-        except ConnectionError as exc:
-            msg = f"Failed to open telnet session to host {self._base_transport_args.host}"
-            if "connection refused" in str(exc).lower():
-                msg = (
-                    f"Failed to open telnet session to host {self._base_transport_args.host}, "
-                    "connection refused"
-                )
-            raise ScrapliConnectionError(msg) from exc
+
+        if not self.socket.isalive():
+            self.socket.open()
+
+        self._handle_control_chars()
 
         self._post_open_closing_log(closing=False)
 
     def close(self) -> None:
         self._pre_open_closing_log(closing=True)
 
-        if self.session:
-            self.session.close()
+        if self.socket:
+            self.socket.close()
 
-        self.session = None
+        self.socket = None
 
         self._post_open_closing_log(closing=True)
 
     def isalive(self) -> bool:
-        if not self.session:
+        if not self.socket:
+            return False
+        if not self.socket.isalive():
             return False
-        return not self.session.eof
+        return True
 
-    @TransportTimeout("timed out reading from transport")
+    def _read(self, n: int = 65535) -> bytes:
+        """
+        Read n bytes from the socket
+
+        Mostly this exists just to assert that socket and socket.sock are not None to appease mypy!
+
+        Args:
+            n: optional amount of bytes to try to recv from the underlying socket
+
+        Returns:
+            N/A
+
+        Raises:
+            ScrapliConnectionNotOpened: if either socket or socket.sock are None
+        """
+        if self.socket is None:
+            raise ScrapliConnectionNotOpened
+        if self.socket.sock is None:
+            raise ScrapliConnectionNotOpened
+        return self.socket.sock.recv(n)
+
+    @timeout_wrapper
     def read(self) -> bytes:
-        if not self.session:
+        if not self.socket:
             raise ScrapliConnectionNotOpened
+
+        if self._initial_buf:
+            buf = self._initial_buf
+            self._initial_buf = b""
+            return buf
+
         try:
-            buf = self.session.read_eager()
+            buf = self._read()
+            buf = buf.replace(b"\x00", b"")
         except Exception as exc:
             raise ScrapliConnectionError(
                 "encountered EOF reading from transport; typically means the device closed the "
@@ -355,9 +505,11 @@ class TelnetTransport(Transport):
         return buf
 
     def write(self, channel_input: bytes) -> None:
-        if not self.session:
+        if self.socket is None:
+            raise ScrapliConnectionNotOpened
+        if self.socket.sock is None:
             raise ScrapliConnectionNotOpened
-        self.session.write(channel_input)
+        self.socket.sock.send(channel_input)
         
     
diff --git a/setup.py b/setup.py index b81183f0..ae38ea88 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ import setuptools -__version__ = "2022.01.30.post1" +__version__ = "2022.07.30" __author__ = "Carl Montanari" with open("README.md", "r", encoding="utf-8") as f: