Skip to content

Commit

Permalink
feat: remote cluster support
Browse files Browse the repository at this point in the history
cleanup formatting

feat: remote cluster support
  • Loading branch information
Ralf Grubenmann committed Jun 19, 2024
1 parent 77b1ecc commit 782d920
Show file tree
Hide file tree
Showing 15 changed files with 223 additions and 142 deletions.
7 changes: 7 additions & 0 deletions example.config.hocon
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,11 @@ git {
}
k8s {
namespace = namespace_where_notebooks_run
remote_clusters = [
{
name = remote_cluster
namespace = notebooks_namespace_in_remote_cluster
kube_config_path = path_where_kubeconfig_is_mounted
}
]
}
4 changes: 1 addition & 3 deletions renku_notebooks/api/amalthea_patches/cloudstorage.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@ def main(server: "UserServer") -> list[dict[str, Any]]:
return []
for i, cloud_storage_request in enumerate(server.cloudstorage):
cloud_storage_patches.extend(
cloud_storage_request.get_manifest_patch(
f"{server.server_name}-ds-{i}", server.k8s_client.preferred_namespace
)
cloud_storage_request.get_manifest_patch(f"{server.server_name}-ds-{i}", server.preferred_namespace)
)
if server.repositories:
cloud_storage_patches.append(
Expand Down
22 changes: 6 additions & 16 deletions renku_notebooks/api/amalthea_patches/git_sidecar.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,7 @@ def main(server: "UserServer"):
},
{
"name": "GIT_RPC_SENTRY__ENABLED",
"value": str(
config.sessions.git_rpc_server.sentry.enabled
).lower(),
"value": str(config.sessions.git_rpc_server.sentry.enabled).lower(),
},
{
"name": "GIT_RPC_SENTRY__DSN",
Expand All @@ -79,9 +77,7 @@ def main(server: "UserServer"):
},
{
"name": "GIT_RPC_SENTRY__SAMPLE_RATE",
"value": str(
config.sessions.git_rpc_server.sentry.sample_rate
),
"value": str(config.sessions.git_rpc_server.sentry.sample_rate),
},
{
"name": "SENTRY_RELEASE",
Expand Down Expand Up @@ -158,16 +154,12 @@ def main(server: "UserServer"):
{
"op": "add",
"path": "/statefulset/spec/template/spec/containers/1/args/-",
"value": (
f"--skip-auth-route=^/sessions/{server.server_name}/sidecar/health$"
),
"value": (f"--skip-auth-route=^/sessions/{server.server_name}/sidecar/health$"),
},
{
"op": "add",
"path": "/statefulset/spec/template/spec/containers/1/args/-",
"value": (
f"--skip-auth-route=^/sessions/{server.server_name}/sidecar/health/$"
),
"value": (f"--skip-auth-route=^/sessions/{server.server_name}/sidecar/health/$"),
},
{
"op": "add",
Expand All @@ -177,9 +169,7 @@ def main(server: "UserServer"):
{
"op": "add",
"path": "/statefulset/spec/template/spec/containers/1/args/-",
"value": (
f"--skip-auth-route=^/sessions/{server.server_name}/sidecar/jsonrpc/map$"
),
"value": (f"--skip-auth-route=^/sessions/{server.server_name}/sidecar/jsonrpc/map$"),
},
{
"op": "add",
Expand All @@ -202,7 +192,7 @@ def main(server: "UserServer"):
"kind": "Service",
"metadata": {
"name": f"{server.server_name}-rpc-server",
"namespace": server.k8s_client.preferred_namespace,
"namespace": server.preferred_namespace,
},
"spec": {
"ports": [
Expand Down
2 changes: 1 addition & 1 deletion renku_notebooks/api/amalthea_patches/jupyter_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def image_pull_secret(server: "UserServer"):
"kind": "Secret",
"metadata": {
"name": image_pull_secret_name,
"namespace": server._k8s_client.preferred_namespace,
"namespace": server.preferred_namespace,
},
"type": "kubernetes.io/dockerconfigjson",
},
Expand Down
2 changes: 1 addition & 1 deletion renku_notebooks/api/classes/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def get_image_manifest(self, image: "Image") -> Optional[dict[str, Any]]:
"""Query the docker API to get the manifest of an image."""
if image.hostname != self.hostname:
raise ImageParseError(
f"The image hostname {image.hostname} does not match " f"the image repository {self.hostname}"
f"The image hostname {image.hostname} does not match the image repository {self.hostname}"
)
token = self._get_docker_token(image)
image_digest_url = f"https://{image.hostname}/v2/{image.name}/manifests/{image.tag}"
Expand Down
158 changes: 102 additions & 56 deletions renku_notebooks/api/classes/k8s_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,11 +270,7 @@ def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens):
(None, None),
)
secrets_container_index, secrets_container = next(
(
(i, c)
for i, c in enumerate(init_containers)
if c.name == "init-user-secrets"
),
((i, c) for i, c in enumerate(init_containers) if c.name == "init-user-secrets"),
(None, None),
)

Expand All @@ -294,16 +290,11 @@ def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens):
else None
)
secrets_renku_access_token_env = (
find_env_var(secrets_container, "RENKU_ACCESS_TOKEN")
if secrets_container is not None
else None
find_env_var(secrets_container, "RENKU_ACCESS_TOKEN") if secrets_container is not None else None
)

patches = list()
if (
git_proxy_container_index is not None
and git_proxy_renku_access_token_env is not None
):
if git_proxy_container_index is not None and git_proxy_renku_access_token_env is not None:
patches.append(
{
"op": "replace",
Expand All @@ -314,10 +305,7 @@ def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens):
"value": renku_tokens.access_token,
}
)
if (
git_proxy_container_index is not None
and git_proxy_renku_refresh_token_env is not None
):
if git_proxy_container_index is not None and git_proxy_renku_refresh_token_env is not None:
patches.append(
{
"op": "replace",
Expand All @@ -328,10 +316,7 @@ def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens):
"value": renku_tokens.refresh_token,
},
)
if (
git_clone_container_index is not None
and git_clone_renku_access_token_env is not None
):
if git_clone_container_index is not None and git_clone_renku_access_token_env is not None:
patches.append(
{
"op": "replace",
Expand All @@ -342,10 +327,7 @@ def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens):
"value": renku_tokens.access_token,
},
)
if (
secrets_container_index is not None
and secrets_renku_access_token_env is not None
):
if secrets_container_index is not None and secrets_renku_access_token_env is not None:
patches.append(
{
"op": "replace",
Expand All @@ -367,6 +349,28 @@ def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens):
)


class RemoteK8sClient(NamespacedK8sClient):
    """A namespaced k8s client that talks to a cluster described by a kubeconfig file.

    In addition to what ``NamespacedK8sClient`` stores, instances carry the
    ``name`` of the cluster and the ``host`` that is later compared against a
    session's routing host to pick the right cluster.
    """

    def __init__(
        self,
        name: str,
        host: str,
        namespace: str,
        kube_config_path: str,
        amalthea_group: str,
        amalthea_version: str,
        amalthea_plural: str,
    ):
        """Set up API handles from the kubeconfig at ``kube_config_path``.

        Args:
            name: Identifier of the remote cluster.
            host: Host associated with the remote cluster.
            namespace: Namespace forwarded to the parent constructor.
            kube_config_path: Path of the kubeconfig file passed to ``load_config``.
            amalthea_group: Forwarded to the parent constructor.
            amalthea_version: Forwarded to the parent constructor.
            amalthea_plural: Forwarded to the parent constructor.
        """
        super().__init__(namespace, amalthea_group, amalthea_version, amalthea_plural)
        self.name, self.host = name, host

        # The configuration has to be loaded before any API object below is
        # constructed, since those are created without an explicit configuration.
        # NOTE(review): load_config presumably changes the process-wide default
        # kubernetes configuration - confirm that clients created afterwards
        # elsewhere are not affected by this.
        load_config(kube_config_path=kube_config_path)

        # Presumably these replace the handles the parent constructor created.
        self._custom_objects = client.CustomObjectsApi(client.ApiClient())

        # A separate handle dedicated to JSON-patch requests.
        json_patch_api = client.CustomObjectsApi(client.ApiClient())
        json_patch_api.api_client.set_default_header("Content-Type", "application/json-patch+json")
        self._custom_objects_patch = json_patch_api

        self._core_v1 = client.CoreV1Api()
        self._apps_v1 = client.AppsV1Api()


class JsServerCache:
def __init__(self, url: str):
self.url = url
Expand Down Expand Up @@ -406,7 +410,7 @@ def get_server(self, name: str) -> Optional[dict[str, Any]]:
if len(output) == 0:
return
if len(output) > 1:
raise ProgrammingError(f"Expected to find 1 server when getting server {name}, " f"found {len(output)}.")
raise ProgrammingError(f"Expected to find 1 server when getting server {name}, found {len(output)}.")
return output[0]


Expand All @@ -417,11 +421,13 @@ def __init__(
renku_ns_client: NamespacedK8sClient,
username_label: str,
session_ns_client: Optional[NamespacedK8sClient] = None,
remote_cluster_clients: dict[str, RemoteK8sClient] = {},
):
self.js_cache = js_cache
self.renku_ns_client = renku_ns_client
self.username_label = username_label
self.session_ns_client = session_ns_client
self.remote_cluster_clients = remote_cluster_clients
if not self.username_label:
raise ProgrammingError("username_label has to be provided to K8sClient")

Expand All @@ -430,14 +436,24 @@ def list_servers(self, safe_username: str) -> list[dict[str, Any]]:
Attempt to use the cache first but if the cache fails then use the k8s API.
"""
try:
return self.js_cache.list_servers(safe_username)
except JSCacheError:
logging.warning(f"Skipping the cache to list servers for user: {safe_username}")
label_selector = f"{self.username_label}={safe_username}"
return self.renku_ns_client.list_servers(label_selector) + (
self.session_ns_client.list_servers(label_selector) if self.session_ns_client is not None else []
)
# try:
# return self.js_cache.list_servers(safe_username)
# except JSCacheError:
# logging.warning(f"Skipping the cache to list servers for user: {safe_username}")
# label_selector = f"{self.username_label}={safe_username}"
# return self.renku_ns_client.list_servers(label_selector) + (
# self.session_ns_client.list_servers(label_selector) if self.session_ns_client is not None else []
# )
logging.warning(f"Skipping the cache to list servers for user: {safe_username}")
label_selector = f"{self.username_label}={safe_username}"
remote_cluster_servers = [
s for c in self.remote_cluster_clients.values() for s in c.list_servers(label_selector)
]
return (
self.renku_ns_client.list_servers(label_selector)
+ (self.session_ns_client.list_servers(label_selector) if self.session_ns_client is not None else [])
+ remote_cluster_servers
)

def get_server(self, name: str, safe_username: str) -> Optional[dict[str, Any]]:
"""Attempt to get a specific server by name from the cache.
Expand All @@ -459,7 +475,7 @@ def get_server(self, name: str, safe_username: str) -> Optional[dict[str, Any]]:
output.append(res)
if len(output) > 1:
raise ProgrammingError(
"Expected less than two results for searching for " f"server {name}, but got {len(output)}"
"Expected less than two results for searching for server {name}, but got {len(output)}"
)
if len(output) == 0:
return
Expand All @@ -482,9 +498,12 @@ def get_server_logs(
)
namespace = server.get("metadata", {}).get("namespace")
pod_name = f"{server_name}-0"
if namespace == self.renku_ns_client.namespace:
return self.renku_ns_client.get_pod_logs(pod_name, containers, max_log_lines)
return self.session_ns_client.get_pod_logs(pod_name, containers, max_log_lines)
host = self._get_host_from_manifest(server)
client = self._get_cluster_client(host)
if not client:
namespace = server.get("metadata", {}).get("namespace")
client = self.renku_ns_client if self.renku_ns_client.namespace == namespace else self.session_ns_client
return client.get_pod_logs(pod_name, containers, max_log_lines)

def get_secret(self, name: str) -> Optional[dict[str, Any]]:
if self.session_ns_client is not None:
Expand All @@ -493,54 +512,81 @@ def get_secret(self, name: str) -> Optional[dict[str, Any]]:
return secret
return self.renku_ns_client.get_secret(name)

def create_server(self, manifest: dict[str, Any], safe_username: str):
def create_server(self, manifest: dict[str, Any], safe_username: str, cluster: str | None):
server_name = manifest.get("metadata", {}).get("name")
server = self.get_server(server_name, safe_username)
if server:
# NOTE: server already exists
return server
if cluster:
cluster_client = self.remote_cluster_clients[cluster]
return cluster_client.create_server(manifest)
if not self.session_ns_client:
return self.renku_ns_client.create_server(manifest)
return self.session_ns_client.create_server(manifest)

def patch_server(self, server_name: str, safe_username: str, patch: dict[str, Any]):
def patch_server(self, server_name: str, safe_username: str, patch: dict[str, Any], host: str | None = None):
server = self.get_server(server_name, safe_username)
if not server:
raise MissingResourceError(
f"Cannot find server {server_name} for user " f"{safe_username} in order to patch it."
f"Cannot find server {server_name} for user {safe_username} in order to patch it."
)

namespace = server.get("metadata", {}).get("namespace")
client = self._get_cluster_client(host)
if not client:
client = self.session_ns_client if self.session_ns_client else self.renku_ns_client

if namespace == self.renku_ns_client.namespace:
return self.renku_ns_client.patch_server(server_name=server_name, patch=patch)
else:
return self.session_ns_client.patch_server(server_name=server_name, patch=patch)
return client.patch_server(server_name=server_name, patch=patch)

def patch_statefulset(self, server_name: str, patch: dict[str, Any]) -> client.V1StatefulSet | None:
client = self.session_ns_client if self.session_ns_client else self.renku_ns_client
def patch_statefulset(
self, server_name: str, patch: dict[str, Any], host: str | None = None
) -> client.V1StatefulSet | None:
client = self._get_cluster_client(host)
if not client:
client = self.session_ns_client if self.session_ns_client else self.renku_ns_client
return client.patch_statefulset(server_name=server_name, patch=patch)

def delete_server(self, server_name: str, safe_username: str, forced: bool = False):
server = self.get_server(server_name, safe_username)
if not server:
raise MissingResourceError(
f"Cannot find server {server_name} for user " f"{safe_username} in order to delete it."
f"Cannot find server {server_name} for user {safe_username} in order to delete it."
)
namespace = server.get("metadata", {}).get("namespace")
if namespace == self.renku_ns_client.namespace:
self.renku_ns_client.delete_server(server_name, forced)
else:
self.session_ns_client.delete_server(server_name, forced)

def patch_tokens(self, server_name, renku_tokens: RenkuTokens, gitlab_token: GitlabToken):
host = self._get_host_from_manifest(server)
client = self._get_cluster_client(host)
if not client:
namespace = server.get("metadata", {}).get("namespace")
client = self.renku_ns_client if self.renku_ns_client.namespace == namespace else self.session_ns_client
assert client is not None
client.delete_server(server_name, forced)

def patch_tokens(self, server_name, renku_tokens: RenkuTokens, gitlab_token: GitlabToken, host: str | None = None):
"""Patch the Renku and Gitlab access tokens used in a session."""
client = self.session_ns_client if self.session_ns_client else self.renku_ns_client
client = self._get_cluster_client(host)
if not client:
client = self.session_ns_client if self.session_ns_client else self.renku_ns_client
client.patch_statefulset_tokens(server_name, renku_tokens)
client.patch_image_pull_secret(server_name, gitlab_token)

def _get_cluster_client(self, host: str | None) -> RemoteK8sClient | None:
    """Find the remote cluster client registered for ``host``.

    Args:
        host: Host to look up; may be ``None`` or empty.

    Returns:
        The first entry of ``self.remote_cluster_clients`` whose ``host``
        attribute equals ``host``, or ``None`` when ``host`` is falsy, no
        remote clusters are configured, or nothing matches.
    """
    if host and self.remote_cluster_clients:
        return next(
            (candidate for candidate in self.remote_cluster_clients.values() if candidate.host == host),
            None,
        )
    return None

def _get_host_from_manifest(self, manifest: dict) -> str:
    """Read the routing host out of a server manifest.

    Args:
        manifest: Mapping that contains ``spec -> routing -> host``.

    Returns:
        The value stored under ``manifest["spec"]["routing"]["host"]``.

    Raises:
        KeyError: If any of the three keys is absent.
    """
    routing = manifest["spec"]["routing"]
    return routing["host"]

@property
def preferred_namespace(self) -> str:
    """Namespace of the session client when one is set, else of the renku client."""
    if self.session_ns_client is None:
        return self.renku_ns_client.namespace
    return self.session_ns_client.namespace

def preferred_cluster_namespace(self, host: str) -> str | None:
cluster_client = self._get_cluster_client(host)
if not cluster_client:
return None
return cluster_client.namespace
Loading

0 comments on commit 782d920

Please sign in to comment.