diff --git a/components/renku_data_services/notebooks/api.spec.yaml b/components/renku_data_services/notebooks/api.spec.yaml
index 1ded7088b..10fd3cbc0 100644
--- a/components/renku_data_services/notebooks/api.spec.yaml
+++ b/components/renku_data_services/notebooks/api.spec.yaml
@@ -25,6 +25,7 @@ paths:
             application/json:
               schema:
                 "$ref": "#/components/schemas/ImageCheckResponse"
+          description: The image check has completed successfully
         '422':
           content:
             application/json:
@@ -886,15 +887,13 @@ components:
         launcher_id:
           $ref: "#/components/schemas/Ulid"
         disk_storage:
-          default: 1
           type: integer
           description: The size of disk storage for the session, in gigabytes
         resource_class_id:
-          default:
           nullable: true
           type: integer
-        cloudstorage:
-          $ref: "#/components/schemas/SessionCloudStoragePostList"
+        data_connectors_overrides:
+          $ref: "#/components/schemas/SessionDataConnectorsOverrideList"
         env_variable_overrides:
           $ref: "#/components/schemas/EnvVariableOverrides"
       required:
@@ -1015,28 +1014,54 @@ components:
       minLength: 26
       maxLength: 26
       pattern: "^[0-7][0-9A-HJKMNP-TV-Z]{25}$"
-    SessionCloudStoragePostList:
+    SessionDataConnectorsOverrideList:
       type: array
       items:
-        "$ref": "#/components/schemas/SessionCloudStoragePost"
-    SessionCloudStoragePost:
+        $ref: "#/components/schemas/SessionDataConnectorOverride"
+    SessionDataConnectorOverride:
       type: object
       properties:
-        configuration:
-          type: object
-          additionalProperties: true
-        readonly:
+        skip:
           type: boolean
+          description: The corresponding data connector will not be mounted if `skip` is set to `true`.
+          default: false
+        data_connector_id:
+          allOf:
+            - $ref: "#/components/schemas/Ulid"
+            - description: |
+                The `data_connector_id` must match an existing data connector from the session launcher's project.
+        configuration:
+          $ref: "#/components/schemas/RCloneConfig"
         source_path:
-          type: string
+          $ref: "#/components/schemas/SourcePath"
         target_path:
-          type: string
-        storage_id:
-          allOf:
-            - "$ref": "#/components/schemas/Ulid"
-            - description: If the storage_id is provided then this config must replace an existing storage config in the session
+          $ref: "#/components/schemas/TargetPath"
+        readonly:
+          $ref: "#/components/schemas/StorageReadOnly"
       required:
-        - storage_id
+        - data_connector_id
+    RCloneConfig:
+      type: object
+      description: Dictionary of rclone key:value pairs (based on schema from '/storage_schema')
+      additionalProperties:
+        oneOf:
+          - type: integer
+          - type: string
+            nullable: true
+          - type: boolean
+          - type: object
+    SourcePath:
+      description: the source path to mount, usually starts with bucket/container name
+      type: string
+      example: bucket/my/storage/folder/
+    TargetPath:
+      description: the target path relative to the working directory where the storage should be mounted
+      type: string
+      example: my/project/folder
+    StorageReadOnly:
+      description: Whether this storage should be mounted readonly or not
+      type: boolean
+      default: true
     ServerName:
       type: string
       minLength: 5
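For reference, a single override entry under the new schema looks as follows. This is a hand-written example: the ULID is made up, and valid `configuration` keys come from the `/storage_schema` endpoint referenced by `RCloneConfig`.

# Hypothetical override entry matching SessionDataConnectorOverride:
override = {
    "data_connector_id": "01JADWB7H3Q8RM2N4P6S8T0V2X",  # must exist in the launcher's project
    "configuration": {"type": "s3", "provider": "AWS"},  # rclone key:value pairs
    "source_path": "bucket/my/storage/folder/",
    "target_path": "my/project/folder",
    "readonly": True,  # the schema default when omitted
}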
diff --git a/components/renku_data_services/notebooks/apispec.py b/components/renku_data_services/notebooks/apispec.py
index 65f56bac7..67f6ab916 100644
--- a/components/renku_data_services/notebooks/apispec.py
+++ b/components/renku_data_services/notebooks/apispec.py
@@ -1,12 +1,12 @@
 # generated by datamodel-codegen:
 #   filename:  api.spec.yaml
-#   timestamp: 2025-09-12T06:56:48+00:00
+#   timestamp: 2025-10-15T12:41:50+00:00

 from __future__ import annotations

 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from pydantic import ConfigDict, Field, RootModel

 from renku_data_services.notebooks.apispec_base import BaseAPISpec
@@ -267,20 +267,6 @@ class SessionLogsResponse(RootModel[Optional[Dict[str, str]]]):
     root: Optional[Dict[str, str]] = None


-class SessionCloudStoragePost(BaseAPISpec):
-    configuration: Optional[Dict[str, Any]] = None
-    readonly: Optional[bool] = None
-    source_path: Optional[str] = None
-    target_path: Optional[str] = None
-    storage_id: str = Field(
-        ...,
-        description="ULID identifier",
-        max_length=26,
-        min_length=26,
-        pattern="^[0-7][0-9A-HJKMNP-TV-Z]{25}$",
-    )
-
-
 class ImageConnectionStatus(Enum):
     connected = "connected"
     pending = "pending"
@@ -368,6 +354,36 @@ class SessionResources(BaseAPISpec):
     requests: Optional[SessionResourcesRequests] = None


+class SessionDataConnectorOverride(BaseAPISpec):
+    skip: bool = Field(
+        False,
+        description="The corresponding data connector will not be mounted if `skip` is set to `true`.",
+    )
+    data_connector_id: str = Field(
+        ...,
+        description="ULID identifier",
+        max_length=26,
+        min_length=26,
+        pattern="^[0-7][0-9A-HJKMNP-TV-Z]{25}$",
+    )
+    configuration: Optional[
+        Dict[str, Union[int, Optional[str], bool, Dict[str, Any]]]
+    ] = None
+    source_path: Optional[str] = Field(
+        None,
+        description="the source path to mount, usually starts with bucket/container name",
+        examples=["bucket/my/storage/folder/"],
+    )
+    target_path: Optional[str] = Field(
+        None,
+        description="the target path relative to the working directory where the storage should be mounted",
+        examples=["my/project/folder"],
+    )
+    readonly: Optional[bool] = Field(
+        True, description="Whether this storage should be mounted readonly or not"
+    )
+
+
 class ImageConnection(BaseAPISpec):
     id: str
     provider_id: str
@@ -396,24 +412,6 @@ class ServersGetResponse(BaseAPISpec):
     servers: Optional[Dict[str, NotebookResponse]] = None


-class SessionPostRequest(BaseAPISpec):
-    launcher_id: str = Field(
-        ...,
-        description="ULID identifier",
-        max_length=26,
-        min_length=26,
-        pattern="^[0-7][0-9A-HJKMNP-TV-Z]{25}$",
-    )
-    disk_storage: int = Field(
-        1, description="The size of disk storage for the session, in gigabytes"
-    )
-    resource_class_id: Optional[int] = None
-    cloudstorage: Optional[List[SessionCloudStoragePost]] = None
-    env_variable_overrides: Optional[List[EnvVarOverride]] = Field(
-        None, description="Environment variable overrides for the session pod"
-    )
-
-
 class SessionResponse(BaseAPISpec):
     image: str
     name: str = Field(
@@ -452,3 +450,21 @@ class ImageCheckResponse(BaseAPISpec):
     accessible: bool = Field(..., description="Whether the image is accessible or not.")
     connection: Optional[ImageConnection] = None
     provider: Optional[ImageProvider] = None
+
+
+class SessionPostRequest(BaseAPISpec):
+    launcher_id: str = Field(
+        ...,
+        description="ULID identifier",
+        max_length=26,
+        min_length=26,
+        pattern="^[0-7][0-9A-HJKMNP-TV-Z]{25}$",
+    )
+    disk_storage: Optional[int] = Field(
+        None, description="The size of disk storage for the session, in gigabytes"
+    )
+    resource_class_id: Optional[int] = None
+    data_connectors_overrides: Optional[List[SessionDataConnectorOverride]] = None
+    env_variable_overrides: Optional[List[EnvVarOverride]] = Field(
+        None, description="Environment variable overrides for the session pod"
+    )
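A round-trip sketch against the regenerated models (this assumes pydantic v2, which the `ConfigDict`/`Field` usage implies; both ULIDs are made up):

from renku_data_services.notebooks import apispec

payload = {
    "launcher_id": "01JADW9S6T4QZV31C5R8YXK0EM",
    "data_connectors_overrides": [
        {"data_connector_id": "01JADWB7H3Q8RM2N4P6S8T0V2X", "skip": True},
    ],
}
body = apispec.SessionPostRequest.model_validate(payload)
assert body.disk_storage is None  # no client-side default anymore; resolved server-side
assert body.data_connectors_overrides[0].skip is True
assert body.data_connectors_overrides[0].readonly is True  # schema default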
diff --git a/components/renku_data_services/notebooks/blueprints.py b/components/renku_data_services/notebooks/blueprints.py
index aa48c0702..f303fe43e 100644
--- a/components/renku_data_services/notebooks/blueprints.py
+++ b/components/renku_data_services/notebooks/blueprints.py
@@ -28,6 +28,8 @@
 from renku_data_services.notebooks.core_sessions import (
     patch_session,
     start_session,
+    validate_session_patch_request,
+    validate_session_post_request,
 )
 from renku_data_services.notebooks.errors.intermittent import AnonymousUserPatchError
 from renku_data_services.project.db import ProjectRepository, ProjectSessionSecretRepository
@@ -219,9 +221,10 @@
         async def _handler(
             request: Request,
             user: AnonymousAPIUser | AuthenticatedAPIUser,
             internal_gitlab_user: APIUser,
             body: apispec.SessionPostRequest,
         ) -> JSONResponse:
+            launch_request = validate_session_post_request(body=body)
             session, created = await start_session(
                 request=request,
-                body=body,
+                launch_request=launch_request,
                 user=user,
                 internal_gitlab_user=internal_gitlab_user,
                 nb_config=self.nb_config,
@@ -289,19 +292,22 @@
         async def _handler(
             user: AnonymousAPIUser | AuthenticatedAPIUser,
             internal_gitlab_user: APIUser,
             session_id: str,
             body: apispec.SessionPatchRequest,
         ) -> HTTPResponse:
+            patch_request = validate_session_patch_request(body=body)
             new_session = await patch_session(
-                body=body,
+                patch_request=patch_request,
                 session_id=session_id,
                 user=user,
                 internal_gitlab_user=internal_gitlab_user,
                 nb_config=self.nb_config,
                 git_provider_helper=self.git_provider_helper,
+                connected_svcs_repo=self.connected_svcs_repo,
+                data_connector_secret_repo=self.data_connector_secret_repo,
                 project_repo=self.project_repo,
                 project_session_secret_repo=self.project_session_secret_repo,
                 rp_repo=self.rp_repo,
                 session_repo=self.session_repo,
+                user_repo=self.user_repo,
                 metrics=self.metrics,
-                connected_svcs_repo=self.connected_svcs_repo,
             )
             return json(new_session.as_apispec().model_dump(exclude_none=True, mode="json"))
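End to end, a launch that skips one data connector could look like the following. The host, route prefix, and token are hypothetical; the actual route is defined by the service's router setup, which this diff does not show.

import httpx

response = httpx.post(
    "https://renku.example.org/api/data/sessions",  # hypothetical URL
    headers={"Authorization": "Bearer <token>"},  # placeholder credentials
    json={
        "launcher_id": "01JADW9S6T4QZV31C5R8YXK0EM",  # made-up ULID
        "data_connectors_overrides": [
            {"data_connector_id": "01JADWB7H3Q8RM2N4P6S8T0V2X", "skip": True},
        ],
    },
)
response.raise_for_status()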
diff --git a/components/renku_data_services/notebooks/core_sessions.py b/components/renku_data_services/notebooks/core_sessions.py
index d70132402..cae1aa0d9 100644
--- a/components/renku_data_services/notebooks/core_sessions.py
+++ b/components/renku_data_services/notebooks/core_sessions.py
@@ -76,10 +76,17 @@
     SessionLocation,
     ShmSizeStr,
     SizeStr,
-    State,
     Storage,
 )
-from renku_data_services.notebooks.models import ExtraSecret, SessionExtraResources
+from renku_data_services.notebooks.models import (
+    ExtraSecret,
+    SessionDataConnectorOverride,
+    SessionEnvVar,
+    SessionExtraResources,
+    SessionLaunchRequest,
+    SessionPatchRequest,
+    SessionState,
+)
 from renku_data_services.notebooks.util.kubernetes_ import (
     renku_2_make_server_name,
 )
@@ -253,7 +260,7 @@ async def get_data_sources(
     server_name: str,
     data_connectors_stream: AsyncIterator[DataConnectorWithSecrets],
     work_dir: PurePosixPath,
-    cloud_storage_overrides: list[apispec.SessionCloudStoragePost],
+    data_connectors_overrides: list[SessionDataConnectorOverride],
     user_repo: UserRepo,
 ) -> SessionExtraResources:
     """Generate cloud storage related resources."""
@@ -261,6 +268,7 @@
     secrets: list[ExtraSecret] = []
     dcs: dict[str, RCloneStorage] = {}
     dcs_secrets: dict[str, list[DataConnectorSecret]] = {}
+    skipped_dcs: set[str] = set()
     user_secret_key: str | None = None
     async for dc in data_connectors_stream:
         mount_folder = (
@@ -285,17 +293,22 @@
     # NOTE: Check the cloud storage overrides from the request body and if any match
     # then overwrite the projects cloud storages
     # NOTE: Cloud storages in the session launch request body that are not from the DB will cause a 404 error
-    # NOTE: Overriding the configuration when a saved secret is there will cause a 422 error
-    for csr in cloud_storage_overrides:
-        csr_id = csr.storage_id
-        if csr_id not in dcs:
+    # TODO: Is this correct? -> NOTE: Overriding the configuration when a saved secret is there will cause a 422 error
+    for dco in data_connectors_overrides:
+        dc_id = str(dco.data_connector_id)
+        if dc_id not in dcs:
             raise errors.MissingResourceError(
-                message=f"You have requested a cloud storage with ID {csr_id} which does not exist "
+                message=f"You have requested a data connector with ID {dc_id} which does not exist "
                 "or you don't have access to."
             )
-        if csr.target_path is not None and not PurePosixPath(csr.target_path).is_absolute():
-            csr.target_path = (work_dir / csr.target_path).as_posix()
-        dcs[csr_id] = dcs[csr_id].with_override(csr)
+        # NOTE: if 'skip' is true, we do not mount that data connector
+        if dco.skip:
+            skipped_dcs.add(dc_id)
+            del dcs[dc_id]
+            continue
+        if dco.target_path is not None and not PurePosixPath(dco.target_path).is_absolute():
+            dco.target_path = (work_dir / dco.target_path).as_posix()
+        dcs[dc_id] = dcs[dc_id].with_override(dco)

     # Handle potential duplicate target_path
     dcs = _deduplicate_target_paths(dcs)
@@ -325,13 +338,85 @@
                 accessMode="ReadOnlyMany" if cs.readonly else "ReadWriteOnce",
             )
         )
+
+    # Add annotations to track skipped data connectors
+    # annotations: dict[str, str] = {"renku.io/mounted_data_connectors_ids": json.dumps(sorted(dcs.keys()))}
+    annotations: dict[str, str] = dict()
+    if skipped_dcs:
+        annotations["renku.io/skipped_data_connectors_ids"] = json.dumps(sorted(skipped_dcs))
+
     return SessionExtraResources(
+        annotations=annotations,
         data_sources=data_sources,
         secrets=secrets,
         data_connector_secrets=dcs_secrets,
     )
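The skip bookkeeping in `get_data_sources` reduces to a few dict/set operations; a standalone restatement with made-up IDs:

import json

dcs = {"01JADW9S6T4QZV31C5R8YXK0EM": ..., "01JADWB7H3Q8RM2N4P6S8T0V2X": ...}
overrides = [{"data_connector_id": "01JADWB7H3Q8RM2N4P6S8T0V2X", "skip": True}]

skipped = {o["data_connector_id"] for o in overrides if o["skip"]}
for dc_id in skipped:
    del dcs[dc_id]  # a skipped connector is simply not mounted

annotations = {}
if skipped:
    annotations["renku.io/skipped_data_connectors_ids"] = json.dumps(sorted(skipped))
# annotations == {"renku.io/skipped_data_connectors_ids": '["01JADWB7H3Q8RM2N4P6S8T0V2X"]'}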
+
+
+async def patch_data_sources(
+    existing_session: AmaltheaSessionV1Alpha1,
+    nb_config: NotebooksConfig,
+    user: AnonymousAPIUser | AuthenticatedAPIUser,
+    server_name: str,
+    data_connectors_stream: AsyncIterator[DataConnectorWithSecrets],
+    work_dir: PurePosixPath,
+    data_connectors_overrides: list[SessionDataConnectorOverride],
+    user_repo: UserRepo,
+) -> None:  # -> SessionExtraResources:
+    """Handle patching data sources."""
+
+    # First, collect the data connectors we already mount in the session
+    existing_dcs: set[str] = set()
+    secret_name_prefix = f"{server_name}-ds-"
+    for ds in existing_session.spec.dataSources or []:
+        if not ds.secretRef:
+            continue
+        if ds.secretRef.name.startswith(secret_name_prefix):
+            dc_id = str(ULID.from_str(ds.secretRef.name[len(secret_name_prefix) :].upper()))
+            existing_dcs.add(dc_id)
+    logger.debug(f"existing_dcs = {existing_dcs}")
+
+    # Collect the data connectors we already skip
+    existing_skipped_dcs: set[str] = set(
+        json.loads(existing_session.metadata.annotations.get("renku.io/skipped_data_connectors_ids", "[]"))
+    )
+    logger.debug(f"existing_skipped_dcs = {existing_skipped_dcs}")
+
+    # Collect the previously skipped data connectors we should mount now
+    newly_unskipped_dcs: set[str] = set()
+    for dco in data_connectors_overrides:
+        dc_id = str(dco.data_connector_id)
+        if dc_id in existing_skipped_dcs and not dco.skip:
+            newly_unskipped_dcs.add(dc_id)
+    logger.debug(f"newly_unskipped_dcs = {newly_unskipped_dcs}")
+
+    # Collect the new data connectors
+    new_dcs: dict[str, DataConnectorWithSecrets] = dict()
+    async for dc in data_connectors_stream:
+        dc_id = str(dc.data_connector.id)
+        if (dc_id in newly_unskipped_dcs) or ((dc_id not in existing_dcs) and (dc_id not in existing_skipped_dcs)):
+            new_dcs[dc_id] = dc
+    logger.debug(f"new_dcs = {sorted(new_dcs.keys())}")
+
+    async def new_dcs_stream() -> AsyncIterator[DataConnectorWithSecrets]:
+        for dc in new_dcs.values():
+            yield dc
+
+    session_extras = await get_data_sources(
+        nb_config=nb_config,
+        server_name=server_name,
+        user=user,
+        data_connectors_stream=new_dcs_stream(),
+        work_dir=work_dir,
+        data_connectors_overrides=data_connectors_overrides,
+        user_repo=user_repo,
+    )
+    logger.debug(f"session_extras.annotations = {session_extras.annotations}")
+    logger.debug(f"session_extras.data_sources = {session_extras.data_sources}")
+
+    # TODO: return the computed session_extras (see the commented-out return type)
+    # and fold them into the session patch; for now they are only computed and logged.


 async def request_dc_secret_creation(
     user: AuthenticatedAPIUser | AnonymousAPIUser,
     nb_config: NotebooksConfig,
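The reconciliation in `patch_data_sources` above is pure set algebra; restated standalone with made-up short IDs in place of ULIDs:

existing = {"A"}          # mounted now (recovered from dataSources secret names)
skipped = {"B"}           # from the skipped-ids annotation
overrides = {"B": False}  # data_connector_id -> requested skip flag

newly_unskipped = {dc for dc, skip in overrides.items() if dc in skipped and not skip}
project_dcs = {"A", "B", "C"}  # every connector attached to the project

to_mount = {
    dc
    for dc in project_dcs
    if dc in newly_unskipped or (dc not in existing and dc not in skipped)
}
assert to_mount == {"B", "C"}  # B was un-skipped, C is new; A stays mounted as-is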
""" - launcher = await session_repo.get_launcher(user, ULID.from_str(body.launcher_id)) + launcher = await session_repo.get_launcher(user=user, launcher_id=launch_request.launcher_id) + launcher_id = launcher.id project = await project_repo.get_project(user=user, project_id=launcher.project_id) # Determine resource_class_id: the class can be overwritten at the user's request - resource_class_id = body.resource_class_id or launcher.resource_class_id + resource_class_id = launch_request.resource_class_id or launcher.resource_class_id cluster = await nb_config.k8s_v2_client.cluster_by_class_id(resource_class_id, user) server_name = renku_2_make_server_name( - user=user, project_id=str(launcher.project_id), launcher_id=body.launcher_id, cluster_id=str(cluster.id) + user=user, project_id=str(launcher.project_id), launcher_id=str(launcher_id), cluster_id=str(cluster.id) ) - existing_session = await nb_config.k8s_v2_client.get_session(server_name, user.id) + existing_session = await nb_config.k8s_v2_client.get_session(name=server_name, safe_username=user.id) if existing_session is not None and existing_session.spec is not None: return existing_session, False @@ -698,7 +784,8 @@ async def start_session( resource_class = resource_pool.get_resource_class(resource_class_id) if not resource_class or not resource_class.id: raise errors.MissingResourceError(message=f"The resource class with ID {resource_class_id} does not exist.") - await nb_config.crc_validator.validate_class_storage(user, resource_class.id, body.disk_storage) + await nb_config.crc_validator.validate_class_storage(user, resource_class.id, launch_request.disk_storage) + disk_storage = launch_request.disk_storage or resource_class.default_storage # Determine session location session_location = SessionLocation.remote if resource_pool.remote else SessionLocation.local @@ -742,7 +829,7 @@ async def start_session( user=user, data_connectors_stream=data_connectors_stream, work_dir=work_dir, - cloud_storage_overrides=body.cloudstorage or [], + data_connectors_overrides=launch_request.data_connectors_overrides or [], user_repo=user_repo, ) ) @@ -795,11 +882,15 @@ async def start_session( ) # Annotations - annotations: dict[str, str] = { - "renku.io/project_id": str(launcher.project_id), - "renku.io/launcher_id": body.launcher_id, - "renku.io/resource_class_id": str(resource_class_id), - } + session_extras = session_extras.concat( + SessionExtraResources( + annotations={ + "renku.io/project_id": str(launcher.project_id), + "renku.io/launcher_id": str(launcher_id), + "renku.io/resource_class_id": str(resource_class_id), + } + ) + ) # Authentication if isinstance(user, AuthenticatedAPIUser): @@ -852,7 +943,7 @@ async def start_session( session_extras = session_extras.concat(SessionExtraResources(secrets=[remote_secret])) # Raise an error if there are invalid environment variables in the request body - verify_launcher_env_variable_overrides(launcher, body) + verify_launcher_env_variable_overrides(launcher, launch_request) env = [ SessionEnvItem(name="RENKU_BASE_URL_PATH", value=base_server_path), SessionEnvItem(name="RENKU_BASE_URL", value=base_server_url), @@ -873,11 +964,11 @@ async def start_session( remote=resource_pool.remote, ) ) - launcher_env_variables = get_launcher_env_variables(launcher, body) + launcher_env_variables = get_launcher_env_variables(launcher, launch_request) env.extend(launcher_env_variables) session = AmaltheaSessionV1Alpha1( - metadata=Metadata(name=server_name, annotations=annotations), + 
@@ -873,11 +964,11 @@
             remote=resource_pool.remote,
         )
     )
-    launcher_env_variables = get_launcher_env_variables(launcher, body)
+    launcher_env_variables = get_launcher_env_variables(launcher, launch_request)
     env.extend(launcher_env_variables)

     session = AmaltheaSessionV1Alpha1(
-        metadata=Metadata(name=server_name, annotations=annotations),
+        metadata=Metadata(name=server_name, annotations=session_extras.annotations),
         spec=AmaltheaSessionSpec(
             location=session_location,
             imagePullSecrets=[ImagePullSecret(name=image_secret.name, adopt=True)] if image_secret else [],
@@ -892,7 +983,7 @@
                 port=environment.port,
                 storage=Storage(
                     className=storage_class,
-                    size=SizeStr(str(body.disk_storage) + "G"),
+                    size=SizeStr(str(disk_storage) + "G"),
                     mountPath=storage_mount.as_posix(),
                 ),
                 workingDir=work_dir.as_posix(),
@@ -950,7 +1041,7 @@
             "cpu": int(resource_class.cpu * 1000),
             "memory": resource_class.memory,
             "gpu": resource_class.gpu,
-            "storage": body.disk_storage,
+            "storage": disk_storage,
             "resource_class_id": resource_class.id,
            "resource_pool_id": resource_pool.id or "",
             "resource_class_name": f"{resource_pool.name}.{resource_class.name}",
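With annotations now folded into `SessionExtraResources`, a created session's metadata should end up carrying something like the following (illustrative values only; the skipped-ids key appears only when at least one connector was skipped):

annotations = {
    "renku.io/project_id": "01JADWC9K2M4N6P8Q0R2S4T6V8",  # made-up ULID
    "renku.io/launcher_id": "01JADW9S6T4QZV31C5R8YXK0EM",  # made-up ULID
    "renku.io/resource_class_id": "3",
    "renku.io/skipped_data_connectors_ids": '["01JADWB7H3Q8RM2N4P6S8T0V2X"]',
}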
@@ -961,17 +1052,19 @@


 async def patch_session(
-    body: apispec.SessionPatchRequest,
+    patch_request: SessionPatchRequest,
     session_id: str,
     user: AnonymousAPIUser | AuthenticatedAPIUser,
     internal_gitlab_user: APIUser,
     nb_config: NotebooksConfig,
     git_provider_helper: GitProviderHelperProto,
+    connected_svcs_repo: ConnectedServicesRepository,
+    data_connector_secret_repo: DataConnectorSecretRepository,
     project_repo: ProjectRepository,
     project_session_secret_repo: ProjectSessionSecretRepository,
     rp_repo: ResourcePoolRepository,
     session_repo: SessionRepository,
-    connected_svcs_repo: ConnectedServicesRepository,
+    user_repo: UserRepo,
     metrics: MetricsService,
 ) -> AmaltheaSessionV1Alpha1:
     """Patch an Amalthea session."""
@@ -991,43 +1084,43 @@
     # TODO: Some patching should only be done when the session is in some states to avoid inadvertent restarts
     # Refresh tokens for git proxy
     if (
-        body.state is not None
-        and body.state.value.lower() == State.Hibernated.value.lower()
-        and body.state.value.lower() != session.status.state.value.lower()
+        patch_request.state is not None
+        and patch_request.state == SessionState.hibernated
+        and patch_request.state.value.lower() != session.status.state.value.lower()
     ):
         # Session is being hibernated
         patch.spec.hibernated = True
         is_getting_hibernated = True
     elif (
-        body.state is not None
-        and body.state.value.lower() == State.Running.value.lower()
-        and session.status.state.value.lower() != body.state.value.lower()
+        patch_request.state is not None
+        and patch_request.state == SessionState.running
+        and session.status.state.value.lower() != patch_request.state.value.lower()
     ):
         # Session is being resumed
         patch.spec.hibernated = False
         await metrics.user_requested_session_resume(user, metadata={"session_id": session_id})

     # Resource class
-    if body.resource_class_id is not None:
-        new_cluster = await nb_config.k8s_v2_client.cluster_by_class_id(body.resource_class_id, user)
+    if patch_request.resource_class_id is not None:
+        new_cluster = await nb_config.k8s_v2_client.cluster_by_class_id(patch_request.resource_class_id, user)
         if new_cluster.id != cluster.id:
             raise errors.ValidationError(
                 message=(
-                    f"The requested resource class {body.resource_class_id} is not in the "
+                    f"The requested resource class {patch_request.resource_class_id} is not in the "
                     f"same cluster {cluster.id} as the current resource class {session.resource_class_id()}."
                 )
             )
-        rp = await rp_repo.get_resource_pool_from_class(user, body.resource_class_id)
-        rc = rp.get_resource_class(body.resource_class_id)
+        rp = await rp_repo.get_resource_pool_from_class(user, patch_request.resource_class_id)
+        rc = rp.get_resource_class(patch_request.resource_class_id)
         if not rc:
             raise errors.MissingResourceError(
-                message=f"The resource class you requested with ID {body.resource_class_id} does not exist"
+                message=f"The resource class you requested with ID {patch_request.resource_class_id} does not exist"
             )
         # TODO: reject session classes which change the cluster
         if not patch.metadata:
             patch.metadata = AmaltheaSessionV1Alpha1MetadataPatch()
         # Patch the resource class ID in the annotations
-        patch.metadata.annotations = {"renku.io/resource_class_id": str(body.resource_class_id)}
+        patch.metadata.annotations = {"renku.io/resource_class_id": str(patch_request.resource_class_id)}
         if not patch.spec.session:
             patch.spec.session = AmaltheaSessionV1Alpha1SpecSessionPatch()
         patch.spec.session.resources = resources_from_resource_class(rc)
@@ -1062,6 +1155,7 @@
     session_secrets = await project_session_secret_repo.get_all_session_secrets_from_project(
         user=user, project_id=project.id
     )
+    data_connectors_stream = data_connector_secret_repo.get_data_connectors_with_secrets(user, project.id)
     git_providers = await git_provider_helper.get_providers(user=user)
     repositories = repositories_from_project(project, git_providers)
@@ -1077,7 +1171,29 @@
         )
     )

-    # Data connectors: skip
+    # Data connectors
+    await patch_data_sources(
+        existing_session=session,
+        nb_config=nb_config,
+        server_name=server_name,
+        user=user,
+        data_connectors_stream=data_connectors_stream,
+        work_dir=work_dir,
+        # TODO: allow 'data_connectors_overrides' to be passed on the PATCH endpoint
+        data_connectors_overrides=[],  # patch_request.data_connectors_overrides or [],
+        user_repo=user_repo,
+    )
+    # session_extras = session_extras.concat(
+    #     await get_data_sources(
+    #         nb_config=nb_config,
+    #         server_name=server_name,
+    #         user=user,
+    #         data_connectors_stream=data_connectors_stream,
+    #         work_dir=work_dir,
+    #         data_connectors_overrides=launch_request.data_connectors_overrides or [],
+    #         user_repo=user_repo,
+    #     )
+    # )
+    # TODO: How can we patch data connectors? Should we even patch them?
+    # TODO: The fact that `start_session()` accepts overrides for data connectors
+    # TODO: but that we do not save these overrides (e.g. as annotations) means that
+    # TODO: they cannot be re-applied when patching.
@@ -1182,7 +1298,15 @@ def _find_mount_folder(dc: RCloneStorage) -> str:
             dc_ids.append(dc_id)
             mount_folders[new_mount_folder] = dc_ids
             result_dcs[dc_id] = dc.with_override(
-                override=apispec.SessionCloudStoragePost(storage_id=dc_id, target_path=new_mount_folder)
+                override=SessionDataConnectorOverride(
+                    skip=False,
+                    data_connector_id=ULID.from_str(dc_id),
+                    target_path=new_mount_folder,
+                    configuration=None,
+                    source_path=None,
+                    readonly=None,
+                )
             )

     return result_dcs
@@ -1216,3 +1340,42 @@
         else:
             patch_list.append(upsert_item)
     return patch_list
+
+
+def validate_session_post_request(body: apispec.SessionPostRequest) -> SessionLaunchRequest:
+    """Validate a session launch request."""
+    data_connectors_overrides = (
+        [
+            SessionDataConnectorOverride(
+                skip=dc.skip,
+                data_connector_id=ULID.from_str(dc.data_connector_id),
+                configuration=dc.configuration,
+                source_path=dc.source_path,
+                target_path=dc.target_path,
+                readonly=dc.readonly,
+            )
+            for dc in body.data_connectors_overrides
+        ]
+        if body.data_connectors_overrides
+        else None
+    )
+    env_variable_overrides = (
+        [SessionEnvVar(name=ev.name, value=ev.value) for ev in body.env_variable_overrides]
+        if body.env_variable_overrides
+        else None
+    )
+    return SessionLaunchRequest(
+        launcher_id=ULID.from_str(body.launcher_id),
+        disk_storage=body.disk_storage,
+        resource_class_id=body.resource_class_id,
+        data_connectors_overrides=data_connectors_overrides,
+        env_variable_overrides=env_variable_overrides,
+    )
+
+
+def validate_session_patch_request(body: apispec.SessionPatchRequest) -> SessionPatchRequest:
+    """Validate a session patch request."""
+    return SessionPatchRequest(
+        resource_class_id=body.resource_class_id,
+        state=SessionState(body.state.value) if body.state else None,
+    )
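Usage of the new validation boundary, as wired in `blueprints.py` (a sketch; it assumes the imports resolve the same way they do there, and the ULID is made up):

from renku_data_services.notebooks import apispec
from renku_data_services.notebooks.core_sessions import validate_session_post_request

body = apispec.SessionPostRequest.model_validate({"launcher_id": "01JADW9S6T4QZV31C5R8YXK0EM"})
launch_request = validate_session_post_request(body=body)
assert str(launch_request.launcher_id) == body.launcher_id  # str -> ULID and back
assert launch_request.data_connectors_overrides is None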
diff --git a/components/renku_data_services/notebooks/models.py b/components/renku_data_services/notebooks/models.py
index 82cd414a2..c877d81c2 100644
--- a/components/renku_data_services/notebooks/models.py
+++ b/components/renku_data_services/notebooks/models.py
@@ -1,15 +1,18 @@
 """Basic models for amalthea sessions."""

 from dataclasses import dataclass, field
+from enum import StrEnum
 from pathlib import Path
-from typing import cast
+from typing import Any, cast

 from kubernetes.client import V1ObjectMeta, V1Secret
 from pydantic import AliasGenerator, BaseModel, Field, Json
+from ulid import ULID

 from renku_data_services.data_connectors.models import DataConnectorSecret
 from renku_data_services.errors import errors
 from renku_data_services.errors.errors import ProgrammingError
+from renku_data_services.notebooks.api.schemas.cloud_storage import RCloneStorageRequestOverride
 from renku_data_services.notebooks.crs import (
     AmaltheaSessionV1Alpha1,
     DataSource,
@@ -145,6 +148,7 @@ def name(self) -> str:
 class SessionExtraResources:
     """Represents extra resources to add to an amalthea session."""

+    annotations: dict[str, str] = field(default_factory=dict)
     containers: list[ExtraContainer] = field(default_factory=list)
     data_connector_secrets: dict[str, list[DataConnectorSecret]] = field(default_factory=dict)
     data_sources: list[DataSource] = field(default_factory=list)
@@ -157,10 +161,14 @@
     def concat(self, added_extras: "SessionExtraResources | None") -> "SessionExtraResources":
         """Concatenates these session extras with more session extras."""
         if added_extras is None:
             return self
+        annotations: dict[str, str] = dict()
+        annotations.update(self.annotations)
+        annotations.update(added_extras.annotations)
         data_connector_secrets: dict[str, list[DataConnectorSecret]] = dict()
         data_connector_secrets.update(self.data_connector_secrets)
         data_connector_secrets.update(added_extras.data_connector_secrets)
         return SessionExtraResources(
+            annotations=annotations,
             containers=self.containers + added_extras.containers,
             data_connector_secrets=data_connector_secrets,
             data_sources=self.data_sources + added_extras.data_sources,
@@ -169,3 +177,41 @@ def concat(self, added_extras: "SessionExtraResources | None") -> "SessionExtraResources":
             volume_mounts=self.volume_mounts + added_extras.volume_mounts,
             volumes=self.volumes + added_extras.volumes,
         )
+
+
+@dataclass(eq=True, kw_only=True)
+class SessionDataConnectorOverride(RCloneStorageRequestOverride):
+    """Model for a data connector override."""
+
+    skip: bool
+    data_connector_id: ULID
+    configuration: dict[str, Any] | None
+    source_path: str | None
+    target_path: str | None
+    readonly: bool | None
+
+
+@dataclass(frozen=True, eq=True, kw_only=True)
+class SessionLaunchRequest:
+    """Model for requesting a session launch."""
+
+    launcher_id: ULID
+    disk_storage: int | None
+    resource_class_id: int | None
+    data_connectors_overrides: list[SessionDataConnectorOverride] | None
+    env_variable_overrides: list[SessionEnvVar] | None
+
+
+class SessionState(StrEnum):
+    """Session state."""
+
+    running = "running"
+    hibernated = "hibernated"
+
+
+@dataclass(frozen=True, eq=True, kw_only=True)
+class SessionPatchRequest:
+    """Model for patching a session."""
+
+    resource_class_id: int | None
+    state: SessionState | None
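Finally, a small sketch of the new domain models in use; as a `StrEnum`, `SessionState` members compare equal to their string values:

from renku_data_services.notebooks.models import SessionPatchRequest, SessionState

patch_request = SessionPatchRequest(
    resource_class_id=None,
    state=SessionState("hibernated"),  # value lookup, same as SessionState.hibernated
)
assert patch_request.state == SessionState.hibernated
assert patch_request.state == "hibernated"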