diff --git a/audit_logger/elastic_filters.py b/audit_logger/elastic_filters.py index 6e886f1..9585ff2 100644 --- a/audit_logger/elastic_filters.py +++ b/audit_logger/elastic_filters.py @@ -3,7 +3,7 @@ from elasticsearch import Elasticsearch from elasticsearch_dsl import A, Q, Search -from fastapi import HTTPException +from fastapi import HTTPException, status from audit_logger.custom_logger import get_logger from audit_logger.models import ( @@ -54,7 +54,7 @@ def process_parameters(self, params: SearchParams) -> Dict[str, Any]: if not response.success(): raise HTTPException( - status_code=400, detail="[QueryFilterElasticsearch] Search failed." + status_code=status.HTTP_400_BAD_REQUEST, detail="Search failed." ) return { diff --git a/audit_logger/exceptions.py b/audit_logger/exceptions.py index 4a7b052..372d71c 100644 --- a/audit_logger/exceptions.py +++ b/audit_logger/exceptions.py @@ -26,9 +26,7 @@ async def value_error_handler(_: Any, exc: Exception) -> JSONResponse: raise exc -async def validation_exception_handler( - _: Any, exc: Exception -) -> JSONResponse: +async def validation_exception_handler(_: Any, exc: Exception) -> JSONResponse: """ Handles validation errors. diff --git a/audit_logger/main.py b/audit_logger/main.py index c0e2662..c7c52cf 100644 --- a/audit_logger/main.py +++ b/audit_logger/main.py @@ -20,7 +20,6 @@ from audit_logger.models import ( AuditLogEntry, BulkAuditLogOptions, - GenericResponse, SearchParams, SearchResults, ) @@ -80,8 +79,9 @@ async def verify_api_key(api_key: str = Depends(api_key_header)) -> str: return api_key -@app.post("/create", dependencies=[Depends(verify_api_key)]) -# ) -> GenericResponse: +@app.post( + "/create", dependencies=[Depends(verify_api_key)], response_class=JSONResponse +) async def create_audit_log_entry(audit_log: AuditLogEntry = Body(...)) -> Any: """ Receives an audit log entry, validates it, and processes @@ -101,10 +101,12 @@ async def create_audit_log_entry(audit_log: AuditLogEntry = Body(...)) -> Any: ) -@app.post("/create-bulk", dependencies=[Depends(verify_api_key)]) +@app.post( + "/create-bulk", dependencies=[Depends(verify_api_key)], response_class=JSONResponse +) async def create_bulk_audit_log_entries( audit_logs: List[AuditLogEntry] = Body(...), -) -> GenericResponse: +) -> Any: """ Receives one or multiple audit log entries, validates them, and processes them to be stored in Elasticsearch. @@ -114,39 +116,55 @@ async def create_bulk_audit_log_entries( Returns: CreateResponse - - Raises: - Union[HTTPException, BulkLimitExceededError] """ bulk_limit = 350 if len(audit_logs) > bulk_limit: raise BulkLimitExceededError(limit=bulk_limit) - return await process_audit_logs( - elastic, - cast(str, env_vars.elastic_index_name), - [dict(model.dict()) for model in audit_logs], - ) + try: + return await process_audit_logs( + elastic, + cast(str, env_vars.elastic_index_name), + [dict(model.dict()) for model in audit_logs], + ) + except HTTPException as e: + raise e + except Exception as e: + logger.error("Error: %s\nFull stack trace:\n%s", e, traceback.format_exc()) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to process audit log entries.", + ) from e -@app.post("/create/create-bulk-auto", dependencies=[Depends(verify_api_key)]) +@app.post( + "/create/create-bulk-auto", + dependencies=[Depends(verify_api_key)], + response_class=JSONResponse, +) async def create_random_audit_log_entries( options: BulkAuditLogOptions, -) -> GenericResponse: +) -> Any: """ Generates and stores a single random audit log entry. Returns: CreateResponse - - Raises: - HTTPException """ - return await process_audit_logs( - elastic, - cast(str, env_vars.elastic_index_name), - generate_audit_log_entries_with_fake_data(options), - ) + try: + return await process_audit_logs( + elastic, + cast(str, env_vars.elastic_index_name), + generate_audit_log_entries_with_fake_data(options), + ) + except HTTPException as e: + raise e + except Exception as e: + logger.error("Error: %s\nFull stack trace:\n%s", e, traceback.format_exc()) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to process audit log entries.", + ) from e @app.post("/search", dependencies=[Depends(verify_api_key)]) diff --git a/audit_logger/models/env_vars.py b/audit_logger/models/env_vars.py index 1a4d937..54b1757 100644 --- a/audit_logger/models/env_vars.py +++ b/audit_logger/models/env_vars.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import List, Optional from pydantic import Field, HttpUrl, field_validator diff --git a/audit_logger/utils.py b/audit_logger/utils.py index cc5e81d..0ef8d51 100644 --- a/audit_logger/utils.py +++ b/audit_logger/utils.py @@ -6,9 +6,19 @@ from typing import Any, Dict, List, Union from zoneinfo import ZoneInfo -from elasticsearch import Elasticsearch, SerializationError, helpers +from elasticsearch import ( + BadRequestError, + ConflictError, + ConnectionError, + Elasticsearch, + NotFoundError, + SerializationError, + TransportError, + helpers, +) from faker import Faker from fastapi import HTTPException, status +from fastapi.responses import JSONResponse from pydantic import ValidationError from audit_logger.custom_logger import get_logger @@ -16,7 +26,6 @@ ActorDetails, AuditLogEntry, BulkAuditLogOptions, - GenericResponse, ResourceDetails, ) from audit_logger.models.env_vars import EnvVars @@ -143,12 +152,12 @@ def generate_audit_log_entries_with_fake_data( return [generate_log_entry().dict() for _ in range(settings.bulk_count)] -# ) -> GenericResponse: +# GenericResponse async def process_audit_logs( elastic: Elasticsearch, elastic_index_name: str, log_entries: Union[AuditLogEntry, List[Union[Dict, AuditLogEntry]]], -) -> Any: +) -> JSONResponse: """ Processes a list of audit log entries by sending them to Elasticsearch using the bulk API. @@ -163,37 +172,76 @@ async def process_audit_logs( Raises: - HTTPException """ - - is_bulk_operation = isinstance(log_entries, list) - if not is_bulk_operation: - log_entries = [log_entries.dict()] - try: + is_bulk_operation = isinstance(log_entries, list) + if not is_bulk_operation: + log_entries = [log_entries.dict()] + operations = create_bulk_operations(elastic_index_name, log_entries) success_count, failed = helpers.bulk(elastic, operations) failed_items = failed if isinstance(failed, list) else [] if len(failed_items) > 0: raise HTTPException( - status_code=500, + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to process audit logs: {str(failed_items)}", ) if is_bulk_operation: - return GenericResponse( - status="success", - success_count=success_count, + return JSONResponse( + content={ + "status": "success", + "success_count": success_count, + "failed_items": failed_items, + }, + status_code=status.HTTP_207_MULTI_STATUS, ) - return status.HTTP_201_CREATED - except SerializationError as e: - logger.error( - "SerializationError: %s\nFull stack trace:\n%s", e, traceback.format_exc() + return JSONResponse( + content={ + "status": "success", + }, + status_code=status.HTTP_201_CREATED, ) - raise HTTPException(status_code=500, detail="Failed to process audit logs") - except Exception as e: + + except SerializationError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Failed to serialize data for Elasticsearch: {}".format(e), + ) from e + + except ConnectionError as e: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Could not connect to Elasticsearch: {}".format(e), + ) from e + + except TransportError as e: # Superclass for more specific transport errors + if e.status_code == 404: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(e), + ) from e + elif e.status_code == 409: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=str(e), + ) from e + else: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Elasticsearch transport error: {}".format(e), + ) from e + + except (NotFoundError, ConflictError, BadRequestError) as e: + raise HTTPException(status_code=e.status_code, detail=str(e)) from e + + except Exception as e: # Catch-all for unexpected errors logger.error("Error: %s\nFull stack trace:\n%s", e, traceback.format_exc()) - raise HTTPException(status_code=500, detail="Failed to process audit logs") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Unexpected error occurred.", + ) from e def validate_date(date_str: str) -> bool: