Skip to content

Commit

Permalink
Yaml config sets (#2876)
Browse files Browse the repository at this point in the history
  • Loading branch information
Adam-D-Lewis authored Jan 6, 2025
1 parent 5c90b2e commit ff66c22
Show file tree
Hide file tree
Showing 9 changed files with 245 additions and 6 deletions.
54 changes: 54 additions & 0 deletions src/_nebari/config_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import logging
import pathlib
from typing import Optional

from packaging.requirements import SpecifierSet
from pydantic import BaseModel, ConfigDict, field_validator

from _nebari._version import __version__
from _nebari.utils import yaml

logger = logging.getLogger(__name__)


class ConfigSetMetadata(BaseModel):
model_config: ConfigDict = ConfigDict(extra="allow", arbitrary_types_allowed=True)
name: str # for use with guided init
description: Optional[str] = None
nebari_version: str | SpecifierSet

@field_validator("nebari_version")
@classmethod
def validate_version_requirement(cls, version_req):
if isinstance(version_req, str):
version_req = SpecifierSet(version_req, prereleases=True)

return version_req

def check_version(self, version):
if not self.nebari_version.contains(version, prereleases=True):
raise ValueError(
f'Nebari version "{version}" is not compatible with '
f'version requirement {self.nebari_version} for "{self.name}" config set.'
)


class ConfigSet(BaseModel):
metadata: ConfigSetMetadata
config: dict


def read_config_set(config_set_filepath: str):
"""Read a config set from a config file."""

filename = pathlib.Path(config_set_filepath)

with filename.open() as f:
config_set_yaml = yaml.load(f)

config_set = ConfigSet(**config_set_yaml)

# validation
config_set.metadata.check_version(__version__)

return config_set
10 changes: 8 additions & 2 deletions src/_nebari/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import pydantic
import requests

from _nebari import constants
from _nebari import constants, utils
from _nebari.config_set import read_config_set
from _nebari.provider import git
from _nebari.provider.cicd import github
from _nebari.provider.cloud import amazon_web_services, azure_cloud, google_cloud
Expand Down Expand Up @@ -47,6 +48,7 @@ def render_config(
region: str = None,
disable_prompt: bool = False,
ssl_cert_email: str = None,
config_set: str = None,
) -> Dict[str, Any]:
config = {
"provider": cloud_provider,
Expand Down Expand Up @@ -176,13 +178,17 @@ def render_config(
config["certificate"] = {"type": CertificateEnum.letsencrypt.value}
config["certificate"]["acme_email"] = ssl_cert_email

if config_set:
config_set = read_config_set(config_set)
config = utils.deep_merge(config, config_set.config)

# validate configuration and convert to model
from nebari.plugins import nebari_plugin_manager

try:
config_model = nebari_plugin_manager.config_schema.model_validate(config)
except pydantic.ValidationError as e:
print(str(e))
raise e

if repository_auto_provision:
match = re.search(github_url_regex, repository)
Expand Down
16 changes: 15 additions & 1 deletion src/_nebari/stages/infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import re
import sys
import tempfile
import warnings
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union

from pydantic import ConfigDict, Field, field_validator, model_validator
Expand Down Expand Up @@ -614,11 +615,23 @@ def check_provider(cls, data: Any) -> Any:
data[provider] = provider_enum_model_map[provider]()
else:
# if the provider field is invalid, it won't be set when this validator is called
# so we need to check for it explicitly here, and set the `pre` to True
# so we need to check for it explicitly here, and set mode to "before"
# TODO: this is a workaround, check if there is a better way to do this in Pydantic v2
raise ValueError(
f"'{provider}' is not a valid enumeration member; permitted: local, existing, aws, gcp, azure"
)
set_providers = {
provider
for provider in provider_name_abbreviation_map.keys()
if provider in data and data[provider]
}
expected_provider_config = provider_enum_name_map[provider]
extra_provider_config = set_providers - {expected_provider_config}
if extra_provider_config:
warnings.warn(
f"Provider is set to {getattr(provider, 'value', provider)}, but configuration defined for other providers: {extra_provider_config}"
)

else:
set_providers = [
provider
Expand All @@ -632,6 +645,7 @@ def check_provider(cls, data: Any) -> Any:
data["provider"] = provider_name_abbreviation_map[set_providers[0]]
elif num_providers == 0:
data["provider"] = schema.ProviderEnum.local.value

return data


Expand Down
9 changes: 9 additions & 0 deletions src/_nebari/subcommands/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class InitInputs(schema.Base):
region: Optional[str] = None
ssl_cert_email: Optional[schema.email_pydantic] = None
disable_prompt: bool = False
config_set: Optional[str] = None
output: pathlib.Path = pathlib.Path("nebari-config.yaml")
explicit: int = 0

Expand Down Expand Up @@ -134,6 +135,7 @@ def handle_init(inputs: InitInputs, config_schema: BaseModel):
terraform_state=inputs.terraform_state,
ssl_cert_email=inputs.ssl_cert_email,
disable_prompt=inputs.disable_prompt,
config_set=inputs.config_set,
)

try:
Expand Down Expand Up @@ -496,6 +498,12 @@ def init(
False,
is_eager=True,
),
config_set: str = typer.Option(
None,
"--config-set",
"-s",
help="Apply a pre-defined set of nebari configuration options.",
),
output: str = typer.Option(
pathlib.Path("nebari-config.yaml"),
"--output",
Expand Down Expand Up @@ -554,6 +562,7 @@ def init(
inputs.terraform_state = terraform_state
inputs.ssl_cert_email = ssl_cert_email
inputs.disable_prompt = disable_prompt
inputs.config_set = config_set
inputs.output = output
inputs.explicit = explicit

Expand Down
4 changes: 2 additions & 2 deletions src/_nebari/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def modified_environ(*remove: List[str], **update: Dict[str, str]):


def deep_merge(*args):
"""Deep merge multiple dictionaries.
"""Deep merge multiple dictionaries. Preserves order in dicts and lists.
>>> value_1 = {
'a': [1, 2],
Expand Down Expand Up @@ -190,7 +190,7 @@ def deep_merge(*args):

if isinstance(d1, dict) and isinstance(d2, dict):
d3 = {}
for key in d1.keys() | d2.keys():
for key in tuple(d1.keys()) + tuple(d2.keys()):
if key in d1 and key in d2:
d3[key] = deep_merge(d1[key], d2[key])
elif key in d1:
Expand Down
73 changes: 73 additions & 0 deletions tests/tests_unit/test_config_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from unittest.mock import patch

import pytest
from packaging.requirements import SpecifierSet

from _nebari.config_set import ConfigSetMetadata, read_config_set

test_version = "2024.12.2"


@pytest.mark.parametrize(
"version_input,test_version,should_pass",
[
# Standard version tests
(">=2024.12.0,<2025.0.0", "2024.12.2", True),
(SpecifierSet(">=2024.12.0,<2025.0.0"), "2024.12.2", True),
# Pre-release version requirement tests
(">=2024.12.0rc1,<2025.0.0", "2024.12.0rc1", True),
(SpecifierSet(">=2024.12.0rc1"), "2024.12.0rc2", True),
# Pre-release test version against standard requirement
(">=2024.12.0,<2025.0.0", "2024.12.1rc1", True),
(SpecifierSet(">=2024.12.0,<2025.0.0"), "2024.12.1rc1", True),
# Failing cases
(">=2025.0.0", "2024.12.2rc1", False),
(SpecifierSet(">=2025.0.0rc1"), "2024.12.2", False),
],
)
def test_version_requirement(version_input, test_version, should_pass):
metadata = ConfigSetMetadata(name="test-config", nebari_version=version_input)

if should_pass:
metadata.check_version(test_version)
else:
with pytest.raises(ValueError) as exc_info:
metadata.check_version(test_version)
assert "Nebari version" in str(exc_info.value)


def test_read_config_set_valid(tmp_path):
config_set_yaml = """
metadata:
name: test-config
nebari_version: ">=2024.12.0"
config:
key: value
"""
config_set_filepath = tmp_path / "config_set.yaml"
config_set_filepath.write_text(config_set_yaml)
with patch("_nebari.config_set.__version__", "2024.12.2"):
config_set = read_config_set(str(config_set_filepath))
assert config_set.metadata.name == "test-config"
assert config_set.config["key"] == "value"


def test_read_config_set_invalid_version(tmp_path):
config_set_yaml = """
metadata:
name: test-config
nebari_version: ">=2025.0.0"
config:
key: value
"""
config_set_filepath = tmp_path / "config_set.yaml"
config_set_filepath.write_text(config_set_yaml)

with patch("_nebari.config_set.__version__", "2024.12.2"):
with pytest.raises(ValueError) as exc_info:
read_config_set(str(config_set_filepath))
assert "Nebari version" in str(exc_info.value)


if __name__ == "__main__":
pytest.main()
10 changes: 10 additions & 0 deletions tests/tests_unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,13 @@ def test_set_provider(config_schema, provider):
result_config_dict = config.model_dump()
assert provider in result_config_dict
assert result_config_dict[provider]["kube_context"] == "some_context"


def test_provider_config_mismatch_warning(config_schema):
config_dict = {
"project_name": "test",
"provider": "local",
"existing": {"kube_context": "some_context"}, # <-- Doesn't match the provider
}
with pytest.warns(UserWarning, match="configuration defined for other providers"):
config_schema(**config_dict)
1 change: 1 addition & 0 deletions tests/tests_unit/test_stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def test_check_immutable_fields_immutable_change(
mock_model_fields, mock_get_state, terraform_state_stage, mock_config
):
old_config = mock_config.model_copy(deep=True)
old_config.local = None
old_config.provider = schema.ProviderEnum.gcp
mock_get_state.return_value = old_config.model_dump()

Expand Down
74 changes: 73 additions & 1 deletion tests/tests_unit/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from _nebari.utils import JsonDiff, JsonDiffEnum, byte_unit_conversion
from _nebari.utils import JsonDiff, JsonDiffEnum, byte_unit_conversion, deep_merge


@pytest.mark.parametrize(
Expand Down Expand Up @@ -64,3 +64,75 @@ def test_JsonDiff_modified():
diff = JsonDiff(obj1, obj2)
modifieds = diff.modified()
assert sorted(modifieds) == sorted([(["b", "!"], 2, 3), (["+"], 4, 5)])


def test_deep_merge_order_preservation_dict():
value_1 = {
"a": [1, 2],
"b": {"c": 1, "z": [5, 6]},
"e": {"f": {"g": {}}},
"m": 1,
}

value_2 = {
"a": [3, 4],
"b": {"d": 2, "z": [7]},
"e": {"f": {"h": 1}},
"m": [1],
}

expected_result = {
"a": [1, 2, 3, 4],
"b": {"c": 1, "z": [5, 6, 7], "d": 2},
"e": {"f": {"g": {}, "h": 1}},
"m": 1,
}

result = deep_merge(value_1, value_2)
assert result == expected_result
assert list(result.keys()) == list(expected_result.keys())
assert list(result["b"].keys()) == list(expected_result["b"].keys())
assert list(result["e"]["f"].keys()) == list(expected_result["e"]["f"].keys())


def test_deep_merge_order_preservation_list():
value_1 = {
"a": [1, 2],
"b": {"c": 1, "z": [5, 6]},
}

value_2 = {
"a": [3, 4],
"b": {"d": 2, "z": [7]},
}

expected_result = {
"a": [1, 2, 3, 4],
"b": {"c": 1, "z": [5, 6, 7], "d": 2},
}

result = deep_merge(value_1, value_2)
assert result == expected_result
assert result["a"] == expected_result["a"]
assert result["b"]["z"] == expected_result["b"]["z"]


def test_deep_merge_single_dict():
value_1 = {
"a": [1, 2],
"b": {"c": 1, "z": [5, 6]},
}

expected_result = value_1

result = deep_merge(value_1)
assert result == expected_result
assert list(result.keys()) == list(expected_result.keys())
assert list(result["b"].keys()) == list(expected_result["b"].keys())


def test_deep_merge_empty():
expected_result = {}

result = deep_merge()
assert result == expected_result

0 comments on commit ff66c22

Please sign in to comment.