From 40d1214c790cc737ee6fca9c19dc6a59c0e32ffb Mon Sep 17 00:00:00 2001 From: MVladislav Date: Sun, 12 Jan 2025 21:28:42 +0100 Subject: [PATCH] refactor: checking open pull requests: - #119 :: not working in test cases, only comment - #153 :: updated - #61 :: removed try block as described - #72 :: function included --- amqtt/broker.py | 92 +++++++++++++++++++-------- amqtt/client.py | 3 +- amqtt/mqtt/protocol/client_handler.py | 1 + tests/test_broker.py | 18 ++++++ 4 files changed, 88 insertions(+), 26 deletions(-) diff --git a/amqtt/broker.py b/amqtt/broker.py index 227120d4..9543de07 100644 --- a/amqtt/broker.py +++ b/amqtt/broker.py @@ -36,6 +36,8 @@ "auth": {"allow-anonymous": True, "password-file": None}, } +# Default port numbers +DEFAULT_PORTS = {"tcp": 1883, "ws": 8883} AMQTT_MAGIC_VALUE_RET_SUBSCRIBED = 0x80 @@ -268,10 +270,8 @@ async def start(self) -> None: msg = "Can't read cert files '{}' or '{}' : {}".format(listener["certfile"], listener["keyfile"], fnfe) raise BrokerError(msg) from fnfe - address, s_port = listener["bind"].split(":") - port = 0 try: - port = int(s_port) + address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]]) except ValueError as e: msg = "Invalid port value in bind value: {}".format(listener["bind"]) raise BrokerError(msg) from e @@ -674,30 +674,28 @@ def retain_message( self.logger.debug(f"Clearing retained messages for topic '{topic_name}'") del self._retained_messages[topic_name] + # NOTE: issue #61 remove try block async def add_subscription(self, subscription: tuple[str, int], session: Session) -> int: - try: - topic_filter, qos = subscription - if "#" in topic_filter and not topic_filter.endswith("#"): - # [MQTT-4.7.1-2] Wildcard character '#' is only allowed as last character in filter - return 0x80 - if topic_filter != "+" and "+" in topic_filter and ("/+" not in topic_filter and "+/" not in topic_filter): - # [MQTT-4.7.1-3] + wildcard character must occupy entire level - return 0x80 - # Check if the client is authorised to connect to the topic - if not await self.topic_filtering(session, topic_filter, Action.SUBSCRIBE): - return 0x80 - qos_conf = self.config.get("max-qos", qos) - if isinstance(qos_conf, int): - qos = min(qos, qos_conf) - if topic_filter not in self._subscriptions: - self._subscriptions[topic_filter] = [] - if all(s.client_id != session.client_id for s, _ in self._subscriptions[topic_filter]): - self._subscriptions[topic_filter].append((session, qos)) - else: - self.logger.debug(f"Client {format_client_message(session=session)} has already subscribed to {topic_filter}") - return qos - except KeyError: + topic_filter, qos = subscription + if "#" in topic_filter and not topic_filter.endswith("#"): + # [MQTT-4.7.1-2] Wildcard character '#' is only allowed as last character in filter + return 0x80 + if topic_filter != "+" and "+" in topic_filter and ("/+" not in topic_filter and "+/" not in topic_filter): + # [MQTT-4.7.1-3] + wildcard character must occupy entire level return 0x80 + # Check if the client is authorised to connect to the topic + if not await self.topic_filtering(session, topic_filter, Action.SUBSCRIBE): + return 0x80 + qos_conf = self.config.get("max-qos", qos) + if isinstance(qos_conf, int): + qos = min(qos, qos_conf) + if topic_filter not in self._subscriptions: + self._subscriptions[topic_filter] = [] + if all(s.client_id != session.client_id for s, _ in self._subscriptions[topic_filter]): + self._subscriptions[topic_filter].append((session, qos)) + else: + self.logger.debug(f"Client {format_client_message(session=session)} has already subscribed to {topic_filter}") + return qos def _del_subscription(self, a_filter: str, session: Session) -> int: """Delete a session subscription on a given topic. @@ -922,3 +920,47 @@ def _get_handler(self, session: Session) -> BrokerProtocolHandler | None: if client_id: return self._sessions.get(client_id, (None, None))[1] return None + + @classmethod + def _split_bindaddr_port(cls, port_str: str, default_port: int) -> tuple[str | None, int]: + """Split an address:port pair into separate IP address and port. with IPv6 special-case handling. + + NOTE: issue #72 + + - Address can be specified using one of the following methods: + - 1883 - Port number only (listen all interfaces) + - :1883 - Port number only (listen all interfaces) + - 0.0.0.0:1883 - IPv4 address + - [::]:1883 - IPv6 address + - empty string - all interfaces default port + """ + + def _parse_port(port_str: str) -> int: + port_str = port_str.removeprefix(":") + + if not port_str: + return default_port + + return int(port_str) + + if port_str.startswith("["): # IPv6 literal + try: + addr_end = port_str.index("]") + except ValueError as e: + msg = "Expecting '[' to be followed by ']'" + raise ValueError(msg) from e + + return (port_str[0 : addr_end + 1], _parse_port(port_str[addr_end + 1 :])) + + if ":" in port_str: + # Address : port + address, port_str = port_str.rsplit(":", 1) + return (address or None, _parse_port(port_str)) + + # Address or port + try: + # Port number? + return (None, _parse_port(port_str)) + except ValueError: + # Address, default port + return (port_str, default_port) diff --git a/amqtt/client.py b/amqtt/client.py index b14a82e1..afc21a0a 100644 --- a/amqtt/client.py +++ b/amqtt/client.py @@ -526,7 +526,8 @@ def cancel_tasks() -> None: while self.client_tasks: task = self.client_tasks.popleft() if not task.done(): - task.cancel() + # task.set_exception(ClientError("Connection lost")) + task.cancel() # NOTE: issue #153 self.logger.debug("Monitoring broker disconnection") # Wait for disconnection from broker (like connection lost) diff --git a/amqtt/mqtt/protocol/client_handler.py b/amqtt/mqtt/protocol/client_handler.py index 491390f6..b0b58510 100644 --- a/amqtt/mqtt/protocol/client_handler.py +++ b/amqtt/mqtt/protocol/client_handler.py @@ -195,6 +195,7 @@ async def handle_connection_closed(self) -> None: self.logger.debug("Broker closed connection") if self._disconnect_waiter is not None and not self._disconnect_waiter.done(): self._disconnect_waiter.set_result(None) + # await self.stop() # NOTE: issue #119 async def wait_disconnect(self) -> None: if self._disconnect_waiter is not None: diff --git a/tests/test_broker.py b/tests/test_broker.py index 1931863e..0944f9b7 100644 --- a/tests/test_broker.py +++ b/tests/test_broker.py @@ -17,6 +17,7 @@ EVENT_BROKER_POST_START, EVENT_BROKER_PRE_SHUTDOWN, EVENT_BROKER_PRE_START, + Broker, ) from amqtt.client import MQTTClient from amqtt.errors import ConnectError @@ -44,6 +45,23 @@ async def async_magic(): MagicMock.__await__ = lambda _: async_magic().__await__() +@pytest.mark.parametrize( + "input_str, output_addr, output_port", + [ + ("1234", None, 1234), + (":1234", None, 1234), + ("0.0.0.0:1234", "0.0.0.0", 1234), + ("[::]:1234", "[::]", 1234), + ("0.0.0.0", "0.0.0.0", 5678), + ("[::]", "[::]", 5678), + ("localhost", "localhost", 5678), + ("localhost:1234", "localhost", 1234), + ], +) +def test_split_bindaddr_port(input_str, output_addr, output_port): + assert Broker._split_bindaddr_port(input_str, 5678) == (output_addr, output_port) + + @pytest.mark.asyncio async def test_start_stop(broker, mock_plugin_manager): mock_plugin_manager.assert_has_calls(