From 4827e1d327c68be2a5560aa3ff86d4f22e28729c Mon Sep 17 00:00:00 2001 From: monoxgas Date: Mon, 18 Nov 2024 15:11:41 -0700 Subject: [PATCH] Extend schemas to align with our backend changes. --- dreadnode_cli/agent/format.py | 4 ++- dreadnode_cli/api.py | 46 ++++++++++++++++++++++++++++------- 2 files changed, 40 insertions(+), 10 deletions(-) diff --git a/dreadnode_cli/agent/format.py b/dreadnode_cli/agent/format.py index 249160c..eaa73ba 100644 --- a/dreadnode_cli/agent/format.py +++ b/dreadnode_cli/agent/format.py @@ -14,12 +14,14 @@ P = t.ParamSpec("P") -def get_status_style(status: str | None) -> str: +def get_status_style(status: api.Client.StrikeRunStatus | api.Client.StrikeRunZoneStatus | None) -> str: return ( { "pending": "dim", "running": "bold cyan", "completed": "bold green", + "mixed": "bold gold3", + "terminated": "bold dark_orange3", "failed": "bold red", "timeout": "bold yellow", }.get(status, "") diff --git a/dreadnode_cli/api.py b/dreadnode_cli/api.py index 6950039..3257601 100644 --- a/dreadnode_cli/api.py +++ b/dreadnode_cli/api.py @@ -231,7 +231,25 @@ def submit_challenge_flag(self, challenge: str, flag: str) -> bool: # Strikes - StrikeRunStatus = t.Literal["pending", "deploying", "running", "completed", "timeout", "failed"] + StrikeRunStatus = t.Literal[ + "pending", # Waiting to be processed in the DB + "deploying", # Dropship pod is being created and configured + "running", # Dropship pod is actively executing + "completed", # All zones finished successfully + "mixed", # Some zones succeeded, others terminated + "terminated", # All zones ended with non-zero exit codes + "timeout", # Maximum allowed run time was exceeded + "failed", # System/infrastructure error occurred + ] + StrikeRunZoneStatus = t.Literal[ + "pending", # Waiting to be processed in the DB + "deploying", # Dropship is creating the zone resources + "running", # Zone pods are actively executing + "completed", # Agent completed successfully (exit code 0) + "terminated", # Agent ended with non-zero exit code + "timeout", # Maximum allowed run time was exceeded + "failed", # System/infrastructure error occurred + ] class StrikeModel(BaseModel): key: str @@ -273,7 +291,7 @@ class StrikeAgentResponse(BaseModel): key: str name: str | None created_at: datetime - latest_run_status: t.Optional["Client.StrikeRunStatus"] + latest_run_status: "Client.StrikeRunStatus | None" latest_run_id: UUID | None versions: list["Client.StrikeAgentVersion"] latest_version: "Client.StrikeAgentVersion" @@ -286,7 +304,7 @@ class StrikeAgentSummaryResponse(BaseModel): key: str name: str | None created_at: datetime - latest_run_status: t.Optional["Client.StrikeRunStatus"] + latest_run_status: "Client.StrikeRunStatus | None" latest_run_id: UUID | None latest_version: "Client.StrikeAgentVersion" revision: int @@ -296,23 +314,30 @@ class StrikeRunOutputScore(BaseModel): explanation: str | None = None metadata: dict[str, t.Any] = {} - class StrikeRunOutput(BaseModel): - data: dict[str, t.Any] + class StrikeRunOutputSummary(BaseModel): score: t.Optional["Client.StrikeRunOutputScore"] = None metadata: dict[str, t.Any] = {} - class StrikeRunZone(BaseModel): + class StrikeRunOutput(StrikeRunOutputSummary): + data: dict[str, t.Any] + + class _StrikeRunZone(BaseModel): id: UUID key: str - status: "Client.StrikeRunStatus" + status: "Client.StrikeRunZoneStatus" start: datetime | None end: datetime | None + + class StrikeRunZoneSummary(_StrikeRunZone): + outputs: list["Client.StrikeRunOutputSummary"] + + class StrikeRunZone(_StrikeRunZone): agent_logs: str | None container_logs: dict[str, str] outputs: list["Client.StrikeRunOutput"] inferences: list[dict[str, t.Any]] - class StrikeRunSummaryResponse(BaseModel): + class _StrikeRun(BaseModel): id: UUID strike_id: UUID strike_key: str @@ -328,7 +353,10 @@ class StrikeRunSummaryResponse(BaseModel): start: datetime | None end: datetime | None - class StrikeRunResponse(StrikeRunSummaryResponse): + class StrikeRunSummaryResponse(_StrikeRun): + zones: list["Client.StrikeRunZoneSummary"] + + class StrikeRunResponse(_StrikeRun): zones: list["Client.StrikeRunZone"] def get_strike(self, strike: str) -> StrikeResponse: