From f6520df9e19b8bd35c2cf8c7e434034ea8f62a80 Mon Sep 17 00:00:00 2001 From: mle Date: Sun, 26 Dec 2021 23:57:32 +0100 Subject: [PATCH] Improve protocol concurrency Improve protocol concurrency, use locks to prevent concurrent UDP requests. --- goodwe/inverter.py | 77 ++++++++++++++++++++++++++--------------- goodwe/protocol.py | 70 ++++++++++++++++++------------------- tests/inverter_check.py | 16 ++++++--- tests/test_protocol.py | 47 +++++++++++++------------ 4 files changed, 121 insertions(+), 89 deletions(-) diff --git a/goodwe/inverter.py b/goodwe/inverter.py index fee71fc..a0e6ef7 100644 --- a/goodwe/inverter.py +++ b/goodwe/inverter.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +import asyncio import io import logging from dataclasses import dataclass @@ -60,37 +63,57 @@ class Inverter: """ def __init__(self, host: str, comm_addr: int = 0, timeout: int = 1, retries: int = 3): - self.host = host - self.comm_addr = comm_addr - self.timeout = timeout - self.retries = retries + self.host: str = host + self.comm_addr: int = comm_addr + self.timeout: int = timeout + self.retries: int = retries + self._running_loop: asyncio.AbstractEventLoop | None = None + self._lock: asyncio.Lock | None = None self._consecutive_failures_count: int = 0 - self.model_name: str = None - self.serial_number: str = None - self.software_version: str = None - self.modbus_version: int = None - self.rated_power: int = None - self.ac_output_type: int = None - self.dsp1_sw_version: int = None - self.dsp2_sw_version: int = None - self.dsp_svn_version: int = None - self.arm_sw_version: int = None - self.arm_svn_version: int = None - self.arm_version: str = None + self.model_name: str | None = None + self.serial_number: str | None = None + self.software_version: str | None = None + self.modbus_version: int | None = None + self.rated_power: int | None = None + self.ac_output_type: int | None = None + self.dsp1_sw_version: int | None = None + self.dsp2_sw_version: int | None = None + self.dsp_svn_version: int | None = None + self.arm_sw_version: int | None = None + self.arm_svn_version: int | None = None + self.arm_version: str | None = None + + def _ensure_lock(self) -> None: + """Validate (or create) asyncio Lock. + + The asyncio.Lock must always be created from within's asyncio loop, + so it cannot be eagerly created in constructor. + Additionally, since asyncio.run() creates and closes its own loop, + the lock's scope (its creating loop) mus be verified to support proper + behavior in subsequent asyncio.run() invocations. + """ + if self._lock and self._running_loop == asyncio.get_event_loop(): + pass + else: + logger.debug('Creating lock instance for current event loop.') + self._lock = asyncio.Lock() + self._running_loop = asyncio.get_event_loop() async def _read_from_socket(self, command: ProtocolCommand) -> bytes: - try: - result = await command.execute(self.host, self.timeout, self.retries) - self._consecutive_failures_count = 0 - return result - except MaxRetriesException: - self._consecutive_failures_count += 1 - raise RequestFailedException(f'No valid response received even after {self.retries} retries', - self._consecutive_failures_count) - except RequestFailedException as ex: - self._consecutive_failures_count += 1 - raise RequestFailedException(ex.message, self._consecutive_failures_count) + self._ensure_lock() + async with self._lock: + try: + result = await command.execute(self.host, self.timeout, self.retries) + self._consecutive_failures_count = 0 + return result + except MaxRetriesException: + self._consecutive_failures_count += 1 + raise RequestFailedException(f'No valid response received even after {self.retries} retries', + self._consecutive_failures_count) + except RequestFailedException as ex: + self._consecutive_failures_count += 1 + raise RequestFailedException(ex.message, self._consecutive_failures_count) async def read_device_info(self): """ diff --git a/goodwe/protocol.py b/goodwe/protocol.py index d7d11b1..44428ed 100644 --- a/goodwe/protocol.py +++ b/goodwe/protocol.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import asyncio import logging +from asyncio.futures import Future from typing import Tuple, Optional, Callable from .const import GOODWE_UDP_PORT @@ -12,68 +15,63 @@ class UdpInverterProtocol(asyncio.DatagramProtocol): def __init__( self, - request: bytes, - validator: Callable[[bytes], bool], - on_response_received: asyncio.futures.Future, + command: ProtocolCommand, timeout: int, retries: int ): super().__init__() - self.request: bytes = request - self.validator: Callable[[bytes], bool] = validator - self.on_response_received: asyncio.futures.Future = on_response_received - self.transport: asyncio.transports.DatagramTransport + self.command: ProtocolCommand = command + self._transport: asyncio.transports.DatagramTransport | None = None self._retry_timeout: int = timeout self._max_retries: int = retries self._retries: int = 0 def connection_made(self, transport: asyncio.DatagramTransport) -> None: """On connection made""" - self.transport = transport + self._transport = transport self._send_request() - def _send_request(self) -> None: - """Send message via transport""" - logger.debug(f'Sent: {self.request.hex()} to {self.transport.get_extra_info("peername")}') - self.transport.sendto(self.request) - asyncio.get_event_loop().call_later(self._retry_timeout, self.retry_mechanism) - def connection_lost(self, exc: Optional[Exception]) -> None: """On connection lost""" if exc is not None: logger.debug(f'Socket closed with error: {exc}') # Cancel Future on connection lost - if not self.on_response_received.done(): - self.on_response_received.cancel() + if not self.command.response_future.done(): + self.command.response_future.cancel() def datagram_received(self, data: bytes, addr: Tuple[str, int]) -> None: """On datagram received""" - logger.debug(f'Received: {data.hex()}') - if self.validator(data): - self.on_response_received.set_result(data) + if self.command.validator(data): + logger.debug(f'Received: {data.hex()}') + self.command.response_future.set_result(data) else: - logger.debug(f'Invalid response: {data.hex()}') + logger.debug(f'Received invalid response: {data.hex()}') self._retries += 1 self._send_request() def error_received(self, exc: Exception) -> None: """On error received""" logger.debug(f'Received error: {exc}') - self.on_response_received.set_exception(exc) + self.command.response_future.set_exception(exc) - def retry_mechanism(self): + def _send_request(self) -> None: + """Send message via transport""" + logger.debug('Sending: %s%s', self.command, + f' - retry #{self._retries}/{self._max_retries}' if self._retries > 0 else '') + self._transport.sendto(self.command.request) + asyncio.get_event_loop().call_later(self._retry_timeout, self._retry_mechanism) + + def _retry_mechanism(self) -> None: """Retry mechanism to prevent hanging transport""" - # If future is done we can close the transport - if self.on_response_received.done(): - self.transport.close() + if self.command.response_future.done(): + self._transport.close() elif self._retries < self._max_retries: + logger.debug('Failed to receive response to %s in time (%ds).', self.command, self._retry_timeout) self._retries += 1 - logger.debug(f'Retry #{self._retries} of {self._max_retries}') self._send_request() else: - logger.debug(f'Max number of retries ({self._max_retries}) reached, closing socket') - self.on_response_received.set_exception(MaxRetriesException) - self.transport.close() + logger.debug('Max number of retries (%d) reached, request %s failed.', self._max_retries, self.command) + self.command.response_future.set_exception(MaxRetriesException) class ProtocolCommand: @@ -82,6 +80,10 @@ class ProtocolCommand: def __init__(self, request: bytes, validator: Callable[[bytes], bool]): self.request: bytes = request self.validator: Callable[[bytes], bool] = validator + self.response_future: Future | None = None + + def __repr__(self): + return self.request.hex() async def execute(self, host: str, timeout: int, retries: int) -> bytes: """ @@ -92,16 +94,14 @@ async def execute(self, host: str, timeout: int, retries: int) -> bytes: Return raw response data """ loop = asyncio.get_running_loop() - on_response_received = loop.create_future() + self.response_future = loop.create_future() transport, _ = await loop.create_datagram_endpoint( - lambda: UdpInverterProtocol( - self.request, self.validator, on_response_received, timeout, retries - ), + lambda: UdpInverterProtocol(self, timeout, retries), remote_addr=(host, GOODWE_UDP_PORT), ) try: - await on_response_received - result = on_response_received.result() + await self.response_future + result = self.response_future.result() if result is not None: return result else: diff --git a/tests/inverter_check.py b/tests/inverter_check.py index a4e32d4..f9afd85 100644 --- a/tests/inverter_check.py +++ b/tests/inverter_check.py @@ -25,7 +25,7 @@ f"- Version: {inverter.software_version}" ) -#response = asyncio.run(inverter.read_runtime_data(True)) +response = asyncio.run(inverter.read_runtime_data(True)) #for sensor in inverter.sensors(): # if sensor.id_ in response: @@ -41,7 +41,15 @@ # print(f"{setting.id_}: \t\t {setting.name} = {value} {setting.unit}") #asyncio.run(inverter.set_operation_mode(2)) -response = asyncio.run(inverter.get_operation_mode()) -print(response) +#response = asyncio.run(inverter.get_operation_mode()) +#print(response) #response = asyncio.run(inverter.write_setting('grid_export_limit', 3999)) -#print(response) \ No newline at end of file +#print(response) + +async def run_in_parallel(inverter): + a, b, c, = await asyncio.gather(inverter.get_grid_export_limit(), inverter.get_ongrid_battery_dod(), inverter.read_runtime_data()) + print(a) + print(b) + print(c) + +asyncio.run(run_in_parallel(inverter)) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index f854b54..63ee244 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -2,19 +2,20 @@ from unittest import TestCase, mock from goodwe.exceptions import MaxRetriesException -from goodwe.protocol import UdpInverterProtocol, ModbusReadCommand, ModbusWriteCommand +from goodwe.protocol import UdpInverterProtocol, ModbusReadCommand, ModbusWriteCommand, ProtocolCommand class TestUDPClientProtocol(TestCase): def setUp(self) -> None: - self.future = mock.Mock() + self.command = ProtocolCommand(bytes.fromhex('636f666665650d0a'), lambda x: True) + self.command.response_future = mock.Mock() # self.processor = mock.Mock() - self.protocol = UdpInverterProtocol(bytes.fromhex('636f666665650d0a'), lambda x: True, self.future, 1, 3) + self.protocol = UdpInverterProtocol(self.command, 1, 3) def test_datagram_received(self): data = b'this is mock data' self.protocol.datagram_received(data, ('127.0.0.1', 1337)) - self.future.set_result.assert_called_once() + self.command.response_future.set_result.assert_called_once() # self.processor.assert_called_once_with(data) # def test_datagram_received_process_exception(self): @@ -29,7 +30,7 @@ def test_datagram_received(self): def test_error_received(self): exc = Exception('something went wrong') self.protocol.error_received(exc) - self.future.set_exception.assert_called_once_with(exc) + self.command.response_future.set_exception.assert_called_once_with(exc) @mock.patch('goodwe.protocol.asyncio.get_event_loop') def test_connection_made(self, mock_get_event_loop): @@ -38,30 +39,30 @@ def test_connection_made(self, mock_get_event_loop): mock_get_event_loop.return_value = mock_loop mock_retry_mechanism = mock.Mock() - self.protocol.retry_mechanism = mock_retry_mechanism + self.protocol._retry_mechanism = mock_retry_mechanism self.protocol.connection_made(transport) - transport.sendto.assert_called_with(self.protocol.request) + transport.sendto.assert_called_with(self.protocol.command.request) mock_get_event_loop.assert_called() mock_loop.call_later.assert_called_with(1, mock_retry_mechanism) def test_connection_lost(self): - self.future.done.return_value = True + self.command.response_future.done.return_value = True self.protocol.connection_lost(None) - self.future.cancel.assert_not_called() + self.command.response_future.cancel.assert_not_called() def test_connection_lost_not_done(self): - self.future.done.return_value = False + self.command.response_future.done.return_value = False self.protocol.connection_lost(None) - self.future.cancel.assert_called() + self.command.response_future.cancel.assert_called() def test_retry_mechanism(self): - self.protocol.transport = mock.Mock() + self.protocol._transport = mock.Mock() self.protocol._send_message = mock.Mock() - self.future.done.return_value = True - self.protocol.retry_mechanism() + self.command.response_future.done.return_value = True + self.protocol._retry_mechanism() - self.protocol.transport.close.assert_called() + self.protocol._transport.close.assert_called() self.protocol._send_message.assert_not_called() @mock.patch('goodwe.protocol.asyncio.get_event_loop') @@ -73,11 +74,11 @@ def call_later(_: int, retry_func: Callable): mock_get_event_loop.return_value = mock_loop mock_loop.call_later = call_later - self.protocol.transport = mock.Mock() - self.future.done.side_effect = [False, False, True] - self.protocol.retry_mechanism() + self.protocol._transport = mock.Mock() + self.command.response_future.done.side_effect = [False, False, True] + self.protocol._retry_mechanism() - self.protocol.transport.close.assert_called() + self.protocol._transport.close.assert_called() self.assertEqual(self.protocol._retries, 2) @mock.patch('goodwe.protocol.asyncio.get_event_loop') @@ -89,10 +90,10 @@ def call_later(_: int, retry_func: Callable): mock_get_event_loop.return_value = mock_loop mock_loop.call_later = call_later - self.protocol.transport = mock.Mock() - self.future.done.side_effect = [False, False, False, False, False] - self.protocol.retry_mechanism() - self.future.set_exception.assert_called_once_with(MaxRetriesException) + self.protocol._transport = mock.Mock() + self.command.response_future.done.side_effect = [False, False, False, False, False] + self.protocol._retry_mechanism() + self.command.response_future.set_exception.assert_called_once_with(MaxRetriesException) self.assertEqual(self.protocol._retries, 3) def test_modbus_read_command(self):