diff --git a/audit_logger/main.py b/audit_logger/main.py index 283bab8..3fe9480 100644 --- a/audit_logger/main.py +++ b/audit_logger/main.py @@ -2,9 +2,10 @@ from contextlib import asynccontextmanager from typing import Any, AsyncGenerator, Dict, List, Optional, cast -from fastapi import Body, FastAPI, HTTPException +from fastapi import Body, Depends, FastAPI, HTTPException from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse +from fastapi.security import APIKeyHeader from audit_logger.config_manager import ConfigManager from audit_logger.custom_logger import get_logger @@ -63,12 +64,20 @@ async def lifespan(_: Any) -> AsyncGenerator[None, None]: lifespan=lifespan, ) +api_key_header = APIKeyHeader(name="X-API-Key") + app.add_exception_handler(RequestValidationError, validation_exception_handler) add_middleware(app, app_config) -@app.post("/create") +async def verify_api_key(api_key: str = Depends(api_key_header)): + if api_key != app_config.authentication.api_key: + raise HTTPException(status_code=401, detail="Invalid API-Key") + return api_key + + +@app.post("/create", dependencies=[Depends(verify_api_key)]) async def create_audit_log_entry( audit_log: AuditLogEntry = Body(...), ) -> GenericResponse: @@ -92,7 +101,7 @@ async def create_audit_log_entry( ) -@app.post("/create-bulk") +@app.post("/create-bulk", dependencies=[Depends(verify_api_key)]) async def create_bulk_audit_log_entries( audit_logs: List[AuditLogEntry] = Body(...), ) -> GenericResponse: @@ -119,7 +128,7 @@ async def create_bulk_audit_log_entries( ) -@app.post("/create/create-bulk-auto") +@app.post("/create/create-bulk-auto", dependencies=[Depends(verify_api_key)]) async def create_fake_audit_log_entries( options: BulkAuditLogOptions, ) -> GenericResponse: @@ -139,7 +148,7 @@ async def create_fake_audit_log_entries( ) -@app.post("/search") +@app.post("/search", dependencies=[Depends(verify_api_key)]) def search_audit_log_entries( params: Optional[SearchParamsV2] = Body(default=None), ) -> SearchResults: diff --git a/audit_logger/models/config.py b/audit_logger/models/config.py index 6087e59..6d31837 100644 --- a/audit_logger/models/config.py +++ b/audit_logger/models/config.py @@ -28,7 +28,14 @@ class APIMiddlewares(CustomBaseModel): cors: CORSSettings = Field(description="CORS middleware settings") +class Authentication(CustomBaseModel): + api_key: str = Field(description="X-API Key") + + class AppConfig(CustomBaseModel): middlewares: Optional[APIMiddlewares] = Field( - description="API Middlewares settings", + description="API middlewares settings", + ) + authentication: Authentication = Field( + description="API authentication settings", ) diff --git a/audit_logger/models/custom_base.py b/audit_logger/models/custom_base.py index 9a972cc..18768c8 100644 --- a/audit_logger/models/custom_base.py +++ b/audit_logger/models/custom_base.py @@ -1,6 +1,6 @@ from typing import Any -from pydantic import BaseModel, Extra +from pydantic import BaseModel class CustomBaseModel(BaseModel): @@ -9,4 +9,4 @@ def __init__(self, **kwargs: Any) -> None: # Forbid extra fields and raise an exception if any are found. class Config: - extra = Extra.forbid + extra = "forbid" diff --git a/config-sample.yaml b/config-sample.yaml index 1097f84..8c4247c 100644 --- a/config-sample.yaml +++ b/config-sample.yaml @@ -1,3 +1,5 @@ +authentication: + api_key: "change-me-plz" middlewares: cors: allow_origins: