diff --git a/Makefile b/Makefile index c78ff9e65..d778007ad 100644 --- a/Makefile +++ b/Makefile @@ -45,7 +45,19 @@ components/renku_data_services/data_connectors/apispec.py: components/renku_data ##@ Apispec -schemas: components/renku_data_services/crc/apispec.py components/renku_data_services/storage/apispec.py components/renku_data_services/users/apispec.py components/renku_data_services/project/apispec.py components/renku_data_services/namespace/apispec.py components/renku_data_services/secrets/apispec.py components/renku_data_services/connected_services/apispec.py components/renku_data_services/repositories/apispec.py components/renku_data_services/notebooks/apispec.py components/renku_data_services/platform/apispec.py components/renku_data_services/message_queue/apispec.py components/renku_data_services/data_connectors/apispec.py ## Generate pydantic classes from apispec yaml files +schemas: components/renku_data_services/crc/apispec.py \ +components/renku_data_services/storage/apispec.py \ +components/renku_data_services/users/apispec.py \ +components/renku_data_services/project/apispec.py \ +components/renku_data_services/session/apispec.py \ +components/renku_data_services/namespace/apispec.py \ +components/renku_data_services/secrets/apispec.py \ +components/renku_data_services/connected_services/apispec.py \ +components/renku_data_services/repositories/apispec.py \ +components/renku_data_services/notebooks/apispec.py \ +components/renku_data_services/platform/apispec.py \ +components/renku_data_services/message_queue/apispec.py \ +components/renku_data_services/data_connectors/apispec.py ## Generate pydantic classes from apispec yaml files @echo "generated classes based on ApiSpec" ##@ Avro schemas @@ -90,6 +102,8 @@ style_checks: ## Run linting and style checks @$(call test_apispec_up_to_date,"platform") @echo "checking message_queue apispec is up to date" @$(call test_apispec_up_to_date,"message_queue") + @echo "checking session apispec is up to date" + @$(call test_apispec_up_to_date,"session") poetry run mypy poetry run ruff format --check poetry run ruff check . diff --git a/components/renku_data_services/migrations/versions/1ef98b967767_add_command_and_args_to_environment.py b/components/renku_data_services/migrations/versions/1ef98b967767_add_command_and_args_to_environment.py new file mode 100644 index 000000000..f91f342ce --- /dev/null +++ b/components/renku_data_services/migrations/versions/1ef98b967767_add_command_and_args_to_environment.py @@ -0,0 +1,41 @@ +"""Add command and args to environment + +Revision ID: 1ef98b967767 +Revises: 584598f3b769 +Create Date: 2024-08-25 21:05:02.158021 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "1ef98b967767" +down_revision = "584598f3b769" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "environments", + sa.Column("args", sa.JSON().with_variant(postgresql.JSONB(astext_type=sa.Text()), "postgresql"), nullable=True), + schema="sessions", + ) + op.add_column( + "environments", + sa.Column( + "command", sa.JSON().with_variant(postgresql.JSONB(astext_type=sa.Text()), "postgresql"), nullable=True + ), + schema="sessions", + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("environments", "command", schema="sessions") + op.drop_column("environments", "args", schema="sessions") + # ### end Alembic commands ### diff --git a/components/renku_data_services/migrations/versions/584598f3b769_expand_and_separate_environments_from_.py b/components/renku_data_services/migrations/versions/584598f3b769_expand_and_separate_environments_from_.py new file mode 100644 index 000000000..6973937d6 --- /dev/null +++ b/components/renku_data_services/migrations/versions/584598f3b769_expand_and_separate_environments_from_.py @@ -0,0 +1,109 @@ +"""expand and separate environments from session launchers + +Revision ID: 584598f3b769 +Revises: 9058bf0a1a12 +Create Date: 2024-08-12 14:25:24.292285 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "584598f3b769" +down_revision = "9058bf0a1a12" +branch_labels = None +depends_on = None + +default_url: str = "/lab" +working_dir: str = "/home/jovyan/work" +mount_dir: str = "/home/jovyan/work" +uid: int = 1000 +gid: int = 1000 +port: int = 8888 + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.execute("DELETE FROM sessions.launchers") + op.drop_column("launchers", "default_url", schema="sessions") + op.drop_column("launchers", "environment_kind", schema="sessions") + op.drop_column("launchers", "container_image", schema="sessions") + op.execute("DROP TYPE environmentkind CASCADE") + op.execute("CREATE TYPE environmentkind AS ENUM ('GLOBAL', 'CUSTOM')") + op.add_column("environments", sa.Column("port", sa.Integer(), nullable=True), schema="sessions") + op.add_column("environments", sa.Column("working_directory", sa.String(), nullable=True), schema="sessions") + op.add_column("environments", sa.Column("mount_directory", sa.String(), nullable=True), schema="sessions") + op.add_column("environments", sa.Column("uid", sa.Integer(), nullable=True), schema="sessions") + op.add_column("environments", sa.Column("gid", sa.Integer(), nullable=True), schema="sessions") + op.add_column( + "environments", + sa.Column("environment_kind", sa.Enum("GLOBAL", "CUSTOM", name="environmentkind"), nullable=True), + schema="sessions", + ) + op.execute(sa.text("UPDATE sessions.environments SET port = :port WHERE port is NULL").bindparams(port=port)) + op.execute( + sa.text( + "UPDATE sessions.environments SET working_directory = :working_dir WHERE working_directory is NULL" + ).bindparams(working_dir=working_dir) + ) + op.execute( + sa.text( + "UPDATE sessions.environments SET mount_directory = :mount_dir WHERE mount_directory is NULL" + ).bindparams(mount_dir=mount_dir) + ) + op.execute(sa.text("UPDATE sessions.environments SET uid = :uid WHERE uid is NULL").bindparams(uid=uid)) + op.execute(sa.text("UPDATE sessions.environments SET gid = :gid WHERE gid is NULL").bindparams(gid=gid)) + op.execute("UPDATE sessions.environments SET environment_kind = 'GLOBAL' WHERE environment_kind is NULL") + op.execute( + sa.text("UPDATE sessions.environments SET default_url = :default_url WHERE default_url is NULL").bindparams( + default_url=default_url + ) + ) + op.alter_column("environments", "port", nullable=False, schema="sessions") + op.alter_column("environments", "working_directory", nullable=False, schema="sessions") + op.alter_column("environments", "mount_directory", nullable=False, schema="sessions") + op.alter_column("environments", "uid", nullable=False, schema="sessions") + op.alter_column("environments", "gid", nullable=False, schema="sessions") + op.alter_column("environments", "environment_kind", nullable=False, schema="sessions") + op.alter_column( + "environments", "default_url", existing_type=sa.VARCHAR(length=200), nullable=False, schema="sessions" + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("environments", "environment_kind", schema="sessions") + op.drop_column("environments", "gid", schema="sessions") + op.drop_column("environments", "uid", schema="sessions") + op.drop_column("environments", "mount_directory", schema="sessions") + op.drop_column("environments", "working_directory", schema="sessions") + op.drop_column("environments", "port", schema="sessions") + op.execute("DROP TYPE environmentkind") + op.execute("CREATE TYPE environmentkind AS ENUM ('global_environment', 'container_image')") + op.add_column( + "launchers", + sa.Column("container_image", sa.VARCHAR(length=500), autoincrement=False, nullable=True), + schema="sessions", + ) + op.add_column( + "launchers", + sa.Column( + "environment_kind", + postgresql.ENUM("global_environment", "container_image", name="environmentkind"), + autoincrement=False, + nullable=False, + ), + schema="sessions", + ) + op.add_column( + "launchers", + sa.Column("default_url", sa.VARCHAR(length=200), autoincrement=False, nullable=True), + schema="sessions", + ) + op.alter_column( + "environments", "default_url", existing_type=sa.VARCHAR(length=200), nullable=True, schema="sessions" + ) + # ### end Alembic commands ### diff --git a/components/renku_data_services/session/api.spec.yaml b/components/renku_data_services/session/api.spec.yaml index 6e4c8dce4..a16046084 100644 --- a/components/renku_data_services/session/api.spec.yaml +++ b/components/renku_data_services/session/api.spec.yaml @@ -10,10 +10,10 @@ servers: paths: /environments: get: - summary: Get all environments + summary: Get all global environments responses: "200": - description: List of environments + description: List of global environments content: application/json: schema: @@ -23,7 +23,7 @@ paths: tags: - environments post: - summary: Create a new session environment + summary: Create a new global session environment description: Requires admin permissions requestBody: required: true @@ -44,7 +44,7 @@ paths: - environments /environments/{environment_id}: get: - summary: Get a session environment + summary: Get a global session environment parameters: - in: path name: environment_id @@ -69,7 +69,7 @@ paths: tags: - environments patch: - summary: Update specific fields of an existing session environment + summary: Update specific fields of an existing global session environment description: Requires admin permissions parameters: - in: path @@ -101,7 +101,7 @@ paths: tags: - environments delete: - summary: Remove a session environment + summary: Remove a global session environment parameters: - in: path name: environment_id @@ -175,7 +175,7 @@ paths: tags: - session_launchers patch: - summary: Update specific fields of an existing session + summary: Update specific fields of an existing session launcher parameters: - in: path name: launcher_id @@ -220,32 +220,6 @@ paths: $ref: "#/components/responses/Error" tags: - session_launchers - /session_launchers/{launcher_id}/start: - post: - summary: Use a session launcher to start a session - parameters: - - in: path - name: launcher_id - required: true - schema: - $ref: "#/components/schemas/Ulid" - requestBody: - required: true - content: - application/json: - schema: - $ref: "#/components/schemas/SessionStart" - responses: - "201": - description: The started session - content: - application/json: - schema: - $ref: "#/components/schemas/Session" - default: - $ref: "#/components/responses/Error" - tags: - - sessions /projects/{project_id}/session_launchers: get: summary: Get a project's session launchers @@ -289,11 +263,31 @@ components: $ref: "#/components/schemas/ContainerImage" default_url: $ref: "#/components/schemas/DefaultUrl" + uid: + $ref: "#/components/schemas/EnvironmentUid" + gid: + $ref: "#/components/schemas/EnvironmentGid" + working_directory: + $ref: "#/components/schemas/EnvironmentWorkingDirectory" + mount_directory: + $ref: "#/components/schemas/EnvironmentMountDirectory" + port: + $ref: "#/components/schemas/EnvironmentPort" + command: + $ref: "#/components/schemas/EnvironmentCommand" + args: + $ref: "#/components/schemas/EnvironmentArgs" required: - id - name - creation_date - container_image + - port + - working_directory + - mount_directory + - uid + - gid + - default_url example: id: 01AN4Z79ZS6XX96588FDX0H099 name: JupyterLab environment @@ -301,6 +295,33 @@ components: description: JupyterLab session environment container_image: renku-jupyter:latest default_url: "/lab" + port: 8080 + working_directory: /home/jovyan/work + mount_directory: /home/jovyan/work + uid: 1000 + gid: 1000 + EnvironmentGetInLauncher: + allOf: + - $ref: "#/components/schemas/Environment" + - type: object + properties: + environment_kind: + $ref: "#/components/schemas/EnvironmentKind" + required: + - environment_kind + example: + environment_kind: global_environment + EnvironmentPostInLauncher: + allOf: + - $ref: "#/components/schemas/EnvironmentPost" + - type: object + properties: + environment_kind: + $ref: "#/components/schemas/EnvironmentKind" + required: + - environment_kind + example: + environment_kind: global_environment EnvironmentPost: description: Data required to create a session environment type: object @@ -312,10 +333,49 @@ components: container_image: $ref: "#/components/schemas/ContainerImage" default_url: - $ref: "#/components/schemas/DefaultUrl" + allOf: + - $ref: "#/components/schemas/DefaultUrl" + - default: /lab + default: /lab + uid: + allOf: + - $ref: "#/components/schemas/EnvironmentUid" + - default: 1000 + default: 1000 + gid: + allOf: + - $ref: "#/components/schemas/EnvironmentGid" + - default: 1000 + default: 1000 + working_directory: + allOf: + - $ref: "#/components/schemas/EnvironmentWorkingDirectory" + - default: /home/jovyan/work + default: /home/jovyan/work + mount_directory: + allOf: + - $ref: "#/components/schemas/EnvironmentMountDirectory" + - default: /home/jovyan/work + default: /home/jovyan/work + port: + allOf: + - $ref: "#/components/schemas/EnvironmentPort" + - default: 8080 + default: 8080 + command: + $ref: "#/components/schemas/EnvironmentCommand" + args: + $ref: "#/components/schemas/EnvironmentArgs" required: - name - container_image + EnvironmentPatchInLauncher: + allOf: + - $ref: "#/components/schemas/EnvironmentPatch" + - type: object + properties: + environment_kind: + $ref: "#/components/schemas/EnvironmentKind" EnvironmentPatch: type: object description: Update a session environment @@ -329,6 +389,20 @@ components: $ref: "#/components/schemas/ContainerImage" default_url: $ref: "#/components/schemas/DefaultUrl" + uid: + $ref: "#/components/schemas/EnvironmentUid" + gid: + $ref: "#/components/schemas/EnvironmentGid" + working_directory: + $ref: "#/components/schemas/EnvironmentWorkingDirectory" + mount_directory: + $ref: "#/components/schemas/EnvironmentMountDirectory" + port: + $ref: "#/components/schemas/EnvironmentPort" + command: + $ref: "#/components/schemas/EnvironmentCommand" + args: + $ref: "#/components/schemas/EnvironmentArgs" SessionLaunchersList: description: A list of Renku session launchers type: array @@ -349,30 +423,36 @@ components: $ref: "#/components/schemas/CreationDate" description: $ref: "#/components/schemas/Description" - environment_kind: - $ref: "#/components/schemas/EnvironmentKind" - environment_id: - $ref: "#/components/schemas/EnvironmentId" + environment: + $ref: "#/components/schemas/EnvironmentGetInLauncher" resource_class_id: $ref: "#/components/schemas/ResourceClassId" - container_image: - $ref: "#/components/schemas/ContainerImage" - default_url: - $ref: "#/components/schemas/DefaultUrl" required: - id - project_id - name - creation_date - - environment_kind + - environment + - resource_class_id example: id: 01AN4Z79ZS5XN0F25N3DB94T4R project_id: 01AN4Z79ZS5XN0F25N3DB94T4R name: Renku R Session creation_date: "2023-11-01T17:32:28Z" description: R compute session - environment_kind: global_environment - environment_id: 01AN4Z79ZS6XX96588FDX0H099 + environment: + id: 01AN4Z79ZS6XX96588FDX0H099 + name: Rstudio + creation_date: "2023-11-01T17:32:28Z" + description: JupyterLab session environment + environment_kind: GLOBAL + container_image: rocker/rstudio + default_url: "/rstudio" + port: 8080 + working_directory: /home/rstudio/work + mount_directory: /home/rstudio/work + uid: 1000 + gid: 1000 SessionLauncherPost: description: Data required to create a session launcher type: object @@ -384,20 +464,21 @@ components: $ref: "#/components/schemas/Ulid" description: $ref: "#/components/schemas/Description" - environment_kind: - $ref: "#/components/schemas/EnvironmentKind" - environment_id: - $ref: "#/components/schemas/EnvironmentId" resource_class_id: $ref: "#/components/schemas/ResourceClassId" - container_image: - $ref: "#/components/schemas/ContainerImage" - default_url: - $ref: "#/components/schemas/DefaultUrl" + environment: + oneOf: + - $ref: "#/components/schemas/EnvironmentPostInLauncher" + - $ref: "#/components/schemas/EnvironmentIdOnlyPost" required: - name - project_id - - environment_kind + - environment + example: + project_id: 01AN4Z79ZS5XN0F25N3DB94T4R + name: Renku R Session + environment: + id: 01AN4Z79ZS6XX96588FDX0H099 SessionLauncherPatch: type: object description: Update a session launcher @@ -407,23 +488,12 @@ components: $ref: "#/components/schemas/SessionName" description: $ref: "#/components/schemas/Description" - environment_kind: - $ref: "#/components/schemas/EnvironmentKind" - environment_id: - $ref: "#/components/schemas/EnvironmentId" - resource_class_id: - $ref: "#/components/schemas/ResourceClassId" - container_image: - $ref: "#/components/schemas/ContainerImage" - default_url: - $ref: "#/components/schemas/DefaultUrl" - SessionStart: - type: object - description: Start a session - additionalProperties: true - properties: resource_class_id: $ref: "#/components/schemas/ResourceClassId" + environment: + oneOf: + - $ref: "#/components/schemas/EnvironmentPatchInLauncher" + - $ref: "#/components/schemas/EnvironmentIdOnlyPatch" Ulid: description: ULID identifier type: string @@ -436,13 +506,25 @@ components: minLength: 1 maxLength: 99 example: My Renku Session :) + EnvironmentIdOnlyPatch: + type: object + properties: + id: + $ref: "#/components/schemas/EnvironmentId" + EnvironmentIdOnlyPost: + type: object + properties: + id: + $ref: "#/components/schemas/EnvironmentId" + required: + - id EnvironmentKind: description: Kind of environment to use type: string enum: - - global_environment - - container_image - example: container_image + - GLOBAL + - CUSTOM + example: CUSTOM EnvironmentId: description: Id of the environment to use type: string @@ -472,15 +554,46 @@ components: type: integer default: null nullable: true - Session: - description: A Renku session - type: object - additionalProperties: true - properties: - name: - $ref: "#/components/schemas/SessionName" - url: - type: string + EnvironmentPort: + type: integer + minimum: 0 + exclusiveMinimum: true + exclusiveMaximum: true + # NOTE: we reserve 65400 - 65535 for usage of Renku sidecars and services + maximum: 65400 + description: The TCP port (on any container in the session) where user requests will be routed to from the ingress + EnvironmentUid: + type: integer + minimum: 0 + exclusiveMinimum: true + maximum: 65535 + description: The user ID used to run the session + EnvironmentGid: + type: integer + minimum: 0 + exclusiveMinimum: true + maximum: 65535 + description: The group ID used to run the session + EnvironmentWorkingDirectory: + type: string + description: The location where the session will start + minLength: 1 + EnvironmentMountDirectory: + type: string + description: The location where the persistent storage for the session will be mounted, usually it should be identical to or a parent of the working directory + minLength: 1 + EnvironmentCommand: + type: array + items: + type: string + description: The command that will be run i.e. will overwrite the image Dockerfile ENTRYPOINT, equivalent to command in Kubernetes + minLength: 1 + EnvironmentArgs: + type: array + items: + type: string + description: The arguments that will follow the command, i.e. will overwrite the image Dockerfile CMD, equivalent to args in Kubernetes + minLength: 1 ErrorResponse: type: object properties: diff --git a/components/renku_data_services/session/apispec.py b/components/renku_data_services/session/apispec.py index 6be545731..3ac7b194c 100644 --- a/components/renku_data_services/session/apispec.py +++ b/components/renku_data_services/session/apispec.py @@ -1,34 +1,20 @@ # generated by datamodel-codegen: # filename: api.spec.yaml -# timestamp: 2024-06-10T13:14:40+00:00 +# timestamp: 2024-08-25T21:01:41+00:00 from __future__ import annotations from datetime import datetime from enum import Enum -from typing import List, Optional +from typing import List, Optional, Union from pydantic import ConfigDict, Field, RootModel from renku_data_services.session.apispec_base import BaseAPISpec class EnvironmentKind(Enum): - global_environment = "global_environment" - container_image = "container_image" - - -class Session(BaseAPISpec): - model_config = ConfigDict( - extra="allow", - ) - name: Optional[str] = Field( - None, - description="Renku session name", - example="My Renku Session :)", - max_length=99, - min_length=1, - ) - url: Optional[str] = None + GLOBAL = "GLOBAL" + CUSTOM = "CUSTOM" class Error(BaseAPISpec): @@ -49,7 +35,7 @@ class Environment(BaseAPISpec): description="ULID identifier", max_length=26, min_length=26, - pattern="^[A-Z0-9]{26}$", + pattern="^[0-7][0-9A-HJKMNP-TV-Z]{25}$", ) name: str = Field( ..., @@ -72,12 +58,46 @@ class Environment(BaseAPISpec): example="renku/renkulab-py:3.10-0.18.1", max_length=500, ) - default_url: Optional[str] = Field( - None, + default_url: str = Field( + ..., description="The default path to open in a session", example="/lab", max_length=200, ) + uid: int = Field( + ..., description="The user ID used to run the session", gt=0, le=65535 + ) + gid: int = Field( + ..., description="The group ID used to run the session", gt=0, le=65535 + ) + working_directory: str = Field( + ..., description="The location where the session will start", min_length=1 + ) + mount_directory: str = Field( + ..., + description="The location where the persistent storage for the session will be mounted, usually it should be identical to or a parent of the working directory", + min_length=1, + ) + port: int = Field( + ..., + description="The TCP port (on any container in the session) where user requests will be routed to from the ingress", + gt=0, + lt=65400, + ) + command: Optional[List[str]] = Field( + None, + description="The command that will be run i.e. will overwrite the image Dockerfile ENTRYPOINT, equivalent to command in Kubernetes", + min_length=1, + ) + args: Optional[List[str]] = Field( + None, + description="The arguments that will follow the command, i.e. will overwrite the image Dockerfile CMD, equivalent to args in Kubernetes", + min_length=1, + ) + + +class EnvironmentGetInLauncher(Environment): + environment_kind: EnvironmentKind class EnvironmentPost(BaseAPISpec): @@ -97,12 +117,44 @@ class EnvironmentPost(BaseAPISpec): example="renku/renkulab-py:3.10-0.18.1", max_length=500, ) - default_url: Optional[str] = Field( - None, + default_url: str = Field( + "/lab", description="The default path to open in a session", example="/lab", max_length=200, ) + uid: int = Field( + 1000, description="The user ID used to run the session", gt=0, le=65535 + ) + gid: int = Field( + 1000, description="The group ID used to run the session", gt=0, le=65535 + ) + working_directory: str = Field( + "/home/jovyan/work", + description="The location where the session will start", + min_length=1, + ) + mount_directory: str = Field( + "/home/jovyan/work", + description="The location where the persistent storage for the session will be mounted, usually it should be identical to or a parent of the working directory", + min_length=1, + ) + port: int = Field( + 8080, + description="The TCP port (on any container in the session) where user requests will be routed to from the ingress", + gt=0, + lt=65400, + ) + command: Optional[List[str]] = Field( + None, + description="The command that will be run i.e. will overwrite the image Dockerfile ENTRYPOINT, equivalent to command in Kubernetes", + min_length=1, + ) + args: Optional[List[str]] = Field( + None, + description="The arguments that will follow the command, i.e. will overwrite the image Dockerfile CMD, equivalent to args in Kubernetes", + min_length=1, + ) class EnvironmentPatch(BaseAPISpec): @@ -131,6 +183,36 @@ class EnvironmentPatch(BaseAPISpec): example="/lab", max_length=200, ) + uid: Optional[int] = Field( + None, description="The user ID used to run the session", gt=0, le=65535 + ) + gid: Optional[int] = Field( + None, description="The group ID used to run the session", gt=0, le=65535 + ) + working_directory: Optional[str] = Field( + None, description="The location where the session will start", min_length=1 + ) + mount_directory: Optional[str] = Field( + None, + description="The location where the persistent storage for the session will be mounted, usually it should be identical to or a parent of the working directory", + min_length=1, + ) + port: Optional[int] = Field( + None, + description="The TCP port (on any container in the session) where user requests will be routed to from the ingress", + gt=0, + lt=65400, + ) + command: Optional[List[str]] = Field( + None, + description="The command that will be run i.e. will overwrite the image Dockerfile ENTRYPOINT, equivalent to command in Kubernetes", + min_length=1, + ) + args: Optional[List[str]] = Field( + None, + description="The arguments that will follow the command, i.e. will overwrite the image Dockerfile CMD, equivalent to args in Kubernetes", + min_length=1, + ) class SessionLauncher(BaseAPISpec): @@ -139,14 +221,14 @@ class SessionLauncher(BaseAPISpec): description="ULID identifier", max_length=26, min_length=26, - pattern="^[A-Z0-9]{26}$", + pattern="^[0-7][0-9A-HJKMNP-TV-Z]{25}$", ) project_id: str = Field( ..., description="ULID identifier", max_length=26, min_length=26, - pattern="^[A-Z0-9]{26}$", + pattern="^[0-7][0-9A-HJKMNP-TV-Z]{25}$", ) name: str = Field( ..., @@ -163,27 +245,45 @@ class SessionLauncher(BaseAPISpec): description: Optional[str] = Field( None, description="A description for the resource", max_length=500 ) - environment_kind: EnvironmentKind - environment_id: Optional[str] = Field( + environment: EnvironmentGetInLauncher + resource_class_id: Optional[int] = Field( + ..., description="The identifier of a resource class" + ) + + +class EnvironmentIdOnlyPatch(BaseAPISpec): + id: Optional[str] = Field( None, description="Id of the environment to use", example="01AN4Z79ZS6XX96588FDX0H099", min_length=1, ) - resource_class_id: Optional[int] = Field( - None, description="The identifier of a resource class" - ) - container_image: Optional[str] = Field( - None, - description="A container image", - example="renku/renkulab-py:3.10-0.18.1", - max_length=500, + + +class EnvironmentIdOnlyPost(BaseAPISpec): + id: str = Field( + ..., + description="Id of the environment to use", + example="01AN4Z79ZS6XX96588FDX0H099", + min_length=1, ) - default_url: Optional[str] = Field( - None, - description="The default path to open in a session", - example="/lab", - max_length=200, + + +class EnvironmentList(RootModel[List[Environment]]): + root: List[Environment] = Field(..., description="A list of session environments") + + +class EnvironmentPostInLauncher(EnvironmentPost): + environment_kind: EnvironmentKind + + +class EnvironmentPatchInLauncher(EnvironmentPatch): + environment_kind: Optional[EnvironmentKind] = None + + +class SessionLaunchersList(RootModel[List[SessionLauncher]]): + root: List[SessionLauncher] = Field( + ..., description="A list of Renku session launchers", min_length=0 ) @@ -203,33 +303,15 @@ class SessionLauncherPost(BaseAPISpec): description="ULID identifier", max_length=26, min_length=26, - pattern="^[A-Z0-9]{26}$", + pattern="^[0-7][0-9A-HJKMNP-TV-Z]{25}$", ) description: Optional[str] = Field( None, description="A description for the resource", max_length=500 ) - environment_kind: EnvironmentKind - environment_id: Optional[str] = Field( - None, - description="Id of the environment to use", - example="01AN4Z79ZS6XX96588FDX0H099", - min_length=1, - ) resource_class_id: Optional[int] = Field( None, description="The identifier of a resource class" ) - container_image: Optional[str] = Field( - None, - description="A container image", - example="renku/renkulab-py:3.10-0.18.1", - max_length=500, - ) - default_url: Optional[str] = Field( - None, - description="The default path to open in a session", - example="/lab", - max_length=200, - ) + environment: Union[EnvironmentPostInLauncher, EnvironmentIdOnlyPost] class SessionLauncherPatch(BaseAPISpec): @@ -246,44 +328,9 @@ class SessionLauncherPatch(BaseAPISpec): description: Optional[str] = Field( None, description="A description for the resource", max_length=500 ) - environment_kind: Optional[EnvironmentKind] = None - environment_id: Optional[str] = Field( - None, - description="Id of the environment to use", - example="01AN4Z79ZS6XX96588FDX0H099", - min_length=1, - ) - resource_class_id: Optional[int] = Field( - None, description="The identifier of a resource class" - ) - container_image: Optional[str] = Field( - None, - description="A container image", - example="renku/renkulab-py:3.10-0.18.1", - max_length=500, - ) - default_url: Optional[str] = Field( - None, - description="The default path to open in a session", - example="/lab", - max_length=200, - ) - - -class SessionStart(BaseAPISpec): - model_config = ConfigDict( - extra="allow", - ) resource_class_id: Optional[int] = Field( None, description="The identifier of a resource class" ) - - -class EnvironmentList(RootModel[List[Environment]]): - root: List[Environment] = Field(..., description="A list of session environments") - - -class SessionLaunchersList(RootModel[List[SessionLauncher]]): - root: List[SessionLauncher] = Field( - ..., description="A list of Renku session launchers", min_length=0 + environment: Optional[Union[EnvironmentPatchInLauncher, EnvironmentIdOnlyPatch]] = ( + None ) diff --git a/components/renku_data_services/session/apispec_base.py b/components/renku_data_services/session/apispec_base.py index a16833290..d91e73fb9 100644 --- a/components/renku_data_services/session/apispec_base.py +++ b/components/renku_data_services/session/apispec_base.py @@ -1,5 +1,7 @@ """Base models for API specifications.""" +from pathlib import PurePosixPath + from pydantic import BaseModel, field_validator from ulid import ULID @@ -12,8 +14,16 @@ class Config: from_attributes = True - @field_validator("id", mode="before", check_fields=False) + @field_validator("id", "project_id", mode="before", check_fields=False) @classmethod def serialize_id(cls, id: str | ULID) -> str: """Custom serializer that can handle ULIDs.""" return str(id) + + @field_validator("working_directory", "mount_directory", check_fields=False, mode="before") + @classmethod + def convert_path_to_string(cls, val: str | PurePosixPath) -> str: + """Converts the python path to a regular string when pydantic deserializes.""" + if isinstance(val, PurePosixPath): + return val.as_posix() + return val diff --git a/components/renku_data_services/session/blueprints.py b/components/renku_data_services/session/blueprints.py index 8fe42995b..75fbf25a4 100644 --- a/components/renku_data_services/session/blueprints.py +++ b/components/renku_data_services/session/blueprints.py @@ -1,6 +1,7 @@ """Session blueprint.""" from dataclasses import dataclass +from pathlib import PurePosixPath from sanic import HTTPResponse, Request, json from sanic.response import JSONResponse @@ -10,7 +11,7 @@ import renku_data_services.base_models as base_models from renku_data_services.base_api.auth import authenticate, validate_path_project_id from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint -from renku_data_services.session import apispec +from renku_data_services.session import apispec, models from renku_data_services.session.db import SessionRepository @@ -47,7 +48,21 @@ def post(self) -> BlueprintFactoryResponse: @authenticate(self.authenticator) @validate(json=apispec.EnvironmentPost) async def _post(_: Request, user: base_models.APIUser, body: apispec.EnvironmentPost) -> JSONResponse: - environment = await self.session_repo.insert_environment(user=user, new_environment=body) + unsaved_environment = models.UnsavedEnvironment( + name=body.name, + description=body.description, + container_image=body.container_image, + default_url=body.default_url, + port=body.port, + working_directory=PurePosixPath(body.working_directory), + mount_directory=PurePosixPath(body.mount_directory), + uid=body.uid, + gid=body.gid, + environment_kind=models.EnvironmentKind.GLOBAL, + command=body.command, + args=body.args, + ) + environment = await self.session_repo.insert_environment(user=user, new_environment=unsaved_environment) return json(apispec.Environment.model_validate(environment).model_dump(exclude_none=True, mode="json"), 201) return "/environments", ["POST"], _post @@ -117,7 +132,32 @@ def post(self) -> BlueprintFactoryResponse: @authenticate(self.authenticator) @validate(json=apispec.SessionLauncherPost) async def _post(_: Request, user: base_models.APIUser, body: apispec.SessionLauncherPost) -> JSONResponse: - launcher = await self.session_repo.insert_launcher(user=user, new_launcher=body) + environment: str | models.UnsavedEnvironment + if isinstance(body.environment, apispec.EnvironmentIdOnlyPost): + environment = body.environment.id + else: + environment = models.UnsavedEnvironment( + name=body.environment.name, + description=body.environment.description, + container_image=body.environment.container_image, + default_url=body.environment.default_url, + port=body.environment.port, + working_directory=PurePosixPath(body.environment.working_directory), + mount_directory=PurePosixPath(body.environment.mount_directory), + uid=body.environment.uid, + gid=body.environment.gid, + environment_kind=models.EnvironmentKind(body.environment.environment_kind.value), + args=body.environment.args, + command=body.environment.command, + ) + new_launcher = models.UnsavedSessionLauncher( + project_id=ULID.from_str(body.project_id), + name=body.name, + description=body.description, + environment=environment, + resource_class_id=body.resource_class_id, + ) + launcher = await self.session_repo.insert_launcher(user=user, new_launcher=new_launcher) return json( apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json"), 201 ) @@ -132,8 +172,35 @@ def patch(self) -> BlueprintFactoryResponse: async def _patch( _: Request, user: base_models.APIUser, launcher_id: ULID, body: apispec.SessionLauncherPatch ) -> JSONResponse: - body_dict = body.model_dump(exclude_none=True) - launcher = await self.session_repo.update_launcher(user=user, launcher_id=launcher_id, **body_dict) + body_dict = body.model_dump(exclude_none=True, mode="json") + async with self.session_repo.session_maker() as session, session.begin(): + current_launcher = await self.session_repo.get_launcher(user, launcher_id) + new_env: models.UnsavedEnvironment | None = None + if ( + isinstance(body.environment, apispec.EnvironmentPatchInLauncher) + and current_launcher.environment.environment_kind == models.EnvironmentKind.GLOBAL + and body.environment.environment_kind == apispec.EnvironmentKind.CUSTOM + ): + # This means that the global environment is being swapped for a custom one, + # so we have to create a brand new environment, but we have to validate here. + validated_env = apispec.EnvironmentPostInLauncher.model_validate(body_dict.pop("environment")) + new_env = models.UnsavedEnvironment( + name=validated_env.name, + description=validated_env.description, + container_image=validated_env.container_image, + default_url=validated_env.default_url, + port=validated_env.port, + working_directory=PurePosixPath(validated_env.working_directory), + mount_directory=PurePosixPath(validated_env.mount_directory), + uid=validated_env.uid, + gid=validated_env.gid, + environment_kind=models.EnvironmentKind(validated_env.environment_kind.value), + args=validated_env.args, + command=validated_env.command, + ) + launcher = await self.session_repo.update_launcher( + user=user, launcher_id=launcher_id, new_custom_environment=new_env, session=session, **body_dict + ) return json(apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json")) return "/session_launchers/", ["PATCH"], _patch diff --git a/components/renku_data_services/session/db.py b/components/renku_data_services/session/db.py index 06f9c89b7..417820e40 100644 --- a/components/renku_data_services/session/db.py +++ b/components/renku_data_services/session/db.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Callable +from contextlib import AbstractAsyncContextManager, nullcontext from datetime import UTC, datetime from typing import Any @@ -15,9 +16,8 @@ from renku_data_services.authz.authz import Authz, ResourceType from renku_data_services.authz.models import Scope from renku_data_services.crc.db import ResourcePoolRepository -from renku_data_services.session import apispec, models +from renku_data_services.session import models from renku_data_services.session import orm as schemas -from renku_data_services.session.apispec import EnvironmentKind class SessionRepository: @@ -31,17 +31,23 @@ def __init__( self.resource_pools: ResourcePoolRepository = resource_pools async def get_environments(self) -> list[models.Environment]: - """Get all session environments from the database.""" + """Get all global session environments from the database.""" async with self.session_maker() as session: - res = await session.scalars(select(schemas.EnvironmentORM)) + res = await session.scalars( + select(schemas.EnvironmentORM).where( + schemas.EnvironmentORM.environment_kind == models.EnvironmentKind.GLOBAL.value + ) + ) environments = res.all() return [e.dump() for e in environments] async def get_environment(self, environment_id: ULID) -> models.Environment: - """Get one session environment from the database.""" + """Get one global session environment from the database.""" async with self.session_maker() as session: res = await session.scalars( - select(schemas.EnvironmentORM).where(schemas.EnvironmentORM.id == str(environment_id)) + select(schemas.EnvironmentORM) + .where(schemas.EnvironmentORM.id == str(environment_id)) + .where(schemas.EnvironmentORM.environment_kind == models.EnvironmentKind.GLOBAL.value) ) environment = res.one_or_none() if environment is None: @@ -50,64 +56,109 @@ async def get_environment(self, environment_id: ULID) -> models.Environment: ) return environment.dump() - async def insert_environment( + async def __insert_environment( self, user: base_models.APIUser, - new_environment: apispec.EnvironmentPost, - ) -> models.Environment: - """Insert a new session environment.""" + session: AsyncSession, + new_environment: models.UnsavedEnvironment, + ) -> schemas.EnvironmentORM: if user.id is None: - raise errors.UnauthorizedError(message="You do not have the required permissions for this operation.") - if not user.is_admin: - raise errors.ForbiddenError(message="You do not have the required permissions for this operation.") - - environment_model = models.Environment( - id=None, + raise errors.UnauthorizedError( + message="You have to be authenticated to insert an environment in the DB.", quiet=True + ) + environment = schemas.EnvironmentORM( name=new_environment.name, + created_by_id=user.id, + creation_date=datetime.now(UTC), description=new_environment.description, container_image=new_environment.container_image, default_url=new_environment.default_url, - created_by=models.Member(id=user.id), - creation_date=datetime.now(UTC).replace(microsecond=0), + port=new_environment.port, + working_directory=new_environment.working_directory, + mount_directory=new_environment.mount_directory, + uid=new_environment.uid, + gid=new_environment.gid, + environment_kind=new_environment.environment_kind, + command=new_environment.command, + args=new_environment.args, ) - environment = schemas.EnvironmentORM.load(environment_model) + + session.add(environment) + return environment + + async def insert_environment( + self, user: base_models.APIUser, new_environment: models.UnsavedEnvironment + ) -> models.Environment: + """Insert a new global session environment.""" + if user.id is None or not user.is_admin: + raise errors.UnauthorizedError( + message="You do not have the required permissions for this operation.", quiet=True + ) + if new_environment.environment_kind != models.EnvironmentKind.GLOBAL: + raise errors.ValidationError(message="This endpoint only supports adding global environments", quiet=True) async with self.session_maker() as session, session.begin(): - session.add(environment) - return environment.dump() + env = await self.__insert_environment(user, session, new_environment) + return env.dump() + + async def __update_environment( + self, + user: base_models.APIUser, + session: AsyncSession, + environment_id: ULID, + kind: models.EnvironmentKind, + **kwargs: dict, + ) -> models.Environment: + res = await session.scalars( + select(schemas.EnvironmentORM) + .where(schemas.EnvironmentORM.id == str(environment_id)) + .where(schemas.EnvironmentORM.environment_kind == kind.value) + ) + environment = res.one_or_none() + if environment is None: + raise errors.MissingResourceError(message=f"Session environment with id '{environment_id}' does not exist.") + + for key, value in kwargs.items(): + # NOTE: Only some fields can be edited + if key in [ + "name", + "description", + "container_image", + "default_url", + "port", + "working_directory", + "mount_directory", + "uid", + "gid", + "args", + "command", + ]: + setattr(environment, key, value) + + return environment.dump() async def update_environment( self, user: base_models.APIUser, environment_id: ULID, **kwargs: dict ) -> models.Environment: - """Update a session environment entry.""" + """Update a global session environment entry.""" if not user.is_admin: - raise errors.ForbiddenError(message="You do not have the required permissions for this operation.") + raise errors.UnauthorizedError(message="You do not have the required permissions for this operation.") async with self.session_maker() as session, session.begin(): - res = await session.scalars( - select(schemas.EnvironmentORM).where(schemas.EnvironmentORM.id == str(environment_id)) + return await self.__update_environment( + user, session, environment_id, models.EnvironmentKind.GLOBAL, **kwargs ) - environment = res.one_or_none() - if environment is None: - raise errors.MissingResourceError( - message=f"Session environment with id '{environment_id}' does not exist." - ) - - for key, value in kwargs.items(): - # NOTE: Only ``name``, ``description``, ``container_image`` and ``default_url`` can be edited - if key in ["name", "description", "container_image", "default_url"]: - setattr(environment, key, value) - - return environment.dump() async def delete_environment(self, user: base_models.APIUser, environment_id: ULID) -> None: - """Delete a session environment entry.""" + """Delete a global session environment entry.""" if not user.is_admin: raise errors.ForbiddenError(message="You do not have the required permissions for this operation.") async with self.session_maker() as session, session.begin(): res = await session.scalars( - select(schemas.EnvironmentORM).where(schemas.EnvironmentORM.id == str(environment_id)) + select(schemas.EnvironmentORM) + .where(schemas.EnvironmentORM.id == str(environment_id)) + .where(schemas.EnvironmentORM.environment_kind == models.EnvironmentKind.GLOBAL.value) ) environment = res.one_or_none() @@ -171,37 +222,19 @@ async def get_launcher(self, user: base_models.APIUser, launcher_id: ULID) -> mo return launcher.dump() async def insert_launcher( - self, user: base_models.APIUser, new_launcher: apispec.SessionLauncherPost + self, user: base_models.APIUser, new_launcher: models.UnsavedSessionLauncher ) -> models.SessionLauncher: """Insert a new session launcher.""" if not user.is_authenticated or user.id is None: raise errors.UnauthorizedError(message="You do not have the required permissions for this operation.") project_id = new_launcher.project_id - authorized = await self.project_authz.has_permission( - user, ResourceType.project, ULID.from_str(project_id), Scope.WRITE - ) + authorized = await self.project_authz.has_permission(user, ResourceType.project, project_id, Scope.WRITE) if not authorized: raise errors.MissingResourceError( message=f"Project with id '{project_id}' does not exist or you do not have access to it." ) - launcher_model = models.SessionLauncher( - id=None, - name=new_launcher.name, - project_id=new_launcher.project_id, - description=new_launcher.description, - environment_kind=new_launcher.environment_kind, - environment_id=new_launcher.environment_id, - resource_class_id=new_launcher.resource_class_id, - container_image=new_launcher.container_image, - default_url=new_launcher.default_url, - created_by=models.Member(id=user.id), - creation_date=datetime.now(UTC).replace(microsecond=0), - ) - - models.SessionLauncher.model_validate(launcher_model) - async with self.session_maker() as session, session.begin(): res = await session.scalars(select(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id)) project = res.one_or_none() @@ -210,16 +243,26 @@ async def insert_launcher( message=f"Project with id '{project_id}' does not exist or you do not have access to it." ) - environment_id = new_launcher.environment_id - if environment_id is not None: - res = await session.scalars( - select(schemas.EnvironmentORM).where(schemas.EnvironmentORM.id == environment_id) + environment_id: ULID + environment: models.Environment + environment_orm: schemas.EnvironmentORM | None + if isinstance(new_launcher.environment, models.UnsavedEnvironment): + environment_orm = await self.__insert_environment(user, session, new_launcher.environment) + environment = environment_orm.dump() + environment_id = environment.id + else: + environment_id = ULID.from_str(new_launcher.environment) + res_env = await session.scalars( + select(schemas.EnvironmentORM) + .where(schemas.EnvironmentORM.id == environment_id) + .where(schemas.EnvironmentORM.environment_kind == models.EnvironmentKind.GLOBAL.value) ) - environment = res.one_or_none() - if environment is None: + environment_orm = res_env.one_or_none() + if environment_orm is None: raise errors.MissingResourceError( message=f"Session environment with id '{environment_id}' does not exist or you do not have access to it." # noqa: E501 ) + environment = environment_orm.dump() resource_class_id = new_launcher.resource_class_id if resource_class_id is not None: @@ -239,25 +282,49 @@ async def insert_launcher( message=f"You do not have access to resource class with id '{resource_class_id}'." ) - launcher = schemas.SessionLauncherORM.load(launcher_model) + launcher = schemas.SessionLauncherORM( + name=new_launcher.name, + created_by_id=user.id, + creation_date=datetime.now(UTC), + description=new_launcher.description, + project_id=new_launcher.project_id, + environment_id=environment_id, + resource_class_id=new_launcher.resource_class_id, + ) session.add(launcher) + await session.flush() + await session.refresh(launcher) return launcher.dump() async def update_launcher( - self, user: base_models.APIUser, launcher_id: ULID, **kwargs: Any + self, + user: base_models.APIUser, + launcher_id: ULID, + new_custom_environment: models.UnsavedEnvironment | None, + session: AsyncSession | None = None, + **kwargs: Any, ) -> models.SessionLauncher: """Update a session launcher entry.""" if not user.is_authenticated or user.id is None: raise errors.UnauthorizedError(message="You do not have the required permissions for this operation.") - async with self.session_maker() as session, session.begin(): + session_ctx: AbstractAsyncContextManager = nullcontext() + tx: AbstractAsyncContextManager = nullcontext() + if not session: + session = self.session_maker() + session_ctx = session + if not session.in_transaction(): + tx = session.begin() + + async with session_ctx, tx: res = await session.scalars( select(schemas.SessionLauncherORM).where(schemas.SessionLauncherORM.id == launcher_id) ) launcher = res.one_or_none() if launcher is None: raise errors.MissingResourceError( - message=f"Session launcher with id '{launcher_id}' does not exist or you do not have access to it." # noqa: E501 + message=f"Session launcher with id '{launcher_id}' does not " + "exist or you do not have access to it." ) authorized = await self.project_authz.has_permission( @@ -269,17 +336,6 @@ async def update_launcher( if not authorized: raise errors.ForbiddenError(message="You do not have the required permissions for this operation.") - environment_id = kwargs.get("environment_id") - if environment_id is not None: - res = await session.scalars( - select(schemas.EnvironmentORM).where(schemas.EnvironmentORM.id == environment_id) - ) - environment = res.one_or_none() - if environment is None: - raise errors.MissingResourceError( - message=f"Session environment with id '{environment_id}' does not exist or you do not have access to it." # noqa: E501 - ) - resource_class_id = kwargs.get("resource_class_id") if resource_class_id is not None: res = await session.scalars( @@ -299,29 +355,86 @@ async def update_launcher( ) for key, value in kwargs.items(): - # NOTE: Only ``name``, ``description``, ``environment_kind``, - # ``environment_id``, ``resource_class_id``, ``container_image`` and - # ``default_url`` can be edited. + # NOTE: Only some fields can be updated. if key in [ "name", "description", - "environment_kind", - "environment_id", "resource_class_id", - "container_image", - "default_url", ]: setattr(launcher, key, value) - if launcher.environment_kind == EnvironmentKind.global_environment: - launcher.container_image = None - if launcher.environment_kind == EnvironmentKind.container_image: - launcher.environment = None - - launcher_model = launcher.dump() - models.SessionLauncher.model_validate(launcher_model) + env_payload = kwargs.get("environment", {}) + await self.__update_launcher_environment(user, launcher, session, new_custom_environment, **env_payload) + return launcher.dump() - return launcher_model + async def __update_launcher_environment( + self, + user: base_models.APIUser, + launcher: schemas.SessionLauncherORM, + session: AsyncSession, + new_custom_environment: models.UnsavedEnvironment | None, + **kwargs: Any, + ) -> None: + current_env_kind = launcher.environment.environment_kind + match new_custom_environment, current_env_kind, kwargs: + case None, _, {"id": env_id, **nothing_else} if len(nothing_else) == 0: + # The environment in the launcher is set via ID, the new ID has to refer + # to an environment that is GLOBAL. + old_environment = launcher.environment + new_environment_id = ULID.from_str(env_id) + res_env = await session.scalars( + select(schemas.EnvironmentORM).where(schemas.EnvironmentORM.id == new_environment_id) + ) + new_environment = res_env.one_or_none() + if new_environment is None: + raise errors.MissingResourceError( + message=f"Session environment with id '{new_environment_id}' does not exist or " + "you do not have access to it." + ) + if new_environment.environment_kind != models.EnvironmentKind.GLOBAL: + raise errors.ValidationError( + message="Cannot set the environment for a launcher to an existing environment if that " + "existing environment is not global", + quiet=True, + ) + launcher.environment_id = new_environment_id + launcher.environment = new_environment + if old_environment.environment_kind == models.EnvironmentKind.CUSTOM: + # A custom environment exists but it is being updated to a global one + # We remove the custom environment to avoid accumulating custom environments that are not associated + # with any launchers. + await session.delete(old_environment) + case None, models.EnvironmentKind.CUSTOM, {**rest} if ( + rest.get("environment_kind") is None + or rest.get("environment_kind") == models.EnvironmentKind.CUSTOM.value + ): + # Custom environment being updated + for key, val in rest.items(): + # NOTE: Only some fields can be updated. + if key in [ + "name", + "description", + "container_image", + "default_url", + "port", + "working_directory", + "mount_directory", + "uid", + "gid", + "args", + "command", + ]: + setattr(launcher.environment, key, val) + case models.UnsavedEnvironment(), models.EnvironmentKind.GLOBAL, {**nothing_else} if ( + len(nothing_else) == 0 and new_custom_environment.environment_kind == models.EnvironmentKind.CUSTOM + ): + # Global environment replaced by a custom one + new_env = await self.__insert_environment(user, session, new_custom_environment) + launcher.environment = new_env + case _: + raise errors.ValidationError( + message="Encountered an invalid payload for updating a launcher environment", quiet=True + ) async def delete_launcher(self, user: base_models.APIUser, launcher_id: ULID) -> None: """Delete a session launcher entry.""" @@ -347,3 +460,5 @@ async def delete_launcher(self, user: base_models.APIUser, launcher_id: ULID) -> raise errors.ForbiddenError(message="You do not have the required permissions for this operation.") await session.delete(launcher) + if launcher.environment.environment_kind == models.EnvironmentKind.CUSTOM: + await session.delete(launcher.environment) diff --git a/components/renku_data_services/session/models.py b/components/renku_data_services/session/models.py index 422524152..6dcff46c2 100644 --- a/components/renku_data_services/session/models.py +++ b/components/renku_data_services/session/models.py @@ -2,62 +2,100 @@ from dataclasses import dataclass from datetime import datetime +from enum import StrEnum +from pathlib import PurePosixPath -from pydantic import BaseModel, model_validator from ulid import ULID from renku_data_services import errors -from renku_data_services.session.apispec import EnvironmentKind -@dataclass(frozen=True, eq=True, kw_only=True) -class Member(BaseModel): - """Member model.""" +class EnvironmentKind(StrEnum): + """The type of environment.""" - id: str + GLOBAL: str = "GLOBAL" + CUSTOM: str = "CUSTOM" -@dataclass(frozen=True, eq=True, kw_only=True) -class Environment(BaseModel): - """Session environment model.""" +@dataclass(kw_only=True, frozen=True, eq=True) +class BaseEnvironment: + """Base session environment model.""" - id: str | None name: str - creation_date: datetime description: str | None container_image: str - default_url: str | None - created_by: Member + default_url: str + port: int + working_directory: PurePosixPath + mount_directory: PurePosixPath + uid: int + gid: int + environment_kind: EnvironmentKind + args: list[str] | None = None + command: list[str] | None = None + + +@dataclass(kw_only=True, frozen=True, eq=True) +class UnsavedEnvironment(BaseEnvironment): + """Session environment model that has not been saved.""" + + port: int = 8888 + description: str | None = None + working_directory: PurePosixPath = PurePosixPath("/home/jovyan/work") + mount_directory: PurePosixPath = PurePosixPath("/home/jovyan/work") + uid: int = 1000 + gid: int = 1000 + + def __post_init__(self) -> None: + if not self.working_directory.is_absolute(): + raise errors.ValidationError(message="The working directory for a session is supposed to be absolute") + if not self.mount_directory.is_absolute(): + raise errors.ValidationError(message="The mount directory for a session is supposed to be absolute") + if self.working_directory.is_reserved(): + raise errors.ValidationError( + message="The requested value for the working directory is reserved by the OS and cannot be used." + ) + if self.mount_directory.is_reserved(): + raise errors.ValidationError( + message="The requested value for the mount directory is reserved by the OS and cannot be used." + ) + + +@dataclass(kw_only=True, frozen=True, eq=True) +class Environment(BaseEnvironment): + """Session environment model that has been saved in the DB.""" + + id: ULID + creation_date: datetime + created_by: str @dataclass(frozen=True, eq=True, kw_only=True) -class SessionLauncher(BaseModel): +class BaseSessionLauncher: """Session launcher model.""" id: ULID | None - project_id: str + project_id: ULID name: str - creation_date: datetime description: str | None - environment_kind: EnvironmentKind - environment_id: str | None + environment: str | UnsavedEnvironment | Environment resource_class_id: int | None - container_image: str | None - default_url: str | None - created_by: Member - @model_validator(mode="after") - def check_launcher_environment_kind(self) -> "SessionLauncher": - """Validates the environment of a launcher.""" - environment_kind = self.environment_kind - environment_id = self.environment_id - container_image = self.container_image +@dataclass(frozen=True, eq=True, kw_only=True) +class UnsavedSessionLauncher(BaseSessionLauncher): + """Session launcher model that has not been persisted in the DB.""" - if environment_kind == EnvironmentKind.global_environment and environment_id is None: - raise errors.ValidationError(message="'environment_id' not set when environment_kind=global_environment") + id: ULID | None = None + environment: str | UnsavedEnvironment + """When a string is passed for the environment it should be the ID of an existing environment.""" - if environment_kind == EnvironmentKind.container_image and container_image is None: - raise errors.ValidationError(message="'container_image' not set when environment_kind=container_image") - return self +@dataclass(frozen=True, eq=True, kw_only=True) +class SessionLauncher(BaseSessionLauncher): + """Session launcher model that has been already saved in the DB.""" + + id: ULID + creation_date: datetime + created_by: str + environment: Environment diff --git a/components/renku_data_services/session/orm.py b/components/renku_data_services/session/orm.py index 4b61d548c..12217cc39 100644 --- a/components/renku_data_services/session/orm.py +++ b/components/renku_data_services/session/orm.py @@ -1,8 +1,10 @@ """SQLAlchemy's schemas for the sessions database.""" from datetime import datetime +from pathlib import PurePosixPath -from sqlalchemy import DateTime, MetaData, String +from sqlalchemy import JSON, DateTime, MetaData, String +from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column, relationship from sqlalchemy.schema import ForeignKey from ulid import ULID @@ -10,10 +12,10 @@ from renku_data_services.crc.orm import ResourceClassORM from renku_data_services.project.orm import ProjectORM from renku_data_services.session import models -from renku_data_services.session.apispec import EnvironmentKind -from renku_data_services.utils.sqlalchemy import ULIDType +from renku_data_services.utils.sqlalchemy import PurePosixPathType, ULIDType metadata_obj = MetaData(schema="sessions") # Has to match alembic ini section name +JSONVariant = JSON().with_variant(JSONB(), "postgresql") class BaseORM(MappedAsDataclass, DeclarativeBase): @@ -27,7 +29,7 @@ class EnvironmentORM(BaseORM): __tablename__ = "environments" - id: Mapped[str] = mapped_column("id", String(26), primary_key=True, default_factory=lambda: str(ULID()), init=False) + id: Mapped[ULID] = mapped_column("id", ULIDType, primary_key=True, default_factory=lambda: str(ULID()), init=False) """Id of this session environment object.""" name: Mapped[str] = mapped_column("name", String(99)) @@ -45,31 +47,36 @@ class EnvironmentORM(BaseORM): container_image: Mapped[str] = mapped_column("container_image", String(500)) """Container image repository and tag.""" - default_url: Mapped[str | None] = mapped_column("default_url", String(200)) + default_url: Mapped[str] = mapped_column("default_url", String(200)) """Default URL path to open in a session.""" - @classmethod - def load(cls, environment: models.Environment) -> "EnvironmentORM": - """Create EnvironmentORM from the session environment model.""" - return cls( - name=environment.name, - created_by_id=environment.created_by.id, - creation_date=environment.creation_date, - description=environment.description, - container_image=environment.container_image, - default_url=environment.default_url, - ) + port: Mapped[int] = mapped_column("port") + working_directory: Mapped[PurePosixPath] = mapped_column("working_directory", PurePosixPathType) + mount_directory: Mapped[PurePosixPath] = mapped_column("mount_directory", PurePosixPathType) + uid: Mapped[int] = mapped_column("uid") + gid: Mapped[int] = mapped_column("gid") + environment_kind: Mapped[models.EnvironmentKind] = mapped_column("environment_kind") + args: Mapped[list[str] | None] = mapped_column("args", JSONVariant, nullable=True) + command: Mapped[list[str] | None] = mapped_column("command", JSONVariant, nullable=True) def dump(self) -> models.Environment: """Create a session environment model from the EnvironmentORM.""" return models.Environment( id=self.id, name=self.name, - created_by=models.Member(id=self.created_by_id), + created_by=self.created_by_id, creation_date=self.creation_date, description=self.description, container_image=self.container_image, default_url=self.default_url, + gid=self.gid, + uid=self.uid, + environment_kind=self.environment_kind, + mount_directory=self.mount_directory, + working_directory=self.working_directory, + port=self.port, + args=self.args, + command=self.command, ) @@ -93,24 +100,15 @@ class SessionLauncherORM(BaseORM): description: Mapped[str | None] = mapped_column("description", String(500)) """Human-readable description of the session launcher.""" - environment_kind: Mapped[EnvironmentKind] - """The kind of environment definition to use.""" - - container_image: Mapped[str | None] = mapped_column("container_image", String(500)) - """Container image repository and tag.""" - - default_url: Mapped[str | None] = mapped_column("default_url", String(200)) - """Default URL path to open in a session.""" - project: Mapped[ProjectORM] = relationship(init=False) - environment: Mapped[EnvironmentORM | None] = relationship(init=False) + environment: Mapped[EnvironmentORM] = relationship(init=False, lazy="joined") project_id: Mapped[ULID] = mapped_column( "project_id", ForeignKey(ProjectORM.id, ondelete="CASCADE"), default=None, index=True ) """Id of the project this session belongs to.""" - environment_id: Mapped[str | None] = mapped_column( + environment_id: Mapped[ULID] = mapped_column( "environment_id", ForeignKey(EnvironmentORM.id), default=None, nullable=True, index=True ) """Id of the session environment.""" @@ -129,29 +127,23 @@ def load(cls, launcher: models.SessionLauncher) -> "SessionLauncherORM": """Create SessionLauncherORM from the session launcher model.""" return cls( name=launcher.name, - created_by_id=launcher.created_by.id, + created_by_id=launcher.created_by, creation_date=launcher.creation_date, description=launcher.description, - environment_kind=launcher.environment_kind, - container_image=launcher.container_image, project_id=ULID.from_str(launcher.project_id), - environment_id=launcher.environment_id, resource_class_id=launcher.resource_class_id, - default_url=launcher.default_url, + environment_id=launcher.environment.id, ) def dump(self) -> models.SessionLauncher: """Create a session launcher model from the SessionLauncherORM.""" return models.SessionLauncher( id=self.id, - project_id=str(self.project_id), + project_id=self.project_id, name=self.name, - created_by=models.Member(id=self.created_by_id), + created_by=self.created_by_id, creation_date=self.creation_date, description=self.description, - environment_kind=self.environment_kind, - environment_id=self.environment_id if self.environment_id is not None else None, - resource_class_id=self.resource_class_id if self.resource_class_id is not None else None, - container_image=self.container_image, - default_url=self.default_url, + environment=self.environment.dump(), + resource_class_id=self.resource_class_id, ) diff --git a/components/renku_data_services/utils/sqlalchemy.py b/components/renku_data_services/utils/sqlalchemy.py index f1cd59c9b..82bcfb9db 100644 --- a/components/renku_data_services/utils/sqlalchemy.py +++ b/components/renku_data_services/utils/sqlalchemy.py @@ -1,5 +1,6 @@ """Utilities for SQLAlchemy.""" +from pathlib import PurePosixPath from typing import cast from sqlalchemy import Dialect, types @@ -23,3 +24,22 @@ def process_result_value(self, value: str | None, dialect: Dialect) -> ULID | No if value is None: return None return cast(ULID, ULID.from_str(value)) # cast because mypy doesn't understand ULID type annotations + + +class PurePosixPathType(types.TypeDecorator): + """Wrapper type for Path <--> str conversion.""" + + impl = types.String + cache_ok = True + + def process_bind_param(self, value: PurePosixPath | None, dialect: Dialect) -> str | None: + """Transform value for storing in the database.""" + if value is None: + return None + return value.as_posix() + + def process_result_value(self, value: str | None, dialect: Dialect) -> PurePosixPath | None: + """Transform string from database into PosixPath.""" + if value is None: + return None + return PurePosixPath(value) diff --git a/test/bases/renku_data_services/data_api/conftest.py b/test/bases/renku_data_services/data_api/conftest.py index b638f4198..6f53e5254 100644 --- a/test/bases/renku_data_services/data_api/conftest.py +++ b/test/bases/renku_data_services/data_api/conftest.py @@ -291,7 +291,18 @@ async def create_resource_pool_helper(admin: bool = False, **payload) -> dict[st "default": True, "node_affinities": [], "tolerations": [], - } + }, + { + "cpu": 2.0, + "memory": 20, + "gpu": 0, + "name": "test-class-name", + "max_storage": 200, + "default_storage": 2, + "default": False, + "node_affinities": [], + "tolerations": [], + }, ], "quota": {"cpu": 100, "memory": 100, "gpu": 0}, "default": False, diff --git a/test/bases/renku_data_services/data_api/test_migrations.py b/test/bases/renku_data_services/data_api/test_migrations.py index 31d8bd74f..6c2d8ac76 100644 --- a/test/bases/renku_data_services/data_api/test_migrations.py +++ b/test/bases/renku_data_services/data_api/test_migrations.py @@ -1,9 +1,12 @@ import base64 +from datetime import UTC, datetime from typing import Any import pytest +import sqlalchemy as sa from alembic.script import ScriptDirectory from sanic_testing.testing import SanicASGITestClient +from ulid import ULID from renku_data_services.app_config.config import Config from renku_data_services.message_queue.avro_models.io.renku.events import v2 @@ -94,3 +97,39 @@ async def test_migration_to_f34b87ddd954( ] assert len(group_removed_events) == 2 assert set(added_group_ids) == {e.id for e in group_removed_events} + + +@pytest.mark.asyncio +async def test_migration_to_584598f3b769(app_config: Config) -> None: + run_migrations_for_app("common", "dcc1c1ee662f") + await app_config.kc_user_repo.initialize(app_config.kc_api) + await app_config.group_repo.generate_user_namespaces() + env_id = str(ULID()) + async with app_config.db.async_session_maker() as session, session.begin(): + await session.execute( + sa.text( + "INSERT INTO " + "sessions.environments(id, name, created_by_id, creation_date, container_image, default_url) " + "VALUES (:id, :name, :created_by, :date, :image, :url)" + ).bindparams( + id=env_id, + name="test", + created_by="test", + date=datetime.now(UTC), + image="test", + url="/test", + ) + ) + run_migrations_for_app("common", "584598f3b769") + async with app_config.db.async_session_maker() as session, session.begin(): + res = await session.execute(sa.text("SELECT * FROM sessions.environments")) + data = res.all() + assert len(data) == 1 + env = data[0]._mapping + assert env["id"] == env_id + assert env["name"] == "test" + assert env["container_image"] == "test" + assert env["default_url"] == "/test" + assert env["port"] == 8888 + assert env["uid"] == 1000 + assert env["gid"] == 1000 diff --git a/test/bases/renku_data_services/data_api/test_resource_pools.py b/test/bases/renku_data_services/data_api/test_resource_pools.py index 58807514c..a7d15ed77 100644 --- a/test/bases/renku_data_services/data_api/test_resource_pools.py +++ b/test/bases/renku_data_services/data_api/test_resource_pools.py @@ -300,7 +300,7 @@ async def test_put_resource_class( ) -> None: _, res = await create_rp(valid_resource_pool_payload, sanic_client) assert res.status_code == 201 - assert len(res.json.get("classes", [])) == 1 + assert len(res.json.get("classes", [])) == 2 res_cls_payload = {**res.json.get("classes", [])[0], "cpu": 5.0} res_cls_expected_response = {**res.json.get("classes", [])[0], "cpu": 5.0} res_cls_payload.pop("id", None) @@ -672,7 +672,7 @@ async def test_patch_tolerations( rp = res.json rp_id = rp["id"] assert len(rp["classes"]) > 0 - res_class = rp["classes"][0] + res_class = [i for i in rp["classes"] if len(i.get("tolerations", [])) > 0][0] res_class_id = res_class["id"] assert len(res_class["tolerations"]) == 1 # Patch in a 2nd toleration @@ -712,7 +712,7 @@ async def test_patch_affinities( rp = res.json rp_id = rp["id"] assert len(rp["classes"]) > 0 - res_class = rp["classes"][0] + res_class = [i for i in rp["classes"] if len(i.get("node_affinities", [])) > 0][0] res_class_id = res_class["id"] assert len(res_class["node_affinities"]) == 1 assert res_class["node_affinities"][0] == {"key": "affinity1", "required_during_scheduling": False} @@ -766,7 +766,7 @@ async def test_remove_all_tolerations_put( rp = res.json rp_id = rp["id"] assert len(rp["classes"]) > 0 - res_class = rp["classes"][0] + res_class = [i for i in rp["classes"] if len(i.get("tolerations", [])) > 0][0] res_class_id = res_class["id"] assert len(res_class["tolerations"]) == 1 assert res_class["tolerations"][0] == "toleration1" @@ -798,7 +798,7 @@ async def test_remove_all_affinities_put( rp = res.json rp_id = rp["id"] assert len(rp["classes"]) > 0 - res_class = rp["classes"][0] + res_class = [i for i in rp["classes"] if len(i.get("node_affinities", [])) > 0][0] res_class_id = res_class["id"] assert len(res_class["node_affinities"]) == 1 assert res_class["node_affinities"][0] == {"key": "affinity1", "required_during_scheduling": False} @@ -830,7 +830,7 @@ async def test_put_tolerations( rp = res.json rp_id = rp["id"] assert len(rp["classes"]) > 0 - res_class = rp["classes"][0] + res_class = [i for i in rp["classes"] if len(i.get("tolerations", [])) > 0][0] res_class_id = res_class["id"] assert len(res_class["tolerations"]) == 1 assert res_class["tolerations"][0] == "toleration1" @@ -862,7 +862,7 @@ async def test_put_affinities( rp = res.json rp_id = rp["id"] assert len(rp["classes"]) > 0 - res_class = rp["classes"][0] + res_class = [i for i in rp["classes"] if len(i.get("node_affinities", [])) > 0][0] res_class_id = res_class["id"] assert len(res_class["node_affinities"]) == 1 assert res_class["node_affinities"][0] == {"key": "affinity1", "required_during_scheduling": False} @@ -900,7 +900,7 @@ async def test_get_all_tolerations( rp = res.json rp_id = rp["id"] assert len(rp["classes"]) > 0 - res_class = rp["classes"][0] + res_class = [i for i in rp["classes"] if len(i.get("tolerations", [])) > 0][0] res_class_id = res_class["id"] _, res = await sanic_client.get( f"/api/data/resource_pools/{rp_id}/classes/{res_class_id}/tolerations", @@ -920,7 +920,7 @@ async def test_get_all_affinities( rp = res.json rp_id = rp["id"] assert len(rp["classes"]) > 0 - res_class = rp["classes"][0] + res_class = [i for i in rp["classes"] if len(i.get("node_affinities", [])) > 0][0] res_class_id = res_class["id"] _, res = await sanic_client.get( f"/api/data/resource_pools/{rp_id}/classes/{res_class_id}/node_affinities", diff --git a/test/bases/renku_data_services/data_api/test_sessions.py b/test/bases/renku_data_services/data_api/test_sessions.py index 825e190a1..2be42158c 100644 --- a/test/bases/renku_data_services/data_api/test_sessions.py +++ b/test/bases/renku_data_services/data_api/test_sessions.py @@ -29,10 +29,12 @@ async def create_session_launcher_helper(name: str, project_id: str, **payload) payload = payload.copy() payload.update({"name": name, "project_id": project_id}) payload["description"] = payload.get("description") or "A session launcher." - payload["environment_kind"] = payload.get("environment_kind") or "container_image" - - if payload["environment_kind"] == "container_image": - payload["container_image"] = payload.get("container_image") or "some_image:some_tag" + if "environment" not in payload: + payload["environment"] = { + "environment_kind": "CUSTOM", + "name": "Test", + "container_image": "some_image:some_tag", + } _, res = await sanic_client.post("/api/data/session_launchers", headers=user_headers, json=payload) @@ -110,7 +112,7 @@ async def test_post_session_environment_unauthorized(sanic_client: SanicASGITest _, res = await sanic_client.post("/api/data/environments", headers=user_headers, json=payload) - assert res.status_code == 403, res.text + assert res.status_code == 401, res.text @pytest.mark.asyncio @@ -150,7 +152,7 @@ async def test_patch_session_environment_unauthorized( _, res = await sanic_client.patch(f"/api/data/environments/{environment_id}", headers=user_headers, json=payload) - assert res.status_code == 403, res.text + assert res.status_code == 401, res.text @pytest.mark.asyncio @@ -217,8 +219,7 @@ async def test_get_session_launcher( "Some launcher", project_id=project["id"], description="Some launcher.", - environment_kind="global_environment", - environment_id=env["id"], + environment={"id": env["id"]}, ) launcher_id = launcher["id"] @@ -229,9 +230,10 @@ async def test_get_session_launcher( assert res.json.get("name") == "Some launcher" assert res.json.get("project_id") == project["id"] assert res.json.get("description") == "Some launcher." - assert res.json.get("environment_kind") == "global_environment" - assert res.json.get("environment_id") == env["id"] - assert res.json.get("container_image") is None + environment = res.json.get("environment", {}) + assert environment.get("environment_kind") == "GLOBAL" + assert environment.get("id") == env["id"] + assert environment.get("container_image") == env["container_image"] assert res.json.get("resource_class_id") is None @@ -276,9 +278,12 @@ async def test_post_session_launcher( "name": "Launcher 1", "project_id": project["id"], "description": "A session launcher.", - "environment_kind": "container_image", - "container_image": "some_image:some_tag", "resource_class_id": resource_pool["classes"][0]["id"], + "environment": { + "container_image": "some_image:some_tag", + "name": "custom_name", + "environment_kind": "CUSTOM", + }, } _, res = await sanic_client.post("/api/data/session_launchers", headers=admin_headers, json=payload) @@ -288,9 +293,10 @@ async def test_post_session_launcher( assert res.json.get("name") == "Launcher 1" assert res.json.get("project_id") == project["id"] assert res.json.get("description") == "A session launcher." - assert res.json.get("environment_kind") == "container_image" - assert res.json.get("container_image") == "some_image:some_tag" - assert res.json.get("environment_id") is None + environment = res.json.get("environment", {}) + assert environment.get("environment_kind") == "CUSTOM" + assert environment.get("container_image") == "some_image:some_tag" + assert environment.get("id") is not None assert res.json.get("resource_class_id") == resource_pool["classes"][0]["id"] @@ -303,20 +309,21 @@ async def test_post_session_launcher_unauthorized( create_project, create_resource_pool, regular_user, + create_session_environment, ) -> None: project = await create_project("Some project") resource_pool_data = valid_resource_pool_payload resource_pool_data["public"] = False resource_pool = await create_resource_pool(admin=True, **resource_pool_data) + environment = await create_session_environment("Test environment") payload = { "name": "Launcher 1", "project_id": project["id"], "description": "A session launcher.", - "environment_kind": "container_image", - "container_image": "some_image:some_tag", "resource_class_id": resource_pool["classes"][0]["id"], + "environment": {"id": environment["id"]}, } _, res = await sanic_client.post("/api/data/session_launchers", headers=user_headers, json=payload) @@ -338,3 +345,130 @@ async def test_delete_session_launcher( _, res = await sanic_client.delete(f"/api/data/session_launchers/{launcher_id}", headers=user_headers) assert res.status_code == 204, res.text + + +@pytest.mark.asyncio +async def test_patch_session_launcher( + sanic_client: SanicASGITestClient, + valid_resource_pool_payload: dict[str, Any], + user_headers, + create_project, + create_resource_pool, +) -> None: + project = await create_project("Some project 1") + resource_pool_data = valid_resource_pool_payload + resource_pool = await create_resource_pool(admin=True, **resource_pool_data) + + payload = { + "name": "Launcher 1", + "project_id": project["id"], + "description": "A session launcher.", + "resource_class_id": resource_pool["classes"][0]["id"], + "environment": { + "container_image": "some_image:some_tag", + "name": "custom_name", + "environment_kind": "CUSTOM", + }, + } + + _, res = await sanic_client.post("/api/data/session_launchers", headers=user_headers, json=payload) + + assert res.status_code == 201, res.text + assert res.json is not None + assert res.json.get("name") == "Launcher 1" + assert res.json.get("description") == "A session launcher." + environment = res.json.get("environment", {}) + assert environment.get("environment_kind") == "CUSTOM" + assert environment.get("container_image") == "some_image:some_tag" + assert environment.get("id") is not None + assert res.json.get("resource_class_id") == resource_pool["classes"][0]["id"] + + patch_payload = { + "name": "New Name", + "description": "An updated session launcher.", + "resource_class_id": resource_pool["classes"][1]["id"], + } + _, res = await sanic_client.patch( + f"/api/data/session_launchers/{res.json['id']}", headers=user_headers, json=patch_payload + ) + assert res.status_code == 200, res.text + assert res.json is not None + assert res.json.get("name") == patch_payload["name"] + assert res.json.get("description") == patch_payload["description"] + assert res.json.get("resource_class_id") == patch_payload["resource_class_id"] + + +@pytest.mark.asyncio +async def test_patch_session_launcher_environment( + sanic_client: SanicASGITestClient, + valid_resource_pool_payload: dict[str, Any], + user_headers, + create_project, + create_resource_pool, + create_session_environment, +) -> None: + project = await create_project("Some project 1") + resource_pool_data = valid_resource_pool_payload + resource_pool = await create_resource_pool(admin=True, **resource_pool_data) + global_env = await create_session_environment("Some environment") + + # Create a new custom environment with the launcher + payload = { + "name": "Launcher 1", + "project_id": project["id"], + "description": "A session launcher.", + "resource_class_id": resource_pool["classes"][0]["id"], + "environment": { + "container_image": "some_image:some_tag", + "name": "custom_name", + "environment_kind": "CUSTOM", + }, + } + _, res = await sanic_client.post("/api/data/session_launchers", headers=user_headers, json=payload) + assert res.status_code == 201, res.text + assert res.json is not None + environment = res.json.get("environment", {}) + assert environment.get("environment_kind") == "CUSTOM" + assert environment.get("container_image") == "some_image:some_tag" + assert environment.get("id") is not None + + # Patch in a global environment + patch_payload = { + "environment": {"id": global_env["id"]}, + } + _, res = await sanic_client.patch( + f"/api/data/session_launchers/{res.json['id']}", headers=user_headers, json=patch_payload + ) + assert res.status_code == 200, res.text + assert res.json is not None + launcher_id = res.json["id"] + global_env["environment_kind"] = "GLOBAL" + assert res.json["environment"] == global_env + + # Trying to patch a field of the global environment should fail + patch_payload = { + "environment": {"container_image": "new_image"}, + } + _, res = await sanic_client.patch( + f"/api/data/session_launchers/{launcher_id}", headers=user_headers, json=patch_payload + ) + assert res.status_code == 422, res.text + + # Patching in a wholly new custom environment over the global is allowed + patch_payload = { + "environment": {"container_image": "new_image", "name": "new_custom", "environment_kind": "CUSTOM"}, + } + _, res = await sanic_client.patch( + f"/api/data/session_launchers/{launcher_id}", headers=user_headers, json=patch_payload + ) + assert res.status_code == 200, res.text + + # Should be able to patch some fields of the custom environment + patch_payload = { + "environment": {"container_image": "nginx:latest"}, + } + _, res = await sanic_client.patch( + f"/api/data/session_launchers/{launcher_id}", headers=user_headers, json=patch_payload + ) + assert res.status_code == 200, res.text + assert res.json["environment"]["container_image"] == "nginx:latest"