diff --git a/jwt_authentication/tests/steps.py b/jwt_authentication/tests/steps.py index 7317b99df..894f665d5 100644 --- a/jwt_authentication/tests/steps.py +++ b/jwt_authentication/tests/steps.py @@ -19,11 +19,24 @@ getuid, ) -HMAC_algorithms = ["HS256", "HS384", "HS512"] -RSA_algorithms = ["RS256", "RS384", "RS512"] -ECDSA_algorithms = ["ES256", "ES384", "ES512", "ES256K"] -PSS_algorithms = ["PS256", "PS384", "PS512"] -EdDSA_algorithms = ["Ed25519", "Ed448"] +algorithm_groups = { + "HMAC": ["HS256", "HS384", "HS512"], + "RSA": ["RS256", "RS384", "RS512"], + "ECDSA": ["ES256", "ES384", "ES512", "ES256K"], + "PSS": ["PS256", "PS384", "PS512"], + "EdDSA": ["Ed25519", "Ed448"], +} + + +def algorithm_from_same_group(algorithm1: str, algorithm2: str) -> bool: + """Check if two algorithms are from the same group.""" + group1 = next( + (group for group, algs in algorithm_groups.items() if algorithm1 in algs), None + ) + group2 = next( + (group for group, algs in algorithm_groups.items() if algorithm2 in algs), None + ) + return group1 == group2 def create_static_jwt( @@ -92,6 +105,23 @@ def create_static_jwt( return jwt.encode(payload, private_key, algorithm=algorithm, headers=headers) +def get_public_key_from_private(private_key_path: str) -> str: + """Get the public key from the private key.""" + with open(private_key_path, "rb") as key_file: + private_key = serialization.load_pem_private_key( + key_file.read(), + password=None, + ) + + public_key = private_key.public_key() + + public_key_pem = public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + return public_key_pem.decode() + + @TestStep(Given) def change_clickhouse_config( self, @@ -174,19 +204,23 @@ def add_static_key_validator_to_config_xml( static_key_in_base64: str = None, public_key: str = None, restart=True, + node: Node = None, ): """Add static key validator to the config.xml.""" entries = {"jwt_validators": {}} entries["jwt_validators"][f"{validator_id}"] = {} - entries["jwt_validators"][f"{validator_id}"]["algo"] = algorithm.lower() + + if algorithm is not None: + entries["jwt_validators"][f"{validator_id}"]["algo"] = algorithm.lower() if secret is not None: entries["jwt_validators"][f"{validator_id}"]["static_key"] = secret - if static_key_in_base64 is not None: - entries["jwt_validators"][f"{validator_id}"][ - "static_key_in_base64" - ] = static_key_in_base64 + + if static_key_in_base64 is not None: + entries["jwt_validators"][f"{validator_id}"][ + "static_key_in_base64" + ] = static_key_in_base64 if public_key is not None: entries["jwt_validators"][f"{validator_id}"]["public_key"] = public_key @@ -197,6 +231,7 @@ def add_static_key_validator_to_config_xml( preprocessed_name="config.xml", restart=restart, config_file=f"{validator_id}.xml", + node=node, ) @@ -428,10 +463,10 @@ def generate_ssh_keys(self, key_type: str = None, algorithm: str = "RS256"): private_key_file = f"private_key_{getuid()}" public_key_file = f"{private_key_file}.pub" - if algorithm in RSA_algorithms + PSS_algorithms: + if algorithm in algorithm_groups["RSA"] + algorithm_groups["PSS"]: key_type = "rsa" command = f"openssl genpkey -algorithm {key_type} -out {private_key_file}" - elif algorithm in ECDSA_algorithms: + elif algorithm in algorithm_groups["ECDSA"]: key_type = "ec" curve_map = { "ES256": "prime256v1", @@ -443,7 +478,7 @@ def generate_ssh_keys(self, key_type: str = None, algorithm: str = "RS256"): command = ( f"openssl ecparam -name {curve_name} -genkey -noout -out {private_key_file}" ) - elif algorithm in EdDSA_algorithms: + elif algorithm in algorithm_groups["EdDSA"]: key_type = algorithm command = f"openssl genpkey -algorithm {key_type} -out {private_key_file}" else: