Skip to content

Commit

Permalink
Background tasks for MQTT (#120)
Browse files Browse the repository at this point in the history
* add FastAPI lifespan handler for connecting to MongoDB and for published messages to MQTT

* add metrics for MQTT queue drops

* fix type of app lifespan contexthandler

* gather tasks even if they are cancelled
  • Loading branch information
jschlyter authored Nov 25, 2024
1 parent 6a96b37 commit a5db5ff
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 21 deletions.
21 changes: 12 additions & 9 deletions aggrec/aggregates.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import json
import logging
import re
Expand All @@ -10,7 +11,6 @@
import bson
import pendulum
import pymongo
from aiomqtt.exceptions import MqttError
from bson.objectid import ObjectId
from fastapi import APIRouter, Header, HTTPException, Request, Response, status
from fastapi.responses import StreamingResponse
Expand Down Expand Up @@ -43,6 +43,11 @@
description="The number of duplicate aggregates received",
)

aggregates_mqtt_queue_drops = meter.create_counter(
"aggregates.mqtt_queue_drops",
description="MQTT messages dropped due to full queue",
)


METADATA_HTTP_HEADERS = [
"User-Agent",
Expand Down Expand Up @@ -341,14 +346,12 @@ async def create_aggregate(
aggregates_by_creator_counter.add(1, {"aggregate_type": aggregate_type.value, "creator": creator})

try:
async with request.app.get_mqtt_client() as mqtt_client:
with tracer.start_as_current_span("mqtt.publish"):
await mqtt_client.publish(
request.app.settings.mqtt.topic,
json.dumps(get_new_aggregate_event_message(metadata, request.app.settings)),
)
except MqttError:
logger.warning("Failed to publish new aggregate to MQTT")
request.app.mqtt_new_aggregate_messages.put_nowait(
json.dumps(get_new_aggregate_event_message(metadata, request.app.settings))
)
except asyncio.QueueFull:
aggregates_mqtt_queue_drops.add(1)
logger.warning("MQTT queue full, message dropped")

return Response(status_code=status.HTTP_201_CREATED, headers={"Location": metadata_location})

Expand Down
64 changes: 52 additions & 12 deletions aggrec/server.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import argparse
import asyncio
import logging
from contextlib import asynccontextmanager

import aiobotocore.session
import aiomqtt
import boto3
import mongoengine
import uvicorn
from aiomqtt.exceptions import MqttError
from fastapi import FastAPI
from opentelemetry import trace
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware

import aggrec.aggregates
Expand All @@ -21,12 +25,14 @@

logger = logging.getLogger(__name__)

tracer = trace.get_tracer("aggrec.tracer")


class AggrecServer(FastAPI):
def __init__(self, settings: Settings):
self.logger = logging.getLogger(__name__).getChild(self.__class__.__name__)
self.settings = settings
super().__init__(**OPENAPI_METADATA)
super().__init__(**OPENAPI_METADATA, lifespan=self.lifespan)
self.add_middleware(ProxyHeadersMiddleware)
self.include_router(aggrec.aggregates.router)
self.include_router(aggrec.extras.router)
Expand All @@ -42,18 +48,19 @@ def __init__(self, settings: Settings):
self.key_resolver = key_resolver_from_client_database(
client_database=str(self.settings.clients_database), key_cache=key_cache
)
self.mqtt_new_aggregate_messages = asyncio.Queue(maxsize=self.settings.mqtt.queue_size)

@staticmethod
def connect_mongodb(settings: Settings):
if mongodb_host := str(settings.mongodb.server):
def connect_mongodb(self):
if mongodb_host := str(self.settings.mongodb.server):
params = {"host": mongodb_host}
if "host" in params and params["host"].startswith("mongomock://"):
import mongomock

params["host"] = params["host"].replace("mongomock://", "mongodb://")
params["mongo_client_class"] = mongomock.MongoClient
logger.info("Mongoengine connect %s", params)
logger.info("Connecting to MongoDB %s", params)
mongoengine.connect(**params, tz_aware=True)
logger.info("MongoDB connected")

def get_mqtt_client(self) -> aiomqtt.Client:
client = aiomqtt.Client(
Expand All @@ -78,12 +85,44 @@ def get_s3_client(self) -> aiobotocore.session.ClientCreatorContext:
self.logger.debug("Created S3 client %s", client)
return client

@classmethod
def factory(cls):
logger.info("Starting Aggregate Receiver version %s", __verbose_version__)
app = cls(settings=Settings())
app.connect_mongodb(app.settings)
return app
async def mqtt_publisher(self):
"""Task for publishing enqueued MQTT messages"""
_logger = self.logger.getChild("mqtt_publisher")
_logger.debug("Starting MQTT publish task")
while True:
try:
async with self.get_mqtt_client() as mqtt_client:
_logger.info("Connected to MQTT broker")
while True:
message = await self.mqtt_new_aggregate_messages.get()
_logger.debug("Publishing new aggregate message on %s", self.settings.mqtt.topic)
with tracer.start_as_current_span("mqtt.publish"):
await mqtt_client.publish(
self.settings.mqtt.topic,
message,
)
except MqttError as exc:
_logger.error("MQTT error: %s", str(exc))
except asyncio.exceptions.CancelledError:
_logger.debug("MQTT publish task cancelled")
return
_logger.info("Reconnecting to MQTT broker in %d seconds", self.settings.mqtt.reconnect_interval)
await asyncio.sleep(self.settings.mqtt.reconnect_interval)

@staticmethod
@asynccontextmanager
async def lifespan(app: "AggrecServer"):
app.logger.debug("Lifespan startup")
app.connect_mongodb()
tasks = []
tasks.append(asyncio.create_task(app.mqtt_publisher()))
logger.debug("Background tasks started")
yield
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
logger.debug("All background tasks cancelled")
app.logger.debug("Lifespan ended")


def main() -> None:
Expand Down Expand Up @@ -111,7 +150,8 @@ def main() -> None:
logging.basicConfig(level=logging.INFO)
log_level = "info"

app = AggrecServer.factory()
logger.info("Starting Aggregate Receiver version %s", __verbose_version__)
app = AggrecServer(settings=Settings())

uvicorn.run(
app,
Expand Down
1 change: 1 addition & 0 deletions aggrec/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class MqttSettings(BaseModel):
password: str | None = None
topic: str = Field(default="aggregates")
reconnect_interval: int = Field(default=5)
queue_size: int = Field(default=1024)


class MongoDB(BaseModel):
Expand Down

0 comments on commit a5db5ff

Please sign in to comment.