Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for token only authentication in db selection and check if credentials are valid. #177

Open
wants to merge 9 commits into
base: dev
Choose a base branch
from
21 changes: 11 additions & 10 deletions src/pystatis/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
from configparser import ConfigParser
from pathlib import Path

from pystatis import db
from pystatis.exception import PystatisConfigError

PKG_NAME = __name__.split(".", maxsplit=1)[0]
DEFAULT_CONFIG_DIR = str(Path().home() / f".{PKG_NAME}")
SUPPORTED_DB = ["genesis", "zensus", "regio"]
Expand Down Expand Up @@ -124,13 +127,7 @@


def init_config() -> None:
"""Create a new config .ini file in the given directory.

One-time function to be called for new users to create a new `config.ini` with default values (empty credentials).

Args:
config_dir (str, optional): Path to the root config directory. Defaults to the user home directory.
"""
"""Initialize the config variable by either creating a new config .ini file or load an existing config."""
if not config_exists():
create_default_config()
write_config()
Expand All @@ -150,9 +147,13 @@

def setup_credentials() -> None:
"""Setup credentials for all supported databases."""
for db in get_supported_db():
config.set(db, "username", _get_user_input(db, "username"))
config.set(db, "password", _get_user_input(db, "password"))
for db_name in get_supported_db():
config.set(db_name, "username", _get_user_input(db_name, "username"))
config.set(db_name, "password", _get_user_input(db_name, "password"))
if not db.check_credentials_are_valid(db_name):
raise PystatisConfigError(

Check warning on line 154 in src/pystatis/config.py

View check run for this annotation

Codecov / codecov/patch

src/pystatis/config.py#L154

Added line #L154 was not covered by tests
f"Provided credentials for database '{db_name}' are not valid! Please provide the correct credentials."
)

write_config()

Expand Down
54 changes: 45 additions & 9 deletions src/pystatis/db.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Module provides functions to set the active database and get active database properties."""

import json
import logging

from pystatis import config
from pystatis.cache import normalize_name
from pystatis import cache
from pystatis.exception import PystatisConfigError
from pystatis import http_helper

logger = logging.getLogger(__name__)

Expand All @@ -24,7 +26,7 @@
regex_db = config.get_db_identifiers()

# Strip optional leading * and trailing job id
table_name = normalize_name(table_name).lstrip("*")
table_name = cache.normalize_name(table_name).lstrip("*")

# Get list of matching dbs
db_matches = [db_name for db_name, reg in regex_db.items() if reg.match(table_name)]
Expand Down Expand Up @@ -52,7 +54,7 @@
"""
for db_name in db_matches:
# Return first hit with existing credentials.
if check_credentials(db_name):
if check_credentials_are_set(db_name):
return db_name

raise PystatisConfigError(
Expand All @@ -63,19 +65,26 @@


def get_host(db_name: str) -> str:
return config.config[db_name]["base_url"]
return config.config[db_name]["base_url"] # type: ignore


def get_user(db_name: str) -> str:
return config.config[db_name]["username"]
return config.config[db_name]["username"] # type: ignore


def set_user(db_name: str, new_username: str) -> None:
config.config.set(db_name, "username", new_username)
check_credentials_are_valid(db_name)
config.write_config()

Check warning on line 78 in src/pystatis/db.py

View check run for this annotation

Codecov / codecov/patch

src/pystatis/db.py#L76-L78

Added lines #L76 - L78 were not covered by tests


def get_pw(db_name: str) -> str:
return config.config[db_name]["password"]
return config.config[db_name]["password"] # type: ignore


def set_pw(db_name: str, new_pw: str) -> None:
config.config.set(db_name, "password", new_pw)
check_credentials_are_valid(db_name)
config.write_config()


Expand All @@ -84,14 +93,41 @@
return get_host(db_name), get_user(db_name), get_pw(db_name)


def check_credentials(db_name: str) -> bool:
def check_credentials_are_set(db_name: str) -> bool:
"""
Checks if a username and password is stored for the specified database.
Checks if a username is stored for the specified database.

We only check for the username and not for the password to be compatible with token-only authentication.

Args:
db_name: Name of database to check credentials for.

Returns:
TRUE if credentials were found, FALSE otherwise.
"""
return get_user(db_name) != "" and get_pw(db_name) != ""
return get_user(db_name) != ""
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change now ensures that we are fine with either user + password or user only (if token is used).
Of course, this does not check whether the credentials are valid.
However, we check the validity of the credentials when they are set with the functions set_credentials etc., so I opted to not check them here again. Otherwise, we would call the /logincheck endpoint each time the user wants to get a table.
What do you think @pmayd ?



def check_credentials_are_valid(db_name: str) -> bool:
"""
Checks if the provided user and password is valid by calling the respective endpoint.

Since the API returns a 200 status code for valid and invalid credentials, we need to parse the response text itself.

Args:
db_name: Name of database to check credentials for.

Returns:
TRUE if credentials are valid, FALSE otherwise.
"""
credential_check_dict = json.loads(
http_helper.load_data(
endpoint="helloworld",
method="logincheck",
params=dict(),
db_name=db_name,
).decode("UTF-8")
)
credential_check_status = credential_check_dict.get("Status", "")
# Do not check for full sentence to be more robust against slight changes in response.
return "erfolgreich" in credential_check_status.lower()
19 changes: 11 additions & 8 deletions src/pystatis/http_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import requests

from pystatis import config, db
from pystatis.cache import cache_data, hit_in_cash, normalize_name, read_from_cache
from pystatis import cache
from pystatis.exception import DestatisStatusError, NoNewerDataError, TableNotFoundError
from pystatis.types import ParamDict

Expand Down Expand Up @@ -43,12 +43,12 @@ def load_data(
name = params.get("name")

if name is not None:
name = normalize_name(name)
name = cache.normalize_name(name)

if endpoint == "data":
if hit_in_cash(cache_dir, name, params):
if cache.hit_in_cash(cache_dir, name, params):
print("hit")
data = read_from_cache(cache_dir, name, params)
data = cache.read_from_cache(cache_dir, name, params)
else:
response = get_data_from_endpoint(endpoint, method, params, db_name)
content_type = response.headers.get("Content-Type", "text/csv").split("/")[
Expand Down Expand Up @@ -79,16 +79,16 @@ def load_data(
)[-1]
data = response.content

cache_data(cache_dir, name, params, data, content_type)
cache.cache_data(cache_dir, name, params, data, content_type)

# bytes response in case of zip content type cannot be directly decoded, so we have to load the zip first!
if content_type == "zip":
data = read_from_cache(cache_dir, name, params)
data = cache.read_from_cache(cache_dir, name, params)
else:
response = get_data_from_endpoint(endpoint, method, params, db_name)
data = response.content

return data
return data # type: ignore


def get_data_from_endpoint(
Expand Down Expand Up @@ -157,7 +157,10 @@ def get_response(db_name: str, params: ParamDict) -> requests.Response:

response.encoding = "UTF-8"
_check_invalid_status_code(response)
_check_invalid_destatis_status_code(response)

# logincheck endpoint only returns string status with failure/success information. No further check necessary.
if method != "logincheck":
_check_invalid_destatis_status_code(response)

return response

Expand Down
33 changes: 25 additions & 8 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import pytest

from pystatis import config
from pystatis import config, db, http_helper
from pystatis.db import check_credentials_are_valid


@pytest.fixture()
Expand Down Expand Up @@ -64,19 +65,35 @@ def test_missing_file(config_, caplog):
assert record.levelname == "CRITICAL"


def test_setup_credentials(config_):
for db in config.get_supported_db():
def test_setup_credentials(mocker, config_):
mocker.patch.object(db, "check_credentials_are_valid", return_value=True)
for db_name in config.get_supported_db():
for field in ["username", "password"]:
if field == "username":
os.environ[f"PYSTATIS_{db.upper()}_API_{field.upper()}"] = "test"
os.environ[f"PYSTATIS_{db_name.upper()}_API_{field.upper()}"] = "test"
else:
os.environ[f"PYSTATIS_{db.upper()}_API_{field.upper()}"] = "test123!"
os.environ[f"PYSTATIS_{db_name.upper()}_API_{field.upper()}"] = (
"test123!"
)

config.setup_credentials()

for db in config.get_supported_db():
assert config_[db]["username"] == "test"
assert config_[db]["password"] == "test123!"
for db_name in config.get_supported_db():
assert config_[db_name]["username"] == "test"
assert config_[db_name]["password"] == "test123!"


@pytest.mark.parametrize(
"mock_return, check_result",
[
(b'{"Status": "erfolgreich"}', True),
(b'{"Status": "fehlgeschlagen"}', False),
],
)
def test_check_credentials_are_valid(mocker, mock_return: bytes, check_result: bool):
mocker.patch.object(http_helper, "load_data", return_value=mock_return)
# Db name not important since we mock the request result anyway.
assert check_credentials_are_valid("genesis") == check_result


def test_supported_db():
Expand Down
4 changes: 2 additions & 2 deletions tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_identify_db_with_multiple_matches(config_):
config_.set("genesis", "password", "test")
db_match = db.identify_db_matches("1234567890")
for db_name in db_match:
if db.check_credentials(db_name):
if db.check_credentials_are_set(db_name):
break
assert db_name == "genesis"

Expand All @@ -60,7 +60,7 @@ def test_identify_db_with_multiple_matches(config_):
config_.set("regio", "password", "test")
db_match = db.identify_db_matches("1234567890")
for db_name in db_match:
if db.check_credentials(db_name):
if db.check_credentials_are_set(db_name):
break
assert db_name == "regio"

Expand Down
2 changes: 1 addition & 1 deletion tests/test_http_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_get_response_from_endpoint(mocker):
"pystatis.http_helper.requests", return_value=_generic_request_status()
)
mocker.patch("pystatis.db.get_settings", return_value=("host", "user", "pw"))
mocker.patch("pystatis.db.check_credentials", return_value=True)
mocker.patch("pystatis.db.check_credentials_are_set", return_value=True)

get_data_from_endpoint(
endpoint="endpoint", method="method", params={"name": "21111-0001"}
Expand Down
10 changes: 5 additions & 5 deletions tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
def test_get_data(
mocker, table_name: str, expected_shape: tuple[int, int], language: str
):
mocker.patch.object(pystatis.db, "check_credentials", return_value=True)
mocker.patch.object(pystatis.db, "check_credentials_are_set", return_value=True)
table = pystatis.Table(name=table_name)
table.get_data(prettify=False, language=language)

Expand All @@ -84,7 +84,7 @@ def test_get_data(
def test_get_data_with_quality_on_and_prettify_false(
mocker, table_name: str, expected_shape: tuple[int, int]
):
mocker.patch.object(pystatis.db, "check_credentials", return_value=True)
mocker.patch.object(pystatis.db, "check_credentials_are_set", return_value=True)
table = pystatis.Table(name=table_name)
table.get_data(prettify=False, quality="on")

Expand Down Expand Up @@ -147,7 +147,7 @@ def test_get_data_with_quality_on_and_prettify_true(
expected_shape: tuple[int, int],
expected_columns: tuple[str],
):
mocker.patch.object(pystatis.db, "check_credentials", return_value=True)
mocker.patch.object(pystatis.db, "check_credentials_are_set", return_value=True)
table = pystatis.Table(name=table_name)
table.get_data(prettify=True, quality="on")

Expand Down Expand Up @@ -732,7 +732,7 @@ def test_prettify(
expected_columns: tuple[str],
language: str,
):
mocker.patch.object(pystatis.db, "check_credentials", return_value=True)
mocker.patch.object(pystatis.db, "check_credentials_are_set", return_value=True)
table = pystatis.Table(name=table_name)
table.get_data(prettify=True, language=language)

Expand Down Expand Up @@ -760,7 +760,7 @@ def test_prettify(
],
)
def test_dtype_time_column(mocker, table_name: str, time_col: str, language: str):
mocker.patch.object(pystatis.db, "check_credentials", return_value=True)
mocker.patch.object(pystatis.db, "check_credentials_are_set", return_value=True)
table = pystatis.Table(name=table_name)
table.get_data(prettify=True, language=language)

Expand Down
Loading