Skip to content

Commit

Permalink
chore: minor refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
evilsocket committed Oct 16, 2024
1 parent dc66370 commit 6a9f03d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
13 changes: 3 additions & 10 deletions dreadnode_cli/api.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import asyncio
import base64
import json
from datetime import datetime, timezone
from typing import Any

import httpx
from rich import print

from dreadnode_cli import __version__
from dreadnode_cli import __version__, utils
from dreadnode_cli.config import ServerConfig, UserConfig
from dreadnode_cli.defaults import (
DEFAULT_MAX_POLL_TIME,
Expand All @@ -17,12 +16,6 @@
)


def _parse_jwt_expiration(token: str) -> datetime:
_, b64payload, _ = token.split(".")
payload = base64.urlsafe_b64decode(b64payload + "==").decode("utf-8")
return datetime.fromtimestamp(json.loads(payload).get("exp"))


class Token:
"""A JWT token with an expiration time."""

Expand All @@ -31,7 +24,7 @@ class Token:

def __init__(self, token: str):
self.data = token
self.expires_at = _parse_jwt_expiration(token)
self.expires_at = utils.parse_jwt_token_expiration(token)

def ttl(self) -> int:
"""Get number of seconds left until the token expires."""
Expand Down Expand Up @@ -226,7 +219,7 @@ async def setup_authenticated_client(config: ServerConfig, force_refresh: bool =
client = Client(base_url=config.url, auth=auth)

if auth.is_expired():
raise Exception("authentication expired")
raise Exception("authentication expired, use [bold]dreadnode login[/] to authenticate again")
elif force_refresh or auth.is_close_to_expiry():
# update the auth data
new_auth = await client.refresh_auth()
Expand Down
10 changes: 10 additions & 0 deletions dreadnode_cli/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import base64
import json
import pathlib
import typing as t
from datetime import datetime
Expand Down Expand Up @@ -31,6 +33,14 @@ def time_to(future_datetime: datetime) -> str:
return ", ".join(result) if result else "Just now"


def parse_jwt_token_expiration(token: str) -> datetime:
"""Return the expiration date from a JWT token."""

_, b64payload, _ = token.split(".")
payload = base64.urlsafe_b64decode(b64payload + "==").decode("utf-8")
return datetime.fromtimestamp(json.loads(payload).get("exp"))


def copy_template(src: pathlib.Path, dest: pathlib.Path, context: dict[str, t.Any]) -> None:
env = Environment(loader=FileSystemLoader(src))

Expand Down

0 comments on commit 6a9f03d

Please sign in to comment.