Skip to content

Commit

Permalink
Merge branch 'main' into enh-parse-images-from-helm
Browse files Browse the repository at this point in the history
  • Loading branch information
viniciusdc authored Oct 30, 2024
2 parents 0341b57 + 88dfe24 commit 6cfe0d3
Show file tree
Hide file tree
Showing 6 changed files with 671 additions and 623 deletions.
1 change: 0 additions & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ on:
push:
branches:
- main
- develop
- release/\d{4}.\d{1,2}.\d{1,2}
paths:
- ".github/workflows/test.yaml"
Expand Down
1,191 changes: 605 additions & 586 deletions RELEASE.md

Large diffs are not rendered by default.

36 changes: 16 additions & 20 deletions src/_nebari/stages/infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class AzureInputVars(schema.Base):
workload_identity_enabled: bool = False


class AWSAmiTypes(enum.Enum):
class AWSAmiTypes(str, enum.Enum):
AL2_x86_64 = "AL2_x86_64"
AL2_x86_64_GPU = "AL2_x86_64_GPU"
CUSTOM = "CUSTOM"
Expand All @@ -151,25 +151,17 @@ class AWSNodeGroupInputVars(schema.Base):
ami_type: Optional[AWSAmiTypes] = None
launch_template: Optional[AWSNodeLaunchTemplate] = None

@field_validator("ami_type", mode="before")
@classmethod
def _infer_and_validate_ami_type(cls, value, values) -> str:
gpu_enabled = values.get("gpu", False)

# Auto-set ami_type if not provided
if not value:
if values.get("launch_template") and values["launch_template"].ami_id:
return "CUSTOM"
if gpu_enabled:
return "AL2_x86_64_GPU"
return "AL2_x86_64"

# Explicit validation
if value == "AL2_x86_64" and gpu_enabled:
raise ValueError(
"ami_type 'AL2_x86_64' cannot be used with GPU enabled (gpu=True)."
)
return value

def construct_aws_ami_type(gpu_enabled: bool, launch_template: AWSNodeLaunchTemplate):
"""Construct the AWS AMI type based on the provided parameters."""

if launch_template and launch_template.ami_id:
return "CUSTOM"

if gpu_enabled:
return "AL2_x86_64_GPU"

return "AL2_x86_64"


class AWSInputVars(schema.Base):
Expand Down Expand Up @@ -858,6 +850,10 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
single_subnet=node_group.single_subnet,
permissions_boundary=node_group.permissions_boundary,
launch_template=node_group.launch_template,
ami_type=construct_aws_ami_type(
gpu_enabled=node_group.gpu,
launch_template=node_group.launch_template,
),
)
for name, node_group in self.config.amazon_web_services.node_groups.items()
],
Expand Down
24 changes: 13 additions & 11 deletions src/_nebari/stages/terraform_state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import re
from typing import Any, Dict, List, Optional, Tuple, Type

from pydantic import field_validator
from pydantic import BaseModel, field_validator

from _nebari import utils
from _nebari.provider import terraform
Expand Down Expand Up @@ -260,7 +260,7 @@ def check_immutable_fields(self):

# compute diff of remote/prior and current nebari config
nebari_config_diff = utils.JsonDiff(
nebari_config_state.model_dump(), self.config.model_dump()
nebari_config_state, self.config.model_dump()
)
# check if any changed fields are immutable
for keys, old, new in nebari_config_diff.modified():
Expand All @@ -275,16 +275,22 @@ def check_immutable_fields(self):
bottom_level_schema = bottom_level_schema[key]
else:
raise e
extra_field_schema = schema.ExtraFieldSchema(
**bottom_level_schema.model_fields[keys[-1]].json_schema_extra or {}
)

# Return a default (mutable) extra field schema if bottom level is not a Pydantic model (such as a free-form 'overrides' block)
if isinstance(bottom_level_schema, BaseModel):
extra_field_schema = schema.ExtraFieldSchema(
**bottom_level_schema.model_fields[keys[-1]].json_schema_extra or {}
)
else:
extra_field_schema = schema.ExtraFieldSchema()

if extra_field_schema.immutable:
key_path = ".".join(keys)
raise ValueError(
f'Attempting to change immutable field "{key_path}" ("{old}"->"{new}") in Nebari config file. Immutable fields cannot be changed after initial deployment.'
)

def get_nebari_config_state(self):
def get_nebari_config_state(self) -> dict:
directory = str(self.output_directory / self.stage_prefix)
tf_state = terraform.show(directory)
nebari_config_state = None
Expand All @@ -294,11 +300,7 @@ def get_nebari_config_state(self):
tf_state.get("values", {}).get("root_module", {}).get("resources", [])
):
if resource["address"] == "terraform_data.nebari_config":
from nebari.plugins import nebari_plugin_manager

nebari_config_state = nebari_plugin_manager.config_schema(
**resource["values"]["input"]
)
nebari_config_state = resource["values"]["input"]
break
return nebari_config_state

Expand Down
5 changes: 5 additions & 0 deletions tests/tests_unit/cli_validate/local.happy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,8 @@ theme:
certificate:
type: lets-encrypt
acme_email: [email protected]
jupyterhub:
overrides:
singleuser:
extraEnv:
TEST_ENV: "my_env"
37 changes: 32 additions & 5 deletions tests/tests_unit/test_stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def terraform_state_stage(mock_config, tmp_path):

@patch.object(TerraformStateStage, "get_nebari_config_state")
def test_check_immutable_fields_no_changes(mock_get_state, terraform_state_stage):
mock_get_state.return_value = terraform_state_stage.config
mock_get_state.return_value = terraform_state_stage.config.model_dump()

# This should not raise an exception
terraform_state_stage.check_immutable_fields()
Expand All @@ -41,7 +41,7 @@ def test_check_immutable_fields_mutable_change(
):
old_config = mock_config.model_copy(deep=True)
old_config.namespace = "old-namespace"
mock_get_state.return_value = old_config
mock_get_state.return_value = old_config.model_dump()

# This should not raise an exception (namespace is mutable)
terraform_state_stage.check_immutable_fields()
Expand All @@ -54,7 +54,7 @@ def test_check_immutable_fields_immutable_change(
):
old_config = mock_config.model_copy(deep=True)
old_config.provider = schema.ProviderEnum.gcp
mock_get_state.return_value = old_config
mock_get_state.return_value = old_config.model_dump()

# Mock the provider field to be immutable
mock_model_fields.__getitem__.return_value.json_schema_extra = {"immutable": True}
Expand All @@ -77,7 +77,7 @@ def test_check_immutable_fields_no_prior_state(mock_get_state, terraform_state_s
def test_check_dict_value_change(mock_get_state, terraform_state_stage, mock_config):
old_config = mock_config.model_copy(deep=True)
terraform_state_stage.config.local.node_selectors["worker"].value += "new_value"
mock_get_state.return_value = old_config
mock_get_state.return_value = old_config.model_dump()

# should not throw an exception
terraform_state_stage.check_immutable_fields()
Expand All @@ -87,7 +87,34 @@ def test_check_dict_value_change(mock_get_state, terraform_state_stage, mock_con
def test_check_list_change(mock_get_state, terraform_state_stage, mock_config):
old_config = mock_config.model_copy(deep=True)
old_config.environments["environment-dask.yaml"].channels.append("defaults")
mock_get_state.return_value = old_config
mock_get_state.return_value = old_config.model_dump()

# should not throw an exception
terraform_state_stage.check_immutable_fields()


@patch.object(TerraformStateStage, "get_nebari_config_state")
def test_check_immutable_fields_old_nebari_version(
mock_get_state, terraform_state_stage, mock_config
):
old_config = mock_config.model_copy(deep=True).model_dump()
old_config["nebari_version"] = "2024.7.1" # Simulate an old version
mock_get_state.return_value = old_config

# This should not raise an exception
terraform_state_stage.check_immutable_fields()


@patch.object(TerraformStateStage, "get_nebari_config_state")
def test_check_immutable_fields_change_dict_any(
mock_get_state, terraform_state_stage, mock_config
):
old_config = mock_config.model_copy(deep=True).model_dump()
# Change the value of a config deep in 'overrides' block
old_config["jupyterhub"]["overrides"]["singleuser"]["extraEnv"][
"TEST_ENV"
] = "new_value"
mock_get_state.return_value = old_config

# This should not raise an exception
terraform_state_stage.check_immutable_fields()

0 comments on commit 6cfe0d3

Please sign in to comment.