Skip to content

Commit

Permalink
- More status codes replaced with FastAPI status constants
Browse files Browse the repository at this point in the history
- linting & formatting
- more exceptions added to `process_audit_logs`
  • Loading branch information
bulletinmybeard committed Apr 6, 2024
1 parent 0b27486 commit f056f2a
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 49 deletions.
4 changes: 2 additions & 2 deletions audit_logger/elastic_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from elasticsearch import Elasticsearch
from elasticsearch_dsl import A, Q, Search
from fastapi import HTTPException
from fastapi import HTTPException, status

from audit_logger.custom_logger import get_logger
from audit_logger.models import (
Expand Down Expand Up @@ -54,7 +54,7 @@ def process_parameters(self, params: SearchParams) -> Dict[str, Any]:

if not response.success():
raise HTTPException(
status_code=400, detail="[QueryFilterElasticsearch] Search failed."
status_code=status.HTTP_400_BAD_REQUEST, detail="Search failed."
)

return {
Expand Down
4 changes: 1 addition & 3 deletions audit_logger/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ async def value_error_handler(_: Any, exc: Exception) -> JSONResponse:
raise exc


async def validation_exception_handler(
_: Any, exc: Exception
) -> JSONResponse:
async def validation_exception_handler(_: Any, exc: Exception) -> JSONResponse:
"""
Handles validation errors.
Expand Down
64 changes: 41 additions & 23 deletions audit_logger/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from audit_logger.models import (
AuditLogEntry,
BulkAuditLogOptions,
GenericResponse,
SearchParams,
SearchResults,
)
Expand Down Expand Up @@ -80,8 +79,9 @@ async def verify_api_key(api_key: str = Depends(api_key_header)) -> str:
return api_key


@app.post("/create", dependencies=[Depends(verify_api_key)])
# ) -> GenericResponse:
@app.post(
"/create", dependencies=[Depends(verify_api_key)], response_class=JSONResponse
)
async def create_audit_log_entry(audit_log: AuditLogEntry = Body(...)) -> Any:
"""
Receives an audit log entry, validates it, and processes
Expand All @@ -101,10 +101,12 @@ async def create_audit_log_entry(audit_log: AuditLogEntry = Body(...)) -> Any:
)


@app.post(
    "/create-bulk", dependencies=[Depends(verify_api_key)], response_class=JSONResponse
)
async def create_bulk_audit_log_entries(
    audit_logs: List[AuditLogEntry] = Body(...),
) -> Any:
    """
    Receives one or multiple audit log entries, validates them, and processes
    them to be stored in Elasticsearch.

    Args:
        audit_logs: List of audit log entries to store (max 350 per request).

    Returns:
        JSONResponse describing the outcome of the bulk insert
        (produced by `process_audit_logs`).

    Raises:
        BulkLimitExceededError: If more than `bulk_limit` entries are
            submitted in a single request.
        HTTPException: If processing or storage fails.
    """
    # Hard cap to protect Elasticsearch from oversized bulk requests.
    bulk_limit = 350
    if len(audit_logs) > bulk_limit:
        raise BulkLimitExceededError(limit=bulk_limit)

    try:
        return await process_audit_logs(
            elastic,
            cast(str, env_vars.elastic_index_name),
            [dict(model.dict()) for model in audit_logs],
        )
    except HTTPException:
        # Already a well-formed HTTP error — propagate unchanged.
        # A bare `raise` preserves the original traceback.
        raise
    except Exception as e:
        logger.error("Error: %s\nFull stack trace:\n%s", e, traceback.format_exc())
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to process audit log entries.",
        ) from e


@app.post(
    "/create/create-bulk-auto",
    dependencies=[Depends(verify_api_key)],
    response_class=JSONResponse,
)
async def create_random_audit_log_entries(
    options: BulkAuditLogOptions,
) -> Any:
    """
    Generates random (fake) audit log entries and stores them in
    Elasticsearch.

    Args:
        options: Bulk-generation settings (e.g. how many fake entries
            to create).

    Returns:
        JSONResponse describing the outcome of the insert
        (produced by `process_audit_logs`).

    Raises:
        HTTPException: If processing or storage fails.
    """
    try:
        return await process_audit_logs(
            elastic,
            cast(str, env_vars.elastic_index_name),
            generate_audit_log_entries_with_fake_data(options),
        )
    except HTTPException:
        # Propagate intentional HTTP errors unchanged (bare `raise`
        # keeps the original traceback).
        raise
    except Exception as e:
        logger.error("Error: %s\nFull stack trace:\n%s", e, traceback.format_exc())
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to process audit log entries.",
        ) from e


@app.post("/search", dependencies=[Depends(verify_api_key)])
Expand Down
2 changes: 1 addition & 1 deletion audit_logger/models/env_vars.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, List
from typing import List, Optional

from pydantic import Field, HttpUrl, field_validator

Expand Down
88 changes: 68 additions & 20 deletions audit_logger/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,26 @@
from typing import Any, Dict, List, Union
from zoneinfo import ZoneInfo

from elasticsearch import Elasticsearch, SerializationError, helpers
from elasticsearch import (
BadRequestError,
ConflictError,
ConnectionError,
Elasticsearch,
NotFoundError,
SerializationError,
TransportError,
helpers,
)
from faker import Faker
from fastapi import HTTPException, status
from fastapi.responses import JSONResponse
from pydantic import ValidationError

from audit_logger.custom_logger import get_logger
from audit_logger.models import (
ActorDetails,
AuditLogEntry,
BulkAuditLogOptions,
GenericResponse,
ResourceDetails,
)
from audit_logger.models.env_vars import EnvVars
Expand Down Expand Up @@ -143,12 +152,12 @@ def generate_audit_log_entries_with_fake_data(
return [generate_log_entry().dict() for _ in range(settings.bulk_count)]


async def process_audit_logs(
    elastic: Elasticsearch,
    elastic_index_name: str,
    log_entries: Union[AuditLogEntry, List[Union[Dict, AuditLogEntry]]],
) -> JSONResponse:
    """
    Processes one or many audit log entries by sending them to Elasticsearch
    using the bulk API.

    Args:
        elastic: Connected Elasticsearch client.
        elastic_index_name: Name of the target index.
        log_entries: A single entry model, or a list of entries
            (models or plain dicts).

    Returns:
        JSONResponse: HTTP 207 with per-request counts for bulk input,
        HTTP 201 for a single successfully stored entry.

    Raises:
        HTTPException: Mapped from serialization, connection, or
            Elasticsearch API/transport failures.
    """
    try:
        is_bulk_operation = isinstance(log_entries, list)
        if not is_bulk_operation:
            log_entries = [log_entries.dict()]

        operations = create_bulk_operations(elastic_index_name, log_entries)
        success_count, failed = helpers.bulk(elastic, operations)
        failed_items = failed if isinstance(failed, list) else []

        if failed_items:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Failed to process audit logs: {failed_items}",
            )

        if is_bulk_operation:
            return JSONResponse(
                content={
                    "status": "success",
                    "success_count": success_count,
                    "failed_items": failed_items,
                },
                status_code=status.HTTP_207_MULTI_STATUS,
            )

        return JSONResponse(
            content={
                "status": "success",
            },
            status_code=status.HTTP_201_CREATED,
        )

    except HTTPException:
        # Re-raise our own HTTP errors untouched; without this clause the
        # catch-all below would mask them as a generic 500.
        raise

    except SerializationError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Failed to serialize data for Elasticsearch: {}".format(e),
        ) from e

    except ConnectionError as e:
        # NOTE: this is elasticsearch's ConnectionError (imported above),
        # which shadows the builtin of the same name.
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Could not connect to Elasticsearch: {}".format(e),
        ) from e

    except (NotFoundError, ConflictError, BadRequestError) as e:
        # Specific API errors must be caught BEFORE the generic
        # TransportError superclass, otherwise this clause is unreachable.
        raise HTTPException(status_code=e.status_code, detail=str(e)) from e

    except TransportError as e:  # Superclass for more specific transport errors
        if e.status_code == 404:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=str(e),
            ) from e
        if e.status_code == 409:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail=str(e),
            ) from e
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Elasticsearch transport error: {}".format(e),
        ) from e

    except Exception as e:  # Catch-all for unexpected errors
        logger.error("Error: %s\nFull stack trace:\n%s", e, traceback.format_exc())
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Unexpected error occurred.",
        ) from e


def validate_date(date_str: str) -> bool:
Expand Down

0 comments on commit f056f2a

Please sign in to comment.