From ae73d044df33b4cc79e8c9e69f090f2f2f619a18 Mon Sep 17 00:00:00 2001 From: Tobias McNulty Date: Wed, 4 Dec 2024 08:54:17 -0500 Subject: [PATCH] Increase timeout and disable Postgres notify on send (#27) Co-authored-by: Colin Copeland Co-authored-by: Simon Kagwi --- src/smpp_gateway/client.py | 10 ++-- .../management/commands/smpp_client.py | 13 +++++ src/smpp_gateway/outgoing.py | 8 ++-- src/smpp_gateway/smpp.py | 7 +++ tests/test_client.py | 10 ++++ tests/test_router.py | 48 ++++++++++++++++++- 6 files changed, 89 insertions(+), 7 deletions(-) diff --git a/src/smpp_gateway/client.py b/src/smpp_gateway/client.py index 09467e6..44462fd 100644 --- a/src/smpp_gateway/client.py +++ b/src/smpp_gateway/client.py @@ -71,6 +71,7 @@ def __init__( submit_sm_params: dict, set_priority_flag: bool, mt_messages_per_second: int, + event_loop_timeout: int, *args, **kwargs, ): @@ -81,6 +82,7 @@ def __init__( self.submit_sm_params = submit_sm_params self.set_priority_flag = set_priority_flag self.mt_messages_per_second = mt_messages_per_second + self.event_loop_timeout = event_loop_timeout super().__init__(*args, **kwargs) self._pg_conn = pg_listen(self.backend.name) @@ -176,11 +178,13 @@ def receive_pg_notify(self): self.send_mt_messages() def send_mt_messages(self): - limit = self.mt_messages_per_second * self.timeout + limit = self.mt_messages_per_second * self.event_loop_timeout smses = get_mt_messages_to_send(limit=limit, backend=self.backend) if len(smses) == 0: return - logger.info(f"Found {len(smses)} messages to send in {self.timeout} seconds") + logger.info( + f"Found {len(smses)} messages to send in {self.event_loop_timeout} seconds" + ) submit_sm_resps = [] for sms in smses: params = {**self.submit_sm_params, **sms["params"]} @@ -237,7 +241,7 @@ def listen(self, ignore_error_codes=None, auto_send_enquire_link=True): while True: # When either main socket has data or _pg_conn has data, select.select will return rlist, _, _ = select.select( - [self._socket, self._pg_conn], [], [], self.timeout + [self._socket, self._pg_conn], [], [], self.event_loop_timeout ) if not rlist and auto_send_enquire_link: self.logger.debug("Socket timeout, listening again") diff --git a/src/smpp_gateway/management/commands/smpp_client.py b/src/smpp_gateway/management/commands/smpp_client.py index be422bd..f348e14 100644 --- a/src/smpp_gateway/management/commands/smpp_client.py +++ b/src/smpp_gateway/management/commands/smpp_client.py @@ -57,6 +57,19 @@ def add_arguments(self, parser): type=int, default=os.environ.get("SMPPLIB_MT_MESSAGES_PER_SECOND", 20), ) + parser.add_argument( + "--socket-timeout", + type=int, + default=os.environ.get("SMPPLIB_SOCKET_TIMEOUT", 30), + ) + parser.add_argument( + "--event-loop-timeout", + type=int, + default=os.environ.get("SMPPLIB_EVENT_LOOP_TIMEOUT", 5), + help="Timeout for listening for Postgres notifications and new " + "incoming messages. This is also the time between enquire_link " + "PDUs sent to the SMPP server when there is no other traffic.", + ) parser.add_argument( "--database-url", default=os.environ.get("DATABASE_URL"), diff --git a/src/smpp_gateway/outgoing.py b/src/smpp_gateway/outgoing.py index 12878f9..89d183b 100644 --- a/src/smpp_gateway/outgoing.py +++ b/src/smpp_gateway/outgoing.py @@ -1,10 +1,10 @@ import logging -from django.db import connection from django.utils import timezone from rapidsms.backends.base import BackendBase from smpp_gateway.models import MTMessage +from smpp_gateway.queries import pg_notify from smpp_gateway.utils import grouper logger = logging.getLogger(__name__) @@ -16,6 +16,8 @@ class SMPPGatewayBackend(BackendBase): # Optional additional params from: # https://github.com/python-smpplib/python-smpplib/blob/d9d91beb2d7f37915b13a064bb93f907379342ec/smpplib/command.py#L652-L700 OPTIONAL_PARAMS = ("source_addr",) + # The minimum priority_flag value for which to send a Postgres notification + minimum_notify_priority_flag = MTMessage.PriorityFlag.LEVEL_2.value def configure(self, **kwargs): self.send_group_size = kwargs.get("send_group_size", 100) @@ -48,5 +50,5 @@ def send(self, id_, text, identities, context=None): MTMessage.objects.bulk_create( [MTMessage(**kwargs) for kwargs in kwargs_group] ) - with connection.cursor() as curs: - curs.execute(f"NOTIFY {self.model.name}") + if context.get("priority_flag", 0) >= self.minimum_notify_priority_flag: + pg_notify(self.model.name) diff --git a/src/smpp_gateway/smpp.py b/src/smpp_gateway/smpp.py index 1c1d401..9796d1f 100644 --- a/src/smpp_gateway/smpp.py +++ b/src/smpp_gateway/smpp.py @@ -20,6 +20,8 @@ def get_smpplib_client( submit_sm_params: dict, set_priority_flag: bool, mt_messages_per_second: int, + socket_timeout: int, + event_loop_timeout: int, hc_check_uuid: str, hc_ping_key: str, hc_check_slug: str, @@ -38,10 +40,13 @@ def get_smpplib_client( submit_sm_params, set_priority_flag, mt_messages_per_second, + event_loop_timeout, + # Pass-through arguments to smpplib.client.Client: host, port, allow_unknown_opt_params=True, sequence_generator=sequence_generator, + timeout=socket_timeout, ) return client @@ -73,6 +78,8 @@ def start_smpp_client(options): json.loads(options["submit_sm_params"]), options["set_priority_flag"], options["mt_messages_per_second"], + options["socket_timeout"], + options["event_loop_timeout"], options["hc_check_uuid"], options["hc_ping_key"], options["hc_check_slug"], diff --git a/tests/test_client.py b/tests/test_client.py index 3331ac5..b3167d5 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -26,6 +26,8 @@ def test_received_mo_message(self): {}, # submit_sm_params False, # set_priority_flag 20, # mt_messages_per_second + 30, # socket_timeout + 5, # event_loop_timeout "", # hc_check_uuid "", # hc_ping_key "", # hc_check_slug @@ -63,6 +65,8 @@ def test_received_message_receipt(self): {}, # submit_sm_params False, # set_priority_flag 20, # mt_messages_per_second + 30, # socket_timeout + 5, # event_loop_timeout "", # hc_check_uuid "", # hc_ping_key "", # hc_check_slug @@ -106,6 +110,8 @@ def test_received_null_short_message(self): {}, # submit_sm_params False, # set_priority_flag 20, # mt_messages_per_second + 30, # socket_timeout + 5, # event_loop_timeout "", # hc_check_uuid "", # hc_ping_key "", # hc_check_slug @@ -145,6 +151,8 @@ def test_message_sent_handler(): {}, # submit_sm_params False, # set_priority_flag 20, # mt_messages_per_second + 30, # socket_timeout + 5, # event_loop_timeout "", # hc_check_uuid "", # hc_ping_key "", # hc_check_slug @@ -183,6 +191,8 @@ def get_client_and_message( submit_sm_params or {}, set_priority_flag, 20, # mt_messages_per_second + 30, # socket_timeout + 5, # event_loop_timeout "", # hc_check_uuid "", # hc_ping_key "", # hc_check_slug diff --git a/tests/test_router.py b/tests/test_router.py index 432f81e..cb678d8 100644 --- a/tests/test_router.py +++ b/tests/test_router.py @@ -1,6 +1,10 @@ +from unittest.mock import patch + +from django.conf import settings from django.test import TestCase from django.test.utils import override_settings +from smpp_gateway.models import MTMessage from smpp_gateway.router import PriorityBlockingRouter from .factories import ConnectionFactory @@ -15,7 +19,9 @@ ) class PriorityBlockingRouterTest(TestCase): def setUp(self): - self.router = PriorityBlockingRouter(apps=[], backends={}) + self.router = PriorityBlockingRouter( + apps=[], backends=settings.INSTALLED_BACKENDS + ) self.connection = ConnectionFactory() def test_new_incoming_message(self): @@ -70,3 +76,43 @@ def test_outgoing_message_extra_backend_context_has_priority_flag(self): ) context = msg.extra_backend_context() self.assertEqual(context["priority_flag"], msg.default_priority_flag.value) + + def test_no_postgres_notification_for_low_priority_messages(self): + """Tests that a Postgres NOTIFY is not done for messages where the + priority_flag is less than 2. + """ + for priority in MTMessage.PriorityFlag.values[:2]: + msg = self.router.new_outgoing_message( + text="foo", + connections=[self.connection], + fields={"priority_flag": priority}, + ) + with patch("smpp_gateway.outgoing.pg_notify") as mock_pg_notify: + self.router.send_to_backend( + backend_name="smppsim", + id_=msg.id, + text=msg.text, + identities=[self.connection.identity], + context=msg.fields, + ) + mock_pg_notify.assert_not_called() + + def test_postgres_notification_for_high_priority_messages(self): + """Tests that a Postgres NOTIFY is done for messages where the + priority_flag is at least 2. + """ + for priority in MTMessage.PriorityFlag.values[2:]: + msg = self.router.new_outgoing_message( + text="foo", + connections=[self.connection], + fields={"priority_flag": priority}, + ) + with patch("smpp_gateway.outgoing.pg_notify") as mock_pg_notify: + self.router.send_to_backend( + backend_name="smppsim", + id_=msg.id, + text=msg.text, + identities=[self.connection.identity], + context=msg.fields, + ) + mock_pg_notify.assert_called_with("smppsim")