diff --git a/dreadnode_cli/api.py b/dreadnode_cli/api.py index 712df08..a0182c3 100644 --- a/dreadnode_cli/api.py +++ b/dreadnode_cli/api.py @@ -86,30 +86,6 @@ async def _log_request(self, request: httpx.Request) -> None: print("Content:", request.content) print("-------------------------------------------") - def _get_headers(self, additional: dict[str, str] | None = None) -> dict[str, str]: - """Get the common headers for every request.""" - - headers = { - "User-Agent": f"dreadnode-cli/{__version__}", - "Accept": "application/json", - } - - if additional: - headers.update(additional) - - return headers - - def _get_auth_cookies(self) -> dict[str, str]: - """Get the authentication cookies for the request. Will raise an error if not authenticated or if the tokens are expired.""" - - if not self.auth: - raise Exception("not authenticated") - - elif self.auth.is_expired(): - raise Exception("authentication expired") - - return {"refresh_token": self.auth.refresh_token.data} - def _get_error_message(self, response: httpx.Response) -> str: """Get the error message from the response.""" @@ -129,15 +105,48 @@ async def _request( ) -> httpx.Response: """Make a request to the Dreadnode API.""" - # common headers - headers = self._get_headers() - # if authentication is required add the necessary cookies (will check for valid auth data) - cookies = self._get_auth_cookies() if auth else None + cookies = None + headers = { + "User-Agent": f"dreadnode-cli/{__version__}", + "Accept": "application/json", + } + + if auth: + # check if we have valid auth data + if not self.auth: + raise Exception("not authenticated") + elif self.auth.is_expired(): + raise Exception("authentication expired") + # The refresh_token can be sent along w/ requests as a cookie and the server will + # automatically refresh the access_token if needed and put it back into + # the access_token cookie. + cookies = {"refresh_token": self.auth.refresh_token.data} + # The access_token should be used for all requests. + headers["Authorization"] = f"Bearer {self.auth.access_token.data}" async with httpx.AsyncClient( cookies=cookies, headers=headers, event_hooks={"request": [self._log_request]} ) as client: response = await client.request(method, f"{self.base_url}{path}", json=json_data) + + # see if the endpoint returned refreshed tokens + access_token = response.cookies.get("access_token") + refresh_token = response.cookies.get("refresh_token") + + if access_token or refresh_token: + print(f"[DEBUG] got tokens from cookies access_token={access_token} refresh_token={refresh_token}") + print(f"[DEBUG] access_token expires in: {Token(access_token).expires_at}") + print(f"[DEBUG] refresh_token expires in: {Token(refresh_token).expires_at}") + + if access_token and refresh_token: + # update auth data only if it changed + if not self.auth or ( + self.auth + and self.auth.access_token.data == access_token + and self.auth.refresh_token.data == refresh_token + ): + self.auth = Authentication(access_token, refresh_token) + if allow_non_ok or response.status_code == 200: return response else: diff --git a/dreadnode_cli/cli.py b/dreadnode_cli/cli.py index 04378af..ccf5d29 100644 --- a/dreadnode_cli/cli.py +++ b/dreadnode_cli/cli.py @@ -52,10 +52,17 @@ def login( if auth is None: raise Exception("authentication failed") + access_token = auth.get("access_token") + refresh_token = auth.get("refresh_token") + + print() + print(f"[DEBUG] access_token expires in: {api.Token(access_token).expires_at}") + print(f"[DEBUG] refresh_token expires in: {api.Token(refresh_token).expires_at}") + # store the authentication data for the profile UserConfig.read().set_profile_config( profile, - ServerConfig(url=server, access_token=auth.get("access_token"), refresh_token=auth.get("refresh_token")), + ServerConfig(url=server, access_token=access_token, refresh_token=refresh_token), ).write() print(":white_check_mark: authentication successful") diff --git a/dreadnode_cli/profile/cli.py b/dreadnode_cli/profile/cli.py index bc24cc2..b5c52f0 100644 --- a/dreadnode_cli/profile/cli.py +++ b/dreadnode_cli/profile/cli.py @@ -25,14 +25,14 @@ def list() -> None: for profile, server in config.servers.items(): active_profile = server == current - refresh_token = Token(server.refresh_token) + access_token = Token(server.access_token) table.add_row( f"[bold]{profile}*[/]" if active_profile else profile, server.url, "[red]expired[/]" - if refresh_token.is_expired() - else f'{refresh_token.expires_at.strftime("%Y-%m-%d %H:%M:%S")} ({utils.time_to(refresh_token.expires_at)})', + if access_token.is_expired() + else f'{access_token.expires_at.strftime("%Y-%m-%d %H:%M:%S")} ({utils.time_to(access_token.expires_at)})', style="cyan" if active_profile else None, )