diff --git a/amqtt/adapters.py b/amqtt/adapters.py index c5a12ad8..d35845b3 100644 --- a/amqtt/adapters.py +++ b/amqtt/adapters.py @@ -71,16 +71,13 @@ async def _feed_buffer(self, n: int = 1) -> None: :param n: if given, feed buffer until it contains at least n bytes. """ buffer = bytearray(self._stream.read()) + message: str | bytes | None = None while len(buffer) < n: - try: + with suppress(ConnectionClosed): message = await self._protocol.recv() - except ConnectionClosed: - message = None if message is None: break - if not isinstance(message, bytes): - msg = "message must be bytes" - raise TypeError(msg) + message = message.encode("utf-8") if isinstance(message, str) else message buffer.extend(message) self._stream = io.BytesIO(buffer) diff --git a/amqtt/broker.py b/amqtt/broker.py index f5054caf..227120d4 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -1,12 +1,13 @@ import asyncio from asyncio import CancelledError, futures from collections import deque +from collections.abc import Generator from enum import Enum from functools import partial import logging import re import ssl -from typing import Any +from typing import Any, ClassVar from transitions import Machine, MachineError import websockets.asyncio.server @@ -20,18 +21,24 @@ WebSocketsWriter, WriterAdapter, ) -from amqtt.errors import AMQTTException, BrokerException, MQTTException, NoDataException +from amqtt.errors import AMQTTError, BrokerError, MQTTError, NoDataError from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler -from amqtt.session import ApplicationMessage, Session +from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session from amqtt.utils import format_client_message, gen_client_id from .plugins.manager import BaseContext, PluginManager -_defaults: dict[str, int | bool | dict[Any, Any]] = { +type _CONFIG_LISTENER = dict[str, int | bool | dict[str, Any]] +type _BROADCAST = dict[str, Session | str | bytes | int | None] + +_defaults: _CONFIG_LISTENER = { "timeout-disconnect-delay": 2, "auth": {"allow-anonymous": True, "password-file": None}, } + +AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80 + EVENT_BROKER_PRE_START = "broker_pre_start" EVENT_BROKER_POST_START = "broker_post_start" EVENT_BROKER_PRE_SHUTDOWN = "broker_pre_shutdown" @@ -59,7 +66,12 @@ def __init__(self, source_session: Session | None, topic: str, data: bytes, qos: class Server: - def __init__(self, listener_name: str, server_instance: Any, max_connections: int = -1) -> None: + def __init__( + self, + listener_name: str, + server_instance: asyncio.Server | websockets.asyncio.server.Server, + max_connections: int = -1, + ) -> None: self.logger = logging.getLogger(__name__) self.instance = server_instance self.conn_count = 0 @@ -100,7 +112,7 @@ class BrokerContext(BaseContext): def __init__(self, broker: "Broker") -> None: super().__init__() - self.config: dict[str, int | bool | dict[Any, Any]] | None = None + self.config: _CONFIG_LISTENER | None = None self._broker_instance = broker async def broadcast_message(self, topic: str, data: bytes, qos: int | None = None) -> None: @@ -110,17 +122,17 @@ def retain_message(self, topic_name: str, data: bytes | bytearray, qos: int | No self._broker_instance.retain_message(None, topic_name, data, qos) @property - def sessions(self) -> Any: - for session in self._broker_instance._sessions.values(): # noqa: SLF001 + def sessions(self) -> Generator[Session]: + for session in self._broker_instance._sessions.values(): yield session[0] @property - def retained_messages(self) -> dict[Any, Any]: - return self._broker_instance._retained_messages # noqa: SLF001 + def retained_messages(self) -> dict[str, RetainedApplicationMessage]: + return self._broker_instance._retained_messages @property - def subscriptions(self) -> dict[Any, Any]: - return self._broker_instance._subscriptions # noqa: SLF001 + def subscriptions(self) -> dict[str, list[tuple[Session, int]]]: + return self._broker_instance._subscriptions class Broker: @@ -132,7 +144,7 @@ class Broker: """ - states = [ + states: ClassVar[list[str]] = [ "new", "starting", "started", @@ -145,7 +157,7 @@ class Broker: def __init__( self, - config: dict[str, Any] | None = None, + config: _CONFIG_LISTENER | None = None, loop: asyncio.AbstractEventLoop | None = None, plugin_namespace: str | None = None, ) -> None: @@ -161,7 +173,7 @@ def __init__( self._sessions: dict[str, tuple[Session, BrokerProtocolHandler]] = {} self._subscriptions: dict[str, list[tuple[Session, int]]] = {} self._retained_messages: dict[str, RetainedApplicationMessage] = {} - self._broadcast_queue: asyncio.Queue[Any] = asyncio.Queue() + self._broadcast_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue() self._broadcast_task: asyncio.Task[Any] | None = None self._broadcast_shutdown_waiter: asyncio.Future[Any] = futures.Future() @@ -172,18 +184,25 @@ def __init__( namespace = plugin_namespace or "amqtt.broker.plugins" self.plugins_manager = PluginManager(namespace, context, self._loop) - def _build_listeners_config(self, broker_config: dict[str, Any]) -> None: + def _build_listeners_config(self, broker_config: _CONFIG_LISTENER) -> None: self.listeners_config = {} try: - listeners_config = broker_config["listeners"] - defaults = listeners_config["default"] + listeners_config = broker_config.get("listeners") + if not isinstance(listeners_config, dict): + msg = "Listener config not found or invalid" + raise BrokerError(msg) + defaults = listeners_config.get("default") + if defaults is None: + msg = "Listener config has not default included or is invalid" + raise BrokerError(msg) + for listener_name, listener_conf in listeners_config.items(): config = defaults.copy() config.update(listener_conf) self.listeners_config[listener_name] = config except KeyError as ke: msg = f"Listener config not found or invalid: {ke}" - raise BrokerException(msg) from ke + raise BrokerError(msg) from ke def _init_states(self) -> None: self.transitions = Machine(states=Broker.states, initial="new") @@ -212,7 +231,7 @@ async def start(self) -> None: # Backwards compat: MachineError is raised by transitions < 0.5.0. self.logger.warning(f"[WARN-0001] Invalid method call at this moment: {exc}") msg = f"Broker instance can't be started: {exc}" - raise BrokerException(msg) from exc + raise BrokerError(msg) from exc await self.plugins_manager.fire_event(EVENT_BROKER_PRE_START) try: @@ -244,10 +263,10 @@ async def start(self) -> None: sc.verify_mode = ssl.CERT_OPTIONAL except KeyError as ke: msg = f"'certfile' or 'keyfile' configuration parameter missing: {ke}" - raise BrokerException(msg) from ke + raise BrokerError(msg) from ke except FileNotFoundError as fnfe: msg = "Can't read cert files '{}' or '{}' : {}".format(listener["certfile"], listener["keyfile"], fnfe) - raise BrokerException(msg) from fnfe + raise BrokerError(msg) from fnfe address, s_port = listener["bind"].split(":") port = 0 @@ -255,7 +274,7 @@ async def start(self) -> None: port = int(s_port) except ValueError as e: msg = "Invalid port value in bind value: {}".format(listener["bind"]) - raise BrokerException(msg) from e + raise BrokerError(msg) from e instance: asyncio.Server | websockets.asyncio.server.Server | None = None if listener["type"] == "tcp": @@ -292,7 +311,7 @@ async def start(self) -> None: self.logger.exception("Broker startup failed") self.transitions.starting_fail() msg = f"Broker instance can't be started: {e}" - raise BrokerException(msg) from e + raise BrokerError(msg) from e async def shutdown(self) -> None: """Stop broker instance. @@ -330,12 +349,12 @@ async def ws_connected(self, websocket: ServerConnection, listener_name: str) -> async def stream_connected(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, listener_name: str) -> None: await self.client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer)) - async def client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None: # noqa: C901, PLR0915, PLR0912 + async def client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None: # Wait for connection available on listener server = self._servers.get(listener_name, None) if not server: msg = f"Invalid listener name '{listener_name}'" - raise BrokerException(msg) + raise BrokerError(msg) await server.acquire_connection() remote_info = writer.get_peer_info() @@ -346,7 +365,7 @@ async def client_connected(self, listener_name: str, reader: ReaderAdapter, writ # Wait for first packet and expect a CONNECT try: handler, client_session = await BrokerProtocolHandler.init_from_connect(reader, writer, self.plugins_manager) - except AMQTTException as exc: + except AMQTTError as exc: self.logger.warning( f"[MQTT-3.1.0-1] {format_client_message(address=remote_address, port=remote_port)}:" f"Can't read first packet an CONNECT: {exc}", @@ -355,7 +374,7 @@ async def client_connected(self, listener_name: str, reader: ReaderAdapter, writ self.logger.debug("Connection closed") server.release_connection() return - except MQTTException: + except MQTTError: self.logger.exception( f"Invalid connection from {format_client_message(address=remote_address, port=remote_port)}", ) @@ -363,7 +382,7 @@ async def client_connected(self, listener_name: str, reader: ReaderAdapter, writ server.release_connection() self.logger.debug("Connection closed") return - except NoDataException as ne: + except NoDataError as ne: self.logger.error(f"No data from {format_client_message(address=remote_address, port=remote_port)} : {ne}") # noqa: TRY400 # cannot replace with exception else test fails server.release_connection() return @@ -385,7 +404,7 @@ async def client_connected(self, listener_name: str, reader: ReaderAdapter, writ if client_session.client_id is None: msg = "Client ID was not correct created/set." - raise BrokerException(msg) + raise BrokerError(msg) timeout_disconnect_delay = self.config.get("timeout-disconnect-delay") if client_session.keep_alive > 0 and isinstance(timeout_disconnect_delay, int): @@ -493,9 +512,10 @@ async def client_connected(self, listener_name: str, reader: ReaderAdapter, writ return_codes = [ await self.add_subscription(subscription, client_session) for subscription in subscriptions.topics ] + await handler.mqtt_acknowledge_subscription(subscriptions.packet_id, return_codes) for index, subscription in enumerate(subscriptions.topics): - if return_codes[index] != 0x80: + if return_codes[index] != AMQTT_MAGIC_VALUE_RET_SUBSCRIBED: await self.plugins_manager.fire_event( EVENT_BROKER_CLIENT_SUBSCRIBED, client_id=client_session.client_id, @@ -577,7 +597,7 @@ async def _stop_handler(self, handler: BrokerProtocolHandler) -> None: except Exception: self.logger.exception("Failed to stop handler") - async def authenticate(self, session: Session, listener: Any) -> bool: # noqa: ARG002 + async def authenticate(self, session: Session, _: dict[str, Any]) -> bool: """Call the authenticate method on registered plugins to test user authentication. User is considered authenticated if all plugins called returns True. @@ -708,7 +728,7 @@ def _del_all_subscriptions(self, session: Session) -> None: :param session: :return: """ - filter_queue: deque[Any] = deque() + filter_queue: deque[str] = deque() for topic in self._subscriptions: if self._del_subscription(topic, session): filter_queue.append(topic) @@ -725,7 +745,7 @@ def matches(self, topic: str, a_filter: str) -> bool: return bool(match_pattern.fullmatch(topic)) async def _broadcast_loop(self) -> None: - running_tasks: deque[Any] = deque() + running_tasks: deque[asyncio.Task[OutgoingApplicationMessage]] = deque() try: while True: while running_tasks and running_tasks[0].done(): @@ -758,7 +778,7 @@ async def _broadcast_loop(self) -> None: if running_tasks: await asyncio.gather(*running_tasks) - async def _run_broadcast(self, running_tasks: deque[Any]) -> None: + async def _run_broadcast(self, running_tasks: deque[asyncio.Task[OutgoingApplicationMessage]]) -> None: broadcast = await self._broadcast_queue.get() if self.logger.isEnabledFor(logging.DEBUG): @@ -800,7 +820,7 @@ async def _run_broadcast(self, running_tasks: deque[Any]) -> None: ) running_tasks.append(task) - async def _retain_broadcast_message(self, broadcast: dict[Any, Any], qos: int, target_session: Session) -> None: + async def _retain_broadcast_message(self, broadcast: dict[str, Any], qos: int, target_session: Session) -> None: if self.logger.isEnabledFor(logging.DEBUG): self.logger.debug( f"retaining application message from {format_client_message(session=broadcast['session'])}" @@ -818,7 +838,7 @@ async def _shutdown_broadcast_loop(self) -> None: self._broadcast_shutdown_waiter.set_result(True) try: await asyncio.wait_for(self._broadcast_task, timeout=30) - except BaseException as e: + except TimeoutError as e: self.logger.warning(f"Failed to cleanly shutdown broadcast loop: {e}") if not self._broadcast_queue.empty(): @@ -831,7 +851,7 @@ async def _broadcast_message( data: bytes | None, force_qos: int | None = None, ) -> None: - broadcast: dict[str, Session | str | bytes | int | None] = {"session": session, "topic": topic, "data": data} + broadcast: _BROADCAST = {"session": session, "topic": topic, "data": data} if force_qos is not None: broadcast["qos"] = force_qos await self._broadcast_queue.put(broadcast) diff --git a/amqtt/client.py b/amqtt/client.py index 6b8649d0..b14a82e1 100644 --- a/amqtt/client.py +++ b/amqtt/client.py @@ -6,7 +6,7 @@ from functools import wraps import logging import ssl -from typing import TYPE_CHECKING, Any, TypeVar, cast +from typing import TYPE_CHECKING, Any, cast from urllib.parse import urlparse, urlunparse import websockets @@ -18,7 +18,7 @@ WebSocketsReader, WebSocketsWriter, ) -from amqtt.errors import ClientException, ConnectException, ProtocolHandlerException +from amqtt.errors import ClientError, ConnectError, ProtocolHandlerError from amqtt.mqtt.connack import CONNECTION_ACCEPTED from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2 from amqtt.mqtt.protocol.client_handler import ClientProtocolHandler @@ -53,10 +53,10 @@ def __init__(self) -> None: base_logger = logging.getLogger(__name__) -F = TypeVar("F", bound=Callable[..., Coroutine[Any, Any, Any]]) +type _F = Callable[..., Coroutine[Any, Any, Any]] -def mqtt_connected(func: F) -> F: +def mqtt_connected(func: _F) -> _F: """MQTTClient coroutines decorator which will wait until connection before calling the decorated method. :param func: coroutine to be called once connected @@ -78,10 +78,10 @@ async def wrapper(self: "MQTTClient", *args: Any, **kwargs: Any) -> Any: t.cancel() if self._no_more_connections.is_set(): msg = "Will not reconnect" - raise ClientException(msg) + raise ClientError(msg) return await func(self, *args, **kwargs) - return cast(F, wrapper) + return cast(_F, wrapper) class MQTTClient: @@ -128,12 +128,14 @@ async def connect( At first, a network connection is established with the server using the given protocol (``mqtt``, ``mqtts``, ``ws`` or ``wss``). - Once the socket is connected, a `CONNECT `_ + Once the socket is connected, a + `CONNECT `_ message is sent with the requested information. This method is a *coroutine*. - :param uri: Broker URI connection, conforming to `MQTT URI scheme `_. + :param uri: Broker URI connection, conforming to + `MQTT URI scheme `_. Uses ``uri`` config attribute by default. :param cleansession: MQTT CONNECT clean session flag :param cafile: server certificate authority file (optional, used for secured connection) @@ -151,8 +153,9 @@ async def connect( try: return await self._do_connect() - except asyncio.CancelledError: - raise + except asyncio.CancelledError as e: + msg = "Future or Task was cancelled" + raise ConnectError(msg) from e except Exception as e: self.logger.warning(f"Connection failed: {e}") if not self.config.get("auto_reconnect", False): @@ -222,14 +225,15 @@ async def reconnect(self, cleansession: bool | None = None) -> int: try: self.logger.debug(f"Reconnect attempt {nb_attempt}...") return await self._do_connect() - except asyncio.CancelledError: - raise + except asyncio.CancelledError as e: + msg = "Future or Task was cancelled" + raise ConnectError(msg) from e except Exception as e: self.logger.warning(f"Reconnection attempt failed: {e}") if reconnect_retries >= 0 and nb_attempt > reconnect_retries: self.logger.exception("Maximum connection attempts reached. Reconnection aborted.") msg = "Too many failed attempts" - raise ConnectException(msg) from e + raise ConnectError(msg) from e delay = min(reconnect_max_interval, 2**nb_attempt) self.logger.debug(f"Waiting {delay} seconds before next attempt") await asyncio.sleep(delay) @@ -282,13 +286,13 @@ async def publish( """ if self._handler is None: msg = "Handler is not initialized." - raise ClientException(msg) + raise ClientError(msg) def get_retain_and_qos() -> tuple[int, bool]: if qos is not None: if qos not in (QOS_0, QOS_1, QOS_2): msg = f"QOS '{qos}' is not one of QOS_0, QOS_1, QOS_2." - raise ClientException(msg) + raise ClientError(msg) _qos = qos else: _qos = self.config["default_qos"] @@ -370,7 +374,7 @@ async def deliver_message(self, timeout_duration: float | None = None) -> Applic """ if self._handler is None: msg = "Handler is not initialized." - raise ClientException(msg) + raise ClientError(msg) deliver_task = asyncio.create_task(self._handler.mqtt_deliver_next_message()) self.client_tasks.append(deliver_task) @@ -401,7 +405,7 @@ async def _connect_coro(self) -> int: """Perform the core connection logic.""" if self.session is None: msg = "Session is not initialized." - raise ClientException(msg) + raise ClientError(msg) kwargs: dict[str, Any] = {} @@ -483,7 +487,7 @@ async def _connect_coro(self) -> int: self.session.transitions.disconnect() self.logger.warning("reader or writer not initialized") msg = "reader or writer not initialized" - raise ClientException(msg) + raise ClientError(msg) # Start MQTT protocol self._handler.attach(self.session, reader, writer) @@ -493,7 +497,7 @@ async def _connect_coro(self) -> int: self.session.transitions.disconnect() self.logger.warning(f"Connection rejected with code '{return_code}'") msg = "Connection rejected by broker" - exc = ConnectException(msg) + exc = ConnectError(msg) exc.return_code = return_code raise exc # Handle MQTT protocol @@ -503,19 +507,19 @@ async def _connect_coro(self) -> int: self.logger.debug(f"Connected to {self.session.remote_address}:{self.session.remote_port}") return return_code - except (InvalidURI, InvalidHandshake, ProtocolHandlerException, ConnectionError, OSError) as e: + except (InvalidURI, InvalidHandshake, ProtocolHandlerError, ConnectionError, OSError) as e: self.logger.warning(f"Connection failed : {self.session.broker_uri} : {e}") self.session.transitions.disconnect() - raise ConnectException(e) from e + raise ConnectError(e) from e async def handle_connection_close(self) -> None: """Handle disconnection from the broker.""" if self.session is None: msg = "Session is not initialized." - raise ClientException(msg) + raise ClientError(msg) if self._handler is None: msg = "Handler is not initialized." - raise ClientException(msg) + raise ClientError(msg) def cancel_tasks() -> None: self._no_more_connections.set() @@ -542,7 +546,7 @@ def cancel_tasks() -> None: self.logger.debug("Auto-reconnecting") try: await self.reconnect() - except ConnectException: + except ConnectError: # Cancel client pending tasks cancel_tasks() else: @@ -565,7 +569,7 @@ def _init_session( if not broker_conf.get("uri"): msg = "Missing connection parameter 'uri'" - raise ClientException(msg) + raise ClientError(msg) session = Session() session.broker_uri = broker_conf["uri"] diff --git a/amqtt/codecs.py b/amqtt/codecs_a.py similarity index 86% rename from amqtt/codecs.py rename to amqtt/codecs_a.py index 55252ec3..c367e3b6 100644 --- a/amqtt/codecs.py +++ b/amqtt/codecs_a.py @@ -2,7 +2,7 @@ from struct import pack, unpack from amqtt.adapters import ReaderAdapter -from amqtt.errors import NoDataException +from amqtt.errors import NoDataError def bytes_to_hex_str(data: bytes) -> str: @@ -30,16 +30,21 @@ def int_to_bytes(int_value: int, length: int) -> bytes: """Convert an integer to a sequence of bytes using big endian byte ordering. :param int_value: integer value to convert - :param length: (optional) byte length - :return: byte sequence. + :param length: byte length (must be 1 or 2) + :return: byte sequence + :raises ValueError: if the length is unsupported """ - if length == 1: - fmt = "!B" - elif length == 2: - fmt = "!H" - else: - msg = "Unsupported length for int to bytes conversion." + # Map length to the appropriate format string + fmt_mapping = { + 1: "!B", # 1 byte, unsigned char + 2: "!H", # 2 bytes, unsigned short + } + + fmt = fmt_mapping.get(length) + if not fmt: + msg = "Unsupported length for int to bytes conversion. Only lengths 1 or 2 are allowed." raise ValueError(msg) + return pack(fmt, int_value) @@ -56,7 +61,7 @@ async def read_or_raise(reader: ReaderAdapter | asyncio.StreamReader, n: int = - data = None if not data: msg = "No more data" - raise NoDataException(msg) + raise NoDataError(msg) return data diff --git a/amqtt/errors.py b/amqtt/errors.py index fdbfd522..71d65c74 100644 --- a/amqtt/errors.py +++ b/amqtt/errors.py @@ -1,32 +1,32 @@ -class AMQTTException(Exception): # noqa: N818 +class AMQTTError(Exception): """aMQTT base exception.""" -class MQTTException(Exception): # noqa: N818 +class MQTTError(Exception): """Base class for all errors referring to MQTT specifications.""" -class CodecException(Exception): # noqa: N818 +class CodecError(Exception): """Exceptions thrown by packet encode/decode functions.""" -class NoDataException(Exception): # noqa: N818 +class NoDataError(Exception): """Exceptions thrown by packet encode/decode functions.""" -class BrokerException(Exception): # noqa: N818 +class BrokerError(Exception): """Exceptions thrown by broker.""" -class ClientException(Exception): # noqa: N818 +class ClientError(Exception): """Exceptions thrown by client.""" -class ConnectException(ClientException): +class ConnectError(ClientError): """Exceptions thrown by client connect.""" return_code: int | None = None -class ProtocolHandlerException(Exception): # noqa: N818 +class ProtocolHandlerError(Exception): """Exceptions thrown by protocol handle.""" diff --git a/amqtt/mqtt/__init__.py b/amqtt/mqtt/__init__.py index ba8d74bc..f538698b 100644 --- a/amqtt/mqtt/__init__.py +++ b/amqtt/mqtt/__init__.py @@ -1,6 +1,8 @@ """INIT.""" -from amqtt.errors import AMQTTException +from typing import Any + +from amqtt.errors import AMQTTError from amqtt.mqtt.connack import ConnackPacket from amqtt.mqtt.connect import ConnectPacket from amqtt.mqtt.disconnect import DisconnectPacket @@ -34,7 +36,9 @@ from amqtt.mqtt.unsuback import UnsubackPacket from amqtt.mqtt.unsubscribe import UnsubscribePacket -packet_dict = { +type _P = MQTTPacket[Any, Any, Any] + +packet_dict: dict[int, type[_P]] = { CONNECT: ConnectPacket, CONNACK: ConnackPacket, PUBLISH: PublishPacket, @@ -52,9 +56,11 @@ } -def packet_class(fixed_header: MQTTFixedHeader) -> type[MQTTPacket]: +def packet_class( + fixed_header: MQTTFixedHeader, +) -> type[_P]: try: return packet_dict[fixed_header.packet_type] except KeyError as e: msg = f"Unexpected packet Type '{fixed_header.packet_type}'" - raise AMQTTException(msg) from e + raise AMQTTError(msg) from e diff --git a/amqtt/mqtt/connack.py b/amqtt/mqtt/connack.py index 456f0447..42a79bca 100644 --- a/amqtt/mqtt/connack.py +++ b/amqtt/mqtt/connack.py @@ -1,8 +1,8 @@ from typing import Self from amqtt.adapters import ReaderAdapter -from amqtt.codecs import bytes_to_int, read_or_raise -from amqtt.errors import AMQTTException +from amqtt.codecs_a import bytes_to_int, read_or_raise +from amqtt.errors import AMQTTError from amqtt.mqtt.packet import CONNACK, MQTTFixedHeader, MQTTPacket, MQTTPayload, MQTTVariableHeader CONNECTION_ACCEPTED = 0x00 @@ -22,7 +22,7 @@ def __init__(self, session_parent: int | None = None, return_code: int | None = self.return_code = return_code @classmethod - async def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader | None) -> Self: # noqa: ARG003 + async def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader | None) -> Self: data = await read_or_raise(reader, 2) session_parent = data[0] & 0x01 return_code = bytes_to_int(data[1]) @@ -40,10 +40,31 @@ def __repr__(self) -> str: return f"{type(self).__name__}(session_parent={hex(self.session_parent or 0)}, return_code={hex(self.return_code or 0)})" -class ConnackPacket(MQTTPacket[ConnackVariableHeader, MQTTPayload[MQTTVariableHeader]]): +class ConnackPacket(MQTTPacket[ConnackVariableHeader, MQTTPayload[MQTTVariableHeader], MQTTFixedHeader]): VARIABLE_HEADER = ConnackVariableHeader PAYLOAD = MQTTPayload[MQTTVariableHeader] + def __init__( + self, + fixed: MQTTFixedHeader | None = None, + variable_header: ConnackVariableHeader | None = None, + payload: MQTTPayload[MQTTVariableHeader] | None = None, + ) -> None: + if fixed is None: + header = MQTTFixedHeader(CONNACK, 0x00) + elif fixed.packet_type != CONNACK: + msg = f"Invalid fixed packet type {fixed.packet_type} for ConnackPacket init" + raise AMQTTError(msg) from None + else: + header = fixed + + super().__init__(header, variable_header, payload) + + @classmethod + def build(cls, session_parent: int | None = None, return_code: int | None = None) -> Self: + v_header = ConnackVariableHeader(session_parent, return_code) + return cls(variable_header=v_header) + @property def return_code(self) -> int | None: if self.variable_header is None: @@ -71,24 +92,3 @@ def session_parent(self, session_parent: int | None) -> None: msg = "Variable header is not set" raise ValueError(msg) self.variable_header.session_parent = session_parent - - def __init__( - self, - fixed: MQTTFixedHeader | None = None, - variable_header: ConnackVariableHeader | None = None, - payload: MQTTPayload | None = None, - ) -> None: - if fixed is None: - header = MQTTFixedHeader(CONNACK, 0x00) - elif fixed.packet_type != CONNACK: - msg = f"Invalid fixed packet type {fixed.packet_type} for ConnackPacket init" - raise AMQTTException(msg) from None - else: - header = fixed - - super().__init__(header, variable_header, payload) - - @classmethod - def build(cls, session_parent: int | None = None, return_code: int | None = None) -> Self: - v_header = ConnackVariableHeader(session_parent, return_code) - return cls(variable_header=v_header) diff --git a/amqtt/mqtt/connect.py b/amqtt/mqtt/connect.py index 4207eaaf..259d2774 100644 --- a/amqtt/mqtt/connect.py +++ b/amqtt/mqtt/connect.py @@ -2,7 +2,7 @@ from typing import Self from amqtt.adapters import ReaderAdapter -from amqtt.codecs import ( +from amqtt.codecs_a import ( bytes_to_int, decode_data_with_length, decode_string, @@ -11,7 +11,7 @@ int_to_bytes, read_or_raise, ) -from amqtt.errors import AMQTTException, NoDataException +from amqtt.errors import AMQTTError, NoDataError from amqtt.mqtt.packet import CONNECT, MQTTFixedHeader, MQTTPacket, MQTTPayload, MQTTVariableHeader from amqtt.utils import gen_client_id @@ -50,7 +50,7 @@ def _get_flag(self, mask: int) -> bool: return bool(self.flags & mask) @classmethod - async def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader) -> Self: # noqa: ARG003 + async def from_stream(cls, reader: ReaderAdapter, fixed_header: MQTTFixedHeader) -> Self: # protocol name protocol_name = await decode_string(reader) @@ -176,14 +176,14 @@ def __repr__(self) -> str: async def from_stream( cls, reader: StreamReader | ReaderAdapter, - fixed_header: MQTTFixedHeader | None, # noqa: ARG003 + fixed_header: MQTTFixedHeader | None, variable_header: ConnectVariableHeader | None, ) -> Self: payload = cls() # Client identifier try: payload.client_id = await decode_string(reader) - except NoDataException: + except NoDataError: payload.client_id = None if payload.client_id is None or payload.client_id == "": @@ -198,27 +198,27 @@ async def from_stream( try: payload.will_topic = await decode_string(reader) payload.will_message = await decode_data_with_length(reader) - except NoDataException: + except NoDataError: payload.will_topic = None payload.will_message = None if variable_header is not None and variable_header.username_flag: try: payload.username = await decode_string(reader) - except NoDataException: + except NoDataError: payload.username = None if variable_header is not None and variable_header.password_flag: try: payload.password = await decode_string(reader) - except NoDataException: + except NoDataError: payload.password = None return payload def to_bytes( self, - fixed_header: MQTTFixedHeader | None = None, # noqa: ARG002 + fixed_header: MQTTFixedHeader | None = None, variable_header: ConnectVariableHeader | None = None, ) -> bytes: out = bytearray() @@ -241,7 +241,7 @@ def to_bytes( return out -class ConnectPacket(MQTTPacket[ConnectVariableHeader, ConnectPayload]): +class ConnectPacket(MQTTPacket[ConnectVariableHeader, ConnectPayload, MQTTFixedHeader]): # type: ignore [type-var] VARIABLE_HEADER = ConnectVariableHeader PAYLOAD = ConnectPayload @@ -256,7 +256,7 @@ def __init__( else: if fixed.packet_type is not CONNECT: msg = f"Invalid fixed packet type {fixed.packet_type} for ConnectPacket init" - raise AMQTTException(msg) + raise AMQTTError(msg) header = fixed super().__init__(header) self.variable_header = variable_header diff --git a/amqtt/mqtt/disconnect.py b/amqtt/mqtt/disconnect.py index aeaab3aa..9fb28262 100644 --- a/amqtt/mqtt/disconnect.py +++ b/amqtt/mqtt/disconnect.py @@ -1,8 +1,8 @@ -from amqtt.errors import AMQTTException +from amqtt.errors import AMQTTError from amqtt.mqtt.packet import DISCONNECT, MQTTFixedHeader, MQTTPacket -class DisconnectPacket(MQTTPacket[None, None]): +class DisconnectPacket(MQTTPacket[None, None, MQTTFixedHeader]): VARIABLE_HEADER = None PAYLOAD = None @@ -12,7 +12,7 @@ def __init__(self, fixed: MQTTFixedHeader | None = None) -> None: else: if fixed.packet_type is not DISCONNECT: msg = f"Invalid fixed packet type {fixed.packet_type} for DisconnectPacket init" - raise AMQTTException(msg) + raise AMQTTError(msg) header = fixed super().__init__(header) self.variable_header = None diff --git a/amqtt/mqtt/packet.py b/amqtt/mqtt/packet.py index 31b68ffb..bc551be9 100644 --- a/amqtt/mqtt/packet.py +++ b/amqtt/mqtt/packet.py @@ -2,16 +2,11 @@ import asyncio from datetime import UTC, datetime from struct import unpack -from typing import Any, Generic, Self, TypeVar +from typing import Self from amqtt.adapters import ReaderAdapter, WriterAdapter -from amqtt.codecs import bytes_to_hex_str, decode_packet_id, int_to_bytes, read_or_raise -from amqtt.errors import CodecException, MQTTException, NoDataException - -TVariableHeader = TypeVar("TVariableHeader", bound="MQTTVariableHeader | None", default="MQTTVariableHeader") -TPayload = TypeVar("TPayload", bound="MQTTPayload[Any] | None", default="MQTTPayload[MQTTVariableHeader]") -TFixedHeader = TypeVar("TFixedHeader", bound="MQTTFixedHeader", default="MQTTFixedHeader") - +from amqtt.codecs_a import bytes_to_hex_str, decode_packet_id, int_to_bytes, read_or_raise +from amqtt.errors import CodecError, MQTTError, NoDataError RESERVED_0 = 0x00 CONNECT = 0x01 @@ -63,7 +58,7 @@ def encode_remaining_length(length: int) -> bytes: return bytes([packet_type_flags]) + encoded_length except OverflowError as exc: msg = f"Fixed header encoding failed: {exc}" - raise CodecException(msg) from exc + raise CodecError(msg) from exc async def to_stream(self, writer: WriterAdapter) -> None: """Write the fixed header to the stream.""" @@ -91,7 +86,7 @@ async def decode_remaining_length() -> int: multiplier *= 128 if multiplier > 128**3: msg = f"Invalid remaining length bytes:{bytes_to_hex_str(buffer)}, packet_type={packet_type}" - raise MQTTException(msg) + raise MQTTError(msg) return value try: @@ -101,7 +96,7 @@ async def decode_remaining_length() -> int: flags = int1 & 0x0F remaining_length = await decode_remaining_length() return cls(packet_type, flags, remaining_length) - except NoDataException: + except NoDataError: return None def __repr__(self) -> str: @@ -145,7 +140,7 @@ def to_bytes(self) -> bytes: async def from_stream( cls: type[Self], reader: ReaderAdapter, - fixed_header: MQTTFixedHeader | None = None, # noqa: ARG003 + fixed_header: MQTTFixedHeader | None = None, ) -> Self: packet_id = await decode_packet_id(reader) return cls(packet_id) @@ -154,7 +149,7 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}(packet_id={self.packet_id})" -class MQTTPayload(Generic[TVariableHeader]): +class MQTTPayload[_VH: MQTTVariableHeader](ABC): """Abstract base class for MQTT payloads.""" async def to_stream(self, writer: asyncio.StreamWriter) -> None: @@ -162,7 +157,7 @@ async def to_stream(self, writer: asyncio.StreamWriter) -> None: await writer.drain() @abstractmethod - def to_bytes(self, fixed_header: MQTTFixedHeader | None = None, variable_header: TVariableHeader | None = None) -> bytes: + def to_bytes(self, fixed_header: MQTTFixedHeader | None = None, variable_header: _VH | None = None) -> bytes: pass @classmethod @@ -171,26 +166,21 @@ async def from_stream( cls: type[Self], reader: asyncio.StreamReader | ReaderAdapter, fixed_header: MQTTFixedHeader | None, - variable_header: TVariableHeader | None, + variable_header: _VH | None, ) -> Self: pass -class MQTTPacket(Generic[TVariableHeader, TPayload, TFixedHeader]): +class MQTTPacket[_VH: MQTTVariableHeader | None, _P: MQTTPayload[MQTTVariableHeader] | None, _FH: MQTTFixedHeader]: """Represents an MQTT packet.""" __slots__ = ("fixed_header", "payload", "protocol_ts", "variable_header") - FIXED_HEADER: type[TFixedHeader] = MQTTFixedHeader # type: ignore [assignment] - VARIABLE_HEADER: type[TVariableHeader] | None = None - PAYLOAD: type[TPayload] | None = None + VARIABLE_HEADER: type[_VH] | None = None + PAYLOAD: type[_P] | None = None + FIXED_HEADER: type[_FH] = MQTTFixedHeader # type: ignore [assignment] - def __init__( - self, - fixed: TFixedHeader | None, - variable_header: TVariableHeader | None = None, - payload: TPayload | None = None, - ) -> None: + def __init__(self, fixed: _FH, variable_header: _VH | None = None, payload: _P | None = None) -> None: self.fixed_header = fixed self.variable_header = variable_header self.payload = payload @@ -204,10 +194,8 @@ async def to_stream(self, writer: WriterAdapter) -> None: def to_bytes(self) -> bytes: """Serialize the packet into bytes.""" - variable_header_bytes = self.variable_header.to_bytes() if isinstance(self.variable_header, MQTTVariableHeader) else b"" - payload_bytes = ( - self.payload.to_bytes(self.fixed_header, self.variable_header) if isinstance(self.payload, MQTTPayload) else b"" - ) + variable_header_bytes = self.variable_header.to_bytes() if self.variable_header is not None else b"" + payload_bytes = self.payload.to_bytes(self.fixed_header, self.variable_header) if self.payload is not None else b"" fixed_header_bytes = b"" if self.fixed_header: @@ -220,8 +208,8 @@ def to_bytes(self) -> bytes: async def from_stream( cls: type[Self], reader: ReaderAdapter, - fixed_header: TFixedHeader | None = None, - variable_header: TVariableHeader | None = None, + fixed_header: _FH | None = None, + variable_header: _VH | None = None, ) -> Self: """Decode an MQTT packet from the stream.""" if fixed_header is None: diff --git a/amqtt/mqtt/pingreq.py b/amqtt/mqtt/pingreq.py index 278d972d..445160a4 100644 --- a/amqtt/mqtt/pingreq.py +++ b/amqtt/mqtt/pingreq.py @@ -1,8 +1,8 @@ -from amqtt.errors import AMQTTException +from amqtt.errors import AMQTTError from amqtt.mqtt.packet import PINGREQ, MQTTFixedHeader, MQTTPacket -class PingReqPacket(MQTTPacket[None, None]): +class PingReqPacket(MQTTPacket[None, None, MQTTFixedHeader]): VARIABLE_HEADER = None PAYLOAD = None @@ -12,7 +12,7 @@ def __init__(self, fixed: MQTTFixedHeader | None = None) -> None: else: if fixed.packet_type is not PINGREQ: msg = f"Invalid fixed packet type {fixed.packet_type} for PingReqPacket init" - raise AMQTTException(msg) + raise AMQTTError(msg) header = fixed super().__init__(header) self.variable_header = None diff --git a/amqtt/mqtt/pingresp.py b/amqtt/mqtt/pingresp.py index a72f351b..f361aeac 100644 --- a/amqtt/mqtt/pingresp.py +++ b/amqtt/mqtt/pingresp.py @@ -1,10 +1,10 @@ from typing import Self -from amqtt.errors import AMQTTException +from amqtt.errors import AMQTTError from amqtt.mqtt.packet import PINGRESP, MQTTFixedHeader, MQTTPacket -class PingRespPacket(MQTTPacket[None, None]): +class PingRespPacket(MQTTPacket[None, None, MQTTFixedHeader]): VARIABLE_HEADER = None PAYLOAD = None @@ -14,7 +14,7 @@ def __init__(self, fixed: MQTTFixedHeader | None = None) -> None: else: if fixed.packet_type is not PINGRESP: msg = f"Invalid fixed packet type {fixed.packet_type} for PingRespPacket init" - raise AMQTTException(msg) + raise AMQTTError(msg) header = fixed super().__init__(header) self.variable_header = None diff --git a/amqtt/mqtt/protocol/broker_handler.py b/amqtt/mqtt/protocol/broker_handler.py index 32ef5e58..59e89fa1 100644 --- a/amqtt/mqtt/protocol/broker_handler.py +++ b/amqtt/mqtt/protocol/broker_handler.py @@ -2,7 +2,7 @@ from asyncio import AbstractEventLoop, Queue from amqtt.adapters import ReaderAdapter, WriterAdapter -from amqtt.errors import MQTTException +from amqtt.errors import MQTTError from amqtt.mqtt.connack import ( BAD_USERNAME_PASSWORD, CONNECTION_ACCEPTED, @@ -11,16 +11,15 @@ UNACCEPTABLE_PROTOCOL_VERSION, ConnackPacket, ) -from amqtt.mqtt.connect import ConnectPacket, ConnectPayload, ConnectVariableHeader +from amqtt.mqtt.connect import ConnectPacket from amqtt.mqtt.disconnect import DisconnectPacket -from amqtt.mqtt.packet import PacketIdVariableHeader from amqtt.mqtt.pingreq import PingReqPacket from amqtt.mqtt.pingresp import PingRespPacket from amqtt.mqtt.protocol.handler import ProtocolHandler from amqtt.mqtt.suback import SubackPacket -from amqtt.mqtt.subscribe import SubscribePacket, SubscribePayload +from amqtt.mqtt.subscribe import SubscribePacket from amqtt.mqtt.unsuback import UnsubackPacket -from amqtt.mqtt.unsubscribe import UnsubscribePacket, UnubscribePayload +from amqtt.mqtt.unsubscribe import UnsubscribePacket from amqtt.plugins.manager import PluginManager from amqtt.session import Session from amqtt.utils import format_client_message @@ -83,7 +82,7 @@ async def handle_disconnect(self, disconnect: DisconnectPacket | None) -> None: async def handle_connection_closed(self) -> None: await self.handle_disconnect(None) - async def handle_connect(self, connect: ConnectPacket) -> None: # noqa: ARG002 + async def handle_connect(self, connect: ConnectPacket) -> None: # Broker handler shouldn't receive CONNECT message during messages handling # as CONNECT messages are managed by the broker on client connection self.logger.error( @@ -93,30 +92,27 @@ async def handle_connect(self, connect: ConnectPacket) -> None: # noqa: ARG002 if self._disconnect_waiter is not None and not self._disconnect_waiter.done(): self._disconnect_waiter.set_result(None) - async def handle_pingreq(self, pingreq: PingReqPacket) -> None: # noqa: ARG002 + async def handle_pingreq(self, pingreq: PingReqPacket) -> None: await self._send_packet(PingRespPacket.build()) async def handle_subscribe(self, subscribe: SubscribePacket) -> None: - if subscribe.variable_header is None or not isinstance(subscribe.variable_header, PacketIdVariableHeader): - msg = f"Invalid variable header in SUBSCRIBE packet: {subscribe.payload}. Expected a PacketIdVariableHeader." - raise MQTTException(msg) - if subscribe.payload is None or not isinstance(subscribe.payload, SubscribePayload): - msg = f"Invalid payload in SUBSCRIBE packet: {subscribe.payload}. Expected a SubscribePayload." - raise MQTTException(msg) + if subscribe.variable_header is None: + msg = "SUBSCRIBE packet: variable header not initialized." + raise MQTTError(msg) + if subscribe.payload is None: + msg = "SUBSCRIBE packet: payload not initialized." + raise MQTTError(msg) subscription: Subscription = Subscription(subscribe.variable_header.packet_id, subscribe.payload.topics) await self._pending_subscriptions.put(subscription) async def handle_unsubscribe(self, unsubscribe: UnsubscribePacket) -> None: - if unsubscribe.variable_header is None or not isinstance(unsubscribe.variable_header, PacketIdVariableHeader): - msg = ( - f"Invalid variable header in UNSUBSCRIBE packet: {unsubscribe.variable_header}." - "Expected a PacketIdVariableHeader." - ) - raise MQTTException(msg) - if unsubscribe.payload is None or not isinstance(unsubscribe.payload, UnubscribePayload): - msg = f"Invalid payload in UNSUBSCRIBE packet: {unsubscribe.payload}. Expected a UnubscribePayload." - raise MQTTException(msg) + if unsubscribe.variable_header is None: + msg = "UNSUBSCRIBE packet: variable header not initialized." + raise MQTTError(msg) + if unsubscribe.payload is None: + msg = "UNSUBSCRIBE packet: payload not initialized." + raise MQTTError(msg) unsubscription: UnSubscription = UnSubscription(unsubscribe.variable_header.packet_id, unsubscribe.payload.topics) await self._pending_unsubscriptions.put(unsubscription) @@ -137,7 +133,7 @@ async def mqtt_acknowledge_unsubscription(self, packet_id: int) -> None: async def mqtt_connack_authorize(self, authorize: bool) -> None: if self.session is None: msg = "Session is not initialized!" - raise MQTTException(msg) + raise MQTTError(msg) connack = ConnackPacket.build(self.session.parent, CONNECTION_ACCEPTED if authorize else NOT_AUTHORIZED) await self._send_packet(connack) @@ -154,30 +150,30 @@ async def init_from_connect( connect = await ConnectPacket.from_stream(reader) await plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=connect) - if connect.payload is None or not isinstance(connect.payload, ConnectPayload): - msg = f"Invalid payload in CONNECT packet: {connect.variable_header}. Expected a ConnectPayload." - raise MQTTException(msg) - if connect.variable_header is None or not isinstance(connect.variable_header, ConnectVariableHeader): - msg = f"Invalid variable header in CONNECT packet: {connect.variable_header}. Expected a ConnectVariableHeader." - raise MQTTException(msg) + if connect.variable_header is None: + msg = "CONNECT packet: variable header not initialized." + raise MQTTError(msg) + if connect.payload is None: + msg = "CONNECT packet: payload not initialized." + raise MQTTError(msg) # this shouldn't be required anymore since broker generates for each client a random client_id if not provided # [MQTT-3.1.3-6] if connect.payload.client_id is None: msg = "[[MQTT-3.1.3-3]] : Client identifier must be present" - raise MQTTException(msg) + raise MQTTError(msg) if connect.variable_header.will_flag and (connect.payload.will_topic is None or connect.payload.will_message is None): msg = "Will flag set, but will topic/message not present in payload" - raise MQTTException(msg) + raise MQTTError(msg) if connect.variable_header.reserved_flag: msg = "[MQTT-3.1.2-3] CONNECT reserved flag must be set to 0" - raise MQTTException(msg) + raise MQTTError(msg) if connect.proto_name != "MQTT": msg = f'[MQTT-3.1.2-1] Incorrect protocol name: "{connect.proto_name}"' - raise MQTTException(msg) + raise MQTTError(msg) remote_info = writer.get_peer_info() if remote_info is not None: @@ -210,7 +206,7 @@ async def init_from_connect( await plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=connack) await connack.to_stream(writer) await writer.close() - raise MQTTException(error_msg) from None + raise MQTTError(error_msg) from None incoming_session = Session() incoming_session.client_id = connect.client_id diff --git a/amqtt/mqtt/protocol/client_handler.py b/amqtt/mqtt/protocol/client_handler.py index 9b87a6d3..491390f6 100644 --- a/amqtt/mqtt/protocol/client_handler.py +++ b/amqtt/mqtt/protocol/client_handler.py @@ -1,15 +1,14 @@ import asyncio from typing import Any -from amqtt.errors import AMQTTException +from amqtt.errors import AMQTTError from amqtt.mqtt.connack import ConnackPacket from amqtt.mqtt.connect import ConnectPacket, ConnectPayload, ConnectVariableHeader from amqtt.mqtt.disconnect import DisconnectPacket -from amqtt.mqtt.packet import PacketIdVariableHeader from amqtt.mqtt.pingreq import PingReqPacket from amqtt.mqtt.pingresp import PingRespPacket from amqtt.mqtt.protocol.handler import EVENT_MQTT_PACKET_RECEIVED, ProtocolHandler -from amqtt.mqtt.suback import SubackPacket, SubackPayload +from amqtt.mqtt.suback import SubackPacket from amqtt.mqtt.subscribe import SubscribePacket from amqtt.mqtt.unsuback import UnsubackPacket from amqtt.mqtt.unsubscribe import UnsubscribePacket @@ -51,7 +50,7 @@ def _build_connect_packet(self) -> ConnectPacket: if self.session is None: msg = "Session is not initialized." - raise AMQTTException(msg) + raise AMQTTError(msg) vh.keep_alive = self.session.keep_alive vh.clean_session_flag = self.session.clean_session if self.session.clean_session is not None else False @@ -87,7 +86,7 @@ async def mqtt_connect(self) -> int | None: if self.reader is None: msg = "Reader is not initialized." - raise AMQTTException(msg) + raise AMQTTError(msg) connack = await ConnackPacket.from_stream(self.reader) await self.plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=connack, session=self.session) @@ -98,8 +97,12 @@ def handle_write_timeout(self) -> None: if not self._ping_task: self.logger.debug("Scheduling Ping") self._ping_task = asyncio.create_task(self.mqtt_ping()) - except Exception as e: - self.logger.debug(f"Exception ignored in ping task: {e!r}") + except asyncio.InvalidStateError as e: + self.logger.warning(f"Invalid state while scheduling ping task: {e!r}") + except asyncio.CancelledError as e: + self.logger.info(f"Ping task was cancelled: {e!r}") + # except Exception as e: + # self.logger.debug(f"Exception ignored in ping task: {e!r}") def handle_read_timeout(self) -> None: pass @@ -115,7 +118,7 @@ async def mqtt_subscribe(self, topics: list[tuple[str, int]], packet_id: int) -> if subscribe.variable_header is None: msg = f"Invalid variable header in SUBSCRIBE packet: {subscribe.variable_header}" - raise AMQTTException(msg) + raise AMQTTError(msg) waiter: asyncio.Future[list[int]] = asyncio.Future() self._subscriptions_waiter[subscribe.variable_header.packet_id] = waiter @@ -126,16 +129,15 @@ async def mqtt_subscribe(self, topics: list[tuple[str, int]], packet_id: int) -> return return_codes async def handle_suback(self, suback: SubackPacket) -> None: - if suback.variable_header is None or not isinstance(suback.variable_header, PacketIdVariableHeader): - msg = f"Invalid variable header in SUBACK packet: {suback.variable_header}" - raise AMQTTException(msg) + if suback.variable_header is None: + msg = "SUBACK packet: variable header not initialized." + raise AMQTTError(msg) + if suback.payload is None: + msg = "SUBACK packet: payload not initialized." + raise AMQTTError(msg) packet_id = suback.variable_header.packet_id - if suback.payload is None or not isinstance(suback.payload, SubackPayload): - msg = f"Invalid payload in SUBACK packet: {suback.payload}. Expected a SubackPayload." - raise AMQTTException(msg) - waiter = self._subscriptions_waiter.get(packet_id) if waiter is not None: waiter.set_result(suback.payload.return_codes) @@ -149,9 +151,9 @@ async def mqtt_unsubscribe(self, topics: list[str], packet_id: int) -> None: """ unsubscribe = UnsubscribePacket.build(topics, packet_id) - if unsubscribe.variable_header is None or not isinstance(unsubscribe.variable_header, PacketIdVariableHeader): - msg = f"Invalid variable header in UNSUBSCRIBE packet: {unsubscribe.variable_header}" - raise AMQTTException(msg) + if unsubscribe.variable_header is None: + msg = "UNSUBSCRIBE packet: variable header not initialized." + raise AMQTTError(msg) await self._send_packet(unsubscribe) waiter: asyncio.Future[Any] = asyncio.Future() @@ -162,9 +164,9 @@ async def mqtt_unsubscribe(self, topics: list[str], packet_id: int) -> None: del self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id] async def handle_unsuback(self, unsuback: UnsubackPacket) -> None: - if unsuback.variable_header is None or not isinstance(unsuback.variable_header, PacketIdVariableHeader): - msg = f"Invalid variable header in UNSUBACK packet: {unsuback.variable_header}" - raise AMQTTException(msg) + if unsuback.variable_header is None: + msg = "UNSUBACK packet: variable header not initialized." + raise AMQTTError(msg) packet_id = unsuback.variable_header.packet_id waiter = self._unsubscriptions_waiter.get(packet_id) diff --git a/amqtt/mqtt/protocol/handler.py b/amqtt/mqtt/protocol/handler.py index 563de922..1ba73e0e 100644 --- a/amqtt/mqtt/protocol/handler.py +++ b/amqtt/mqtt/protocol/handler.py @@ -1,12 +1,12 @@ import asyncio -from asyncio import InvalidStateError +from asyncio import InvalidStateError, QueueFull, QueueShutDown import collections import itertools import logging -from typing import Any, cast +from typing import cast from amqtt.adapters import ReaderAdapter, WriterAdapter -from amqtt.errors import AMQTTException, MQTTException, NoDataException, ProtocolHandlerException +from amqtt.errors import AMQTTError, MQTTError, NoDataError, ProtocolHandlerError from amqtt.mqtt import packet_class from amqtt.mqtt.connack import ConnackPacket from amqtt.mqtt.connect import ConnectPacket @@ -83,7 +83,7 @@ def __init__( def _init_session(self, session: Session) -> None: if not session: msg = "Session cannot be None" - raise AMQTTException(msg) + raise AMQTTError(msg) log = logging.getLogger(__name__) self.session = session self.logger = logging.LoggerAdapter(log, {"client_id": self.session.client_id}) @@ -94,7 +94,7 @@ def _init_session(self, session: Session) -> None: def attach(self, session: Session, reader: ReaderAdapter, writer: WriterAdapter) -> None: if self.session: msg = "Handler is already attached to a session" - raise ProtocolHandlerException(msg) + raise ProtocolHandlerError(msg) self._init_session(session) self.reader = reader self.writer = writer @@ -110,7 +110,7 @@ def _is_attached(self) -> bool: async def start(self) -> None: if not self._is_attached(): msg = "Handler is not attached to a stream" - raise ProtocolHandlerException(msg) + raise ProtocolHandlerError(msg) self._reader_ready = asyncio.Event() self._reader_task = asyncio.create_task(self._reader_loop()) await self._reader_ready.wait() @@ -133,8 +133,12 @@ async def stop(self) -> None: try: if self.writer is not None: await self.writer.close() - except Exception as e: - self.logger.debug(f"Handler writer close failed: {e}") + except asyncio.CancelledError: + self.logger.debug("Writer close was cancelled.") + except OSError as e: + self.logger.debug(f"Writer close failed due to I/O error: {e}") + except TimeoutError: + self.logger.debug("Writer close operation timed out.") def _stop_waiters(self) -> None: self.logger.debug(f"Stopping {len(self._puback_waiters)} puback waiters") @@ -149,7 +153,7 @@ def _stop_waiters(self) -> None: ): if not isinstance(waiter, asyncio.Future): msg = "Waiter is not a asyncio.Future" - raise AMQTTException(msg) + raise AMQTTError(msg) waiter.cancel() async def _retry_deliveries(self) -> None: @@ -158,7 +162,7 @@ async def _retry_deliveries(self) -> None: if self.session is None: msg = "Session is not initialized." - raise AMQTTException(msg) + raise AMQTTError(msg) tasks = [ asyncio.create_task( @@ -196,13 +200,13 @@ async def mqtt_publish( """ if self.session is None: msg = "Session is not initialized." - raise AMQTTException(msg) + raise AMQTTError(msg) if qos in (QOS_1, QOS_2): packet_id = self.session.next_packet_id if packet_id in self.session.inflight_out: msg = f"A message with the same packet ID '{packet_id}' is already in flight" - raise AMQTTException(msg) + raise AMQTTError(msg) else: packet_id = None message: OutgoingApplicationMessage = OutgoingApplicationMessage(packet_id, topic, qos, data, retain) @@ -227,7 +231,7 @@ async def _handle_message_flow(self, app_message: IncomingApplicationMessage | O await self._handle_qos2_message_flow(app_message) else: msg = f"Unexpected QOS value '{app_message.qos}'" - raise AMQTTException(msg) + raise AMQTTError(msg) async def _handle_qos0_message_flow(self, app_message: IncomingApplicationMessage | OutgoingApplicationMessage) -> None: """Handle QOS_0 application message acknowledgment. @@ -241,7 +245,7 @@ async def _handle_qos0_message_flow(self, app_message: IncomingApplicationMessag raise ValueError(msg) if self.session is None: msg = "Session is not initialized." - raise AMQTTException(msg) + raise AMQTTError(msg) if app_message.direction == OUTGOING: packet = app_message.build_publish_packet() @@ -257,8 +261,10 @@ async def _handle_qos0_message_flow(self, app_message: IncomingApplicationMessag else: try: self.session.delivered_message_queue.put_nowait(app_message) - except Exception as e: - self.logger.warning(f"Delivered messages queue full. QOS_0 message discarded: {e}") + except QueueShutDown as e: + self.logger.warning(f"Delivered messages queue is shut down. QOS_0 message discarded: {e}") + except QueueFull as e: + self.logger.warning(f"Delivered messages queue is full. QOS_0 message discarded: {e}") async def _handle_qos1_message_flow(self, app_message: OutgoingApplicationMessage | IncomingApplicationMessage) -> None: """Handle QOS_1 application message acknowledgment. @@ -275,10 +281,10 @@ async def _handle_qos1_message_flow(self, app_message: OutgoingApplicationMessag raise ValueError(msg) if app_message.puback_packet: msg = f"Message '{app_message.packet_id}' has already been acknowledged" - raise AMQTTException(msg) + raise AMQTTError(msg) if self.session is None: msg = "Session is not initialized." - raise AMQTTException(msg) + raise AMQTTError(msg) if app_message.direction == OUTGOING: if app_message.packet_id not in self.session.inflight_out and isinstance(app_message, OutgoingApplicationMessage): @@ -311,7 +317,7 @@ async def _handle_qos1_message_flow(self, app_message: OutgoingApplicationMessag await self._send_packet(puback) app_message.puback_packet = puback - async def _handle_qos2_message_flow(self, app_message: OutgoingApplicationMessage | IncomingApplicationMessage) -> None: # noqa: PLR0915 + async def _handle_qos2_message_flow(self, app_message: OutgoingApplicationMessage | IncomingApplicationMessage) -> None: """Handle QOS_2 application message acknowledgment. For incoming messages, this method stores the message, sends PUBREC, waits for PUBREL, initiate delivery @@ -327,19 +333,19 @@ async def _handle_qos2_message_flow(self, app_message: OutgoingApplicationMessag raise ValueError(msg) if self.session is None: msg = "Session is not initialized." - raise AMQTTException(msg) + raise AMQTTError(msg) if app_message.direction == OUTGOING: if app_message.pubrel_packet and app_message.pubcomp_packet: msg = f"Message '{app_message.packet_id}' has already been acknowledged" - raise AMQTTException(msg) + raise AMQTTError(msg) if not app_message.pubrel_packet: # Store message if app_message.publish_packet is not None: # This is a retry flow, no need to store just check the message exists in session if app_message.packet_id not in self.session.inflight_out: msg = f"Unknown inflight message '{app_message.packet_id}' in session" - raise AMQTTException(msg) + raise AMQTTError(msg) publish_packet = app_message.build_publish_packet(dup=True) elif isinstance(app_message, OutgoingApplicationMessage): # Store message in session @@ -355,7 +361,7 @@ async def _handle_qos2_message_flow(self, app_message: OutgoingApplicationMessag # PUBREC waiter already exists for this packet ID message = f"Can't add PUBREC waiter, a waiter already exists for message Id '{app_message.packet_id}'" self.logger.warning(message) - raise AMQTTException(message) + raise AMQTTError(message) waiter_pub_rec: asyncio.Future[PubrecPacket] = asyncio.Future() self._pubrec_waiters[app_message.packet_id] = waiter_pub_rec try: @@ -405,13 +411,13 @@ async def _handle_qos2_message_flow(self, app_message: OutgoingApplicationMessag else: self.logger.debug("Unknown direction!") - async def _reader_loop(self) -> None: # noqa: C901, PLR0912, PLR0915 + async def _reader_loop(self) -> None: if self.session is None: msg = "Session is not initialized." - raise AMQTTException(msg) + raise AMQTTError(msg) if not self._reader_ready: msg = "Reader ready is not initialized." - raise ProtocolHandlerException(msg) + raise ProtocolHandlerError(msg) self.logger.debug(f"{self.session.client_id} Starting reader coro") running_tasks: collections.deque[asyncio.Task[None]] = collections.deque() @@ -428,77 +434,77 @@ async def _reader_loop(self) -> None: # noqa: C901, PLR0912, PLR0915 fixed_header = ( await asyncio.wait_for(MQTTFixedHeader.from_stream(self.reader), keepalive_timeout) if self.reader else None ) - if fixed_header: - if fixed_header.packet_type in (RESERVED_0, RESERVED_15): - self.logger.warning( - f"{self.session.client_id} Received reserved packet, which is forbidden: closing connection", - ) - await self.handle_connection_closed() - else: - cls = packet_class(fixed_header) - if self.reader is None: - self.logger.warning("Reader is not initialized!") - continue - # NOTE: type => MQTTPacket - packet: Any = await cls.from_stream(self.reader, fixed_header=fixed_header) - - await self.plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=packet, session=self.session) - if packet.fixed_header is None or packet.fixed_header.packet_type not in ( - CONNACK, - SUBSCRIBE, - UNSUBSCRIBE, - SUBACK, - UNSUBACK, - PUBACK, - PUBREC, - PUBREL, - PUBCOMP, - PINGREQ, - PINGRESP, - PUBLISH, - DISCONNECT, - CONNECT, - ): - self.logger.warning( - f"{self.session.client_id} Unhandled packet type: {packet.fixed_header.packet_type}", - ) - continue - - task: asyncio.Task[None] | None = None - if packet.fixed_header.packet_type == CONNACK: - task = asyncio.create_task(self.handle_connack(packet)) - elif packet.fixed_header.packet_type == SUBSCRIBE: - task = asyncio.create_task(self.handle_subscribe(packet)) - elif packet.fixed_header.packet_type == UNSUBSCRIBE: - task = asyncio.create_task(self.handle_unsubscribe(packet)) - elif packet.fixed_header.packet_type == SUBACK: - task = asyncio.create_task(self.handle_suback(packet)) - elif packet.fixed_header.packet_type == UNSUBACK: - task = asyncio.create_task(self.handle_unsuback(packet)) - elif packet.fixed_header.packet_type == PUBACK: - task = asyncio.create_task(self.handle_puback(packet)) - elif packet.fixed_header.packet_type == PUBREC: - task = asyncio.create_task(self.handle_pubrec(packet)) - elif packet.fixed_header.packet_type == PUBREL: - task = asyncio.create_task(self.handle_pubrel(packet)) - elif packet.fixed_header.packet_type == PUBCOMP: - task = asyncio.create_task(self.handle_pubcomp(packet)) - elif packet.fixed_header.packet_type == PINGREQ: - task = asyncio.create_task(self.handle_pingreq(packet)) - elif packet.fixed_header.packet_type == PINGRESP: - task = asyncio.create_task(self.handle_pingresp(packet)) - elif packet.fixed_header.packet_type == PUBLISH: - task = asyncio.create_task(self.handle_publish(packet)) - elif packet.fixed_header.packet_type == DISCONNECT: - task = asyncio.create_task(self.handle_disconnect(packet)) - elif packet.fixed_header.packet_type == CONNECT: - await self.handle_connect(packet) - if task: - running_tasks.append(task) - else: + if not fixed_header: self.logger.debug(f"{self.session.client_id} No more data (EOF received), stopping reader coro") break - except MQTTException: + + if fixed_header.packet_type in (RESERVED_0, RESERVED_15): + self.logger.warning( + f"{self.session.client_id} Received reserved packet, which is forbidden: closing connection", + ) + await self.handle_connection_closed() + continue + + cls = packet_class(fixed_header) + if self.reader is None: + self.logger.warning("Reader is not initialized!") + continue + packet = await cls.from_stream(self.reader, fixed_header=fixed_header) + + await self.plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=packet, session=self.session) + if packet.fixed_header is None or packet.fixed_header.packet_type not in ( + CONNACK, + SUBSCRIBE, + UNSUBSCRIBE, + SUBACK, + UNSUBACK, + PUBACK, + PUBREC, + PUBREL, + PUBCOMP, + PINGREQ, + PINGRESP, + PUBLISH, + DISCONNECT, + CONNECT, + ): + self.logger.warning(f"{self.session.client_id} Unhandled packet type: {packet.fixed_header.packet_type}") + continue + + task: asyncio.Task[None] | None = None + if packet.fixed_header.packet_type == CONNACK and isinstance(packet, ConnackPacket): + task = asyncio.create_task(self.handle_connack(packet)) + elif packet.fixed_header.packet_type == SUBSCRIBE and isinstance(packet, SubscribePacket): + task = asyncio.create_task(self.handle_subscribe(packet)) + elif packet.fixed_header.packet_type == UNSUBSCRIBE and isinstance(packet, UnsubscribePacket): + task = asyncio.create_task(self.handle_unsubscribe(packet)) + elif packet.fixed_header.packet_type == SUBACK and isinstance(packet, SubackPacket): + task = asyncio.create_task(self.handle_suback(packet)) + elif packet.fixed_header.packet_type == UNSUBACK and isinstance(packet, UnsubackPacket): + task = asyncio.create_task(self.handle_unsuback(packet)) + elif packet.fixed_header.packet_type == PUBACK and isinstance(packet, PubackPacket): + task = asyncio.create_task(self.handle_puback(packet)) + elif packet.fixed_header.packet_type == PUBREC and isinstance(packet, PubrecPacket): + task = asyncio.create_task(self.handle_pubrec(packet)) + elif packet.fixed_header.packet_type == PUBREL and isinstance(packet, PubrelPacket): + task = asyncio.create_task(self.handle_pubrel(packet)) + elif packet.fixed_header.packet_type == PUBCOMP and isinstance(packet, PubcompPacket): + task = asyncio.create_task(self.handle_pubcomp(packet)) + elif packet.fixed_header.packet_type == PINGREQ and isinstance(packet, PingReqPacket): + task = asyncio.create_task(self.handle_pingreq(packet)) + elif packet.fixed_header.packet_type == PINGRESP and isinstance(packet, PingRespPacket): + task = asyncio.create_task(self.handle_pingresp(packet)) + elif packet.fixed_header.packet_type == PUBLISH and isinstance(packet, PublishPacket): + task = asyncio.create_task(self.handle_publish(packet)) + elif packet.fixed_header.packet_type == DISCONNECT and isinstance(packet, DisconnectPacket): + task = asyncio.create_task(self.handle_disconnect(packet)) + elif packet.fixed_header.packet_type == CONNECT and isinstance(packet, ConnectPacket): + # TODO: why is this not like all other inside create_task? + await self.handle_connect(packet) + # task = asyncio.create_task(self.handle_connect(packet)) + if task: + running_tasks.append(task) + except MQTTError: self.logger.debug("Message discarded") except asyncio.CancelledError: self.logger.debug("Task cancelled, reader loop ending") @@ -506,11 +512,11 @@ async def _reader_loop(self) -> None: # noqa: C901, PLR0912, PLR0915 except TimeoutError: self.logger.debug(f"{self.session.client_id} Input stream read timeout") self.handle_read_timeout() - except NoDataException: + except NoDataError: self.logger.debug(f"{self.session.client_id} No data available") - except Exception as e: - self.logger.warning(f"{type(self).__name__} Unhandled exception in reader coro: {e!r}") - break + # except Exception as e: + # self.logger.warning(f"{type(self).__name__} Unhandled exception in reader coro: {e!r}") + # break while running_tasks: running_tasks.popleft().cancel() await self.handle_connection_closed() @@ -555,7 +561,7 @@ async def _send_packet( async def mqtt_deliver_next_message(self) -> ApplicationMessage | None: if self.session is None: msg = "Session is not initialized." - raise AMQTTException(msg) + raise AMQTTError(msg) if not self._is_attached(): return None @@ -573,73 +579,73 @@ async def mqtt_deliver_next_message(self) -> ApplicationMessage | None: def handle_write_timeout(self) -> None: if self.session is None: msg = "Session is not initialized." - raise AMQTTException(msg) + raise AMQTTError(msg) self.logger.debug(f"{self.session.client_id} write timeout unhandled") def handle_read_timeout(self) -> None: if self.session is None: msg = "Session is not initialized." - raise AMQTTException(msg) + raise AMQTTError(msg) self.logger.debug(f"{self.session.client_id} read timeout unhandled") - async def handle_connack(self, connack: ConnackPacket) -> None: # noqa: ARG002 + async def handle_connack(self, connack: ConnackPacket) -> None: if self.session is None: msg = "Session is not initialized." - raise AMQTTException(msg) + raise AMQTTError(msg) self.logger.debug(f"{self.session.client_id} CONNACK unhandled") - async def handle_connect(self, connect: ConnectPacket) -> None: # noqa: ARG002 + async def handle_connect(self, connect: ConnectPacket) -> None: if self.session is None: msg = "Session is not initialized." - raise AMQTTException(msg) + raise AMQTTError(msg) self.logger.debug(f"{self.session.client_id} CONNECT unhandled") - async def handle_subscribe(self, subscribe: SubscribePacket) -> None: # noqa: ARG002 + async def handle_subscribe(self, subscribe: SubscribePacket) -> None: if self.session is None: msg = "Session is not initialized." - raise AMQTTException(msg) + raise AMQTTError(msg) self.logger.debug(f"{self.session.client_id} SUBSCRIBE unhandled") - async def handle_unsubscribe(self, subscribe: UnsubscribePacket) -> None: # noqa: ARG002 + async def handle_unsubscribe(self, subscribe: UnsubscribePacket) -> None: if self.session is None: msg = "Session is not initialized." - raise AMQTTException(msg) + raise AMQTTError(msg) self.logger.debug(f"{self.session.client_id} UNSUBSCRIBE unhandled") - async def handle_suback(self, suback: SubackPacket) -> None: # noqa: ARG002 + async def handle_suback(self, suback: SubackPacket) -> None: if self.session is None: msg = "Session is not initialized." - raise AMQTTException(msg) + raise AMQTTError(msg) self.logger.debug(f"{self.session.client_id} SUBACK unhandled") - async def handle_unsuback(self, unsuback: UnsubackPacket) -> None: # noqa: ARG002 + async def handle_unsuback(self, unsuback: UnsubackPacket) -> None: if self.session is None: msg = "Session is not initialized." - raise AMQTTException(msg) + raise AMQTTError(msg) self.logger.debug(f"{self.session.client_id} UNSUBACK unhandled") - async def handle_pingresp(self, pingresp: PingRespPacket) -> None: # noqa: ARG002 + async def handle_pingresp(self, pingresp: PingRespPacket) -> None: if self.session is None: msg = "Session is not initialized." - raise AMQTTException(msg) + raise AMQTTError(msg) self.logger.debug(f"{self.session.client_id} PINGRESP unhandled") - async def handle_pingreq(self, pingreq: PingReqPacket) -> None: # noqa: ARG002 + async def handle_pingreq(self, pingreq: PingReqPacket) -> None: if self.session is None: msg = "Session is not initialized." - raise AMQTTException(msg) + raise AMQTTError(msg) self.logger.debug(f"{self.session.client_id} PINGREQ unhandled") - async def handle_disconnect(self, disconnect: DisconnectPacket) -> None: # noqa: ARG002 + async def handle_disconnect(self, disconnect: DisconnectPacket) -> None: if self.session is None: msg = "Session is not initialized." - raise AMQTTException(msg) + raise AMQTTError(msg) self.logger.debug(f"{self.session.client_id} DISCONNECT unhandled") async def handle_connection_closed(self) -> None: if self.session is None: msg = "Session is not initialized." - raise AMQTTException(msg) + raise AMQTTError(msg) self.logger.debug(f"{self.session.client_id} Connection closed unhandled") async def handle_puback(self, puback: PubackPacket) -> None: diff --git a/amqtt/mqtt/puback.py b/amqtt/mqtt/puback.py index e7452259..70ac1bf9 100644 --- a/amqtt/mqtt/puback.py +++ b/amqtt/mqtt/puback.py @@ -1,10 +1,10 @@ from typing import Self -from amqtt.errors import AMQTTException +from amqtt.errors import AMQTTError from amqtt.mqtt.packet import PUBACK, MQTTFixedHeader, MQTTPacket, PacketIdVariableHeader -class PubackPacket(MQTTPacket[PacketIdVariableHeader, None]): +class PubackPacket(MQTTPacket[PacketIdVariableHeader, None, MQTTFixedHeader]): VARIABLE_HEADER = PacketIdVariableHeader PAYLOAD = None @@ -32,7 +32,7 @@ def __init__( else: if fixed.packet_type is not PUBACK: msg = f"Invalid fixed packet type {fixed.packet_type} for PubackPacket init" - raise AMQTTException(msg) + raise AMQTTError(msg) header = fixed super().__init__(header, variable_header, None) diff --git a/amqtt/mqtt/pubcomp.py b/amqtt/mqtt/pubcomp.py index 1d2e3c92..647720bb 100644 --- a/amqtt/mqtt/pubcomp.py +++ b/amqtt/mqtt/pubcomp.py @@ -1,10 +1,10 @@ from typing import Self -from amqtt.errors import AMQTTException +from amqtt.errors import AMQTTError from amqtt.mqtt.packet import PUBCOMP, MQTTFixedHeader, MQTTPacket, PacketIdVariableHeader -class PubcompPacket(MQTTPacket[PacketIdVariableHeader]): +class PubcompPacket(MQTTPacket[PacketIdVariableHeader, None, MQTTFixedHeader]): VARIABLE_HEADER = PacketIdVariableHeader PAYLOAD = None @@ -18,7 +18,7 @@ def __init__( else: if fixed.packet_type is not PUBCOMP: msg = f"Invalid fixed packet type {fixed.packet_type} for PubcompPacket init" - raise AMQTTException( + raise AMQTTError( msg, ) header = fixed diff --git a/amqtt/mqtt/publish.py b/amqtt/mqtt/publish.py index 2ca4cbe0..98e2c845 100644 --- a/amqtt/mqtt/publish.py +++ b/amqtt/mqtt/publish.py @@ -2,8 +2,8 @@ from typing import Self from amqtt.adapters import ReaderAdapter -from amqtt.codecs import decode_packet_id, decode_string, encode_string, int_to_bytes -from amqtt.errors import AMQTTException, MQTTException +from amqtt.codecs_a import decode_packet_id, decode_string, encode_string, int_to_bytes +from amqtt.errors import AMQTTError, MQTTError from amqtt.mqtt.packet import PUBLISH, MQTTFixedHeader, MQTTPacket, MQTTPayload, MQTTVariableHeader @@ -14,7 +14,7 @@ def __init__(self, topic_name: str, packet_id: int | None = None) -> None: super().__init__() if "*" in topic_name: msg = "[MQTT-3.3.2-2] Topic name in the PUBLISH Packet MUST NOT contain wildcard characters." - raise MQTTException(msg) + raise MQTTError(msg) self.topic_name = topic_name self.packet_id = packet_id @@ -36,7 +36,7 @@ async def from_stream(cls, reader: ReaderAdapter | asyncio.StreamReader, fixed_h return cls(topic_name, packet_id) -class PublishPayload(MQTTPayload): +class PublishPayload(MQTTPayload[MQTTVariableHeader]): __slots__ = ("data",) def __init__(self, data: bytes | None = None) -> None: @@ -45,8 +45,8 @@ def __init__(self, data: bytes | None = None) -> None: def to_bytes( self, - fixed_header: MQTTFixedHeader | None = None, # noqa: ARG002 - variable_header: MQTTVariableHeader | None = None, # noqa: ARG002 + fixed_header: MQTTFixedHeader | None = None, + variable_header: MQTTVariableHeader | None = None, ) -> bytes: return self.data if self.data is not None else b"" @@ -74,7 +74,7 @@ def __repr__(self) -> str: return f"{type(self).__name__}(data={repr(self.data)!r})" -class PublishPacket(MQTTPacket[PublishVariableHeader, PublishPayload]): +class PublishPacket(MQTTPacket[PublishVariableHeader, PublishPayload, MQTTFixedHeader]): VARIABLE_HEADER = PublishVariableHeader PAYLOAD = PublishPayload @@ -92,7 +92,7 @@ def __init__( header = MQTTFixedHeader(PUBLISH, 0x00) elif fixed.packet_type != PUBLISH: msg = f"Invalid fixed packet type {fixed.packet_type} for PublishPacket init" - raise AMQTTException(msg) from None + raise AMQTTError(msg) from None else: header = fixed @@ -100,24 +100,28 @@ def __init__( self.variable_header = variable_header self.payload = payload + @classmethod + def build(cls, topic_name: str, message: bytes, packet_id: int | None, dup_flag: bool, qos: int | None, retain: bool) -> Self: + v_header = PublishVariableHeader(topic_name, packet_id) + payload = PublishPayload(message) + packet = cls(variable_header=v_header, payload=payload) + packet.dup_flag = dup_flag + packet.retain_flag = retain + packet.qos = qos + return packet + def set_flags(self, dup_flag: bool = False, qos: int = 0, retain_flag: bool = False) -> None: self.dup_flag = dup_flag self.retain_flag = retain_flag self.qos = qos def _set_header_flag(self, val: bool, mask: int) -> None: - if self.fixed_header is None: - msg = "Fixed header is not set" - raise ValueError(msg) if val: self.fixed_header.flags |= mask else: self.fixed_header.flags &= ~mask def _get_header_flag(self, mask: int) -> bool: - if self.fixed_header is None: - msg = "Fixed header is not set" - raise ValueError(msg) return bool(self.fixed_header.flags & mask) @property @@ -138,16 +142,10 @@ def retain_flag(self, val: bool) -> None: @property def qos(self) -> int | None: - if self.fixed_header is None: - msg = "Fixed header is not set" - raise ValueError(msg) return (self.fixed_header.flags & self.QOS_FLAG) >> 1 @qos.setter def qos(self, val: int) -> None: - if self.fixed_header is None: - msg = "Fixed header is not set" - raise ValueError(msg) self.fixed_header.flags &= 0xF9 self.fixed_header.flags |= val << 1 @@ -192,13 +190,3 @@ def topic_name(self, name: str) -> None: msg = "Variable header is not set" raise ValueError(msg) self.variable_header.topic_name = name - - @classmethod - def build(cls, topic_name: str, message: bytes, packet_id: int | None, dup_flag: bool, qos: int | None, retain: bool) -> Self: - v_header = PublishVariableHeader(topic_name, packet_id) - payload = PublishPayload(message) - packet = cls(variable_header=v_header, payload=payload) - packet.dup_flag = dup_flag - packet.retain_flag = retain - packet.qos = qos - return packet diff --git a/amqtt/mqtt/pubrec.py b/amqtt/mqtt/pubrec.py index b7d10e76..4ca75a23 100644 --- a/amqtt/mqtt/pubrec.py +++ b/amqtt/mqtt/pubrec.py @@ -1,10 +1,10 @@ from typing import Self -from amqtt.errors import AMQTTException +from amqtt.errors import AMQTTError from amqtt.mqtt.packet import PUBREC, MQTTFixedHeader, MQTTPacket, PacketIdVariableHeader -class PubrecPacket(MQTTPacket[PacketIdVariableHeader]): +class PubrecPacket(MQTTPacket[PacketIdVariableHeader, None, MQTTFixedHeader]): VARIABLE_HEADER = PacketIdVariableHeader PAYLOAD = None @@ -18,7 +18,7 @@ def __init__( else: if fixed.packet_type is not PUBREC: msg = f"Invalid fixed packet type {fixed.packet_type} for PubrecPacket init" - raise AMQTTException(msg) + raise AMQTTError(msg) header = fixed super().__init__(header) self.variable_header = variable_header diff --git a/amqtt/mqtt/pubrel.py b/amqtt/mqtt/pubrel.py index 102c0528..1bc3c429 100644 --- a/amqtt/mqtt/pubrel.py +++ b/amqtt/mqtt/pubrel.py @@ -1,10 +1,10 @@ from typing import Self -from amqtt.errors import AMQTTException +from amqtt.errors import AMQTTError from amqtt.mqtt.packet import PUBREL, MQTTFixedHeader, MQTTPacket, PacketIdVariableHeader -class PubrelPacket(MQTTPacket[PacketIdVariableHeader]): +class PubrelPacket(MQTTPacket[PacketIdVariableHeader, None, MQTTFixedHeader]): VARIABLE_HEADER = PacketIdVariableHeader PAYLOAD = None @@ -18,7 +18,7 @@ def __init__( else: if fixed.packet_type is not PUBREL: msg = f"Invalid fixed packet type {fixed.packet_type} for PubrelPacket init" - raise AMQTTException(msg) + raise AMQTTError(msg) header = fixed super().__init__(header) self.variable_header = variable_header diff --git a/amqtt/mqtt/suback.py b/amqtt/mqtt/suback.py index 7c32060a..a78d2e7a 100644 --- a/amqtt/mqtt/suback.py +++ b/amqtt/mqtt/suback.py @@ -2,12 +2,12 @@ from typing import Self from amqtt.adapters import ReaderAdapter -from amqtt.codecs import bytes_to_int, int_to_bytes, read_or_raise -from amqtt.errors import AMQTTException, NoDataException +from amqtt.codecs_a import bytes_to_int, int_to_bytes, read_or_raise +from amqtt.errors import AMQTTError, NoDataError from amqtt.mqtt.packet import SUBACK, MQTTFixedHeader, MQTTPacket, MQTTPayload, MQTTVariableHeader, PacketIdVariableHeader -class SubackPayload(MQTTPayload): +class SubackPayload(MQTTPayload[MQTTVariableHeader]): __slots__ = ("return_codes",) RETURN_CODE_00 = 0x00 @@ -24,8 +24,8 @@ def __repr__(self) -> str: def to_bytes( self, - fixed_header: MQTTFixedHeader | None = None, # noqa: ARG002 - variable_header: MQTTVariableHeader | None = None, # noqa: ARG002 + fixed_header: MQTTFixedHeader | None = None, + variable_header: MQTTVariableHeader | None = None, ) -> bytes: out = b"" for return_code in self.return_codes: @@ -42,7 +42,7 @@ async def from_stream( return_codes = [] if fixed_header is None or variable_header is None: msg = "Fixed header or variable header cannot be None" - raise AMQTTException(msg) + raise AMQTTError(msg) bytes_to_read = fixed_header.remaining_length - variable_header.bytes_length for _ in range(bytes_to_read): @@ -50,12 +50,12 @@ async def from_stream( return_code_byte = await read_or_raise(reader, 1) return_code = bytes_to_int(return_code_byte) return_codes.append(return_code) - except NoDataException: + except NoDataError: break return cls(return_codes) -class SubackPacket(MQTTPacket[PacketIdVariableHeader, SubackPayload]): +class SubackPacket(MQTTPacket[PacketIdVariableHeader, SubackPayload, MQTTFixedHeader]): VARIABLE_HEADER = PacketIdVariableHeader PAYLOAD = SubackPayload @@ -70,7 +70,7 @@ def __init__( else: if fixed.packet_type is not SUBACK: msg = f"Invalid fixed packet type {fixed.packet_type} for SubackPacket init" - raise AMQTTException(msg) + raise AMQTTError(msg) header = fixed super().__init__(header) diff --git a/amqtt/mqtt/subscribe.py b/amqtt/mqtt/subscribe.py index 9c1f56d6..d3e144d3 100644 --- a/amqtt/mqtt/subscribe.py +++ b/amqtt/mqtt/subscribe.py @@ -2,12 +2,12 @@ from typing import Self from amqtt.adapters import ReaderAdapter -from amqtt.codecs import bytes_to_int, decode_string, encode_string, int_to_bytes, read_or_raise -from amqtt.errors import AMQTTException, NoDataException +from amqtt.codecs_a import bytes_to_int, decode_string, encode_string, int_to_bytes, read_or_raise +from amqtt.errors import AMQTTError, NoDataError from amqtt.mqtt.packet import SUBSCRIBE, MQTTFixedHeader, MQTTPacket, MQTTPayload, MQTTVariableHeader, PacketIdVariableHeader -class SubscribePayload(MQTTPayload): +class SubscribePayload(MQTTPayload[MQTTVariableHeader]): __slots__ = ("topics",) def __init__(self, topics: list[tuple[str, int]] | None = None) -> None: @@ -16,8 +16,8 @@ def __init__(self, topics: list[tuple[str, int]] | None = None) -> None: def to_bytes( self, - fixed_header: MQTTFixedHeader | None = None, # noqa: ARG002 - variable_header: MQTTVariableHeader | None = None, # noqa: ARG002 + fixed_header: MQTTFixedHeader | None = None, + variable_header: MQTTVariableHeader | None = None, ) -> bytes: out = b"" for topic in self.topics: @@ -46,7 +46,7 @@ async def from_stream( qos = bytes_to_int(qos_byte) topics.append((topic, qos)) read_bytes += 2 + len(topic.encode("utf-8")) + 1 - except NoDataException: + except NoDataError: break return cls(topics) @@ -54,7 +54,7 @@ def __repr__(self) -> str: return type(self).__name__ + f"(topics={self.topics!r})" -class SubscribePacket(MQTTPacket[PacketIdVariableHeader, SubscribePayload]): +class SubscribePacket(MQTTPacket[PacketIdVariableHeader, SubscribePayload, MQTTFixedHeader]): VARIABLE_HEADER = PacketIdVariableHeader PAYLOAD = SubscribePayload @@ -69,7 +69,7 @@ def __init__( else: if fixed.packet_type is not SUBSCRIBE: msg = f"Invalid fixed packet type {fixed.packet_type} for SubscribePacket init" - raise AMQTTException(msg) + raise AMQTTError(msg) header = fixed super().__init__(header) diff --git a/amqtt/mqtt/unsuback.py b/amqtt/mqtt/unsuback.py index a8385a64..fde941c8 100644 --- a/amqtt/mqtt/unsuback.py +++ b/amqtt/mqtt/unsuback.py @@ -1,10 +1,10 @@ from typing import Self -from amqtt.errors import AMQTTException +from amqtt.errors import AMQTTError from amqtt.mqtt.packet import UNSUBACK, MQTTFixedHeader, MQTTPacket, PacketIdVariableHeader -class UnsubackPacket(MQTTPacket[PacketIdVariableHeader]): +class UnsubackPacket(MQTTPacket[PacketIdVariableHeader, None, MQTTFixedHeader]): VARIABLE_HEADER = PacketIdVariableHeader PAYLOAD = None @@ -19,7 +19,7 @@ def __init__( else: if fixed.packet_type is not UNSUBACK: msg = f"Invalid fixed packet type {fixed.packet_type} for UnsubackPacket init" - raise AMQTTException( + raise AMQTTError( msg, ) header = fixed diff --git a/amqtt/mqtt/unsubscribe.py b/amqtt/mqtt/unsubscribe.py index b6f73b0c..d27cce6c 100644 --- a/amqtt/mqtt/unsubscribe.py +++ b/amqtt/mqtt/unsubscribe.py @@ -2,8 +2,8 @@ from typing import Self from amqtt.adapters import ReaderAdapter -from amqtt.codecs import decode_string, encode_string -from amqtt.errors import AMQTTException, NoDataException +from amqtt.codecs_a import decode_string, encode_string +from amqtt.errors import AMQTTError, NoDataError from amqtt.mqtt.packet import UNSUBSCRIBE, MQTTFixedHeader, MQTTPacket, MQTTPayload, MQTTVariableHeader, PacketIdVariableHeader @@ -14,7 +14,7 @@ def __init__(self, topics: list[str] | None = None) -> None: super().__init__() self.topics = topics or [] - def to_bytes(self, fixed_header: MQTTFixedHeader | None = None, variable_header: MQTTVariableHeader | None = None) -> bytes: # noqa: ARG002 + def to_bytes(self, fixed_header: MQTTFixedHeader | None = None, variable_header: MQTTVariableHeader | None = None) -> bytes: out = b"" for topic in self.topics: out += encode_string(topic) @@ -39,7 +39,7 @@ async def from_stream( topic = await decode_string(reader) topics.append(topic) read_bytes += 2 + len(topic.encode("utf-8")) - except NoDataException: + except NoDataError: break return cls(topics) @@ -59,7 +59,7 @@ def __init__( else: if fixed.packet_type is not UNSUBSCRIBE: msg = f"Invalid fixed packet type {fixed.packet_type} for UnsubscribePacket init" - raise AMQTTException(msg) + raise AMQTTError(msg) header = fixed super().__init__(header) diff --git a/amqtt/plugins/authentication.py b/amqtt/plugins/authentication.py index c1a528e0..0738b19e 100644 --- a/amqtt/plugins/authentication.py +++ b/amqtt/plugins/authentication.py @@ -1,9 +1,9 @@ from pathlib import Path -from typing import Any from passlib.apps import custom_app_context as pwd_context from amqtt.plugins.manager import BaseContext +from amqtt.session import Session class BaseAuthPlugin: @@ -15,7 +15,7 @@ def __init__(self, context: BaseContext) -> None: if not self.auth_config: self.context.logger.warning("'auth' section not found in context configuration") - async def authenticate(self, *args: Any, **kwargs: Any) -> bool | None: # noqa: ARG002 + async def authenticate(self, *args: None, **kwargs: Session) -> bool | None: """Logic for base Authentication. Returns True if auth config exists.""" if not self.auth_config: # auth config section not found @@ -27,7 +27,7 @@ async def authenticate(self, *args: Any, **kwargs: Any) -> bool | None: # noqa: class AnonymousAuthPlugin(BaseAuthPlugin): """Authentication plugin allowing anonymous access.""" - async def authenticate(self, *args: Any, **kwargs: Any) -> bool: + async def authenticate(self, *args: None, **kwargs: Session) -> bool: authenticated = await super().authenticate(*args, **kwargs) if authenticated: # Default to allowing anonymous @@ -36,7 +36,7 @@ async def authenticate(self, *args: Any, **kwargs: Any) -> bool: self.context.logger.debug("Authentication success: config allows anonymous") return True - session = kwargs.get("session") + session: Session | None = kwargs.get("session") if session and session.username: self.context.logger.debug(f"Authentication success: session has username '{session.username}'") return True @@ -80,7 +80,7 @@ def _read_password_file(self) -> None: except Exception: self.context.logger.exception(f"Unexpected error reading password file '{password_file}'") - async def authenticate(self, *args: Any, **kwargs: Any) -> bool | None: + async def authenticate(self, *args: None, **kwargs: Session) -> bool | None: """Authenticate users based on the file-stored user database.""" authenticated = await super().authenticate(*args, **kwargs) if authenticated: diff --git a/amqtt/plugins/logging.py b/amqtt/plugins/logging_a.py similarity index 96% rename from amqtt/plugins/logging.py rename to amqtt/plugins/logging_a.py index f1fe3e1c..42124f8e 100644 --- a/amqtt/plugins/logging.py +++ b/amqtt/plugins/logging_a.py @@ -15,7 +15,7 @@ class EventLoggerPlugin: def __init__(self, context: BaseContext) -> None: self.context = context - async def log_event(self, *args: Any, **kwargs: Any) -> None: # noqa: ARG002 + async def log_event(self, *args: Any, **kwargs: Any) -> None: """Log the occurrence of an event.""" event_name = kwargs["event_name"].replace("old", "") self.context.logger.info(f"### '{event_name}' EVENT FIRED ###") @@ -34,7 +34,7 @@ class PacketLoggerPlugin: def __init__(self, context: BaseContext) -> None: self.context = context - async def on_mqtt_packet_received(self, *args: Any, **kwargs: Any) -> None: # noqa: ARG002 + async def on_mqtt_packet_received(self, *args: Any, **kwargs: Any) -> None: """Log an MQTT packet when it is received.""" packet = kwargs.get("packet") session: Session | None = kwargs.get("session") @@ -44,7 +44,7 @@ async def on_mqtt_packet_received(self, *args: Any, **kwargs: Any) -> None: # n else: self.context.logger.debug(f"<-in-- {packet!r}") - async def on_mqtt_packet_sent(self, *args: Any, **kwargs: Any) -> None: # noqa: ARG002 + async def on_mqtt_packet_sent(self, *args: Any, **kwargs: Any) -> None: """Log an MQTT packet when it is sent.""" packet = kwargs.get("packet") session: Session | None = kwargs.get("session") diff --git a/amqtt/plugins/manager.py b/amqtt/plugins/manager.py index b12e9076..a2b1fe49 100644 --- a/amqtt/plugins/manager.py +++ b/amqtt/plugins/manager.py @@ -1,6 +1,7 @@ __all__ = ["BaseContext", "PluginManager", "get_plugin_manager"] import asyncio +from collections.abc import Awaitable, Callable import contextlib import copy from importlib.metadata import EntryPoint, EntryPoints, entry_points @@ -118,7 +119,7 @@ def plugins(self) -> list[Plugin]: """ return self._plugins - def _schedule_coro(self, coro: Any) -> Any: + def _schedule_coro(self, coro: Awaitable[str | bool | None]) -> asyncio.Future[str | bool | None]: return asyncio.ensure_future(coro) async def fire_event(self, event_name: str, wait: bool = False, *args: Any, **kwargs: Any) -> None: @@ -156,7 +157,12 @@ def clean_fired_events(future: asyncio.Future[Any]) -> None: await asyncio.wait(tasks) self.logger.debug(f"Plugins len(_fired_events)={len(self._fired_events)}") - async def map(self, coro: Any, *args: Any, **kwargs: Any) -> dict[Plugin, Any]: + async def map( + self, + coro: Callable[[Plugin, Any], Awaitable[str | bool | None]], + *args: Any, + **kwargs: Any, + ) -> dict[Plugin, str | bool | None]: """Schedule a given coroutine call for each plugin. The coro called gets the Plugin instance as the first argument of its method call. @@ -190,15 +196,15 @@ async def map(self, coro: Any, *args: Any, **kwargs: Any) -> dict[Plugin, Any]: return ret_dict @staticmethod - async def _call_coro(plugin: Plugin, coro_name: str, *args: Any, **kwargs: Any) -> Any: + async def _call_coro(plugin: Plugin, coro_name: str, *args: Any, **kwargs: Any) -> str | bool | None: if not hasattr(plugin.object, coro_name): - # Plugin doesn't implement coro_name + logging.warning("Plugin doesn't implement coro_name") return None - coro = getattr(plugin.object, coro_name)(*args, **kwargs) + coro: Awaitable[str | bool | None] = getattr(plugin.object, coro_name)(*args, **kwargs) return await coro - async def map_plugin_coro(self, coro_name: str, *args: Any, **kwargs: Any) -> dict[Plugin, Any]: + async def map_plugin_coro(self, coro_name: str, *args: Any, **kwargs: Any) -> dict[Plugin, str | bool | None]: """Call a plugin declared by plugin by its name. :param coro_name: diff --git a/amqtt/plugins/persistence.py b/amqtt/plugins/persistence.py index 673f22a7..8baf2d9b 100644 --- a/amqtt/plugins/persistence.py +++ b/amqtt/plugins/persistence.py @@ -12,17 +12,18 @@ def __init__(self, context: BaseContext) -> None: self.conn: sqlite3.Connection | None = None self.cursor: sqlite3.Cursor | None = None self.db_file: str | None = None - try: - if self.context.config is None: - msg = "Context configuration is missing or malformed." - raise ValueError(msg) - self.persistence_config: dict[str, Any] = self.context.config.get("persistence", {}) + self.persistence_config: dict[str, Any] + + if ( + persistence_config := self.context.config.get("persistence") if self.context.config is not None else None + ) is not None: + self.persistence_config = persistence_config self.init_db() - except (KeyError, ValueError) as e: - self.context.logger.warning(f"'persistence' section not found in context configuration: {e}") + else: + self.context.logger.warning("'persistence' section not found in context configuration") def init_db(self) -> None: - self.db_file = self.persistence_config.get("file", None) + self.db_file = self.persistence_config.get("file") if not self.db_file: self.context.logger.warning("'file' persistence parameter not found") else: @@ -36,9 +37,7 @@ def init_db(self) -> None: self.cursor.execute( "CREATE TABLE IF NOT EXISTS session(client_id TEXT PRIMARY KEY, data BLOB)", ) - self.cursor.execute( - "PRAGMA table_info(session)", - ) + self.cursor.execute("PRAGMA table_info(session)") columns = {col[1] for col in self.cursor.fetchall()} required_columns = {"client_id", "data"} if not required_columns.issubset(columns): @@ -56,14 +55,13 @@ async def save_session(self, session: Session) -> None: except Exception: self.context.logger.exception(f"Failed saving session '{session}'") - async def find_session(self, client_id: str) -> Any | None: + async def find_session(self, client_id: str) -> Session | None: if self.cursor: row = self.cursor.execute( "SELECT data FROM session where client_id=?", (client_id,), ).fetchone() - if row: - return json.loads(row[0]) + return json.loads(row[0]) if row else None return None async def del_session(self, client_id: str) -> None: diff --git a/amqtt/plugins/sys/broker.py b/amqtt/plugins/sys/broker.py index 67385796..8b4b22bf 100644 --- a/amqtt/plugins/sys/broker.py +++ b/amqtt/plugins/sys/broker.py @@ -2,12 +2,13 @@ from collections import deque from collections.abc import Buffer from datetime import UTC, datetime -from typing import Any, SupportsIndex, SupportsInt +from typing import SupportsIndex, SupportsInt import amqtt from amqtt.broker import BrokerContext -from amqtt.codecs import int_to_bytes_str -from amqtt.mqtt.packet import PUBLISH +from amqtt.codecs_a import int_to_bytes_str +from amqtt.mqtt.packet import PUBLISH, MQTTFixedHeader, MQTTPacket, PacketIdVariableHeader +from amqtt.mqtt.subscribe import SubscribePayload DOLLAR_SYS_ROOT = "$SYS/broker/" STAT_BYTES_SENT = "bytes_sent" @@ -48,18 +49,18 @@ async def _broadcast_sys_topic(self, topic_basename: str, data: bytes) -> None: """Broadcast a system topic.""" await self.context.broadcast_message(topic_basename, data) - def schedule_broadcast_sys_topic(self, topic_basename: str, data: bytes) -> asyncio.Task[Any]: + def schedule_broadcast_sys_topic(self, topic_basename: str, data: bytes) -> asyncio.Task[None]: """Schedule broadcasting of system topics.""" return asyncio.ensure_future( self._broadcast_sys_topic(DOLLAR_SYS_ROOT + topic_basename, data), loop=self.context.loop, ) - async def on_broker_pre_start(self, *args: Any, **kwargs: Any) -> None: # noqa: ARG002 + async def on_broker_pre_start(self, *args: None, **kwargs: None) -> None: """Clear statistics before broker start.""" self._clear_stats() - async def on_broker_post_start(self, *args: Any, **kwargs: Any) -> None: # noqa: ARG002 + async def on_broker_post_start(self, *args: None, **kwargs: None) -> None: """Initialize statistics and start $SYS broadcasting.""" self._stats[STAT_START_TIME] = int(datetime.now(tz=UTC).timestamp()) version = f"HBMQTT version {amqtt.__version__}" @@ -84,7 +85,7 @@ async def on_broker_post_start(self, *args: Any, **kwargs: Any) -> None: # noqa pass # 'sys_interval' config parameter not found - async def on_broker_pre_stop(self, *args: Any, **kwargs: Any) -> None: # noqa: ARG002 + async def on_broker_pre_stop(self, *args: None, **kwargs: None) -> None: """Stop $SYS topics broadcasting.""" if self._sys_handle: self._sys_handle.cancel() @@ -106,7 +107,7 @@ def broadcast_dollar_sys_topics(self) -> None: subscriptions_count = sum(len(sub) for sub in self.context.subscriptions.values()) # Broadcast updates - tasks: deque[asyncio.Task[Any]] = deque() + tasks: deque[asyncio.Task[None]] = deque() stats: dict[str, int | str] = { "load/bytes/received": self._stats[STAT_BYTES_RECEIVED], "load/bytes/sent": self._stats[STAT_BYTES_SENT], @@ -150,7 +151,11 @@ def broadcast_dollar_sys_topics(self) -> None: else None ) - async def on_mqtt_packet_received(self, *args: Any, **kwargs: Any) -> None: # noqa: ARG002 + async def on_mqtt_packet_received( + self, + *args: None, + **kwargs: MQTTPacket[PacketIdVariableHeader, SubscribePayload, MQTTFixedHeader], + ) -> None: """Handle incoming MQTT packets.""" packet = kwargs.get("packet") if packet: @@ -160,7 +165,11 @@ async def on_mqtt_packet_received(self, *args: Any, **kwargs: Any) -> None: # n if packet.fixed_header.packet_type == PUBLISH: self._stats[STAT_PUBLISH_RECEIVED] += 1 - async def on_mqtt_packet_sent(self, *args: Any, **kwargs: Any) -> None: # noqa: ARG002 + async def on_mqtt_packet_sent( + self, + *args: None, + **kwargs: MQTTPacket[PacketIdVariableHeader, SubscribePayload, MQTTFixedHeader], + ) -> None: """Handle sent MQTT packets.""" packet = kwargs.get("packet") if packet: @@ -170,7 +179,7 @@ async def on_mqtt_packet_sent(self, *args: Any, **kwargs: Any) -> None: # noqa: if packet.fixed_header.packet_type == PUBLISH: self._stats[STAT_PUBLISH_SENT] += 1 - async def on_broker_client_connected(self, *args: Any, **kwargs: Any) -> None: # noqa: ARG002 + async def on_broker_client_connected(self, *args: None, **kwargs: None) -> None: """Handle broker client connection.""" self._stats[STAT_CLIENTS_CONNECTED] += 1 self._stats[STAT_CLIENTS_MAXIMUM] = max( @@ -178,7 +187,7 @@ async def on_broker_client_connected(self, *args: Any, **kwargs: Any) -> None: self._stats[STAT_CLIENTS_CONNECTED], ) - async def on_broker_client_disconnected(self, *args: Any, **kwargs: Any) -> None: # noqa: ARG002 + async def on_broker_client_disconnected(self, *args: None, **kwargs: None) -> None: """Handle broker client disconnection.""" self._stats[STAT_CLIENTS_CONNECTED] -= 1 self._stats[STAT_CLIENTS_DISCONNECTED] += 1 diff --git a/amqtt/plugins/topic_checking.py b/amqtt/plugins/topic_checking.py index 8c23a547..0445e585 100644 --- a/amqtt/plugins/topic_checking.py +++ b/amqtt/plugins/topic_checking.py @@ -11,7 +11,7 @@ def __init__(self, context: BaseContext) -> None: if self.topic_config is None: self.context.logger.warning("'topic-check' section not found in context configuration") - async def topic_filtering(self, *args: Any, **kwargs: Any) -> bool: # noqa: ARG002 + async def topic_filtering(self, *args: Any, **kwargs: Any) -> bool: if not self.topic_config: # auth config section not found self.context.logger.warning("'auth' section not found in context configuration") diff --git a/amqtt/scripts/pub_script.py b/amqtt/scripts/pub_script.py index a2ff20bd..9f0c7450 100644 --- a/amqtt/scripts/pub_script.py +++ b/amqtt/scripts/pub_script.py @@ -42,7 +42,7 @@ import amqtt from amqtt.client import MQTTClient -from amqtt.errors import ConnectException +from amqtt.errors import ConnectError from amqtt.utils import read_yaml_config logger = logging.getLogger(__name__) @@ -64,9 +64,10 @@ def _get_qos(arguments: dict[str, Any]) -> int | None: return None -def _get_extra_headers(arguments: dict[str, Any]) -> Any: +def _get_extra_headers(arguments: dict[str, Any]) -> dict[str, Any]: try: - return json.loads(arguments["--extra-headers"]) + extra_headers: dict[str, Any] = json.loads(arguments["--extra-headers"]) + return extra_headers except (json.JSONDecodeError, TypeError): return {} @@ -129,7 +130,7 @@ async def do_pub(client: MQTTClient, arguments: dict[str, Any]) -> None: except KeyboardInterrupt: await client.disconnect() logger.info(f"{client.client_id} Disconnected from broker") - except ConnectException as ce: + except ConnectError as ce: logger.fatal(f"Connection to '{arguments['--url']}' failed: {ce!r}") except asyncio.CancelledError: logger.fatal("Publish canceled due to previous error") diff --git a/amqtt/scripts/sub_script.py b/amqtt/scripts/sub_script.py index 7e17d62b..8e251f3b 100644 --- a/amqtt/scripts/sub_script.py +++ b/amqtt/scripts/sub_script.py @@ -39,7 +39,7 @@ import amqtt from amqtt.client import MQTTClient -from amqtt.errors import ConnectException, MQTTException +from amqtt.errors import ConnectError, MQTTError from amqtt.mqtt.constants import QOS_0 from amqtt.utils import read_yaml_config @@ -62,9 +62,10 @@ def _get_qos(arguments: dict[str, Any]) -> int: return QOS_0 -def _get_extra_headers(arguments: dict[str, Any]) -> Any: +def _get_extra_headers(arguments: dict[str, Any]) -> dict[str, Any]: try: - return json.loads(arguments["--extra-headers"]) + extra_headers: dict[str, Any] = json.loads(arguments["--extra-headers"]) + return extra_headers except (json.JSONDecodeError, TypeError): return {} @@ -97,7 +98,7 @@ async def do_sub(client: MQTTClient, arguments: dict[str, Any]) -> None: count += 1 sys.stdout.buffer.write(message.publish_packet.data) sys.stdout.write("\n") - except MQTTException: + except MQTTError: logger.debug("Error reading packet") await client.disconnect() @@ -105,7 +106,7 @@ async def do_sub(client: MQTTClient, arguments: dict[str, Any]) -> None: except KeyboardInterrupt: await client.disconnect() logger.info(f"{client.client_id} Disconnected from broker") - except ConnectException as ce: + except ConnectError as ce: logger.fatal(f"Connection to '{arguments['--url']}' failed: {ce!r}") except asyncio.CancelledError: logger.fatal("Publish canceled due to previous error") diff --git a/amqtt/session.py b/amqtt/session.py index 75336409..1b4a63e0 100644 --- a/amqtt/session.py +++ b/amqtt/session.py @@ -1,10 +1,10 @@ from asyncio import Queue from collections import OrderedDict -from typing import Any +from typing import Any, ClassVar from transitions import Machine -from amqtt.errors import AMQTTException +from amqtt.errors import AMQTTError from amqtt.mqtt.publish import PublishPacket OUTGOING = 0 @@ -107,7 +107,7 @@ def __init__(self, packet_id: int | None, topic: str, qos: int | None, data: byt class Session: - states = ["new", "connected", "disconnected"] + states: ClassVar[list[str]] = ["new", "connected", "disconnected"] def __init__(self) -> None: self._init_states() @@ -179,7 +179,7 @@ def next_packet_id(self) -> int: self._packet_id = (self._packet_id % 65535) + 1 if self._packet_id == limit: msg = "More than 65535 messages pending. No free packet ID" - raise AMQTTException(msg) + raise AMQTTError(msg) return self._packet_id diff --git a/amqtt/utils.py b/amqtt/utils.py index 66caba02..63b85d96 100644 --- a/amqtt/utils.py +++ b/amqtt/utils.py @@ -5,6 +5,7 @@ import secrets import string import typing +from typing import Any import yaml @@ -37,11 +38,12 @@ def gen_client_id() -> str: return gen_id -def read_yaml_config(config_file: str | Path) -> typing.Any | dict[str, typing.Any] | None: +def read_yaml_config(config_file: str | Path) -> dict[str, Any] | None: """Read a YAML configuration file.""" try: - with Path(str(config_file)).open() as stream: - return yaml.full_load(stream) + with Path(str(config_file)).open(encoding="utf-8") as stream: + yaml_result: dict[str, Any] = yaml.full_load(stream) + return yaml_result except yaml.YAMLError: logger.exception(f"Invalid config_file {config_file}") return None diff --git a/pyproject.toml b/pyproject.toml index d1740035..6d08e4bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,56 +113,42 @@ extend-select = [ ] ignore = [ - "ANN401", # Checks that function arguments are annotated with a more specific type than Any. - "BLE001", # Checks for except clauses that catch all exceptions. - "D107", # Missing docstring in `__init__` "ERA001", # Checks for commented-out Python code. "FBT001", # Checks for the use of boolean positional arguments in function definitions. "FBT002", # Checks for the use of boolean positional arguments in function definitions. - "FBT003", # Checks for boolean positional arguments in function calls. - "FIX002", # Checks for "TODO" comments. "G004", # Logging statement uses f-string - "PLR2004", # Magic value used in comparison, consider replacing 5 with a constant variable - "RUF001", # Checks for ambiguous Unicode characters in strings. - "RUF012", # Checks for mutable default values in class attributes. - "TD002", # Checks that a TODO comment includes an author. - "TD003", # Checks that a TODO comment is associated with a link to a relevant issue or ticket. - "TRY002", # Checks for code that raises Exception or BaseException directly. "TRY300", # Checks for return statements in try blocks. - "TRY301", # Checks for raise statements within try blocks. - "D103", # Missing docstring in public function "D100", # Missing docstring in public module "D101", # Missing docstring in public class "D102", # Missing docstring in public method + "D103", # Missing docstring in public function "D105", # Missing docstring in magic method + "D107", # Missing docstring in `__init__` + "FIX002", # Checks for "TODO" comments. + "TD002", # Checks that a TODO comment includes an author. + "TD003", # Checks that a TODO comment is associated with a link to a relevant issue or ticket. + "SLF001", # Private member accessed + "ANN401", # Dynamically typed expressions (typing.Any) are disallowed + "ARG002", # Unused method argument + "ARG003", # Unused class method argument + "PLR2004", # Magic value used in comparison, consider replacing with a constant variable ] [tool.ruff.lint.per-file-ignores] -"__init__.py" = [ - "F403", # Checks for the use of wildcard imports. - "F405", # Checks for names that might be undefined -] "tests/**" = [ - "D100", # Missing docstring in public module - "D101", # - "D102", # Missing docstring in public method - "D103", # Missing docstring in public function - "D104", # Missing docstring in public package - "N802", # Function name {name} should be lowercase - "N806", # Variable `userId` in function should be lowercase - "N816", # Variable {name} in global scope should not be mixedCase - "S101", # Use of assert detected - "S106", # Possible hardcoded password assigned to argument: "password_file" - "SLF001", # Private member accessed: {access} - "ANN001", - "ANN201", - "ARG001", - "ASYNC110", - "INP001", - "PGH003", - "PTH107", - "PTH110", - "PTH118", + "ALL", + "D100", # Missing docstring in public module + "D101", # + "D102", # Missing docstring in public method + "D103", # Missing docstring in public function + "D104", # Missing docstring in public package + "N802", # Function name {name} should be lowercase + "N806", # Variable `userId` in function should be lowercase + "N816", # Variable {name} in global scope should not be mixedCase + "S101", # Use of assert detected + "S106", # Possible hardcoded password assigned to argument: "password_file" + "SLF001", # Private member accessed + ] [tool.ruff.lint.flake8-pytest-style] @@ -178,12 +164,12 @@ case-sensitive = true extra-standard-library = ["typing_extensions"] [tool.ruff.lint.mccabe] -max-complexity = 20 +max-complexity = 42 [tool.ruff.lint.pylint] max-args = 12 -max-branches = 25 -max-statements = 70 +max-branches = 42 +max-statements = 142 max-returns = 10 # ----------------------------------- PYTEST ----------------------------------- diff --git a/samples/client_publish.py b/samples/client_publish.py index b3e12144..57bcf58d 100644 --- a/samples/client_publish.py +++ b/samples/client_publish.py @@ -1,7 +1,7 @@ import asyncio import logging -from amqtt.client import ConnectException, MQTTClient +from amqtt.client import ConnectError, MQTTClient from amqtt.mqtt.constants import QOS_1, QOS_2 # @@ -43,7 +43,7 @@ async def test_coro2() -> None: await C.publish("a/b", b"TEST MESSAGE WITH QOS_2", qos=0x02) logger.info("messages published") await C.disconnect() - except ConnectException as ce: + except ConnectError as ce: logger.exception(f"Connection failed: {ce}") asyncio.get_event_loop().stop() diff --git a/samples/client_publish_acl.py b/samples/client_publish_acl.py index 196b6ebd..0a1cbb69 100644 --- a/samples/client_publish_acl.py +++ b/samples/client_publish_acl.py @@ -1,7 +1,7 @@ import asyncio import logging -from amqtt.client import ConnectException, MQTTClient +from amqtt.client import ConnectError, MQTTClient # # This sample shows how to publish messages to broker using different QOS @@ -26,7 +26,7 @@ async def test_coro() -> None: await C.publish("calendar/amqtt/releases", b"NEW RELEASE", qos=0x01) logger.info("messages published") await C.disconnect() - except ConnectException as ce: + except ConnectError as ce: logger.exception(f"Connection failed: {ce}") asyncio.get_event_loop().stop() diff --git a/samples/client_subscribe.py b/samples/client_subscribe.py index 99c2f32b..0a29efb7 100644 --- a/samples/client_subscribe.py +++ b/samples/client_subscribe.py @@ -1,7 +1,7 @@ import asyncio import logging -from amqtt.client import ClientException, MQTTClient +from amqtt.client import ClientError, MQTTClient from amqtt.mqtt.constants import QOS_1, QOS_2 # @@ -30,7 +30,7 @@ async def uptime_coro() -> None: await C.unsubscribe(["$SYS/broker/uptime", "$SYS/broker/load/#"]) logger.info("UnSubscribed") await C.disconnect() - except ClientException as ce: + except ClientError as ce: logger.exception(f"Client exception: {ce}") diff --git a/samples/client_subscribe_acl.py b/samples/client_subscribe_acl.py index d0bc405b..aa694350 100644 --- a/samples/client_subscribe_acl.py +++ b/samples/client_subscribe_acl.py @@ -1,7 +1,7 @@ import asyncio import logging -from amqtt.client import ClientException, MQTTClient +from amqtt.client import ClientError, MQTTClient from amqtt.mqtt.constants import QOS_1 # @@ -34,7 +34,7 @@ async def uptime_coro() -> None: await C.unsubscribe(["$SYS/broker/uptime", "$SYS/broker/load/#"]) logger.info("UnSubscribed") await C.disconnect() - except ClientException as ce: + except ClientError as ce: logger.exception(f"Client exception: {ce}") diff --git a/tests/conftest.py b/tests/conftest.py index 96a1da4a..65128f34 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -106,7 +106,7 @@ def ca_file_fixture(): temp_dir = Path(tempfile.mkdtemp(prefix="amqtt-test-")) url = "http://test.mosquitto.org/ssl/mosquitto.org.crt" ca_file = temp_dir / "mosquitto.org.crt" - urllib.request.urlretrieve(url, str(ca_file)) # noqa: S310 + urllib.request.urlretrieve(url, str(ca_file)) log.info(f"Stored mosquitto cert at {ca_file}") # Yield the CA file path for tests diff --git a/tests/mqtt/test_connect.py b/tests/mqtt/test_connect.py index 8b0a990c..f552d46c 100644 --- a/tests/mqtt/test_connect.py +++ b/tests/mqtt/test_connect.py @@ -1,6 +1,3 @@ -# Copyright (c) 2015 Nicolas JOUANIN -# -# See the file license.txt for copying permission. import asyncio import unittest @@ -36,7 +33,7 @@ def test_decode_ok(self): assert message.payload.will_topic == "WillTopic" assert message.payload.will_message == b"WillMessage" assert message.payload.username == "user" - assert message.payload.password == "password" # noqa: S105 + assert message.payload.password == "password" def test_decode_ok_will_flag(self): data = b"\x10\x26\x00\x04MQTT\x04\xca\x00\x00\x00\x0a0123456789\x00\x04user\x00\x08password" @@ -58,7 +55,7 @@ def test_decode_ok_will_flag(self): assert message.payload.will_topic is None assert message.payload.will_message is None assert message.payload.username == "user" - assert message.payload.password == "password" # noqa: S105 + assert message.payload.password == "password" def test_decode_fail_reserved_flag(self): data = ( @@ -151,5 +148,5 @@ def test_getattr_ok(self): assert message.will_message == b"WillMessage" assert message.payload.username == "user" assert message.username == "user" - assert message.payload.password == "password" # noqa: S105 - assert message.password == "password" # noqa: S105 + assert message.payload.password == "password" + assert message.password == "password" diff --git a/tests/mqtt/test_packet.py b/tests/mqtt/test_packet.py index 9b503c01..1a692402 100644 --- a/tests/mqtt/test_packet.py +++ b/tests/mqtt/test_packet.py @@ -1,13 +1,10 @@ -# Copyright (c) 2015 Nicolas JOUANIN -# -# See the file license.txt for copying permission. import asyncio import unittest import pytest from amqtt.adapters import BufferReader -from amqtt.errors import MQTTException +from amqtt.errors import MQTTError from amqtt.mqtt.packet import CONNECT, MQTTFixedHeader @@ -38,7 +35,7 @@ def test_from_bytes_with_length(self): def test_from_bytes_ko_with_length(self): data = b"\x10\xff\xff\xff\xff\x7f" stream = BufferReader(data) - with pytest.raises(MQTTException): + with pytest.raises(MQTTError): self.loop.run_until_complete(MQTTFixedHeader.from_stream(stream)) def test_to_bytes(self): diff --git a/tests/mqtt/test_puback.py b/tests/mqtt/test_puback.py index 84b2fd0b..d2f3a179 100644 --- a/tests/mqtt/test_puback.py +++ b/tests/mqtt/test_puback.py @@ -1,6 +1,3 @@ -# Copyright (c) 2015 Nicolas JOUANIN -# -# See the file license.txt for copying permission. import asyncio import unittest diff --git a/tests/mqtt/test_pubcomp.py b/tests/mqtt/test_pubcomp.py index b1bf7a70..86ea7ca8 100644 --- a/tests/mqtt/test_pubcomp.py +++ b/tests/mqtt/test_pubcomp.py @@ -1,6 +1,3 @@ -# Copyright (c) 2015 Nicolas JOUANIN -# -# See the file license.txt for copying permission. import asyncio import unittest diff --git a/tests/mqtt/test_publish.py b/tests/mqtt/test_publish.py index c9477d91..71a115db 100644 --- a/tests/mqtt/test_publish.py +++ b/tests/mqtt/test_publish.py @@ -1,6 +1,3 @@ -# Copyright (c) 2015 Nicolas JOUANIN -# -# See the file license.txt for copying permission. import asyncio import unittest diff --git a/tests/mqtt/test_pubrec.py b/tests/mqtt/test_pubrec.py index 8bc15d29..08f80883 100644 --- a/tests/mqtt/test_pubrec.py +++ b/tests/mqtt/test_pubrec.py @@ -1,6 +1,3 @@ -# Copyright (c) 2015 Nicolas JOUANIN -# -# See the file license.txt for copying permission. import asyncio import unittest diff --git a/tests/mqtt/test_pubrel.py b/tests/mqtt/test_pubrel.py index b96ae5ba..2c89d49c 100644 --- a/tests/mqtt/test_pubrel.py +++ b/tests/mqtt/test_pubrel.py @@ -1,6 +1,3 @@ -# Copyright (c) 2015 Nicolas JOUANIN -# -# See the file license.txt for copying permission. import asyncio import unittest diff --git a/tests/mqtt/test_suback.py b/tests/mqtt/test_suback.py index 579b6700..f7a0b686 100644 --- a/tests/mqtt/test_suback.py +++ b/tests/mqtt/test_suback.py @@ -1,6 +1,3 @@ -# Copyright (c) 2015 Nicolas JOUANIN -# -# See the file license.txt for copying permission. import asyncio import unittest diff --git a/tests/mqtt/test_subscribe.py b/tests/mqtt/test_subscribe.py index e46c72b0..c2a141b8 100644 --- a/tests/mqtt/test_subscribe.py +++ b/tests/mqtt/test_subscribe.py @@ -1,6 +1,3 @@ -# Copyright (c) 2015 Nicolas JOUANIN -# -# See the file license.txt for copying permission. import asyncio import unittest diff --git a/tests/mqtt/test_unsubscribe.py b/tests/mqtt/test_unsubscribe.py index e73d04ea..b4caab36 100644 --- a/tests/mqtt/test_unsubscribe.py +++ b/tests/mqtt/test_unsubscribe.py @@ -1,6 +1,3 @@ -# Copyright (c) 2015 Nicolas JOUANIN -# -# See the file license.txt for copying permission. import asyncio import unittest diff --git a/tests/plugins/test_authentication.py b/tests/plugins/test_authentication.py index 4a27c160..92d79c9f 100644 --- a/tests/plugins/test_authentication.py +++ b/tests/plugins/test_authentication.py @@ -60,7 +60,7 @@ def test_allow(self) -> None: } s = Session() s.username = "user" - s.password = "test" # noqa: S105 + s.password = "test" auth_plugin = FileAuthPlugin(context) ret = self.loop.run_until_complete(auth_plugin.authenticate(session=s)) assert ret @@ -75,7 +75,7 @@ def test_wrong_password(self) -> None: } s = Session() s.username = "user" - s.password = "wrong password" # noqa: S105 + s.password = "wrong password" auth_plugin = FileAuthPlugin(context) ret = self.loop.run_until_complete(auth_plugin.authenticate(session=s)) assert not ret @@ -90,7 +90,7 @@ def test_unknown_password(self) -> None: } s = Session() s.username = "some user" - s.password = "some password" # noqa: S105 + s.password = "some password" auth_plugin = FileAuthPlugin(context) ret = self.loop.run_until_complete(auth_plugin.authenticate(session=s)) assert not ret diff --git a/tests/test_broker.py b/tests/test_broker.py index 87ee98ee..1931863e 100644 --- a/tests/test_broker.py +++ b/tests/test_broker.py @@ -19,7 +19,7 @@ EVENT_BROKER_PRE_START, ) from amqtt.client import MQTTClient -from amqtt.errors import ConnectException +from amqtt.errors import ConnectError from amqtt.mqtt.connack import ConnackPacket from amqtt.mqtt.connect import ConnectPacket, ConnectPayload, ConnectVariableHeader from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2 @@ -176,7 +176,7 @@ async def test_client_connect_clean_session_false(broker): return_code = None try: await client.connect("mqtt://127.0.0.1/", cleansession=False) - except ConnectException as ce: + except ConnectError as ce: return_code = ce.return_code assert return_code == 0x02 assert client.session is not None @@ -361,7 +361,7 @@ async def test_client_publish_acl_forbidden(acl_broker): await sub_client.deliver_message(timeout_duration=1) msg = "Should not have worked" raise AssertionError(msg) - except Exception: # noqa: S110 + except Exception: pass await pub_client.disconnect() @@ -396,7 +396,7 @@ async def test_client_publish_acl_permitted_sub_forbidden(acl_broker): await sub_client2.deliver_message(timeout_duration=1) msg = "Should not have worked" raise AssertionError(msg) - except Exception: # noqa: S110 + except Exception: pass await pub_client.disconnect() @@ -575,7 +575,7 @@ async def test_client_subscribe_publish_dollar_topic_1(broker): message = None try: message = await sub_client.deliver_message(timeout_duration=2) - except Exception: # noqa: S110 + except Exception: pass except RuntimeError as e: # The loop is closed with pending tasks. Needs fine tuning. @@ -601,7 +601,7 @@ async def test_client_subscribe_publish_dollar_topic_2(broker): message = None try: message = await sub_client.deliver_message(timeout_duration=2) - except Exception: # noqa: S110 + except Exception: pass except RuntimeError as e: # The loop is closed with pending tasks. Needs fine tuning. diff --git a/tests/test_cli.py b/tests/test_cli.py index f097cd64..a85a07c5 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -3,16 +3,16 @@ def test_smometest(): amqtt_path = "amqtt" - output = subprocess.check_output([amqtt_path, "--help"]) # noqa: S603 + output = subprocess.check_output([amqtt_path, "--help"]) assert b"Usage" in output assert b"aMQTT" in output amqtt_sub_path = "amqtt_sub" - output = subprocess.check_output([amqtt_sub_path, "--help"]) # noqa: S603 + output = subprocess.check_output([amqtt_sub_path, "--help"]) assert b"Usage" in output assert b"amqtt_sub" in output amqtt_pub_path = "amqtt_pub" - output = subprocess.check_output([amqtt_pub_path, "--help"]) # noqa: S603 + output = subprocess.check_output([amqtt_pub_path, "--help"]) assert b"Usage" in output assert b"amqtt_pub" in output diff --git a/tests/test_client.py b/tests/test_client.py index e7f38c99..10de381a 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,7 +4,7 @@ import pytest from amqtt.client import MQTTClient -from amqtt.errors import ConnectException +from amqtt.errors import ConnectError from amqtt.mqtt.constants import QOS_0, QOS_1, QOS_2 formatter = "[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s" @@ -32,7 +32,7 @@ async def test_connect_tcp_failure(): config = {"auto_reconnect": False} client = MQTTClient(config=config) - with pytest.raises(ConnectException): + with pytest.raises(ConnectError): await client.connect("mqtt://127.0.0.1/") diff --git a/tests/test_codecs.py b/tests/test_codecs.py index 514c8398..5640c188 100644 --- a/tests/test_codecs.py +++ b/tests/test_codecs.py @@ -2,7 +2,7 @@ import unittest from amqtt.adapters import StreamReaderAdapter -from amqtt.codecs import ( +from amqtt.codecs_a import ( bytes_to_hex_str, bytes_to_int, decode_string,