diff --git a/dreadnode_cli/agent/cli.py b/dreadnode_cli/agent/cli.py index 8aa5570..fc7625f 100644 --- a/dreadnode_cli/agent/cli.py +++ b/dreadnode_cli/agent/cli.py @@ -16,16 +16,15 @@ from dreadnode_cli.agent.format import ( format_agent, format_agent_versions, + format_models, format_run, format_runs, - format_strike_models, format_strikes, - format_user_models, ) from dreadnode_cli.agent.templates import cli as templates_cli from dreadnode_cli.agent.templates.format import format_templates from dreadnode_cli.agent.templates.manager import TemplateManager -from dreadnode_cli.config import UserConfig, UserModel, UserModels +from dreadnode_cli.config import UserConfig from dreadnode_cli.profile.cli import switch as switch_profile from dreadnode_cli.types import GithubRepo from dreadnode_cli.utils import download_and_unzip_archive, get_repo_archive_source_path, pretty_cli @@ -340,22 +339,14 @@ def deploy( if strike is None: raise Exception("No strike specified, use -s/--strike or set the strike in the agent config") - user_models = UserModels.read() - user_model: UserModel | None = None - # Verify the model if it was supplied if model is not None: - # check if it's a user model - user_model = next((m for m in user_models.models if m.key == model), None) - if not user_model: - # check if it's a strike model - strike_response = client.get_strike(strike) - if not any(m.key == model for m in strike_response.models): - models(directory, strike=strike) - print() - raise Exception(f"Model '{model}' is not a user model nor was found in strike '{strike_response.name}'") - - run = client.start_strike_run(agent.latest_version.id, strike=strike, model=model, user_model=user_model) + strike_response = client.get_strike(strike) + if not any(m.key == model for m in strike_response.models): + print(format_models(strike_response.models)) + raise Exception(f"Model '{model}' not found in strike '{strike_response.name}'") + + run = client.start_strike_run(agent.latest_version.id, strike=strike, model=model) agent_config.add_run(run.id).write(directory) formatted = format_run(run) @@ -378,11 +369,6 @@ def models( ] = pathlib.Path("."), strike: t.Annotated[str | None, typer.Option("--strike", "-s", help="The strike to query")] = None, ) -> None: - user_models = UserModels.read() - if user_models.models: - print("[bold]User models:[/]\n") - print(format_user_models(user_models.models)) - if strike is None: agent_config = AgentConfig.read(directory) ensure_profile(agent_config) @@ -392,9 +378,7 @@ def models( raise Exception("No strike specified, use -s/--strike or set the strike in the agent config") strike_response = api.create_client().get_strike(strike) - if user_models.models: - print("\n[bold]Strike models:[/]\n") - print(format_strike_models(strike_response.models)) + print(format_models(strike_response.models)) @cli.command(help="List available strikes") diff --git a/dreadnode_cli/agent/format.py b/dreadnode_cli/agent/format.py index 442f97d..f668c1f 100644 --- a/dreadnode_cli/agent/format.py +++ b/dreadnode_cli/agent/format.py @@ -9,7 +9,6 @@ from rich.text import Text from dreadnode_cli import api -from dreadnode_cli.config import UserModel P = t.ParamSpec("P") @@ -63,26 +62,7 @@ def format_time(dt: datetime | None) -> str: return dt.astimezone().strftime("%c") if dt else "-" -def format_user_models(models: list[UserModel]) -> RenderableType: - table = Table(box=box.ROUNDED) - table.add_column("key") - table.add_column("name") - table.add_column("provider") - table.add_column("api_key") - - for model in models: - provider_style = get_model_provider_style(model.provider) - table.add_row( - Text(model.key), - Text(model.name, style=f"bold {provider_style}"), - Text(model.provider, style=provider_style), - Text("yes" if model.api_key else "no", style="green" if model.api_key else "dim"), - ) - - return table - - -def format_strike_models(models: list[api.Client.StrikeModel]) -> RenderableType: +def format_models(models: list[api.Client.StrikeModel]) -> RenderableType: table = Table(box=box.ROUNDED) table.add_column("key") table.add_column("name") @@ -292,8 +272,7 @@ def format_run(run: api.Client.StrikeRunResponse, *, verbose: bool = False, incl agent_name = f"[bold magenta]{run.agent_key}[/]" table.add_row("", "") - # um@ is added to indicate a user model - table.add_row("model", run.model.replace("um@", "") if run.model else "") + table.add_row("model", run.model or "") table.add_row("agent", f"{agent_name} ([dim]rev[/] [yellow]{run.agent_revision}[/])") table.add_row("image", Text(run.agent_version.container.image, style="cyan")) table.add_row("notes", run.agent_version.notes or "-") @@ -325,8 +304,7 @@ def format_runs(runs: list[api.Client.StrikeRunSummaryResponse]) -> RenderableTy str(run.id), f"[bold magenta]{run.agent_key}[/] [dim]:[/] [yellow]{run.agent_revision}[/]", Text(run.status, style="bold " + get_status_style(run.status)), - # um@ is added to indicate a user model - Text(run.model.replace("um@", "") if run.model else "-"), + Text(run.model or "-"), format_time(run.start), Text(format_duration(run.start, run.end), style="bold cyan"), ) diff --git a/dreadnode_cli/api.py b/dreadnode_cli/api.py index a7a9f3d..48b8d28 100644 --- a/dreadnode_cli/api.py +++ b/dreadnode_cli/api.py @@ -11,7 +11,7 @@ from rich import print from dreadnode_cli import __version__, utils -from dreadnode_cli.config import UserConfig, UserModel +from dreadnode_cli.config import UserConfig from dreadnode_cli.defaults import ( DEBUG, DEFAULT_MAX_POLL_TIME, @@ -430,12 +430,7 @@ def create_strike_agent_version( return self.StrikeAgentResponse(**response.json()) def start_strike_run( - self, - agent_version_id: UUID, - *, - model: str | None = None, - user_model: UserModel | None = None, - strike: UUID | str | None = None, + self, agent_version_id: UUID, *, model: str | None = None, strike: UUID | str | None = None ) -> StrikeRunResponse: response = self.request( "POST", @@ -443,7 +438,6 @@ def start_strike_run( json_data={ "agent_version_id": str(agent_version_id), "model": model, - "user_model": user_model.model_dump(mode="json") if user_model else None, "strike": str(strike) if strike else None, }, ) diff --git a/dreadnode_cli/config.py b/dreadnode_cli/config.py index b9c901a..b5c5f5f 100644 --- a/dreadnode_cli/config.py +++ b/dreadnode_cli/config.py @@ -1,11 +1,11 @@ -from pydantic import BaseModel +import pydantic from rich import print from ruamel.yaml import YAML -from dreadnode_cli.defaults import DEFAULT_PROFILE_NAME, USER_CONFIG_PATH, USER_MODELS_CONFIG_PATH +from dreadnode_cli.defaults import DEFAULT_PROFILE_NAME, USER_CONFIG_PATH -class ServerConfig(BaseModel): +class ServerConfig(pydantic.BaseModel): """Server specific authentication data and API URL.""" url: str @@ -16,7 +16,7 @@ class ServerConfig(BaseModel): refresh_token: str -class UserConfig(BaseModel): +class UserConfig(pydantic.BaseModel): """User configuration supporting multiple server profiles.""" active: str | None = None @@ -74,31 +74,3 @@ def set_server_config(self, config: ServerConfig, profile: str | None = None) -> profile = profile or self.active or DEFAULT_PROFILE_NAME self.servers[profile] = config return self - - -class UserModel(BaseModel): - """ - A user defined model. - """ - - key: str - name: str - provider: str - generator_id: str - api_key: str | None = None - - -class UserModels(BaseModel): - """User models configuration.""" - - models: list[UserModel] = [] - - @classmethod - def read(cls) -> "UserModels": - """Read the user models configuration from the file system or return an empty instance.""" - - if not USER_MODELS_CONFIG_PATH.exists(): - return cls() - - with USER_MODELS_CONFIG_PATH.open("r") as f: - return cls.model_validate(YAML().load(f)) diff --git a/dreadnode_cli/defaults.py b/dreadnode_cli/defaults.py index 46607cd..aab6b78 100644 --- a/dreadnode_cli/defaults.py +++ b/dreadnode_cli/defaults.py @@ -25,12 +25,6 @@ os.getenv("DREADNODE_USER_CONFIG_FILE") or pathlib.Path.home() / ".dreadnode" / "config" ) -# path to the user models configuration file -USER_MODELS_CONFIG_PATH = pathlib.Path( - # allow overriding the user config file via env variable - os.getenv("DREADNODE_USER_CONFIG_FILE") or pathlib.Path.home() / ".dreadnode" / "models.yml" -) - # path to the templates directory TEMPLATES_PATH = pathlib.Path( # allow overriding the templates path via env variable