Skip to content

Commit

Permalink
Improve protocol concurrency
Browse files Browse the repository at this point in the history
Improve protocol concurrency, use locks to prevent concurrent UDP requests.
  • Loading branch information
mletenay committed Dec 27, 2021
1 parent 069f994 commit f6520df
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 89 deletions.
77 changes: 50 additions & 27 deletions goodwe/inverter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from __future__ import annotations

import asyncio
import io
import logging
from dataclasses import dataclass
Expand Down Expand Up @@ -60,37 +63,57 @@ class Inverter:
"""

def __init__(self, host: str, comm_addr: int = 0, timeout: int = 1, retries: int = 3):
self.host = host
self.comm_addr = comm_addr
self.timeout = timeout
self.retries = retries
self.host: str = host
self.comm_addr: int = comm_addr
self.timeout: int = timeout
self.retries: int = retries
self._running_loop: asyncio.AbstractEventLoop | None = None
self._lock: asyncio.Lock | None = None
self._consecutive_failures_count: int = 0

self.model_name: str = None
self.serial_number: str = None
self.software_version: str = None
self.modbus_version: int = None
self.rated_power: int = None
self.ac_output_type: int = None
self.dsp1_sw_version: int = None
self.dsp2_sw_version: int = None
self.dsp_svn_version: int = None
self.arm_sw_version: int = None
self.arm_svn_version: int = None
self.arm_version: str = None
self.model_name: str | None = None
self.serial_number: str | None = None
self.software_version: str | None = None
self.modbus_version: int | None = None
self.rated_power: int | None = None
self.ac_output_type: int | None = None
self.dsp1_sw_version: int | None = None
self.dsp2_sw_version: int | None = None
self.dsp_svn_version: int | None = None
self.arm_sw_version: int | None = None
self.arm_svn_version: int | None = None
self.arm_version: str | None = None

def _ensure_lock(self) -> None:
"""Validate (or create) asyncio Lock.
The asyncio.Lock must always be created from within's asyncio loop,
so it cannot be eagerly created in constructor.
Additionally, since asyncio.run() creates and closes its own loop,
the lock's scope (its creating loop) mus be verified to support proper
behavior in subsequent asyncio.run() invocations.
"""
if self._lock and self._running_loop == asyncio.get_event_loop():
pass
else:
logger.debug('Creating lock instance for current event loop.')
self._lock = asyncio.Lock()
self._running_loop = asyncio.get_event_loop()

async def _read_from_socket(self, command: ProtocolCommand) -> bytes:
try:
result = await command.execute(self.host, self.timeout, self.retries)
self._consecutive_failures_count = 0
return result
except MaxRetriesException:
self._consecutive_failures_count += 1
raise RequestFailedException(f'No valid response received even after {self.retries} retries',
self._consecutive_failures_count)
except RequestFailedException as ex:
self._consecutive_failures_count += 1
raise RequestFailedException(ex.message, self._consecutive_failures_count)
self._ensure_lock()
async with self._lock:
try:
result = await command.execute(self.host, self.timeout, self.retries)
self._consecutive_failures_count = 0
return result
except MaxRetriesException:
self._consecutive_failures_count += 1
raise RequestFailedException(f'No valid response received even after {self.retries} retries',
self._consecutive_failures_count)
except RequestFailedException as ex:
self._consecutive_failures_count += 1
raise RequestFailedException(ex.message, self._consecutive_failures_count)

async def read_device_info(self):
"""
Expand Down
70 changes: 35 additions & 35 deletions goodwe/protocol.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

import asyncio
import logging
from asyncio.futures import Future
from typing import Tuple, Optional, Callable

from .const import GOODWE_UDP_PORT
Expand All @@ -12,68 +15,63 @@
class UdpInverterProtocol(asyncio.DatagramProtocol):
def __init__(
self,
request: bytes,
validator: Callable[[bytes], bool],
on_response_received: asyncio.futures.Future,
command: ProtocolCommand,
timeout: int,
retries: int
):
super().__init__()
self.request: bytes = request
self.validator: Callable[[bytes], bool] = validator
self.on_response_received: asyncio.futures.Future = on_response_received
self.transport: asyncio.transports.DatagramTransport
self.command: ProtocolCommand = command
self._transport: asyncio.transports.DatagramTransport | None = None
self._retry_timeout: int = timeout
self._max_retries: int = retries
self._retries: int = 0

def connection_made(self, transport: asyncio.DatagramTransport) -> None:
"""On connection made"""
self.transport = transport
self._transport = transport
self._send_request()

def _send_request(self) -> None:
"""Send message via transport"""
logger.debug(f'Sent: {self.request.hex()} to {self.transport.get_extra_info("peername")}')
self.transport.sendto(self.request)
asyncio.get_event_loop().call_later(self._retry_timeout, self.retry_mechanism)

def connection_lost(self, exc: Optional[Exception]) -> None:
"""On connection lost"""
if exc is not None:
logger.debug(f'Socket closed with error: {exc}')
# Cancel Future on connection lost
if not self.on_response_received.done():
self.on_response_received.cancel()
if not self.command.response_future.done():
self.command.response_future.cancel()

def datagram_received(self, data: bytes, addr: Tuple[str, int]) -> None:
"""On datagram received"""
logger.debug(f'Received: {data.hex()}')
if self.validator(data):
self.on_response_received.set_result(data)
if self.command.validator(data):
logger.debug(f'Received: {data.hex()}')
self.command.response_future.set_result(data)
else:
logger.debug(f'Invalid response: {data.hex()}')
logger.debug(f'Received invalid response: {data.hex()}')
self._retries += 1
self._send_request()

def error_received(self, exc: Exception) -> None:
"""On error received"""
logger.debug(f'Received error: {exc}')
self.on_response_received.set_exception(exc)
self.command.response_future.set_exception(exc)

def retry_mechanism(self):
def _send_request(self) -> None:
"""Send message via transport"""
logger.debug('Sending: %s%s', self.command,
f' - retry #{self._retries}/{self._max_retries}' if self._retries > 0 else '')
self._transport.sendto(self.command.request)
asyncio.get_event_loop().call_later(self._retry_timeout, self._retry_mechanism)

def _retry_mechanism(self) -> None:
"""Retry mechanism to prevent hanging transport"""
# If future is done we can close the transport
if self.on_response_received.done():
self.transport.close()
if self.command.response_future.done():
self._transport.close()
elif self._retries < self._max_retries:
logger.debug('Failed to receive response to %s in time (%ds).', self.command, self._retry_timeout)
self._retries += 1
logger.debug(f'Retry #{self._retries} of {self._max_retries}')
self._send_request()
else:
logger.debug(f'Max number of retries ({self._max_retries}) reached, closing socket')
self.on_response_received.set_exception(MaxRetriesException)
self.transport.close()
logger.debug('Max number of retries (%d) reached, request %s failed.', self._max_retries, self.command)
self.command.response_future.set_exception(MaxRetriesException)


class ProtocolCommand:
Expand All @@ -82,6 +80,10 @@ class ProtocolCommand:
def __init__(self, request: bytes, validator: Callable[[bytes], bool]):
self.request: bytes = request
self.validator: Callable[[bytes], bool] = validator
self.response_future: Future | None = None

def __repr__(self):
return self.request.hex()

async def execute(self, host: str, timeout: int, retries: int) -> bytes:
"""
Expand All @@ -92,16 +94,14 @@ async def execute(self, host: str, timeout: int, retries: int) -> bytes:
Return raw response data
"""
loop = asyncio.get_running_loop()
on_response_received = loop.create_future()
self.response_future = loop.create_future()
transport, _ = await loop.create_datagram_endpoint(
lambda: UdpInverterProtocol(
self.request, self.validator, on_response_received, timeout, retries
),
lambda: UdpInverterProtocol(self, timeout, retries),
remote_addr=(host, GOODWE_UDP_PORT),
)
try:
await on_response_received
result = on_response_received.result()
await self.response_future
result = self.response_future.result()
if result is not None:
return result
else:
Expand Down
16 changes: 12 additions & 4 deletions tests/inverter_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
f"- Version: {inverter.software_version}"
)

#response = asyncio.run(inverter.read_runtime_data(True))
response = asyncio.run(inverter.read_runtime_data(True))

#for sensor in inverter.sensors():
# if sensor.id_ in response:
Expand All @@ -41,7 +41,15 @@
# print(f"{setting.id_}: \t\t {setting.name} = {value} {setting.unit}")

#asyncio.run(inverter.set_operation_mode(2))
response = asyncio.run(inverter.get_operation_mode())
print(response)
#response = asyncio.run(inverter.get_operation_mode())
#print(response)
#response = asyncio.run(inverter.write_setting('grid_export_limit', 3999))
#print(response)
#print(response)

async def run_in_parallel(inverter):
a, b, c, = await asyncio.gather(inverter.get_grid_export_limit(), inverter.get_ongrid_battery_dod(), inverter.read_runtime_data())
print(a)
print(b)
print(c)

asyncio.run(run_in_parallel(inverter))
47 changes: 24 additions & 23 deletions tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@
from unittest import TestCase, mock

from goodwe.exceptions import MaxRetriesException
from goodwe.protocol import UdpInverterProtocol, ModbusReadCommand, ModbusWriteCommand
from goodwe.protocol import UdpInverterProtocol, ModbusReadCommand, ModbusWriteCommand, ProtocolCommand


class TestUDPClientProtocol(TestCase):
def setUp(self) -> None:
self.future = mock.Mock()
self.command = ProtocolCommand(bytes.fromhex('636f666665650d0a'), lambda x: True)
self.command.response_future = mock.Mock()
# self.processor = mock.Mock()
self.protocol = UdpInverterProtocol(bytes.fromhex('636f666665650d0a'), lambda x: True, self.future, 1, 3)
self.protocol = UdpInverterProtocol(self.command, 1, 3)

def test_datagram_received(self):
data = b'this is mock data'
self.protocol.datagram_received(data, ('127.0.0.1', 1337))
self.future.set_result.assert_called_once()
self.command.response_future.set_result.assert_called_once()
# self.processor.assert_called_once_with(data)

# def test_datagram_received_process_exception(self):
Expand All @@ -29,7 +30,7 @@ def test_datagram_received(self):
def test_error_received(self):
exc = Exception('something went wrong')
self.protocol.error_received(exc)
self.future.set_exception.assert_called_once_with(exc)
self.command.response_future.set_exception.assert_called_once_with(exc)

@mock.patch('goodwe.protocol.asyncio.get_event_loop')
def test_connection_made(self, mock_get_event_loop):
Expand All @@ -38,30 +39,30 @@ def test_connection_made(self, mock_get_event_loop):
mock_get_event_loop.return_value = mock_loop

mock_retry_mechanism = mock.Mock()
self.protocol.retry_mechanism = mock_retry_mechanism
self.protocol._retry_mechanism = mock_retry_mechanism
self.protocol.connection_made(transport)

transport.sendto.assert_called_with(self.protocol.request)
transport.sendto.assert_called_with(self.protocol.command.request)
mock_get_event_loop.assert_called()
mock_loop.call_later.assert_called_with(1, mock_retry_mechanism)

def test_connection_lost(self):
self.future.done.return_value = True
self.command.response_future.done.return_value = True
self.protocol.connection_lost(None)
self.future.cancel.assert_not_called()
self.command.response_future.cancel.assert_not_called()

def test_connection_lost_not_done(self):
self.future.done.return_value = False
self.command.response_future.done.return_value = False
self.protocol.connection_lost(None)
self.future.cancel.assert_called()
self.command.response_future.cancel.assert_called()

def test_retry_mechanism(self):
self.protocol.transport = mock.Mock()
self.protocol._transport = mock.Mock()
self.protocol._send_message = mock.Mock()
self.future.done.return_value = True
self.protocol.retry_mechanism()
self.command.response_future.done.return_value = True
self.protocol._retry_mechanism()

self.protocol.transport.close.assert_called()
self.protocol._transport.close.assert_called()
self.protocol._send_message.assert_not_called()

@mock.patch('goodwe.protocol.asyncio.get_event_loop')
Expand All @@ -73,11 +74,11 @@ def call_later(_: int, retry_func: Callable):
mock_get_event_loop.return_value = mock_loop
mock_loop.call_later = call_later

self.protocol.transport = mock.Mock()
self.future.done.side_effect = [False, False, True]
self.protocol.retry_mechanism()
self.protocol._transport = mock.Mock()
self.command.response_future.done.side_effect = [False, False, True]
self.protocol._retry_mechanism()

self.protocol.transport.close.assert_called()
self.protocol._transport.close.assert_called()
self.assertEqual(self.protocol._retries, 2)

@mock.patch('goodwe.protocol.asyncio.get_event_loop')
Expand All @@ -89,10 +90,10 @@ def call_later(_: int, retry_func: Callable):
mock_get_event_loop.return_value = mock_loop
mock_loop.call_later = call_later

self.protocol.transport = mock.Mock()
self.future.done.side_effect = [False, False, False, False, False]
self.protocol.retry_mechanism()
self.future.set_exception.assert_called_once_with(MaxRetriesException)
self.protocol._transport = mock.Mock()
self.command.response_future.done.side_effect = [False, False, False, False, False]
self.protocol._retry_mechanism()
self.command.response_future.set_exception.assert_called_once_with(MaxRetriesException)
self.assertEqual(self.protocol._retries, 3)

def test_modbus_read_command(self):
Expand Down

0 comments on commit f6520df

Please sign in to comment.