diff --git a/docs/api_docs/channel/async_channel.md b/docs/api_docs/channel/async_channel.md
index 8a04cc28..97c167fc 100644
--- a/docs/api_docs/channel/async_channel.md
+++ b/docs/api_docs/channel/async_channel.md
@@ -32,19 +32,13 @@ scrapli.channel.async_channel
import asyncio
import re
import time
-from io import BytesIO
-
-try:
- from contextlib import asynccontextmanager
-except ImportError: # pragma: nocover
- # needed for 3.6 support, no asynccontextmanager until 3.7
- from async_generator import asynccontextmanager # type: ignore # pragma: nocover
-
+from contextlib import asynccontextmanager
from datetime import datetime
+from io import BytesIO
from typing import AsyncIterator, List, Optional, Tuple
from scrapli.channel.base_channel import BaseChannel, BaseChannelArgs
-from scrapli.decorators import ChannelTimeout
+from scrapli.decorators import timeout_wrapper
from scrapli.exceptions import ScrapliAuthenticationFailed, ScrapliTimeout
from scrapli.transport.base import AsyncTransport
@@ -290,7 +284,7 @@ class AsyncChannel(BaseChannel):
return read_buf.getvalue()
- @ChannelTimeout(message="timed out during in channel ssh authentication")
+ @timeout_wrapper
async def channel_authenticate_ssh(
self, auth_password: str, auth_private_key_passphrase: str
) -> None:
@@ -363,7 +357,7 @@ class AsyncChannel(BaseChannel):
):
return
- @ChannelTimeout(message="timed out during in channel telnet authentication")
+ @timeout_wrapper
async def channel_authenticate_telnet( # noqa: C901
self, auth_username: str = "", auth_password: str = ""
) -> None:
@@ -450,7 +444,7 @@ class AsyncChannel(BaseChannel):
):
return
- @ChannelTimeout(message="timed out getting prompt")
+ @timeout_wrapper
async def get_prompt(self) -> str:
"""
Get current channel prompt
@@ -486,7 +480,7 @@ class AsyncChannel(BaseChannel):
current_prompt = channel_match.group(0)
return current_prompt.decode().strip()
- @ChannelTimeout(message="timed out sending input to device")
+ @timeout_wrapper
async def send_input(
self,
channel_input: str,
@@ -534,7 +528,7 @@ class AsyncChannel(BaseChannel):
)
return buf, processed_buf
- @ChannelTimeout(message="timed out sending input to device")
+ @timeout_wrapper
async def send_input_and_read(
self,
channel_input: str,
@@ -589,7 +583,7 @@ class AsyncChannel(BaseChannel):
return buf, processed_buf
- @ChannelTimeout(message="timed out sending interactive input to device")
+ @timeout_wrapper
async def send_inputs_interact(
self,
interact_events: List[Tuple[str, str, Optional[bool]]],
@@ -970,7 +964,7 @@ class AsyncChannel(BaseChannel):
return read_buf.getvalue()
- @ChannelTimeout(message="timed out during in channel ssh authentication")
+ @timeout_wrapper
async def channel_authenticate_ssh(
self, auth_password: str, auth_private_key_passphrase: str
) -> None:
@@ -1043,7 +1037,7 @@ class AsyncChannel(BaseChannel):
):
return
- @ChannelTimeout(message="timed out during in channel telnet authentication")
+ @timeout_wrapper
async def channel_authenticate_telnet( # noqa: C901
self, auth_username: str = "", auth_password: str = ""
) -> None:
@@ -1130,7 +1124,7 @@ class AsyncChannel(BaseChannel):
):
return
- @ChannelTimeout(message="timed out getting prompt")
+ @timeout_wrapper
async def get_prompt(self) -> str:
"""
Get current channel prompt
@@ -1166,7 +1160,7 @@ class AsyncChannel(BaseChannel):
current_prompt = channel_match.group(0)
return current_prompt.decode().strip()
- @ChannelTimeout(message="timed out sending input to device")
+ @timeout_wrapper
async def send_input(
self,
channel_input: str,
@@ -1214,7 +1208,7 @@ class AsyncChannel(BaseChannel):
)
return buf, processed_buf
- @ChannelTimeout(message="timed out sending input to device")
+ @timeout_wrapper
async def send_input_and_read(
self,
channel_input: str,
@@ -1269,7 +1263,7 @@ class AsyncChannel(BaseChannel):
return buf, processed_buf
- @ChannelTimeout(message="timed out sending interactive input to device")
+ @timeout_wrapper
async def send_inputs_interact(
self,
interact_events: List[Tuple[str, str, Optional[bool]]],
diff --git a/docs/api_docs/channel/sync_channel.md b/docs/api_docs/channel/sync_channel.md
index 999cedab..d55df88f 100644
--- a/docs/api_docs/channel/sync_channel.md
+++ b/docs/api_docs/channel/sync_channel.md
@@ -38,7 +38,7 @@ from threading import Lock
from typing import Iterator, List, Optional, Tuple
from scrapli.channel.base_channel import BaseChannel, BaseChannelArgs
-from scrapli.decorators import ChannelTimeout
+from scrapli.decorators import timeout_wrapper
from scrapli.exceptions import ScrapliAuthenticationFailed, ScrapliTimeout
from scrapli.transport.base import Transport
@@ -281,7 +281,7 @@ class Channel(BaseChannel):
return read_buf.getvalue()
- @ChannelTimeout(message="timed out during in channel ssh authentication")
+ @timeout_wrapper
def channel_authenticate_ssh(
self, auth_password: str, auth_private_key_passphrase: str
) -> None:
@@ -353,7 +353,7 @@ class Channel(BaseChannel):
):
return
- @ChannelTimeout(message="timed out during in channel telnet authentication")
+ @timeout_wrapper
def channel_authenticate_telnet(self, auth_username: str = "", auth_password: str = "") -> None:
"""
Handle Telnet Authentication
@@ -434,7 +434,7 @@ class Channel(BaseChannel):
):
return
- @ChannelTimeout(message="timed out getting prompt")
+ @timeout_wrapper
def get_prompt(self) -> str:
"""
Get current channel prompt
@@ -470,7 +470,7 @@ class Channel(BaseChannel):
current_prompt = channel_match.group(0)
return current_prompt.decode().strip()
- @ChannelTimeout(message="timed out sending input to device")
+ @timeout_wrapper
def send_input(
self,
channel_input: str,
@@ -518,7 +518,7 @@ class Channel(BaseChannel):
)
return buf, processed_buf
- @ChannelTimeout(message="timed out sending input to device")
+ @timeout_wrapper
def send_input_and_read(
self,
channel_input: str,
@@ -573,7 +573,7 @@ class Channel(BaseChannel):
return buf, processed_buf
- @ChannelTimeout(message="timed out sending interactive input to device")
+ @timeout_wrapper
def send_inputs_interact(
self,
interact_events: List[Tuple[str, str, Optional[bool]]],
@@ -951,7 +951,7 @@ class Channel(BaseChannel):
return read_buf.getvalue()
- @ChannelTimeout(message="timed out during in channel ssh authentication")
+ @timeout_wrapper
def channel_authenticate_ssh(
self, auth_password: str, auth_private_key_passphrase: str
) -> None:
@@ -1023,7 +1023,7 @@ class Channel(BaseChannel):
):
return
- @ChannelTimeout(message="timed out during in channel telnet authentication")
+ @timeout_wrapper
def channel_authenticate_telnet(self, auth_username: str = "", auth_password: str = "") -> None:
"""
Handle Telnet Authentication
@@ -1104,7 +1104,7 @@ class Channel(BaseChannel):
):
return
- @ChannelTimeout(message="timed out getting prompt")
+ @timeout_wrapper
def get_prompt(self) -> str:
"""
Get current channel prompt
@@ -1140,7 +1140,7 @@ class Channel(BaseChannel):
current_prompt = channel_match.group(0)
return current_prompt.decode().strip()
- @ChannelTimeout(message="timed out sending input to device")
+ @timeout_wrapper
def send_input(
self,
channel_input: str,
@@ -1188,7 +1188,7 @@ class Channel(BaseChannel):
)
return buf, processed_buf
- @ChannelTimeout(message="timed out sending input to device")
+ @timeout_wrapper
def send_input_and_read(
self,
channel_input: str,
@@ -1243,7 +1243,7 @@ class Channel(BaseChannel):
return buf, processed_buf
- @ChannelTimeout(message="timed out sending interactive input to device")
+ @timeout_wrapper
def send_inputs_interact(
self,
interact_events: List[Tuple[str, str, Optional[bool]]],
diff --git a/docs/api_docs/decorators.md b/docs/api_docs/decorators.md
index 12c28830..dc0387c4 100644
--- a/docs/api_docs/decorators.md
+++ b/docs/api_docs/decorators.md
@@ -34,908 +34,359 @@ import signal
import sys
import threading
from concurrent.futures import ThreadPoolExecutor, wait
-from functools import update_wrapper
+from functools import partial, update_wrapper
from logging import Logger, LoggerAdapter
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any, Callable, Tuple
from scrapli.exceptions import ScrapliTimeout
if TYPE_CHECKING:
- from scrapli.channel import Channel # pragma: no cover
from scrapli.driver import AsyncGenericDriver, GenericDriver # pragma: no cover
from scrapli.transport.base.base_transport import BaseTransport # pragma: no cover
if TYPE_CHECKING:
- LoggerAdapterT = LoggerAdapter[Logger] # pylint:disable=E1136
+ LoggerAdapterT = LoggerAdapter[Logger] # pragma: no cover # pylint:disable=E1136
else:
LoggerAdapterT = LoggerAdapter
_IS_WINDOWS = sys.platform.startswith("win")
-class TransportTimeout:
- def __init__(self, message: str = "") -> None:
- """
- Transport timeout decorator
+FUNC_TIMEOUT_MESSAGE_MAP = {
+ "channel_authenticate_ssh": "timed out during in channel ssh authentication",
+ "channel_authenticate_telnet": "timed out during in channel telnet authentication",
+ "get_prompt": "timed out getting prompt",
+ "send_input": "timed out sending input to device",
+ "send_input_and_read": "timed out sending input to device",
+ "send_inputs_interact": "timed out sending interactive input to device",
+ "read": "timed out reading from transport",
+}
+
+
+def _get_timeout_message(func_name: str) -> str:
+ """
+ Return appropriate timeout message for the given function name
+
+ Args:
+ func_name: name of function to fetch timeout message for
+
+ Returns:
+ str: timeout message
+
+ Raises:
+ N/A
+
+ """
+ return FUNC_TIMEOUT_MESSAGE_MAP.get(func_name, "unspecified timeout occurred")
+
+
+def _signal_raise_exception(
+ signum: Any, frame: Any, transport: "BaseTransport", logger: LoggerAdapterT, message: str
+) -> None:
+ """
+ Signal method exception handler
+
+ Args:
+ signum: singum from the singal handler, unused here
+ frame: frame from the signal handler, unused here
+ transport: transport to close
+ logger: logger to write closing messages to
+ message: exception message
+
+ Returns:
+ None
+
+ Raises:
+ N/A
+
+ """
+ _, _ = signum, frame
+
+ return _handle_timeout(transport=transport, logger=logger, message=message)
+
+
+def _multiprocessing_timeout(
+ transport: "BaseTransport",
+ logger: LoggerAdapterT,
+ timeout: float,
+ wrapped_func: Callable[..., Any],
+ args: Any,
+ kwargs: Any,
+) -> Any:
+ """
+ Return appropriate timeout message for the given function name
+
+ Args:
+ transport: transport to close (if timeout occurs)
+ logger: logger to write closing message to
+ timeout: timeout in seconds
+ wrapped_func: function being decorated
+ args: function args
+ kwargs: function kwargs
+
+ Returns:
+ Any: result of the wrapped function
+
+ Raises:
+ N/A
+
+ """
+ with ThreadPoolExecutor(max_workers=1) as pool:
+ future = pool.submit(wrapped_func, *args, **kwargs)
+ wait([future], timeout=timeout)
+ if not future.done():
+ return _handle_timeout(
+ transport=transport,
+ logger=logger,
+ message=_get_timeout_message(func_name=wrapped_func.__name__),
+ )
+ return future.result()
- Args:
- message: accepts message from decorated function to add context to any timeout
- (if a timeout happens!)
- Returns:
- None
+def _handle_timeout(transport: "BaseTransport", logger: LoggerAdapterT, message: str) -> None:
+ """
+ Timeout handler method to close connections and raise ScrapliTimeout
- Raises:
- N/A
+ Args:
+ transport: transport to close
+ logger: logger to write closing message to
+ message: message to pass to ScrapliTimeout exception
- """
- self.message = message
- self.transport_instance: "BaseTransport"
- self.transport_timeout_transport = 0.0
+ Returns:
+ None
- def __call__(self, wrapped_func: Callable[..., Any]) -> Callable[..., Any]:
- """
- Decorate an "operation" to modify the timeout_transport value for duration of that operation
+ Raises:
+ ScrapliTimeout: always, if we hit this method we have already timed out!
- This decorator wraps a transport read operation and is used to allow users to control the
- transport timeout via the `timeout_transport` attribute. This decorator should be applied to
- any transport "read" operations.
+ """
+ logger.critical("operation timed out, closing connection")
+ transport.close()
+ raise ScrapliTimeout(message)
- Args:
- wrapped_func: function being decorated
- Returns:
- decorate: decorated func
+def _get_transport_logger_timeout(
+ cls: Any,
+) -> Tuple["BaseTransport", LoggerAdapterT, float]:
+ """
+ Fetch the transport, logger and timeout from the channel or transport object
- Raises:
- N/A
+ Args:
+ cls: Channel or Transport object (self from wrapped function) to grab transport/logger and
+ timeout values from
- """
+ Returns:
+ Tuple: transport, logger, and timeout value
- if asyncio.iscoroutinefunction(wrapped_func):
+ Raises:
+ N/A
- async def decorate(*args: Any, **kwargs: Any) -> Any:
- self.transport_instance = args[0]
- self.transport_timeout_transport = self._get_timeout_transport()
+ """
+ if hasattr(cls, "transport"):
+ return (
+ cls.transport,
+ cls.logger,
+ cls._base_channel_args.timeout_ops, # pylint: disable=W0212
+ )
- if not self.transport_timeout_transport:
- return await wrapped_func(*args, **kwargs)
+ return (
+ cls,
+ cls.logger,
+ cls._base_transport_args.timeout_transport, # pylint: disable=W0212
+ )
- try:
- return await asyncio.wait_for(
- wrapped_func(*args, **kwargs), timeout=self.transport_timeout_transport
- )
- except asyncio.TimeoutError:
- self._handle_timeout()
-
- else:
- # ignoring type error:
- # "All conditional function variants must have identical signatures"
- # one is sync one is async so never going to be identical here!
- def decorate(*args: Any, **kwargs: Any) -> Any: # type: ignore
- self.transport_instance = args[0]
- self.transport_timeout_transport = self._get_timeout_transport()
-
- if not self.transport_timeout_transport:
- return wrapped_func(*args, **kwargs)
-
- transport_instance_class_name = self.transport_instance.__class__.__name__
-
- if (
- transport_instance_class_name in ("SystemTransport", "TelnetTransport")
- or _IS_WINDOWS
- or threading.current_thread() is not threading.main_thread()
- ):
- return self._multiprocessing_timeout(
- wrapped_func=wrapped_func,
- args=args,
- kwargs=kwargs,
- )
-
- old = signal.signal(signal.SIGALRM, self._signal_raise_exception)
- signal.setitimer(signal.ITIMER_REAL, self.transport_timeout_transport)
- try:
- return wrapped_func(*args, **kwargs)
- finally:
- if self.transport_timeout_transport:
- signal.setitimer(signal.ITIMER_REAL, 0)
- signal.signal(signal.SIGALRM, old)
-
- # ensures that the wrapped function is updated w/ the original functions docs/etc. --
- # necessary for introspection for the auto gen docs to work!
- update_wrapper(wrapper=decorate, wrapped=wrapped_func)
- return decorate
-
- def _get_timeout_transport(self) -> float:
- """
- Fetch and return timeout transport from the transport object
-
- Args:
- N/A
-
- Returns:
- float: transport timeout value
-
- Raises:
- N/A
-
- """
- transport_args = self.transport_instance._base_transport_args # pylint: disable=W0212
- return transport_args.timeout_transport
-
- def _handle_timeout(self) -> None:
- """
- Timeout handler method to close connections and raise ScrapliTimeout
-
- Args:
- N/A
-
- Returns:
- None
-
- Raises:
- ScrapliTimeout: always, if we hit this method we have already timed out!
-
- """
- self.transport_instance.logger.critical("transport operation timed out, closing transport")
- self.transport_instance.close()
- raise ScrapliTimeout(self.message)
-
- def _multiprocessing_timeout(
- self, wrapped_func: Callable[..., Any], args: Any, kwargs: Any
- ) -> Any:
- """
- Multiprocessing method for timeouts; works in threads and on windows
-
- Args:
- wrapped_func: function being decorated
- args: function being decorated args
- kwargs: function being decorated kwargs
-
- Returns:
- Any: result of decorated function
-
- Raises:
- N/A
-
- """
- with ThreadPoolExecutor(max_workers=1) as pool:
- future = pool.submit(wrapped_func, *args, **kwargs)
- wait([future], timeout=self.transport_timeout_transport)
- if not future.done():
- self._handle_timeout()
- return future.result()
- def _signal_raise_exception(self, signum: Any, frame: Any) -> None:
- """
- Signal method exception handler
+def timeout_wrapper(wrapped_func: Callable[..., Any]) -> Callable[..., Any]:
+ """
+ Timeout wrapper for transports
- Args:
- signum: singum from the singal handler, unused here
- frame: frame from the signal handler, unused here
+ Args:
+ wrapped_func: function being wrapped -- must be a method of Channel or Transport
- Returns:
- None
+ Returns:
+ Any: result of wrapped function
- Raises:
- N/A
+ Raises:
+ N/A
- """
- _, _ = signum, frame
- self._handle_timeout()
+ """
+ if asyncio.iscoroutinefunction(wrapped_func):
+ async def decorate(*args: Any, **kwargs: Any) -> Any:
+ transport, logger, timeout = _get_transport_logger_timeout(cls=args[0])
-class ChannelTimeout:
- def __init__(self, message: str = "") -> None:
- """
- Channel timeout decorator
+ if not timeout:
+ return await wrapped_func(*args, **kwargs)
- Args:
- message: accepts message from decorated function to add context to any timeout
- (if a timeout happens!)
+ try:
+ return await asyncio.wait_for(wrapped_func(*args, **kwargs), timeout=timeout)
+ except asyncio.TimeoutError:
+ _handle_timeout(
+ transport=transport,
+ logger=logger,
+ message=_get_timeout_message(func_name=wrapped_func.__name__),
+ )
- Returns:
- None
+ else:
+ # ignoring type error:
+ # "All conditional function variants must have identical signatures"
+ # one is sync one is async so never going to be identical here!
+ def decorate(*args: Any, **kwargs: Any) -> Any: # type: ignore
+ transport, logger, timeout = _get_transport_logger_timeout(cls=args[0])
+
+ if not timeout:
+ return wrapped_func(*args, **kwargs)
+
+ cls_name = transport.__class__.__name__
+
+ if (
+ cls_name in ("SystemTransport", "TelnetTransport")
+ or _IS_WINDOWS
+ or threading.current_thread() is not threading.main_thread()
+ ):
+ return _multiprocessing_timeout(
+ transport=transport,
+ logger=logger,
+ timeout=timeout,
+ wrapped_func=wrapped_func,
+ args=args,
+ kwargs=kwargs,
+ )
- Raises:
- N/A
+ callback = partial(
+ _signal_raise_exception,
+ transport=transport,
+ logger=logger,
+ message=_get_timeout_message(wrapped_func.__name__),
+ )
- """
- self.message = message
- self.channel_timeout_ops = 0.0
- self.channel_logger: LoggerAdapterT
- self.transport_instance: "BaseTransport"
+ old = signal.signal(signal.SIGALRM, callback)
+ signal.setitimer(signal.ITIMER_REAL, timeout)
+ try:
+ return wrapped_func(*args, **kwargs)
+ finally:
+ if timeout:
+ signal.setitimer(signal.ITIMER_REAL, 0)
+ signal.signal(signal.SIGALRM, old)
- def __call__(self, wrapped_func: Callable[..., Any]) -> Callable[..., Any]:
- """
- Decorate an "operation" to modify the timeout_ops value for duration of that operation
+ # ensures that the wrapped function is updated w/ the original functions docs/etc. --
+ # necessary for introspection for the auto gen docs to work!
+ update_wrapper(wrapper=decorate, wrapped=wrapped_func)
+ return decorate
- This decorator wraps send command/config ops and is used to allow users to set a
- `timeout_ops` value for the duration of a single method call -- this makes it so users don't
- need to manually set/reset the value
- Args:
- wrapped_func: function being decorated
+def timeout_modifier(wrapped_func: Callable[..., Any]) -> Callable[..., Any]:
+ """
+ Decorate an "operation" to modify the timeout_ops value for duration of that operation
- Returns:
- decorate: decorated func
+ This decorator wraps send command/config ops and is used to allow users to set a
+ `timeout_ops` value for the duration of a single method call -- this makes it so users don't
+ need to manually set/reset the value
- Raises:
- N/A
+ Args:
+ wrapped_func: function being decorated
- """
- if asyncio.iscoroutinefunction(wrapped_func):
+ Returns:
+ decorate: decorated func
- async def decorate(*args: Any, **kwargs: Any) -> Any:
- channel_instance: "Channel" = args[0]
- self.channel_logger = channel_instance.logger
- self.channel_timeout_ops = (
- channel_instance._base_channel_args.timeout_ops # pylint: disable=W0212
- )
+ Raises:
+ N/A
- if not self.channel_timeout_ops:
- return await wrapped_func(*args, **kwargs)
-
- self.transport_instance = channel_instance.transport
-
- try:
- return await asyncio.wait_for(
- wrapped_func(*args, **kwargs), timeout=self.channel_timeout_ops
- )
- except asyncio.TimeoutError:
- self._handle_timeout()
-
- else:
- # ignoring type error:
- # "All conditional function variants must have identical signatures"
- # one is sync one is async so never going to be identical here!
- def decorate(*args: Any, **kwargs: Any) -> Any: # type: ignore
- channel_instance: "Channel" = args[0]
- self.channel_logger = channel_instance.logger
- self.channel_timeout_ops = (
- channel_instance._base_channel_args.timeout_ops # pylint: disable=W0212
- )
+ """
+ if asyncio.iscoroutinefunction(wrapped_func):
- if not self.channel_timeout_ops:
- return wrapped_func(*args, **kwargs)
-
- self.transport_instance = channel_instance.transport
- transport_instance_class_name = self.transport_instance.__class__.__name__
-
- if (
- transport_instance_class_name in ("SystemTransport", "TelnetTransport")
- or _IS_WINDOWS
- or threading.current_thread() is not threading.main_thread()
- ):
- return self._multiprocessing_timeout(
- wrapped_func=wrapped_func,
- args=args,
- kwargs=kwargs,
- )
-
- old = signal.signal(signal.SIGALRM, self._signal_raise_exception)
- signal.setitimer(signal.ITIMER_REAL, self.channel_timeout_ops)
- try:
- return wrapped_func(*args, **kwargs)
- finally:
- if self.channel_timeout_ops:
- signal.setitimer(signal.ITIMER_REAL, 0)
- signal.signal(signal.SIGALRM, old)
-
- # ensures that the wrapped function is updated w/ the original functions docs/etc. --
- # necessary for introspection for the auto gen docs to work!
- update_wrapper(wrapper=decorate, wrapped=wrapped_func)
- return decorate
-
- def _handle_timeout(self) -> None:
- """
- Timeout handler method to close connections and raise ScrapliTimeout
-
- Args:
- N/A
-
- Returns:
- None
-
- Raises:
- ScrapliTimeout: always, if we hit this method we have already timed out!
-
- """
- self.channel_logger.critical("channel operation timed out, closing transport")
- self.transport_instance.close()
- raise ScrapliTimeout(self.message)
-
- def _multiprocessing_timeout(
- self, wrapped_func: Callable[..., Any], args: Any, kwargs: Any
- ) -> Any:
- """
- Multiprocessing method for timeouts; works in threads and on windows
-
- Args:
- wrapped_func: function being decorated
- args: function being decorated args
- kwargs: function being decorated kwargs
-
- Returns:
- Any: result of decorated function
-
- Raises:
- N/A
-
- """
- with ThreadPoolExecutor(max_workers=1) as pool:
- future = pool.submit(wrapped_func, *args, **kwargs)
- wait([future], timeout=self.channel_timeout_ops)
- if not future.done():
- self._handle_timeout()
- return future.result()
+ async def decorate(*args: Any, **kwargs: Any) -> Any:
+ driver_instance: "AsyncGenericDriver" = args[0]
+ driver_logger = driver_instance.logger
+
+ timeout_ops_kwarg = kwargs.get("timeout_ops", None)
- def _signal_raise_exception(self, signum: Any, frame: Any) -> None:
- """
- Signal method exception handler
-
- Args:
- signum: singum from the singal handler, unused here
- frame: frame from the signal handler, unused here
-
- Returns:
- None
-
- Raises:
- N/A
-
- """
- _, _ = signum, frame
- self._handle_timeout()
-
-
-class TimeoutOpsModifier:
- def __call__(self, wrapped_func: Callable[..., Any]) -> Callable[..., Any]:
- """
- Decorate an "operation" to modify the timeout_ops value for duration of that operation
-
- This decorator wraps send command/config ops and is used to allow users to set a
- `timeout_ops` value for the duration of a single method call -- this makes it so users don't
- need to manually set/reset the value
-
- Args:
- wrapped_func: function being decorated
-
- Returns:
- decorate: decorated func
-
- Raises:
- N/A
-
- """
- if asyncio.iscoroutinefunction(wrapped_func):
-
- async def decorate(*args: Any, **kwargs: Any) -> Any:
- driver_instance: "AsyncGenericDriver" = args[0]
- driver_logger = driver_instance.logger
-
- timeout_ops_kwarg = kwargs.get("timeout_ops", None)
-
- if timeout_ops_kwarg is None or timeout_ops_kwarg == driver_instance.timeout_ops:
- result = await wrapped_func(*args, **kwargs)
- else:
- driver_logger.info(
- "modifying driver timeout for current operation, temporary timeout_ops "
- f"value: '{timeout_ops_kwarg}'"
- )
- base_timeout_ops = driver_instance.timeout_ops
- driver_instance.timeout_ops = kwargs["timeout_ops"]
- result = await wrapped_func(*args, **kwargs)
- driver_instance.timeout_ops = base_timeout_ops
- return result
-
- else:
- # ignoring type error:
- # "All conditional function variants must have identical signatures"
- # one is sync one is async so never going to be identical here!
- def decorate(*args: Any, **kwargs: Any) -> Any: # type: ignore
- driver_instance: "GenericDriver" = args[0]
- driver_logger = driver_instance.logger
-
- timeout_ops_kwarg = kwargs.get("timeout_ops", None)
-
- if timeout_ops_kwarg is None or timeout_ops_kwarg == driver_instance.timeout_ops:
- result = wrapped_func(*args, **kwargs)
- else:
- driver_logger.info(
- "modifying driver timeout for current operation, temporary timeout_ops "
- f"value: '{timeout_ops_kwarg}'"
- )
- base_timeout_ops = driver_instance.timeout_ops
- driver_instance.timeout_ops = kwargs["timeout_ops"]
- result = wrapped_func(*args, **kwargs)
- driver_instance.timeout_ops = base_timeout_ops
- return result
-
- # ensures that the wrapped function is updated w/ the original functions docs/etc. --
- # necessary for introspection for the auto gen docs to work!
- update_wrapper(wrapper=decorate, wrapped=wrapped_func)
- return decorate
+ if timeout_ops_kwarg is None or timeout_ops_kwarg == driver_instance.timeout_ops:
+ result = await wrapped_func(*args, **kwargs)
+ else:
+ driver_logger.info(
+ "modifying driver timeout for current operation, temporary timeout_ops "
+ f"value: '{timeout_ops_kwarg}'"
+ )
+ base_timeout_ops = driver_instance.timeout_ops
+ driver_instance.timeout_ops = kwargs["timeout_ops"]
+ result = await wrapped_func(*args, **kwargs)
+ driver_instance.timeout_ops = base_timeout_ops
+ return result
+
+ else:
+ # ignoring type error:
+ # "All conditional function variants must have identical signatures"
+ # one is sync one is async so never going to be identical here!
+ def decorate(*args: Any, **kwargs: Any) -> Any: # type: ignore
+ driver_instance: "GenericDriver" = args[0]
+ driver_logger = driver_instance.logger
+
+ timeout_ops_kwarg = kwargs.get("timeout_ops", None)
+
+ if timeout_ops_kwarg is None or timeout_ops_kwarg == driver_instance.timeout_ops:
+ result = wrapped_func(*args, **kwargs)
+ else:
+ driver_logger.info(
+ "modifying driver timeout for current operation, temporary timeout_ops "
+ f"value: '{timeout_ops_kwarg}'"
+ )
+ base_timeout_ops = driver_instance.timeout_ops
+ driver_instance.timeout_ops = kwargs["timeout_ops"]
+ result = wrapped_func(*args, **kwargs)
+ driver_instance.timeout_ops = base_timeout_ops
+ return result
+
+ # ensures that the wrapped function is updated w/ the original functions docs/etc. --
+ # necessary for introspection for the auto gen docs to work!
+ update_wrapper(wrapper=decorate, wrapped=wrapped_func)
+ return decorate
+## Functions
-## Classes
-
-### ChannelTimeout
+
+#### timeout_modifier
+`timeout_modifier(wrapped_func: Callable[..., Any]) ‑> Callable[..., Any]`
```text
-Channel timeout decorator
+Decorate an "operation" to modify the timeout_ops value for duration of that operation
+
+This decorator wraps send command/config ops and is used to allow users to set a
+`timeout_ops` value for the duration of a single method call -- this makes it so users don't
+need to manually set/reset the value
Args:
- message: accepts message from decorated function to add context to any timeout
- (if a timeout happens!)
+ wrapped_func: function being decorated
Returns:
- None
+ decorate: decorated func
Raises:
N/A
```
-
-
- Expand source code
-
-
-
-class ChannelTimeout:
- def __init__(self, message: str = "") -> None:
- """
- Channel timeout decorator
-
- Args:
- message: accepts message from decorated function to add context to any timeout
- (if a timeout happens!)
-
- Returns:
- None
- Raises:
- N/A
- """
- self.message = message
- self.channel_timeout_ops = 0.0
- self.channel_logger: LoggerAdapterT
- self.transport_instance: "BaseTransport"
- def __call__(self, wrapped_func: Callable[..., Any]) -> Callable[..., Any]:
- """
- Decorate an "operation" to modify the timeout_ops value for duration of that operation
-
- This decorator wraps send command/config ops and is used to allow users to set a
- `timeout_ops` value for the duration of a single method call -- this makes it so users don't
- need to manually set/reset the value
-
- Args:
- wrapped_func: function being decorated
-
- Returns:
- decorate: decorated func
-
- Raises:
- N/A
-
- """
- if asyncio.iscoroutinefunction(wrapped_func):
-
- async def decorate(*args: Any, **kwargs: Any) -> Any:
- channel_instance: "Channel" = args[0]
- self.channel_logger = channel_instance.logger
- self.channel_timeout_ops = (
- channel_instance._base_channel_args.timeout_ops # pylint: disable=W0212
- )
-
- if not self.channel_timeout_ops:
- return await wrapped_func(*args, **kwargs)
-
- self.transport_instance = channel_instance.transport
-
- try:
- return await asyncio.wait_for(
- wrapped_func(*args, **kwargs), timeout=self.channel_timeout_ops
- )
- except asyncio.TimeoutError:
- self._handle_timeout()
-
- else:
- # ignoring type error:
- # "All conditional function variants must have identical signatures"
- # one is sync one is async so never going to be identical here!
- def decorate(*args: Any, **kwargs: Any) -> Any: # type: ignore
- channel_instance: "Channel" = args[0]
- self.channel_logger = channel_instance.logger
- self.channel_timeout_ops = (
- channel_instance._base_channel_args.timeout_ops # pylint: disable=W0212
- )
-
- if not self.channel_timeout_ops:
- return wrapped_func(*args, **kwargs)
-
- self.transport_instance = channel_instance.transport
- transport_instance_class_name = self.transport_instance.__class__.__name__
-
- if (
- transport_instance_class_name in ("SystemTransport", "TelnetTransport")
- or _IS_WINDOWS
- or threading.current_thread() is not threading.main_thread()
- ):
- return self._multiprocessing_timeout(
- wrapped_func=wrapped_func,
- args=args,
- kwargs=kwargs,
- )
-
- old = signal.signal(signal.SIGALRM, self._signal_raise_exception)
- signal.setitimer(signal.ITIMER_REAL, self.channel_timeout_ops)
- try:
- return wrapped_func(*args, **kwargs)
- finally:
- if self.channel_timeout_ops:
- signal.setitimer(signal.ITIMER_REAL, 0)
- signal.signal(signal.SIGALRM, old)
-
- # ensures that the wrapped function is updated w/ the original functions docs/etc. --
- # necessary for introspection for the auto gen docs to work!
- update_wrapper(wrapper=decorate, wrapped=wrapped_func)
- return decorate
-
- def _handle_timeout(self) -> None:
- """
- Timeout handler method to close connections and raise ScrapliTimeout
-
- Args:
- N/A
-
- Returns:
- None
-
- Raises:
- ScrapliTimeout: always, if we hit this method we have already timed out!
-
- """
- self.channel_logger.critical("channel operation timed out, closing transport")
- self.transport_instance.close()
- raise ScrapliTimeout(self.message)
-
- def _multiprocessing_timeout(
- self, wrapped_func: Callable[..., Any], args: Any, kwargs: Any
- ) -> Any:
- """
- Multiprocessing method for timeouts; works in threads and on windows
-
- Args:
- wrapped_func: function being decorated
- args: function being decorated args
- kwargs: function being decorated kwargs
-
- Returns:
- Any: result of decorated function
-
- Raises:
- N/A
-
- """
- with ThreadPoolExecutor(max_workers=1) as pool:
- future = pool.submit(wrapped_func, *args, **kwargs)
- wait([future], timeout=self.channel_timeout_ops)
- if not future.done():
- self._handle_timeout()
- return future.result()
-
- def _signal_raise_exception(self, signum: Any, frame: Any) -> None:
- """
- Signal method exception handler
-
- Args:
- signum: singum from the singal handler, unused here
- frame: frame from the signal handler, unused here
-
- Returns:
- None
-
- Raises:
- N/A
-
- """
- _, _ = signum, frame
- self._handle_timeout()
-
-
-
-
-
-
-
-
-### TimeoutOpsModifier
-
-
-
-
-
- Expand source code
-
-
-
-class TimeoutOpsModifier:
- def __call__(self, wrapped_func: Callable[..., Any]) -> Callable[..., Any]:
- """
- Decorate an "operation" to modify the timeout_ops value for duration of that operation
-
- This decorator wraps send command/config ops and is used to allow users to set a
- `timeout_ops` value for the duration of a single method call -- this makes it so users don't
- need to manually set/reset the value
-
- Args:
- wrapped_func: function being decorated
-
- Returns:
- decorate: decorated func
-
- Raises:
- N/A
-
- """
- if asyncio.iscoroutinefunction(wrapped_func):
-
- async def decorate(*args: Any, **kwargs: Any) -> Any:
- driver_instance: "AsyncGenericDriver" = args[0]
- driver_logger = driver_instance.logger
-
- timeout_ops_kwarg = kwargs.get("timeout_ops", None)
-
- if timeout_ops_kwarg is None or timeout_ops_kwarg == driver_instance.timeout_ops:
- result = await wrapped_func(*args, **kwargs)
- else:
- driver_logger.info(
- "modifying driver timeout for current operation, temporary timeout_ops "
- f"value: '{timeout_ops_kwarg}'"
- )
- base_timeout_ops = driver_instance.timeout_ops
- driver_instance.timeout_ops = kwargs["timeout_ops"]
- result = await wrapped_func(*args, **kwargs)
- driver_instance.timeout_ops = base_timeout_ops
- return result
-
- else:
- # ignoring type error:
- # "All conditional function variants must have identical signatures"
- # one is sync one is async so never going to be identical here!
- def decorate(*args: Any, **kwargs: Any) -> Any: # type: ignore
- driver_instance: "GenericDriver" = args[0]
- driver_logger = driver_instance.logger
-
- timeout_ops_kwarg = kwargs.get("timeout_ops", None)
-
- if timeout_ops_kwarg is None or timeout_ops_kwarg == driver_instance.timeout_ops:
- result = wrapped_func(*args, **kwargs)
- else:
- driver_logger.info(
- "modifying driver timeout for current operation, temporary timeout_ops "
- f"value: '{timeout_ops_kwarg}'"
- )
- base_timeout_ops = driver_instance.timeout_ops
- driver_instance.timeout_ops = kwargs["timeout_ops"]
- result = wrapped_func(*args, **kwargs)
- driver_instance.timeout_ops = base_timeout_ops
- return result
-
- # ensures that the wrapped function is updated w/ the original functions docs/etc. --
- # necessary for introspection for the auto gen docs to work!
- update_wrapper(wrapper=decorate, wrapped=wrapped_func)
- return decorate
-
-
-
-
-
-
-
-
-### TransportTimeout
+
+#### timeout_wrapper
+`timeout_wrapper(wrapped_func: Callable[..., Any]) ‑> Callable[..., Any]`
```text
-Transport timeout decorator
+Timeout wrapper for transports
Args:
- message: accepts message from decorated function to add context to any timeout
- (if a timeout happens!)
+ wrapped_func: function being wrapped -- must be a method of Channel or Transport
Returns:
- None
+ Any: result of wrapped function
Raises:
N/A
-```
-
-
-
- Expand source code
-
-
-
-class TransportTimeout:
- def __init__(self, message: str = "") -> None:
- """
- Transport timeout decorator
-
- Args:
- message: accepts message from decorated function to add context to any timeout
- (if a timeout happens!)
-
- Returns:
- None
-
- Raises:
- N/A
-
- """
- self.message = message
- self.transport_instance: "BaseTransport"
- self.transport_timeout_transport = 0.0
-
- def __call__(self, wrapped_func: Callable[..., Any]) -> Callable[..., Any]:
- """
- Decorate an "operation" to modify the timeout_transport value for duration of that operation
-
- This decorator wraps a transport read operation and is used to allow users to control the
- transport timeout via the `timeout_transport` attribute. This decorator should be applied to
- any transport "read" operations.
-
- Args:
- wrapped_func: function being decorated
-
- Returns:
- decorate: decorated func
-
- Raises:
- N/A
-
- """
-
- if asyncio.iscoroutinefunction(wrapped_func):
-
- async def decorate(*args: Any, **kwargs: Any) -> Any:
- self.transport_instance = args[0]
- self.transport_timeout_transport = self._get_timeout_transport()
-
- if not self.transport_timeout_transport:
- return await wrapped_func(*args, **kwargs)
-
- try:
- return await asyncio.wait_for(
- wrapped_func(*args, **kwargs), timeout=self.transport_timeout_transport
- )
- except asyncio.TimeoutError:
- self._handle_timeout()
-
- else:
- # ignoring type error:
- # "All conditional function variants must have identical signatures"
- # one is sync one is async so never going to be identical here!
- def decorate(*args: Any, **kwargs: Any) -> Any: # type: ignore
- self.transport_instance = args[0]
- self.transport_timeout_transport = self._get_timeout_transport()
-
- if not self.transport_timeout_transport:
- return wrapped_func(*args, **kwargs)
-
- transport_instance_class_name = self.transport_instance.__class__.__name__
-
- if (
- transport_instance_class_name in ("SystemTransport", "TelnetTransport")
- or _IS_WINDOWS
- or threading.current_thread() is not threading.main_thread()
- ):
- return self._multiprocessing_timeout(
- wrapped_func=wrapped_func,
- args=args,
- kwargs=kwargs,
- )
-
- old = signal.signal(signal.SIGALRM, self._signal_raise_exception)
- signal.setitimer(signal.ITIMER_REAL, self.transport_timeout_transport)
- try:
- return wrapped_func(*args, **kwargs)
- finally:
- if self.transport_timeout_transport:
- signal.setitimer(signal.ITIMER_REAL, 0)
- signal.signal(signal.SIGALRM, old)
-
- # ensures that the wrapped function is updated w/ the original functions docs/etc. --
- # necessary for introspection for the auto gen docs to work!
- update_wrapper(wrapper=decorate, wrapped=wrapped_func)
- return decorate
-
- def _get_timeout_transport(self) -> float:
- """
- Fetch and return timeout transport from the transport object
-
- Args:
- N/A
-
- Returns:
- float: transport timeout value
-
- Raises:
- N/A
-
- """
- transport_args = self.transport_instance._base_transport_args # pylint: disable=W0212
- return transport_args.timeout_transport
-
- def _handle_timeout(self) -> None:
- """
- Timeout handler method to close connections and raise ScrapliTimeout
-
- Args:
- N/A
-
- Returns:
- None
-
- Raises:
- ScrapliTimeout: always, if we hit this method we have already timed out!
-
- """
- self.transport_instance.logger.critical("transport operation timed out, closing transport")
- self.transport_instance.close()
- raise ScrapliTimeout(self.message)
-
- def _multiprocessing_timeout(
- self, wrapped_func: Callable[..., Any], args: Any, kwargs: Any
- ) -> Any:
- """
- Multiprocessing method for timeouts; works in threads and on windows
-
- Args:
- wrapped_func: function being decorated
- args: function being decorated args
- kwargs: function being decorated kwargs
-
- Returns:
- Any: result of decorated function
-
- Raises:
- N/A
-
- """
- with ThreadPoolExecutor(max_workers=1) as pool:
- future = pool.submit(wrapped_func, *args, **kwargs)
- wait([future], timeout=self.transport_timeout_transport)
- if not future.done():
- self._handle_timeout()
- return future.result()
-
- def _signal_raise_exception(self, signum: Any, frame: Any) -> None:
- """
- Signal method exception handler
-
- Args:
- signum: singum from the singal handler, unused here
- frame: frame from the signal handler, unused here
-
- Returns:
- None
-
- Raises:
- N/A
-
- """
- _, _ = signum, frame
- self._handle_timeout()
-
-
-
\ No newline at end of file
+```
\ No newline at end of file
diff --git a/docs/api_docs/driver/base/async_driver.md b/docs/api_docs/driver/base/async_driver.md
index d233b9f1..4eedcb56 100644
--- a/docs/api_docs/driver/base/async_driver.md
+++ b/docs/api_docs/driver/base/async_driver.md
@@ -30,13 +30,15 @@ scrapli.driver.base.async_driver
"""scrapli.driver.base.async_driver"""
from types import TracebackType
-from typing import Any, Optional, Type
+from typing import Any, Optional, Type, TypeVar
from scrapli.channel import AsyncChannel
from scrapli.driver.base.base_driver import BaseDriver
from scrapli.exceptions import ScrapliValueError
from scrapli.transport import ASYNCIO_TRANSPORTS
+_T = TypeVar("_T", bound="AsyncDriver")
+
class AsyncDriver(BaseDriver):
def __init__(self, **kwargs: Any):
@@ -53,7 +55,7 @@ class AsyncDriver(BaseDriver):
base_channel_args=self._base_channel_args,
)
- async def __aenter__(self) -> "AsyncDriver":
+ async def __aenter__(self: _T) -> _T:
"""
Enter method for context manager
@@ -61,7 +63,7 @@ class AsyncDriver(BaseDriver):
N/A
Returns:
- AsyncDriver: opened AsyncDriver object
+ _T: a concrete implementation of the opened AsyncDriver object
Raises:
N/A
@@ -371,7 +373,7 @@ class AsyncDriver(BaseDriver):
base_channel_args=self._base_channel_args,
)
- async def __aenter__(self) -> "AsyncDriver":
+ async def __aenter__(self: _T) -> _T:
"""
Enter method for context manager
@@ -379,7 +381,7 @@ class AsyncDriver(BaseDriver):
N/A
Returns:
- AsyncDriver: opened AsyncDriver object
+ _T: a concrete implementation of the opened AsyncDriver object
Raises:
N/A
diff --git a/docs/api_docs/driver/base/sync_driver.md b/docs/api_docs/driver/base/sync_driver.md
index f2a19063..8436fd20 100644
--- a/docs/api_docs/driver/base/sync_driver.md
+++ b/docs/api_docs/driver/base/sync_driver.md
@@ -30,13 +30,15 @@ scrapli.driver.base.sync_driver
"""scrapli.driver.base.sync_driver"""
from types import TracebackType
-from typing import Any, Optional, Type
+from typing import Any, Optional, Type, TypeVar
from scrapli.channel import Channel
from scrapli.driver.base.base_driver import BaseDriver
from scrapli.exceptions import ScrapliValueError
from scrapli.transport import ASYNCIO_TRANSPORTS
+_T = TypeVar("_T", bound="Driver")
+
class Driver(BaseDriver):
def __init__(self, **kwargs: Any):
@@ -53,7 +55,7 @@ class Driver(BaseDriver):
base_channel_args=self._base_channel_args,
)
- def __enter__(self) -> "Driver":
+ def __enter__(self: _T) -> _T:
"""
Enter method for context manager
@@ -61,7 +63,7 @@ class Driver(BaseDriver):
N/A
Returns:
- Driver: opened Driver object
+ _T: a concrete implementation of the opened Driver object
Raises:
N/A
@@ -322,7 +324,7 @@ class Driver(BaseDriver):
base_channel_args=self._base_channel_args,
)
- def __enter__(self) -> "Driver":
+ def __enter__(self: _T) -> _T:
"""
Enter method for context manager
@@ -330,7 +332,7 @@ class Driver(BaseDriver):
N/A
Returns:
- Driver: opened Driver object
+ _T: a concrete implementation of the opened Driver object
Raises:
N/A
diff --git a/docs/api_docs/driver/core/juniper_junos/base_driver.md b/docs/api_docs/driver/core/juniper_junos/base_driver.md
index 50d71fcb..0e3bd9dc 100644
--- a/docs/api_docs/driver/core/juniper_junos/base_driver.md
+++ b/docs/api_docs/driver/core/juniper_junos/base_driver.md
@@ -90,7 +90,7 @@ PRIVS = {
),
"root_shell": (
PrivilegeLevel(
- pattern=r"^.*root@(?:\S*:\S*\s?)?[%\#]\s?$",
+ pattern=r"^.*root@(?:\S*:?\S*\s?)?[%\#]\s?$",
name="root_shell",
previous_priv="exec",
deescalate="exit",
diff --git a/docs/api_docs/driver/generic/async_driver.md b/docs/api_docs/driver/generic/async_driver.md
index e5ab89f4..2f78fc15 100644
--- a/docs/api_docs/driver/generic/async_driver.md
+++ b/docs/api_docs/driver/generic/async_driver.md
@@ -33,7 +33,7 @@ import asyncio
from io import BytesIO
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
-from scrapli.decorators import TimeoutOpsModifier
+from scrapli.decorators import timeout_modifier
from scrapli.driver import AsyncDriver
from scrapli.driver.generic.base_driver import BaseGenericDriver
from scrapli.exceptions import ScrapliTimeout, ScrapliValueError
@@ -118,7 +118,7 @@ class AsyncGenericDriver(AsyncDriver, BaseGenericDriver):
prompt: str = await self.channel.get_prompt()
return prompt
- @TimeoutOpsModifier()
+ @timeout_modifier
async def _send_command(
self,
command: str,
@@ -312,7 +312,7 @@ class AsyncGenericDriver(AsyncDriver, BaseGenericDriver):
timeout_ops=timeout_ops,
)
- @TimeoutOpsModifier()
+ @timeout_modifier
async def send_and_read(
self,
channel_input: str,
@@ -372,7 +372,7 @@ class AsyncGenericDriver(AsyncDriver, BaseGenericDriver):
raw_response=raw_response, processed_response=processed_response, response=response
)
- @TimeoutOpsModifier()
+ @timeout_modifier
async def send_interactive(
self,
interact_events: Union[List[Tuple[str, str]], List[Tuple[str, str, bool]]],
@@ -784,7 +784,7 @@ class AsyncGenericDriver(AsyncDriver, BaseGenericDriver):
prompt: str = await self.channel.get_prompt()
return prompt
- @TimeoutOpsModifier()
+ @timeout_modifier
async def _send_command(
self,
command: str,
@@ -978,7 +978,7 @@ class AsyncGenericDriver(AsyncDriver, BaseGenericDriver):
timeout_ops=timeout_ops,
)
- @TimeoutOpsModifier()
+ @timeout_modifier
async def send_and_read(
self,
channel_input: str,
@@ -1038,7 +1038,7 @@ class AsyncGenericDriver(AsyncDriver, BaseGenericDriver):
raw_response=raw_response, processed_response=processed_response, response=response
)
- @TimeoutOpsModifier()
+ @timeout_modifier
async def send_interactive(
self,
interact_events: Union[List[Tuple[str, str]], List[Tuple[str, str, bool]]],
diff --git a/docs/api_docs/driver/generic/sync_driver.md b/docs/api_docs/driver/generic/sync_driver.md
index 1e661c26..aa3ef0bd 100644
--- a/docs/api_docs/driver/generic/sync_driver.md
+++ b/docs/api_docs/driver/generic/sync_driver.md
@@ -33,7 +33,7 @@ import time
from io import BytesIO
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
-from scrapli.decorators import TimeoutOpsModifier
+from scrapli.decorators import timeout_modifier
from scrapli.driver import Driver
from scrapli.driver.generic.base_driver import BaseGenericDriver
from scrapli.exceptions import ScrapliTimeout, ScrapliValueError
@@ -119,7 +119,7 @@ class GenericDriver(Driver, BaseGenericDriver):
prompt: str = self.channel.get_prompt()
return prompt
- @TimeoutOpsModifier()
+ @timeout_modifier
def _send_command(
self,
command: str,
@@ -313,7 +313,7 @@ class GenericDriver(Driver, BaseGenericDriver):
timeout_ops=timeout_ops,
)
- @TimeoutOpsModifier()
+ @timeout_modifier
def send_and_read(
self,
channel_input: str,
@@ -373,7 +373,7 @@ class GenericDriver(Driver, BaseGenericDriver):
raw_response=raw_response, processed_response=processed_response, response=response
)
- @TimeoutOpsModifier()
+ @timeout_modifier
def send_interactive(
self,
interact_events: Union[List[Tuple[str, str]], List[Tuple[str, str, bool]]],
@@ -783,7 +783,7 @@ class GenericDriver(Driver, BaseGenericDriver):
prompt: str = self.channel.get_prompt()
return prompt
- @TimeoutOpsModifier()
+ @timeout_modifier
def _send_command(
self,
command: str,
@@ -977,7 +977,7 @@ class GenericDriver(Driver, BaseGenericDriver):
timeout_ops=timeout_ops,
)
- @TimeoutOpsModifier()
+ @timeout_modifier
def send_and_read(
self,
channel_input: str,
@@ -1037,7 +1037,7 @@ class GenericDriver(Driver, BaseGenericDriver):
raw_response=raw_response, processed_response=processed_response, response=response
)
- @TimeoutOpsModifier()
+ @timeout_modifier
def send_interactive(
self,
interact_events: Union[List[Tuple[str, str]], List[Tuple[str, str, bool]]],
diff --git a/docs/api_docs/driver/network/async_driver.md b/docs/api_docs/driver/network/async_driver.md
index 4fc72d19..3c5f2a93 100644
--- a/docs/api_docs/driver/network/async_driver.md
+++ b/docs/api_docs/driver/network/async_driver.md
@@ -136,7 +136,10 @@ class AsyncNetworkDriver(AsyncGenericDriver, BaseNetworkDriver):
(escalate_priv.escalate, escalate_priv.escalate_prompt, False),
(self.auth_secondary, escalate_priv.pattern, True),
],
- interaction_complete_patterns=[escalate_priv.pattern],
+ interaction_complete_patterns=[
+ self.privilege_levels[escalate_priv.previous_priv].pattern,
+ escalate_priv.pattern,
+ ],
)
except ScrapliTimeout as exc:
raise ScrapliAuthenticationFailed(
@@ -462,7 +465,7 @@ class AsyncNetworkDriver(AsyncGenericDriver, BaseNetworkDriver):
if failed_when_contains is None:
failed_when_contains = self.failed_when_contains
- # type hint is due to the TimeoutModifier wrapper returning `Any` so that we dont anger the
+ # type hint is due to the timeout_modifier wrapper returning `Any` so that we dont anger the
# asyncio parts (which will get an awaitable not a Response returned)
response: Response = await super().send_interactive(
interact_events=interact_events,
@@ -860,7 +863,10 @@ class AsyncNetworkDriver(AsyncGenericDriver, BaseNetworkDriver):
(escalate_priv.escalate, escalate_priv.escalate_prompt, False),
(self.auth_secondary, escalate_priv.pattern, True),
],
- interaction_complete_patterns=[escalate_priv.pattern],
+ interaction_complete_patterns=[
+ self.privilege_levels[escalate_priv.previous_priv].pattern,
+ escalate_priv.pattern,
+ ],
)
except ScrapliTimeout as exc:
raise ScrapliAuthenticationFailed(
@@ -1186,7 +1192,7 @@ class AsyncNetworkDriver(AsyncGenericDriver, BaseNetworkDriver):
if failed_when_contains is None:
failed_when_contains = self.failed_when_contains
- # type hint is due to the TimeoutModifier wrapper returning `Any` so that we dont anger the
+ # type hint is due to the timeout_modifier wrapper returning `Any` so that we dont anger the
# asyncio parts (which will get an awaitable not a Response returned)
response: Response = await super().send_interactive(
interact_events=interact_events,
diff --git a/docs/api_docs/driver/network/base_driver.md b/docs/api_docs/driver/network/base_driver.md
index f5ba00dd..fdfca7fd 100644
--- a/docs/api_docs/driver/network/base_driver.md
+++ b/docs/api_docs/driver/network/base_driver.md
@@ -42,7 +42,7 @@ from scrapli.helper import user_warning
from scrapli.response import MultiResponse, Response
if TYPE_CHECKING:
- LoggerAdapterT = LoggerAdapter[Logger] # pylint:disable=E1136
+ LoggerAdapterT = LoggerAdapter[Logger] # pragma: no cover # pylint:disable=E1136
else:
LoggerAdapterT = LoggerAdapter
@@ -142,7 +142,7 @@ class BaseNetworkDriver:
rf"({priv_level_data.pattern})" for priv_level_data in self.privilege_levels.values()
)
- @lru_cache()
+ @lru_cache(maxsize=64)
def _determine_current_priv(self, current_prompt: str) -> List[str]:
"""
Determine current privilege level from prompt string
@@ -664,7 +664,7 @@ class BaseNetworkDriver:
rf"({priv_level_data.pattern})" for priv_level_data in self.privilege_levels.values()
)
- @lru_cache()
+ @lru_cache(maxsize=64)
def _determine_current_priv(self, current_prompt: str) -> List[str]:
"""
Determine current privilege level from prompt string
diff --git a/docs/api_docs/driver/network/sync_driver.md b/docs/api_docs/driver/network/sync_driver.md
index 7e4ced0d..cf8fd253 100644
--- a/docs/api_docs/driver/network/sync_driver.md
+++ b/docs/api_docs/driver/network/sync_driver.md
@@ -136,7 +136,10 @@ class NetworkDriver(GenericDriver, BaseNetworkDriver):
(escalate_priv.escalate, escalate_priv.escalate_prompt, False),
(self.auth_secondary, escalate_priv.pattern, True),
],
- interaction_complete_patterns=[escalate_priv.pattern],
+ interaction_complete_patterns=[
+ self.privilege_levels[escalate_priv.previous_priv].pattern,
+ escalate_priv.pattern,
+ ],
)
except ScrapliTimeout as exc:
raise ScrapliAuthenticationFailed(
@@ -462,7 +465,7 @@ class NetworkDriver(GenericDriver, BaseNetworkDriver):
if failed_when_contains is None:
failed_when_contains = self.failed_when_contains
- # type hint is due to the TimeoutModifier wrapper returning `Any` so that we dont anger the
+ # type hint is due to the timeout_modifier wrapper returning `Any` so that we dont anger the
# asyncio parts (which will get an awaitable not a Response returned)
response: Response = super().send_interactive(
interact_events=interact_events,
@@ -860,7 +863,10 @@ class NetworkDriver(GenericDriver, BaseNetworkDriver):
(escalate_priv.escalate, escalate_priv.escalate_prompt, False),
(self.auth_secondary, escalate_priv.pattern, True),
],
- interaction_complete_patterns=[escalate_priv.pattern],
+ interaction_complete_patterns=[
+ self.privilege_levels[escalate_priv.previous_priv].pattern,
+ escalate_priv.pattern,
+ ],
)
except ScrapliTimeout as exc:
raise ScrapliAuthenticationFailed(
@@ -1186,7 +1192,7 @@ class NetworkDriver(GenericDriver, BaseNetworkDriver):
if failed_when_contains is None:
failed_when_contains = self.failed_when_contains
- # type hint is due to the TimeoutModifier wrapper returning `Any` so that we dont anger the
+ # type hint is due to the timeout_modifier wrapper returning `Any` so that we dont anger the
# asyncio parts (which will get an awaitable not a Response returned)
response: Response = super().send_interactive(
interact_events=interact_events,
diff --git a/docs/api_docs/transport/plugins/asyncssh.md b/docs/api_docs/transport/plugins/asyncssh.md
index 861d5c59..ecfba1d4 100644
--- a/docs/api_docs/transport/plugins/asyncssh.md
+++ b/docs/api_docs/transport/plugins/asyncssh.md
@@ -37,7 +37,7 @@ from asyncssh.connection import SSHClientConnection, connect
from asyncssh.misc import ConnectionLost, PermissionDenied
from asyncssh.stream import SSHReader, SSHWriter
-from scrapli.decorators import TransportTimeout
+from scrapli.decorators import timeout_wrapper
from scrapli.exceptions import (
ScrapliAuthenticationFailed,
ScrapliConnectionError,
@@ -275,7 +275,7 @@ class AsyncsshTransport(AsyncTransport):
pass
return False
- @TransportTimeout("timed out reading from transport")
+ @timeout_wrapper
async def read(self) -> bytes:
if not self.stdout:
raise ScrapliConnectionNotOpened
@@ -583,7 +583,7 @@ class AsyncsshTransport(AsyncTransport):
pass
return False
- @TransportTimeout("timed out reading from transport")
+ @timeout_wrapper
async def read(self) -> bytes:
if not self.stdout:
raise ScrapliConnectionNotOpened
diff --git a/docs/api_docs/transport/plugins/asynctelnet.md b/docs/api_docs/transport/plugins/asynctelnet.md
index 4e8d0584..540ac108 100644
--- a/docs/api_docs/transport/plugins/asynctelnet.md
+++ b/docs/api_docs/transport/plugins/asynctelnet.md
@@ -34,22 +34,14 @@ import socket
from dataclasses import dataclass
from typing import Optional
-from scrapli.decorators import TransportTimeout
+from scrapli.decorators import timeout_wrapper
from scrapli.exceptions import (
ScrapliAuthenticationFailed,
ScrapliConnectionError,
ScrapliConnectionNotOpened,
)
from scrapli.transport.base import AsyncTransport, BasePluginTransportArgs, BaseTransportArgs
-
-# telnet control characters we care about
-IAC = bytes([255])
-DONT = bytes([254])
-DO = bytes([253])
-WONT = bytes([252])
-WILL = bytes([251])
-TERM_TYPE = bytes([24])
-SUPPRESS_GO_AHEAD = bytes([3])
+from scrapli.transport.base.telnet_common import DO, DONT, IAC, SUPPRESS_GO_AHEAD, WILL, WONT
@dataclass()
@@ -68,10 +60,9 @@ class AsynctelnetTransport(AsyncTransport):
self.stdin: Optional[asyncio.StreamWriter] = None
self._initial_buf = b""
- self._stdout_binary_transmission = False
def _handle_control_chars_response(self, control_buf: bytes, c: bytes) -> bytes:
- """ "
+ """
Handle the actual response to control characters
Broken up to be easier to test as well as to appease mr. mccabe
@@ -128,7 +119,7 @@ class AsynctelnetTransport(AsyncTransport):
return control_buf
async def _handle_control_chars(self) -> None:
- """ "
+ """
Handle control characters -- nearly identical to CPython telnetlib
Basically we want to read and "decline" any and all control options that the server proposes
@@ -227,7 +218,7 @@ class AsynctelnetTransport(AsyncTransport):
return False
return not self.stdout.at_eof()
- @TransportTimeout("timed out reading from transport")
+ @timeout_wrapper
async def read(self) -> bytes:
if not self.stdout:
raise ScrapliConnectionNotOpened
@@ -301,10 +292,9 @@ class AsynctelnetTransport(AsyncTransport):
self.stdin: Optional[asyncio.StreamWriter] = None
self._initial_buf = b""
- self._stdout_binary_transmission = False
def _handle_control_chars_response(self, control_buf: bytes, c: bytes) -> bytes:
- """ "
+ """
Handle the actual response to control characters
Broken up to be easier to test as well as to appease mr. mccabe
@@ -361,7 +351,7 @@ class AsynctelnetTransport(AsyncTransport):
return control_buf
async def _handle_control_chars(self) -> None:
- """ "
+ """
Handle control characters -- nearly identical to CPython telnetlib
Basically we want to read and "decline" any and all control options that the server proposes
@@ -460,7 +450,7 @@ class AsynctelnetTransport(AsyncTransport):
return False
return not self.stdout.at_eof()
- @TransportTimeout("timed out reading from transport")
+ @timeout_wrapper
async def read(self) -> bytes:
if not self.stdout:
raise ScrapliConnectionNotOpened
diff --git a/docs/api_docs/transport/plugins/paramiko.md b/docs/api_docs/transport/plugins/paramiko.md
index 6509cd30..3ccf0670 100644
--- a/docs/api_docs/transport/plugins/paramiko.md
+++ b/docs/api_docs/transport/plugins/paramiko.md
@@ -216,7 +216,11 @@ class ParamikoTransport(Transport):
raise ScrapliConnectionNotOpened
if self._base_transport_args.transport_options.get("enable_rsa2", False) is False:
- self.session.disabled_algorithms = {"keys": ["rsa-sha2-256", "rsa-sha2-512"]}
+ # do this for "keys" and "pubkeys": https://github.com/paramiko/paramiko/issues/1984
+ self.session.disabled_algorithms = {
+ "keys": ["rsa-sha2-256", "rsa-sha2-512"],
+ "pubkeys": ["rsa-sha2-256", "rsa-sha2-512"],
+ }
try:
paramiko_key = RSAKey(filename=self.plugin_transport_args.auth_private_key)
@@ -537,7 +541,11 @@ class ParamikoTransport(Transport):
raise ScrapliConnectionNotOpened
if self._base_transport_args.transport_options.get("enable_rsa2", False) is False:
- self.session.disabled_algorithms = {"keys": ["rsa-sha2-256", "rsa-sha2-512"]}
+ # do this for "keys" and "pubkeys": https://github.com/paramiko/paramiko/issues/1984
+ self.session.disabled_algorithms = {
+ "keys": ["rsa-sha2-256", "rsa-sha2-512"],
+ "pubkeys": ["rsa-sha2-256", "rsa-sha2-512"],
+ }
try:
paramiko_key = RSAKey(filename=self.plugin_transport_args.auth_private_key)
diff --git a/docs/api_docs/transport/plugins/system.md b/docs/api_docs/transport/plugins/system.md
index 5d0c65c8..eb21223c 100644
--- a/docs/api_docs/transport/plugins/system.md
+++ b/docs/api_docs/transport/plugins/system.md
@@ -33,7 +33,7 @@ import sys
from dataclasses import dataclass
from typing import List, Optional
-from scrapli.decorators import TransportTimeout
+from scrapli.decorators import timeout_wrapper
from scrapli.exceptions import (
ScrapliConnectionError,
ScrapliConnectionNotOpened,
@@ -180,7 +180,7 @@ class SystemTransport(Transport):
return True
return False
- @TransportTimeout("timed out reading from transport")
+ @timeout_wrapper
def read(self) -> bytes:
if not self.session:
raise ScrapliConnectionNotOpened
@@ -437,7 +437,7 @@ class SystemTransport(Transport):
return True
return False
- @TransportTimeout("timed out reading from transport")
+ @timeout_wrapper
def read(self) -> bytes:
if not self.session:
raise ScrapliConnectionNotOpened
diff --git a/docs/api_docs/transport/plugins/telnet.md b/docs/api_docs/transport/plugins/telnet.md
index 298f3339..14f2c972 100644
--- a/docs/api_docs/transport/plugins/telnet.md
+++ b/docs/api_docs/transport/plugins/telnet.md
@@ -30,12 +30,13 @@ scrapli.transport.plugins.telnet.transport
"""scrapli.transport.plugins.telnet.transport"""
from dataclasses import dataclass
-from telnetlib import Telnet
from typing import Optional
-from scrapli.decorators import TransportTimeout
+from scrapli.decorators import timeout_wrapper
from scrapli.exceptions import ScrapliConnectionError, ScrapliConnectionNotOpened
from scrapli.transport.base import BasePluginTransportArgs, BaseTransportArgs, Transport
+from scrapli.transport.base.base_socket import Socket
+from scrapli.transport.base.telnet_common import DO, DONT, IAC, SUPPRESS_GO_AHEAD, WILL, WONT
@dataclass()
@@ -43,81 +44,196 @@ class PluginTransportArgs(BasePluginTransportArgs):
pass
-class ScrapliTelnet(Telnet):
- def __init__(self, host: str, port: int, timeout: float) -> None:
+class TelnetTransport(Transport):
+ def __init__(
+ self, base_transport_args: BaseTransportArgs, plugin_transport_args: PluginTransportArgs
+ ) -> None:
+ super().__init__(base_transport_args=base_transport_args)
+ self.plugin_transport_args = plugin_transport_args
+
+ self.socket: Optional[Socket] = None
+ self._initial_buf = b""
+
+ def _set_socket_timeout(self, timeout: float) -> None:
"""
- ScrapliTelnet class for typing purposes
+ Set underlying socket timeout
+
+ Mostly this exists just to assert that socket and socket.sock are not None to appease mypy!
Args:
- host: string of host
- port: integer port to connect to
- timeout: timeout value in seconds
+ timeout: float value to set as the timeout
Returns:
- None
+ N/A
Raises:
+ ScrapliConnectionNotOpened: if either socket or socket.sock are None
+ """
+ if self.socket is None:
+ raise ScrapliConnectionNotOpened
+ if self.socket.sock is None:
+ raise ScrapliConnectionNotOpened
+ self.socket.sock.settimeout(timeout)
+
+ def _handle_control_chars_response(self, control_buf: bytes, c: bytes) -> bytes:
+ """
+ Handle the actual response to control characters
+
+ Broken up to be easier to test as well as to appease mr. mccabe
+
+ NOTE: see the asynctelnet transport for additional comments inline about what is going on
+ here.
+
+ Args:
+ control_buf: current control_buf to work with
+ c: currently read control char to process
+
+ Returns:
+ bytes: updated control_buf
+
+ Raises:
+ ScrapliConnectionNotOpened: if connection is not opened for some reason
+
+ """
+ if not self.socket:
+ raise ScrapliConnectionNotOpened
+
+ if not control_buf:
+ if c != IAC:
+ self._initial_buf += c
+ else:
+ control_buf += c
+
+ elif len(control_buf) == 1 and c in (DO, DONT, WILL, WONT):
+ control_buf += c
+
+ elif len(control_buf) == 2:
+ cmd = control_buf[1:2]
+ control_buf = b""
+
+ if (cmd == DO) and (c == SUPPRESS_GO_AHEAD):
+ self.write(IAC + WILL + c)
+ elif cmd in (DO, DONT):
+ self.write(IAC + WONT + c)
+ elif cmd == WILL:
+ self.write(IAC + DO + c)
+ elif cmd == WONT:
+ self.write(IAC + DONT + c)
+
+ return control_buf
+
+ def _handle_control_chars(self) -> None:
+ """
+ Handle control characters -- nearly identical to CPython (removed in 3.11) telnetlib
+
+ Basically we want to read and "decline" any and all control options that the server proposes
+ to us -- so if they say "DO" XYZ directive, we say "DONT", if they say "WILL" we say "WONT".
+
+ NOTE: see the asynctelnet transport for additional comments inline about what is going on
+ here.
+
+ Args:
N/A
+ Returns:
+ None
+
+ Raises:
+ ScrapliConnectionNotOpened: if connection is not opened for some reason
+ ScrapliConnectionNotOpened: if we read an empty byte string from the reader -- this
+ indicates the server sent an EOF -- see #142
+
"""
- self.eof: bool
- self.timeout: float
+ if not self.socket:
+ raise ScrapliConnectionNotOpened
- super().__init__(host, port, int(timeout))
+ control_buf = b""
+ original_socket_timeout = self._base_transport_args.timeout_socket
+ self._set_socket_timeout(self._base_transport_args.timeout_socket / 4)
-class TelnetTransport(Transport):
- def __init__(
- self, base_transport_args: BaseTransportArgs, plugin_transport_args: PluginTransportArgs
- ) -> None:
- super().__init__(base_transport_args=base_transport_args)
- self.plugin_transport_args = plugin_transport_args
+ while True:
+ try:
+ c = self._read(1)
+ if not c:
+ raise ScrapliConnectionNotOpened("server returned EOF, connection not opened")
+ except TimeoutError:
+ # shouldn't really matter/need to be reset back to "normal", but don't really want
+ # to leave it modified as that would be confusing!
+ self._base_transport_args.timeout_socket = original_socket_timeout
+ return
- self.session: Optional[ScrapliTelnet] = None
+ self._set_socket_timeout(self._base_transport_args.timeout_socket / 10)
+ control_buf = self._handle_control_chars_response(control_buf=control_buf, c=c)
def open(self) -> None:
self._pre_open_closing_log(closing=False)
- # establish session with "socket" timeout, then reset timeout to "transport" timeout
- try:
- self.session = ScrapliTelnet(
+ if not self.socket:
+ self.socket = Socket(
host=self._base_transport_args.host,
port=self._base_transport_args.port,
timeout=self._base_transport_args.timeout_socket,
)
- self.session.timeout = self._base_transport_args.timeout_transport
- except ConnectionError as exc:
- msg = f"Failed to open telnet session to host {self._base_transport_args.host}"
- if "connection refused" in str(exc).lower():
- msg = (
- f"Failed to open telnet session to host {self._base_transport_args.host}, "
- "connection refused"
- )
- raise ScrapliConnectionError(msg) from exc
+
+ if not self.socket.isalive():
+ self.socket.open()
+
+ self._handle_control_chars()
self._post_open_closing_log(closing=False)
def close(self) -> None:
self._pre_open_closing_log(closing=True)
- if self.session:
- self.session.close()
+ if self.socket:
+ self.socket.close()
- self.session = None
+ self.socket = None
self._post_open_closing_log(closing=True)
def isalive(self) -> bool:
- if not self.session:
+ if not self.socket:
+ return False
+ if not self.socket.isalive():
return False
- return not self.session.eof
+ return True
- @TransportTimeout("timed out reading from transport")
+ def _read(self, n: int = 65535) -> bytes:
+ """
+ Read n bytes from the socket
+
+ Mostly this exists just to assert that socket and socket.sock are not None to appease mypy!
+
+ Args:
+ n: optional amount of bytes to try to recv from the underlying socket
+
+ Returns:
+ N/A
+
+ Raises:
+ ScrapliConnectionNotOpened: if either socket or socket.sock are None
+ """
+ if self.socket is None:
+ raise ScrapliConnectionNotOpened
+ if self.socket.sock is None:
+ raise ScrapliConnectionNotOpened
+ return self.socket.sock.recv(n)
+
+ @timeout_wrapper
def read(self) -> bytes:
- if not self.session:
+ if not self.socket:
raise ScrapliConnectionNotOpened
+
+ if self._initial_buf:
+ buf = self._initial_buf
+ self._initial_buf = b""
+ return buf
+
try:
- buf = self.session.read_eager()
+ buf = self._read()
+ buf = buf.replace(b"\x00", b"")
except Exception as exc:
raise ScrapliConnectionError(
"encountered EOF reading from transport; typically means the device closed the "
@@ -126,9 +242,11 @@ class TelnetTransport(Transport):
return buf
def write(self, channel_input: bytes) -> None:
- if not self.session:
+ if self.socket is None:
+ raise ScrapliConnectionNotOpened
+ if self.socket.sock is None:
raise ScrapliConnectionNotOpened
- self.session.write(channel_input)
+ self.socket.sock.send(channel_input)
@@ -164,67 +282,17 @@ class PluginTransportArgs(BasePluginTransportArgs):
-### ScrapliTelnet
+### TelnetTransport
```text
-Telnet interface class.
-
-An instance of this class represents a connection to a telnet
-server. The instance is initially not connected; the open()
-method must be used to establish a connection. Alternatively, the
-host name and optional port number can be passed to the
-constructor, too.
-
-Don't try to reopen an already connected instance.
-
-This class has many read_*() methods. Note that some of them
-raise EOFError when the end of the connection is read, because
-they can return an empty string for other reasons. See the
-individual doc strings.
-
-read_until(expected, [timeout])
- Read until the expected string has been seen, or a timeout is
- hit (default is no timeout); may block.
-
-read_all()
- Read all data until EOF; may block.
-
-read_some()
- Read at least one byte or EOF; may block.
-
-read_very_eager()
- Read all data available already queued or on the socket,
- without blocking.
-
-read_eager()
- Read either data already queued or some data available on the
- socket, without blocking.
-
-read_lazy()
- Read all data in the raw queue (processing it first), without
- doing any socket I/O.
-
-read_very_lazy()
- Reads all data in the cooked queue, without doing any socket
- I/O.
-
-read_sb_data()
- Reads available data between SB ... SE sequence. Don't block.
-
-set_option_negotiation_callback(callback)
- Each time a telnet option is read on the input flow, this callback
- (if set) is called with the following parameters :
- callback(telnet socket, command, option)
- option will be chr(0) when there is no option.
- No other action is done afterwards by telnetlib.
+Helper class that provides a standard way to create an ABC using
+inheritance.
-ScrapliTelnet class for typing purposes
+Scrapli's transport base class
Args:
- host: string of host
- port: integer port to connect to
- timeout: timeout value in seconds
+ base_transport_args: base transport args dataclass
Returns:
None
@@ -239,114 +307,196 @@ Raises:
-class ScrapliTelnet(Telnet):
- def __init__(self, host: str, port: int, timeout: float) -> None:
+class TelnetTransport(Transport):
+ def __init__(
+ self, base_transport_args: BaseTransportArgs, plugin_transport_args: PluginTransportArgs
+ ) -> None:
+ super().__init__(base_transport_args=base_transport_args)
+ self.plugin_transport_args = plugin_transport_args
+
+ self.socket: Optional[Socket] = None
+ self._initial_buf = b""
+
+ def _set_socket_timeout(self, timeout: float) -> None:
"""
- ScrapliTelnet class for typing purposes
+ Set underlying socket timeout
+
+ Mostly this exists just to assert that socket and socket.sock are not None to appease mypy!
Args:
- host: string of host
- port: integer port to connect to
- timeout: timeout value in seconds
+ timeout: float value to set as the timeout
Returns:
- None
+ N/A
Raises:
- N/A
+ ScrapliConnectionNotOpened: if either socket or socket.sock are None
+ """
+ if self.socket is None:
+ raise ScrapliConnectionNotOpened
+ if self.socket.sock is None:
+ raise ScrapliConnectionNotOpened
+ self.socket.sock.settimeout(timeout)
+ def _handle_control_chars_response(self, control_buf: bytes, c: bytes) -> bytes:
"""
- self.eof: bool
- self.timeout: float
+ Handle the actual response to control characters
- super().__init__(host, port, int(timeout))
-
-
-
+ Broken up to be easier to test as well as to appease mr. mccabe
+ NOTE: see the asynctelnet transport for additional comments inline about what is going on
+ here.
-#### Ancestors (in MRO)
-- telnetlib.Telnet
+ Args:
+ control_buf: current control_buf to work with
+ c: currently read control char to process
+ Returns:
+ bytes: updated control_buf
+ Raises:
+ ScrapliConnectionNotOpened: if connection is not opened for some reason
-### TelnetTransport
+ """
+ if not self.socket:
+ raise ScrapliConnectionNotOpened
+ if not control_buf:
+ if c != IAC:
+ self._initial_buf += c
+ else:
+ control_buf += c
-```text
-Helper class that provides a standard way to create an ABC using
-inheritance.
+ elif len(control_buf) == 1 and c in (DO, DONT, WILL, WONT):
+ control_buf += c
-Scrapli's transport base class
+ elif len(control_buf) == 2:
+ cmd = control_buf[1:2]
+ control_buf = b""
-Args:
- base_transport_args: base transport args dataclass
+ if (cmd == DO) and (c == SUPPRESS_GO_AHEAD):
+ self.write(IAC + WILL + c)
+ elif cmd in (DO, DONT):
+ self.write(IAC + WONT + c)
+ elif cmd == WILL:
+ self.write(IAC + DO + c)
+ elif cmd == WONT:
+ self.write(IAC + DONT + c)
-Returns:
- None
+ return control_buf
-Raises:
- N/A
-```
+ def _handle_control_chars(self) -> None:
+ """
+ Handle control characters -- nearly identical to CPython (removed in 3.11) telnetlib
-
-
- Expand source code
-
-
-
-class TelnetTransport(Transport):
- def __init__(
- self, base_transport_args: BaseTransportArgs, plugin_transport_args: PluginTransportArgs
- ) -> None:
- super().__init__(base_transport_args=base_transport_args)
- self.plugin_transport_args = plugin_transport_args
+ Basically we want to read and "decline" any and all control options that the server proposes
+ to us -- so if they say "DO" XYZ directive, we say "DONT", if they say "WILL" we say "WONT".
+
+ NOTE: see the asynctelnet transport for additional comments inline about what is going on
+ here.
+
+ Args:
+ N/A
+
+ Returns:
+ None
+
+ Raises:
+ ScrapliConnectionNotOpened: if connection is not opened for some reason
+ ScrapliConnectionNotOpened: if we read an empty byte string from the reader -- this
+ indicates the server sent an EOF -- see #142
+
+ """
+ if not self.socket:
+ raise ScrapliConnectionNotOpened
+
+ control_buf = b""
+
+ original_socket_timeout = self._base_transport_args.timeout_socket
+ self._set_socket_timeout(self._base_transport_args.timeout_socket / 4)
- self.session: Optional[ScrapliTelnet] = None
+ while True:
+ try:
+ c = self._read(1)
+ if not c:
+ raise ScrapliConnectionNotOpened("server returned EOF, connection not opened")
+ except TimeoutError:
+ # shouldn't really matter/need to be reset back to "normal", but don't really want
+ # to leave it modified as that would be confusing!
+ self._base_transport_args.timeout_socket = original_socket_timeout
+ return
+
+ self._set_socket_timeout(self._base_transport_args.timeout_socket / 10)
+ control_buf = self._handle_control_chars_response(control_buf=control_buf, c=c)
def open(self) -> None:
self._pre_open_closing_log(closing=False)
- # establish session with "socket" timeout, then reset timeout to "transport" timeout
- try:
- self.session = ScrapliTelnet(
+ if not self.socket:
+ self.socket = Socket(
host=self._base_transport_args.host,
port=self._base_transport_args.port,
timeout=self._base_transport_args.timeout_socket,
)
- self.session.timeout = self._base_transport_args.timeout_transport
- except ConnectionError as exc:
- msg = f"Failed to open telnet session to host {self._base_transport_args.host}"
- if "connection refused" in str(exc).lower():
- msg = (
- f"Failed to open telnet session to host {self._base_transport_args.host}, "
- "connection refused"
- )
- raise ScrapliConnectionError(msg) from exc
+
+ if not self.socket.isalive():
+ self.socket.open()
+
+ self._handle_control_chars()
self._post_open_closing_log(closing=False)
def close(self) -> None:
self._pre_open_closing_log(closing=True)
- if self.session:
- self.session.close()
+ if self.socket:
+ self.socket.close()
- self.session = None
+ self.socket = None
self._post_open_closing_log(closing=True)
def isalive(self) -> bool:
- if not self.session:
+ if not self.socket:
+ return False
+ if not self.socket.isalive():
return False
- return not self.session.eof
+ return True
- @TransportTimeout("timed out reading from transport")
+ def _read(self, n: int = 65535) -> bytes:
+ """
+ Read n bytes from the socket
+
+ Mostly this exists just to assert that socket and socket.sock are not None to appease mypy!
+
+ Args:
+ n: optional amount of bytes to try to recv from the underlying socket
+
+ Returns:
+ N/A
+
+ Raises:
+ ScrapliConnectionNotOpened: if either socket or socket.sock are None
+ """
+ if self.socket is None:
+ raise ScrapliConnectionNotOpened
+ if self.socket.sock is None:
+ raise ScrapliConnectionNotOpened
+ return self.socket.sock.recv(n)
+
+ @timeout_wrapper
def read(self) -> bytes:
- if not self.session:
+ if not self.socket:
raise ScrapliConnectionNotOpened
+
+ if self._initial_buf:
+ buf = self._initial_buf
+ self._initial_buf = b""
+ return buf
+
try:
- buf = self.session.read_eager()
+ buf = self._read()
+ buf = buf.replace(b"\x00", b"")
except Exception as exc:
raise ScrapliConnectionError(
"encountered EOF reading from transport; typically means the device closed the "
@@ -355,9 +505,11 @@ class TelnetTransport(Transport):
return buf
def write(self, channel_input: bytes) -> None:
- if not self.session:
+ if self.socket is None:
+ raise ScrapliConnectionNotOpened
+ if self.socket.sock is None:
raise ScrapliConnectionNotOpened
- self.session.write(channel_input)
+ self.socket.sock.send(channel_input)
diff --git a/setup.py b/setup.py
index b81183f0..ae38ea88 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
import setuptools
-__version__ = "2022.01.30.post1"
+__version__ = "2022.07.30"
__author__ = "Carl Montanari"
with open("README.md", "r", encoding="utf-8") as f: