Skip to content

Commit

Permalink
update jwt static key steps
Browse files Browse the repository at this point in the history
  • Loading branch information
alsugiliazova committed Dec 4, 2024
1 parent 7e210f6 commit 85525d6
Showing 1 changed file with 48 additions and 13 deletions.
61 changes: 48 additions & 13 deletions jwt_authentication/tests/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,24 @@
getuid,
)

HMAC_algorithms = ["HS256", "HS384", "HS512"]
RSA_algorithms = ["RS256", "RS384", "RS512"]
ECDSA_algorithms = ["ES256", "ES384", "ES512", "ES256K"]
PSS_algorithms = ["PS256", "PS384", "PS512"]
EdDSA_algorithms = ["Ed25519", "Ed448"]
algorithm_groups = {
"HMAC": ["HS256", "HS384", "HS512"],
"RSA": ["RS256", "RS384", "RS512"],
"ECDSA": ["ES256", "ES384", "ES512", "ES256K"],
"PSS": ["PS256", "PS384", "PS512"],
"EdDSA": ["Ed25519", "Ed448"],
}


def algorithm_from_same_group(algorithm1: str, algorithm2: str) -> bool:
"""Check if two algorithms are from the same group."""
group1 = next(
(group for group, algs in algorithm_groups.items() if algorithm1 in algs), None
)
group2 = next(
(group for group, algs in algorithm_groups.items() if algorithm2 in algs), None
)
return group1 == group2


def create_static_jwt(
Expand Down Expand Up @@ -92,6 +105,23 @@ def create_static_jwt(
return jwt.encode(payload, private_key, algorithm=algorithm, headers=headers)


def get_public_key_from_private(private_key_path: str) -> str:
"""Get the public key from the private key."""
with open(private_key_path, "rb") as key_file:
private_key = serialization.load_pem_private_key(
key_file.read(),
password=None,
)

public_key = private_key.public_key()

public_key_pem = public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
return public_key_pem.decode()


@TestStep(Given)
def change_clickhouse_config(
self,
Expand Down Expand Up @@ -174,19 +204,23 @@ def add_static_key_validator_to_config_xml(
static_key_in_base64: str = None,
public_key: str = None,
restart=True,
node: Node = None,
):
"""Add static key validator to the config.xml."""

entries = {"jwt_validators": {}}
entries["jwt_validators"][f"{validator_id}"] = {}
entries["jwt_validators"][f"{validator_id}"]["algo"] = algorithm.lower()

if algorithm is not None:
entries["jwt_validators"][f"{validator_id}"]["algo"] = algorithm.lower()

if secret is not None:
entries["jwt_validators"][f"{validator_id}"]["static_key"] = secret
if static_key_in_base64 is not None:
entries["jwt_validators"][f"{validator_id}"][
"static_key_in_base64"
] = static_key_in_base64

if static_key_in_base64 is not None:
entries["jwt_validators"][f"{validator_id}"][
"static_key_in_base64"
] = static_key_in_base64

if public_key is not None:
entries["jwt_validators"][f"{validator_id}"]["public_key"] = public_key
Expand All @@ -197,6 +231,7 @@ def add_static_key_validator_to_config_xml(
preprocessed_name="config.xml",
restart=restart,
config_file=f"{validator_id}.xml",
node=node,
)


Expand Down Expand Up @@ -428,10 +463,10 @@ def generate_ssh_keys(self, key_type: str = None, algorithm: str = "RS256"):
private_key_file = f"private_key_{getuid()}"
public_key_file = f"{private_key_file}.pub"

if algorithm in RSA_algorithms + PSS_algorithms:
if algorithm in algorithm_groups["RSA"] + algorithm_groups["PSS"]:
key_type = "rsa"
command = f"openssl genpkey -algorithm {key_type} -out {private_key_file}"
elif algorithm in ECDSA_algorithms:
elif algorithm in algorithm_groups["ECDSA"]:
key_type = "ec"
curve_map = {
"ES256": "prime256v1",
Expand All @@ -443,7 +478,7 @@ def generate_ssh_keys(self, key_type: str = None, algorithm: str = "RS256"):
command = (
f"openssl ecparam -name {curve_name} -genkey -noout -out {private_key_file}"
)
elif algorithm in EdDSA_algorithms:
elif algorithm in algorithm_groups["EdDSA"]:
key_type = algorithm
command = f"openssl genpkey -algorithm {key_type} -out {private_key_file}"
else:
Expand Down

0 comments on commit 85525d6

Please sign in to comment.