Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Renew tokens #209

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 105 additions & 34 deletions src/scitacean/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .logging import get_logger
from .pid import PID
from .typing import DownloadConnection, FileTransfer, UploadConnection
from .util.credentials import ExpiringToken, SecretStr, StrStorage
from .util.credentials import SecretStr, Token


class Client:
Expand All @@ -34,8 +34,9 @@ class Client:
Clients hold all information needed to communicate with a SciCat instance
and a filesystem that holds data files (via ``file_transfer``).

Use :func:`Client.from_token` or :func:`Client.from_credentials` to initialize
a client instead of the constructor directly.
Use :func:`Client.from_token`, :func:`Client.from_credentials`, or
:func:`Client.without_login` to initialize a client instead
of the constructor directly.

See the user guide for typical usage patterns.
In particular, `Downloading Datasets <../../user-guide/downloading.ipynb>`_
Expand All @@ -50,8 +51,8 @@ def __init__(
):
"""Initialize a client.

Do not use directly, instead use :func:`Client.from_token`
or :func:`Client.from_credentials`!
Do not use directly, instead use :func:`Client.from_token`,
:func:`Client.from_credentials`, or :func:`Client.without_login`!
"""
self._client = client
self._file_transfer = file_transfer
Expand All @@ -61,8 +62,9 @@ def from_token(
cls,
*,
url: str,
token: str | StrStorage,
token: str | SecretStr | Token,
file_transfer: FileTransfer | None = None,
auto_renew_period: datetime.timedelta | None = datetime.timedelta(seconds=30),
) -> Client:
"""Create a new client and authenticate with a token.

Expand All @@ -74,26 +76,33 @@ def from_token(
User token to authenticate with SciCat.
file_transfer:
Handler for down-/uploads of files.
auto_renew_period:
If not ``None``, the SciCat login is renewed in operations
that happen within this time delta of the login expiration time.

Returns
-------
:
A new client.
"""
return Client(
client=ScicatClient.from_token(url=url, token=token),
client=ScicatClient.from_token(
url=url,
token=token,
auto_renew_period=auto_renew_period,
),
file_transfer=file_transfer,
)

# TODO rename to login? and provide logout?
@classmethod
def from_credentials(
cls,
*,
url: str,
username: str | StrStorage,
password: str | StrStorage,
username: str | SecretStr,
password: str | SecretStr,
file_transfer: FileTransfer | None = None,
auto_renew_period: datetime.timedelta | None = datetime.timedelta(seconds=30),
) -> Client:
"""Create a new client and authenticate with username and password.

Expand All @@ -108,6 +117,9 @@ def from_credentials(
Password of the user.
file_transfer:
Handler for down-/uploads of files.
auto_renew_period:
If not ``None``, the SciCat login is renewed in operations
that happen within this time delta of the login expiration time.

Returns
-------
Expand All @@ -116,14 +128,20 @@ def from_credentials(
"""
return Client(
client=ScicatClient.from_credentials(
url=url, username=username, password=password
url=url,
username=username,
password=password,
auto_renew_period=auto_renew_period,
),
file_transfer=file_transfer,
)

@classmethod
def without_login(
cls, *, url: str, file_transfer: FileTransfer | None = None
cls,
*,
url: str,
file_transfer: FileTransfer | None = None,
) -> Client:
"""Create a new client without authentication.

Expand All @@ -143,7 +161,8 @@ def without_login(
A new client.
"""
return Client(
client=ScicatClient.without_login(url=url), file_transfer=file_transfer
client=ScicatClient.without_login(url=url),
file_transfer=file_transfer,
)

@property
Expand Down Expand Up @@ -559,24 +578,23 @@ class ScicatClient:
def __init__(
self,
url: str,
token: str | StrStorage | None,
token: str | SecretStr | Token | None,
timeout: datetime.timedelta | None,
auto_renew_period: datetime.timedelta | None = datetime.timedelta(seconds=30),
):
# Need to add a final /
self._base_url = url[:-1] if url.endswith("/") else url
self._timeout = datetime.timedelta(seconds=10) if timeout is None else timeout
self._token: StrStorage | None = (
ExpiringToken.from_jwt(SecretStr(token))
if isinstance(token, str)
else token
)
self._token = _wrap_token(token)
self._auto_renew_period = auto_renew_period

@classmethod
def from_token(
cls,
url: str,
token: str | StrStorage,
token: str | SecretStr | Token,
timeout: datetime.timedelta | None = None,
auto_renew_period: datetime.timedelta | None = datetime.timedelta(seconds=30),
) -> ScicatClient:
"""Create a new low-level client and authenticate with a token.

Expand All @@ -588,21 +606,30 @@ def from_token(
User token to authenticate with SciCat.
timeout:
Timeout for all API requests.
auto_renew_period:
If not ``None``, the SciCat login is renewed in operations
that happen within this time delta of the login expiration time.

Returns
-------
:
A new low-level client.
"""
return ScicatClient(url=url, token=token, timeout=timeout)
return ScicatClient(
url=url,
token=token,
timeout=timeout,
auto_renew_period=auto_renew_period,
)

@classmethod
def from_credentials(
cls,
url: str,
username: str | StrStorage,
password: str | StrStorage,
username: str | SecretStr,
password: str | SecretStr,
timeout: datetime.timedelta | None = None,
auto_renew_period: datetime.timedelta | None = datetime.timedelta(seconds=30),
) -> ScicatClient:
"""Create a new low-level client and authenticate with username and password.

Expand All @@ -617,16 +644,17 @@ def from_credentials(
Password of the user.
timeout:
Timeout for all API requests.
auto_renew_period:
If not ``None``, the SciCat login is renewed in operations
that happen within this time delta of the login expiration time.

Returns
-------
:
A new low-level client.
"""
if not isinstance(username, StrStorage):
username = SecretStr(username)
if not isinstance(password, StrStorage):
password = SecretStr(password)
username = SecretStr(username)
password = SecretStr(password)
return ScicatClient(
url=url,
token=SecretStr(
Expand All @@ -638,6 +666,7 @@ def from_credentials(
)
),
timeout=timeout,
auto_renew_period=auto_renew_period,
)

@classmethod
Expand All @@ -661,7 +690,9 @@ def without_login(
:
A new low-level client.
"""
return ScicatClient(url=url, token=None, timeout=timeout)
return ScicatClient(
url=url, token=None, timeout=timeout, auto_renew_period=None
)

def get_dataset_model(
self, pid: PID, strict_validation: bool = False
Expand Down Expand Up @@ -1001,11 +1032,39 @@ def validate_dataset_model(
if not response["valid"]:
raise ValueError(f"Dataset {dset} did not pass validation in SciCat.")

def renew_login(self) -> None:
"""Request and assign a new SciCat token.

Can be used to prolong a login session before a token expires.
The new token is assigned to the client and is used for all future operations.

Raises :class:`scitacean.ScicatCommError` if renewal fails.
In this case, the old token will not be replaced.
"""
response = self._call_endpoint(
cmd="post", url="users/jwt", operation="renew_login", renew_login=False
)
self._token = _wrap_token(response["jwt"])

def _renew_login_if_needed(self, operation: str) -> None:
if (
self._token is not None
and self._token.expires_at is not None
and self._auto_renew_period is not None
):
if self._token.expires_at + self._auto_renew_period > datetime.datetime.now(
tz=self._token.expires_at.tzinfo
):
get_logger().info(
"Renewing SciCat login during operation '%s'", operation
)
self.renew_login()

def _send_to_scicat(
self, *, cmd: str, url: str, data: model.BaseModel | None = None
) -> requests.Response:
if self._token is not None:
token = self._token.get_str()
token = self._token.expose()
headers = {"Authorization": f"Bearer {token}"}
else:
token = ""
Expand Down Expand Up @@ -1043,7 +1102,11 @@ def _call_endpoint(
url: str,
data: model.BaseModel | None = None,
operation: str,
renew_login: bool = True,
) -> Any:
if renew_login:
self._renew_login_if_needed(operation)

full_url = _url_concat(self._base_url, url)
logger = get_logger()
logger.info("Calling SciCat API at %s for operation '%s'", full_url, operation)
Expand Down Expand Up @@ -1099,12 +1162,12 @@ def _make_orig_datablock(


def _log_in_via_users_login(
url: str, username: StrStorage, password: StrStorage, timeout: datetime.timedelta
url: str, username: SecretStr, password: SecretStr, timeout: datetime.timedelta
) -> requests.Response:
# Currently only used for functional accounts.
response = requests.post(
_url_concat(url, "Users/login"),
json={"username": username.get_str(), "password": password.get_str()},
json={"username": username.expose(), "password": password.expose()},
stream=False,
verify=True,
timeout=timeout.seconds,
Expand All @@ -1117,7 +1180,7 @@ def _log_in_via_users_login(


def _log_in_via_auth_msad(
url: str, username: StrStorage, password: StrStorage, timeout: datetime.timedelta
url: str, username: SecretStr, password: SecretStr, timeout: datetime.timedelta
) -> requests.Response:
# Used for user accounts.
import re
Expand All @@ -1126,7 +1189,7 @@ def _log_in_via_auth_msad(
base_url = re.sub(r"/api/v\d+/?", "", url)
response = requests.post(
_url_concat(base_url, "auth/msad"),
json={"username": username.get_str(), "password": password.get_str()},
json={"username": username.expose(), "password": password.expose()},
stream=False,
verify=True,
timeout=timeout.seconds,
Expand All @@ -1137,7 +1200,7 @@ def _log_in_via_auth_msad(


def _get_token(
url: str, username: StrStorage, password: StrStorage, timeout: datetime.timedelta
url: str, username: SecretStr, password: SecretStr, timeout: datetime.timedelta
) -> str:
"""Log in using the provided username + password.

Expand All @@ -1164,6 +1227,14 @@ def _get_token(
raise ScicatLoginError(response.content)


def _wrap_token(token: str | SecretStr | Token | None) -> Token | None:
match token:
case str() | SecretStr():
return Token.from_jwt(token, denial_period=datetime.timedelta(seconds=2))
case Token() | None:
return token


FileSelector = (
bool | str | list[str] | tuple[str] | re.Pattern[str] | Callable[[File], bool]
)
Expand Down
17 changes: 13 additions & 4 deletions src/scitacean/testing/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ..error import ScicatCommError
from ..pid import PID
from ..typing import FileTransfer
from ..util.credentials import StrStorage
from ..util.credentials import SecretStr


def _conditionally_disabled(func: Callable[..., Any]) -> Callable[..., Any]:
Expand Down Expand Up @@ -134,8 +134,9 @@ def from_token(
cls,
*,
url: str,
token: str | StrStorage,
token: str | SecretStr,
file_transfer: FileTransfer | None = None,
auto_renew_period: datetime.timedelta | None = datetime.timedelta(seconds=30),
) -> FakeClient:
"""Create a new fake client.

Expand All @@ -148,9 +149,10 @@ def from_credentials(
cls,
*,
url: str,
username: str | StrStorage,
password: str | StrStorage,
username: str | SecretStr,
password: str | SecretStr,
file_transfer: FileTransfer | None = None,
auto_renew_period: datetime.timedelta | None = datetime.timedelta(seconds=30),
) -> FakeClient:
"""Create a new fake client.

Expand Down Expand Up @@ -288,6 +290,13 @@ def validate_dataset_model(
# Models were locally validated on construction, assume they are valid.
pass

@_conditionally_disabled
def renew_login(self) -> None:
"""Request a new SciCat token.

Does nothing because FakeScicatClient does not use authentication.
"""


def _model_dict(mod: model.BaseModel) -> dict[str, Any]:
return {
Expand Down
Loading
Loading