Skip to content

Commit

Permalink
Merge branch 'main' into ci/authentication/use-github-secrets-instead…
Browse files Browse the repository at this point in the history
…-of-vault
  • Loading branch information
marcelovilla authored Jan 16, 2025
2 parents 470e981 + 00a6545 commit d3a7b90
Show file tree
Hide file tree
Showing 27 changed files with 984 additions and 218 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_aws_integration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ on:
env:
AWS_DEFAULT_REGION: "us-west-2"
NEBARI_IMAGE_TAG: ${{ github.event.inputs.image-tag || 'main' }}
TF_LOG: ${{ github.event.inputs.tf-log-level || 'info' }}
TF_LOG: ${{ github.event.inputs.tf-log-level || 'info' }}

jobs:
test-aws-integration:
Expand Down
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ repos:
exclude: "^src/_nebari/template/"

- repo: https://github.com/crate-ci/typos
rev: typos-dict-v0.11.37
rev: dictgen-v0.3.1
hooks:
- id: typos

Expand All @@ -61,7 +61,7 @@ repos:
args: ["--line-length=88", "--exclude=/src/_nebari/template/"]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.1
rev: v0.8.6
hooks:
- id: ruff
args: ["--fix"]
Expand All @@ -77,7 +77,7 @@ repos:

# terraform
- repo: https://github.com/antonbabenko/pre-commit-terraform
rev: v1.96.2
rev: v1.96.3
hooks:
- id: terraform_fmt
args:
Expand Down
262 changes: 148 additions & 114 deletions docs-sphinx/cli.html

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/_nebari/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def common(
[],
"--import-plugin",
help="Import nebari plugin",
callback=import_plugin,
),
excluded_stages: typing.List[str] = typer.Option(
[],
Expand Down
54 changes: 54 additions & 0 deletions src/_nebari/config_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import logging
import pathlib
from typing import Optional

from packaging.requirements import SpecifierSet
from pydantic import BaseModel, ConfigDict, field_validator

from _nebari._version import __version__
from _nebari.utils import yaml

logger = logging.getLogger(__name__)


class ConfigSetMetadata(BaseModel):
model_config: ConfigDict = ConfigDict(extra="allow", arbitrary_types_allowed=True)
name: str # for use with guided init
description: Optional[str] = None
nebari_version: str | SpecifierSet

@field_validator("nebari_version")
@classmethod
def validate_version_requirement(cls, version_req):
if isinstance(version_req, str):
version_req = SpecifierSet(version_req, prereleases=True)

return version_req

def check_version(self, version):
if not self.nebari_version.contains(version, prereleases=True):
raise ValueError(
f'Nebari version "{version}" is not compatible with '
f'version requirement {self.nebari_version} for "{self.name}" config set.'
)


class ConfigSet(BaseModel):
metadata: ConfigSetMetadata
config: dict


def read_config_set(config_set_filepath: str):
"""Read a config set from a config file."""

filename = pathlib.Path(config_set_filepath)

with filename.open() as f:
config_set_yaml = yaml.load(f)

config_set = ConfigSet(**config_set_yaml)

# validation
config_set.metadata.check_version(__version__)

return config_set
2 changes: 1 addition & 1 deletion src/_nebari/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
DEFAULT_NEBARI_IMAGE_TAG = CURRENT_RELEASE
DEFAULT_NEBARI_WORKFLOW_CONTROLLER_IMAGE_TAG = CURRENT_RELEASE

DEFAULT_CONDA_STORE_IMAGE_TAG = "2024.3.1"
DEFAULT_CONDA_STORE_IMAGE_TAG = "2024.11.2"

LATEST_SUPPORTED_PYTHON_VERSION = "3.10"

Expand Down
10 changes: 8 additions & 2 deletions src/_nebari/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import pydantic
import requests

from _nebari import constants
from _nebari import constants, utils
from _nebari.config_set import read_config_set
from _nebari.provider import git
from _nebari.provider.cicd import github
from _nebari.provider.cloud import amazon_web_services, azure_cloud, google_cloud
Expand Down Expand Up @@ -47,6 +48,7 @@ def render_config(
region: str = None,
disable_prompt: bool = False,
ssl_cert_email: str = None,
config_set: str = None,
) -> Dict[str, Any]:
config = {
"provider": cloud_provider,
Expand Down Expand Up @@ -176,13 +178,17 @@ def render_config(
config["certificate"] = {"type": CertificateEnum.letsencrypt.value}
config["certificate"]["acme_email"] = ssl_cert_email

if config_set:
config_set = read_config_set(config_set)
config = utils.deep_merge(config, config_set.config)

# validate configuration and convert to model
from nebari.plugins import nebari_plugin_manager

try:
config_model = nebari_plugin_manager.config_schema.model_validate(config)
except pydantic.ValidationError as e:
print(str(e))
raise e

if repository_auto_provision:
match = re.search(github_url_regex, repository)
Expand Down
22 changes: 21 additions & 1 deletion src/_nebari/stages/infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import re
import sys
import tempfile
import warnings
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union

from pydantic import ConfigDict, Field, field_validator, model_validator
Expand Down Expand Up @@ -94,6 +95,7 @@ class AzureInputVars(schema.Base):
name: str
environment: str
region: str
authorized_ip_ranges: List[str] = ["0.0.0.0/0"]
kubeconfig_filename: str = get_kubeconfig_filename()
kubernetes_version: str
node_groups: Dict[str, AzureNodeGroupInputVars]
Expand All @@ -104,6 +106,7 @@ class AzureInputVars(schema.Base):
tags: Dict[str, str] = {}
max_pods: Optional[int] = None
network_profile: Optional[Dict[str, str]] = None
azure_policy_enabled: Optional[bool] = None
workload_identity_enabled: bool = False


Expand Down Expand Up @@ -360,6 +363,7 @@ class AzureProvider(schema.Base):
region: str
kubernetes_version: Optional[str] = None
storage_account_postfix: str
authorized_ip_ranges: Optional[List[str]] = ["0.0.0.0/0"]
resource_group_name: Optional[str] = None
node_groups: Dict[str, AzureNodeGroup] = DEFAULT_AZURE_NODE_GROUPS
storage_account_postfix: str
Expand All @@ -370,6 +374,7 @@ class AzureProvider(schema.Base):
network_profile: Optional[Dict[str, str]] = None
max_pods: Optional[int] = None
workload_identity_enabled: bool = False
azure_policy_enabled: Optional[bool] = None

@model_validator(mode="before")
@classmethod
Expand Down Expand Up @@ -613,11 +618,23 @@ def check_provider(cls, data: Any) -> Any:
data[provider] = provider_enum_model_map[provider]()
else:
# if the provider field is invalid, it won't be set when this validator is called
# so we need to check for it explicitly here, and set the `pre` to True
# so we need to check for it explicitly here, and set mode to "before"
# TODO: this is a workaround, check if there is a better way to do this in Pydantic v2
raise ValueError(
f"'{provider}' is not a valid enumeration member; permitted: local, existing, aws, gcp, azure"
)
set_providers = {
provider
for provider in provider_name_abbreviation_map.keys()
if provider in data and data[provider]
}
expected_provider_config = provider_enum_name_map[provider]
extra_provider_config = set_providers - {expected_provider_config}
if extra_provider_config:
warnings.warn(
f"Provider is set to {getattr(provider, 'value', provider)}, but configuration defined for other providers: {extra_provider_config}"
)

else:
set_providers = [
provider
Expand All @@ -631,6 +648,7 @@ def check_provider(cls, data: Any) -> Any:
data["provider"] = provider_name_abbreviation_map[set_providers[0]]
elif num_providers == 0:
data["provider"] = schema.ProviderEnum.local.value

return data


Expand Down Expand Up @@ -784,6 +802,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
environment=self.config.namespace,
region=self.config.azure.region,
kubernetes_version=self.config.azure.kubernetes_version,
authorized_ip_ranges=self.config.azure.authorized_ip_ranges,
node_groups={
name: AzureNodeGroupInputVars(
instance=node_group.instance,
Expand All @@ -809,6 +828,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
network_profile=self.config.azure.network_profile,
max_pods=self.config.azure.max_pods,
workload_identity_enabled=self.config.azure.workload_identity_enabled,
azure_policy_enabled=self.config.azure.azure_policy_enabled,
).model_dump()
elif self.config.provider == schema.ProviderEnum.aws:
return AWSInputVars(
Expand Down
2 changes: 2 additions & 0 deletions src/_nebari/stages/infrastructure/template/azure/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ module "kubernetes" {
kubernetes_version = var.kubernetes_version
tags = var.tags
max_pods = var.max_pods
authorized_ip_ranges = var.authorized_ip_ranges

network_profile = var.network_profile

Expand All @@ -43,4 +44,5 @@ module "kubernetes" {
vnet_subnet_id = var.vnet_subnet_id
private_cluster_enabled = var.private_cluster_enabled
workload_identity_enabled = var.workload_identity_enabled
azure_policy_enabled = var.azure_policy_enabled
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ resource "azurerm_kubernetes_cluster" "main" {
location = var.location
resource_group_name = var.resource_group_name
tags = var.tags
api_server_access_profile {
authorized_ip_ranges = var.authorized_ip_ranges
}

# To enable Azure AD Workload Identity oidc_issuer_enabled must be set to true.
oidc_issuer_enabled = var.workload_identity_enabled
Expand All @@ -15,6 +18,9 @@ resource "azurerm_kubernetes_cluster" "main" {
# Azure requires that a new, non-existent Resource Group is used, as otherwise the provisioning of the Kubernetes Service will fail.
node_resource_group = var.node_resource_group_name
private_cluster_enabled = var.private_cluster_enabled
# https://learn.microsoft.com/en-ie/azure/governance/policy/concepts/policy-for-kubernetes
azure_policy_enabled = var.azure_policy_enabled


dynamic "network_profile" {
for_each = var.network_profile != null ? [var.network_profile] : []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,15 @@ variable "workload_identity_enabled" {
type = bool
default = false
}

variable "authorized_ip_ranges" {
description = "The ip range allowed to access the Kubernetes API server, defaults to 0.0.0.0/0"
type = list(string)
default = ["0.0.0.0/0"]
}

variable "azure_policy_enabled" {
description = "Enable Azure Policy"
type = bool
default = false
}
12 changes: 12 additions & 0 deletions src/_nebari/stages/infrastructure/template/azure/variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,15 @@ variable "workload_identity_enabled" {
type = bool
default = false
}

variable "authorized_ip_ranges" {
description = "The ip range allowed to access the Kubernetes API server, defaults to 0.0.0.0/0"
type = list(string)
default = ["0.0.0.0/0"]
}

variable "azure_policy_enabled" {
description = "Enable Azure Policy"
type = bool
default = false
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
from pathlib import Path

import requests
from conda_store_server import api, orm, schema
from conda_store_server import api
from conda_store_server._internal import schema
from conda_store_server._internal.server.dependencies import get_conda_store
from conda_store_server.server.auth import GenericOAuthAuthentication
from conda_store_server.server.dependencies import get_conda_store
from conda_store_server.storage import S3Storage


Expand Down Expand Up @@ -422,8 +423,7 @@ async def authenticate(self, request):
for namespace in namespaces:
_namespace = api.get_namespace(db, name=namespace)
if _namespace is None:
db.add(orm.Namespace(name=namespace))
db.commit()
api.ensure_namespace(db, name=namespace)

return schema.AuthenticationToken(
primary_namespace=username,
Expand Down
15 changes: 13 additions & 2 deletions src/_nebari/subcommands/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@

@hookimpl
def nebari_subcommand(cli: typer.Typer):
EXTERNAL_PLUGIN_STYLE = "cyan"

@cli.command()
def info(ctx: typer.Context):
from nebari.plugins import nebari_plugin_manager

rich.print(f"Nebari version: {__version__}")

external_plugins = nebari_plugin_manager.get_external_plugins()

hooks = collections.defaultdict(list)
for plugin in nebari_plugin_manager.plugin_manager.get_plugins():
for hook in nebari_plugin_manager.plugin_manager.get_hookcallers(plugin):
Expand All @@ -27,7 +31,8 @@ def info(ctx: typer.Context):

for hook_name, modules in hooks.items():
for module in modules:
table.add_row(hook_name, module)
style = EXTERNAL_PLUGIN_STYLE if module in external_plugins else None
table.add_row(hook_name, module, style=style)

rich.print(table)

Expand All @@ -36,8 +41,14 @@ def info(ctx: typer.Context):
table.add_column("priority")
table.add_column("module")
for stage in nebari_plugin_manager.ordered_stages:
style = (
EXTERNAL_PLUGIN_STYLE if stage.__module__ in external_plugins else None
)
table.add_row(
stage.name, str(stage.priority), f"{stage.__module__}.{stage.__name__}"
stage.name,
str(stage.priority),
f"{stage.__module__}.{stage.__name__}",
style=style,
)

rich.print(table)
Loading

0 comments on commit d3a7b90

Please sign in to comment.