Skip to content

Commit

Permalink
new: added --source-branch argument
Browse files Browse the repository at this point in the history
  • Loading branch information
evilsocket committed Nov 20, 2024
1 parent 0180a22 commit 058477e
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 3 deletions.
1 change: 1 addition & 0 deletions CLI.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ $ dreadnode agent init [OPTIONS] STRIKE
* `-n, --name TEXT`: The project name (used for container naming)
* `-t, --template [rigging_basic|rigging_loop|nerve_basic]`: The template to use for the agent [default: rigging_basic]
* `--source TEXT`: Initialize the agent using a custom template from a github repository, ZIP archive URL or local folder
* `--source-branch TEXT`: If --source is a github repository, use this as the branch name [default: main]
* `--help`: Show this message and exit.

### `dreadnode agent latest`
Expand Down
9 changes: 8 additions & 1 deletion dreadnode_cli/agent/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ def init(
help="Initialize the agent using a custom template from a github repository, ZIP archive URL or local folder",
),
] = None,
source_branch: t.Annotated[
str,
typer.Option(
"--source-branch",
help="If --source is a github repository, use this as the branch name",
),
] = "main",
) -> None:
print(f":coffee: Fetching strike '{strike}' ...")

Expand Down Expand Up @@ -100,7 +107,7 @@ def init(
# - full github repository URL
# - full ZIP archive URL
# - username/repo
source = normalize_template_source(source)
source = normalize_template_source(source, source_branch)
# download and unzip to a temporary directory
source_dir = download_and_unzip_archive(source)
# make sure the temporary directory is cleaned up
Expand Down
8 changes: 8 additions & 0 deletions dreadnode_cli/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ def test_normalize_template_source_custom_zip_url() -> None:
assert normalize_template_source(url) == url


def test_normalize_template_source_custom_branch() -> None:
# Test with custom branch name
url = "user/repo"
branch = "develop"
expected = "https://github.com/user/repo/archive/refs/heads/develop.zip"
assert normalize_template_source(url, branch) == expected


def test_download_and_unzip_archive_success(tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch) -> None:
# create a mock zip file with test content
test_file_content = b"test content"
Expand Down
4 changes: 2 additions & 2 deletions dreadnode_cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def parse_jwt_token_expiration(token: str) -> datetime:
return datetime.fromtimestamp(json.loads(payload).get("exp"))


def normalize_template_source(source: str) -> str:
def normalize_template_source(source: str, source_branch: str = "main") -> str:
"""Normalize a template source to a ZIP archive URL."""

# github repository / ZIP archive URL
Expand All @@ -77,7 +77,7 @@ def normalize_template_source(source: str) -> str:

if not source.lower().endswith(".zip"):
# normalize to ZIP archive URL
source = f"{source}/archive/refs/heads/main.zip"
source = f"{source}/archive/refs/heads/{source_branch}.zip"

return source

Expand Down

0 comments on commit 058477e

Please sign in to comment.