From 7ae2d25d3f1fa82d02f6f95d57e1cac6b928f425 Mon Sep 17 00:00:00 2001 From: Hysun He Date: Sun, 29 Dec 2024 14:54:10 +0800 Subject: [PATCH 1/5] [OCI] Support OCI Object Storage (#4501) * OCI Object Storage Support * example yaml update * example update * add more example yaml * Support RClone-RPM pkg * Add smoke test * ver * smoke test * Resolve dependancy conflict between oci-cli and runpod * Use latest RClone version (v1.68.2) * minor optimize * Address review comments * typo * test * sync code with repo * Address review comments & more testing. * address one more comment --- examples/oci/dataset-mount.yaml | 35 ++ examples/oci/dataset-upload-and-mount.yaml | 47 ++ examples/oci/oci-mounts.yaml | 26 ++ sky/adaptors/oci.py | 33 +- sky/cloud_stores.py | 61 +++ sky/data/data_transfer.py | 37 ++ sky/data/data_utils.py | 11 + sky/data/mounting_utils.py | 43 ++ sky/data/storage.py | 458 +++++++++++++++++++- sky/task.py | 10 + tests/smoke_tests/test_mount_and_storage.py | 17 + 11 files changed, 773 insertions(+), 5 deletions(-) create mode 100644 examples/oci/dataset-mount.yaml create mode 100644 examples/oci/dataset-upload-and-mount.yaml create mode 100644 examples/oci/oci-mounts.yaml diff --git a/examples/oci/dataset-mount.yaml b/examples/oci/dataset-mount.yaml new file mode 100644 index 00000000000..91bec9cda65 --- /dev/null +++ b/examples/oci/dataset-mount.yaml @@ -0,0 +1,35 @@ +name: cpu-task1 + +resources: + cloud: oci + region: us-sanjose-1 + cpus: 2 + disk_size: 256 + disk_tier: medium + use_spot: False + +file_mounts: + # Mount an existing oci bucket + /datasets-storage: + source: oci://skybucket + mode: MOUNT # Either MOUNT or COPY. Optional. + +# Working directory (optional) containing the project codebase. +# Its contents are synced to ~/sky_workdir/ on the cluster. +workdir: . + +num_nodes: 1 + +# Typical use: pip install -r requirements.txt +# Invoked under the workdir (i.e., can use its files). +setup: | + echo "*** Running setup for the task. ***" + +# Typical use: make use of resources, such as running training. +# Invoked under the workdir (i.e., can use its files). +run: | + echo "*** Running the task on OCI ***" + timestamp=$(date +%s) + ls -lthr /datasets-storage + echo "hi" >> /datasets-storage/foo.txt + ls -lthr /datasets-storage diff --git a/examples/oci/dataset-upload-and-mount.yaml b/examples/oci/dataset-upload-and-mount.yaml new file mode 100644 index 00000000000..13ddc4d2b35 --- /dev/null +++ b/examples/oci/dataset-upload-and-mount.yaml @@ -0,0 +1,47 @@ +name: cpu-task1 + +resources: + cloud: oci + region: us-sanjose-1 + cpus: 2 + disk_size: 256 + disk_tier: medium + use_spot: False + +file_mounts: + /datasets-storage: + name: skybucket # Name of storage, optional when source is bucket URI + source: ['./examples/oci'] # Source path, can be local or bucket URL. Optional, do not specify to create an empty bucket. + store: oci # E.g 'oci', 's3', 'gcs'...; default: None. Optional. + persistent: True # Defaults to True; can be set to false. Optional. + mode: MOUNT # Either MOUNT or COPY. Optional. + + /datasets-storage2: + name: skybucket2 # Name of storage, optional when source is bucket URI + source: './examples/oci' # Source path, can be local or bucket URL. Optional, do not specify to create an empty bucket. + store: oci # E.g 'oci', 's3', 'gcs'...; default: None. Optional. + persistent: True # Defaults to True; can be set to false. Optional. + mode: MOUNT # Either MOUNT or COPY. Optional. + +# Working directory (optional) containing the project codebase. 
+# Its contents are synced to ~/sky_workdir/ on the cluster. +workdir: . + +num_nodes: 1 + +# Typical use: pip install -r requirements.txt +# Invoked under the workdir (i.e., can use its files). +setup: | + echo "*** Running setup for the task. ***" + +# Typical use: make use of resources, such as running training. +# Invoked under the workdir (i.e., can use its files). +run: | + echo "*** Running the task on OCI ***" + ls -lthr /datasets-storage + echo "hi" >> /datasets-storage/foo.txt + ls -lthr /datasets-storage + + ls -lthr /datasets-storage2 + echo "hi" >> /datasets-storage2/foo2.txt + ls -lthr /datasets-storage2 diff --git a/examples/oci/oci-mounts.yaml b/examples/oci/oci-mounts.yaml new file mode 100644 index 00000000000..6fd2aaf16eb --- /dev/null +++ b/examples/oci/oci-mounts.yaml @@ -0,0 +1,26 @@ +resources: + cloud: oci + +file_mounts: + ~/tmpfile: ~/tmpfile + ~/a/b/c/tmpfile: ~/tmpfile + /tmp/workdir: ~/tmp-workdir + + /mydir: + name: skybucket + source: ['~/tmp-workdir'] + store: oci + mode: MOUNT + +setup: | + echo "*** Setup ***" + +run: | + echo "*** Run ***" + + ls -lthr ~/tmpfile + ls -lthr ~/a/b/c + echo hi >> /tmp/workdir/new_file + ls -lthr /tmp/workdir + + ls -lthr /mydir diff --git a/sky/adaptors/oci.py b/sky/adaptors/oci.py index 8fe09479a38..31712de414f 100644 --- a/sky/adaptors/oci.py +++ b/sky/adaptors/oci.py @@ -1,9 +1,11 @@ """Oracle OCI cloud adaptor""" +import functools import logging import os from sky.adaptors import common +from sky.clouds.utils import oci_utils # Suppress OCI circuit breaker logging before lazy import, because # oci modules prints additional message during imports, i.e., the @@ -30,10 +32,16 @@ def get_config_file() -> str: def get_oci_config(region=None, profile='DEFAULT'): conf_file_path = get_config_file() + if not profile or profile == 'DEFAULT': + config_profile = oci_utils.oci_config.get_profile() + else: + config_profile = profile + oci_config = oci.config.from_file(file_location=conf_file_path, - profile_name=profile) + profile_name=config_profile) if region is not None: oci_config['region'] = region + return oci_config @@ -54,6 +62,29 @@ def get_identity_client(region=None, profile='DEFAULT'): return oci.identity.IdentityClient(get_oci_config(region, profile)) +def get_object_storage_client(region=None, profile='DEFAULT'): + return oci.object_storage.ObjectStorageClient( + get_oci_config(region, profile)) + + def service_exception(): """OCI service exception.""" return oci.exceptions.ServiceError + + +def with_oci_env(f): + + @functools.wraps(f) + def wrapper(*args, **kwargs): + # pylint: disable=line-too-long + enter_env_cmds = [ + 'conda info --envs | grep "sky-oci-cli-env" || conda create -n sky-oci-cli-env python=3.10 -y', + '. $(conda info --base 2> /dev/null)/etc/profile.d/conda.sh > /dev/null 2>&1 || true', + 'conda activate sky-oci-cli-env', 'pip install oci-cli', + 'export OCI_CLI_SUPPRESS_FILE_PERMISSIONS_WARNING=True' + ] + operation_cmd = [f(*args, **kwargs)] + leave_env_cmds = ['conda deactivate'] + return ' && '.join(enter_env_cmds + operation_cmd + leave_env_cmds) + + return wrapper diff --git a/sky/cloud_stores.py b/sky/cloud_stores.py index e9c111c56ac..108f33f2c1f 100644 --- a/sky/cloud_stores.py +++ b/sky/cloud_stores.py @@ -7,6 +7,7 @@ * Better interface. * Better implementation (e.g., fsspec, smart_open, using each cloud's SDK). 
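Editor's aside: the `with_oci_env` decorator added in `sky/adaptors/oci.py` above wraps a command-generating function so that its output runs inside a dedicated `sky-oci-cli-env` conda environment. A minimal, self-contained sketch of the pattern (the env-creation and `pip install oci-cli` steps are omitted for brevity, and `make_list_command` is a hypothetical generator, not patch code):

import functools

def with_oci_env(f):
    # Simplified: the real decorator also creates the conda env and
    # installs oci-cli before running the wrapped command.
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        enter_env_cmds = [
            'conda activate sky-oci-cli-env',
            'export OCI_CLI_SUPPRESS_FILE_PERMISSIONS_WARNING=True',
        ]
        leave_env_cmds = ['conda deactivate']
        return ' && '.join(enter_env_cmds + [f(*args, **kwargs)] +
                           leave_env_cmds)
    return wrapper

@with_oci_env
def make_list_command(bucket_name: str) -> str:  # hypothetical example
    return f'oci os object list --bucket-name {bucket_name}'

print(make_list_command('skybucket'))
# conda activate sky-oci-cli-env && export OCI_CLI_SUPPRESS_FILE_PERMISSIONS_WARNING=True
#   && oci os object list --bucket-name skybucket && conda deactivate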
""" +import os import shlex import subprocess import time @@ -18,6 +19,7 @@ from sky.adaptors import azure from sky.adaptors import cloudflare from sky.adaptors import ibm +from sky.adaptors import oci from sky.clouds import gcp from sky.data import data_utils from sky.data.data_utils import Rclone @@ -470,6 +472,64 @@ def make_sync_file_command(self, source: str, destination: str) -> str: return self.make_sync_dir_command(source, destination) +class OciCloudStorage(CloudStorage): + """OCI Cloud Storage.""" + + def is_directory(self, url: str) -> bool: + """Returns whether OCI 'url' is a directory. + In cloud object stores, a "directory" refers to a regular object whose + name is a prefix of other objects. + """ + bucket_name, path = data_utils.split_oci_path(url) + + client = oci.get_object_storage_client() + namespace = client.get_namespace( + compartment_id=oci.get_oci_config()['tenancy']).data + + objects = client.list_objects(namespace_name=namespace, + bucket_name=bucket_name, + prefix=path).data.objects + + if len(objects) == 0: + # A directory with few or no items + return True + + if len(objects) > 1: + # A directory with more than 1 items + return True + + object_name = objects[0].name + if path.endswith(object_name): + # An object path + return False + + # A directory with only 1 item + return True + + @oci.with_oci_env + def make_sync_dir_command(self, source: str, destination: str) -> str: + """Downloads using OCI CLI.""" + bucket_name, path = data_utils.split_oci_path(source) + + download_via_ocicli = (f'oci os object sync --no-follow-symlinks ' + f'--bucket-name {bucket_name} ' + f'--prefix "{path}" --dest-dir "{destination}"') + + return download_via_ocicli + + @oci.with_oci_env + def make_sync_file_command(self, source: str, destination: str) -> str: + """Downloads a file using OCI CLI.""" + bucket_name, path = data_utils.split_oci_path(source) + filename = os.path.basename(path) + destination = os.path.join(destination, filename) + + download_via_ocicli = (f'oci os object get --bucket-name {bucket_name} ' + f'--name "{path}" --file "{destination}"') + + return download_via_ocicli + + def get_storage_from_path(url: str) -> CloudStorage: """Returns a CloudStorage by identifying the scheme:// in a URL.""" result = urllib.parse.urlsplit(url) @@ -485,6 +545,7 @@ def get_storage_from_path(url: str) -> CloudStorage: 's3': S3CloudStorage(), 'r2': R2CloudStorage(), 'cos': IBMCosCloudStorage(), + 'oci': OciCloudStorage(), # TODO: This is a hack, as Azure URL starts with https://, we should # refactor the registry to be able to take regex, so that Azure blob can # be identified with `https://(.*?)\.blob\.core\.windows\.net` diff --git a/sky/data/data_transfer.py b/sky/data/data_transfer.py index 374871031cb..3ccc6f8fc0e 100644 --- a/sky/data/data_transfer.py +++ b/sky/data/data_transfer.py @@ -200,3 +200,40 @@ def _add_bucket_iam_member(bucket_name: str, role: str, member: str) -> None: bucket.set_iam_policy(policy) logger.debug(f'Added {member} with role {role} to {bucket_name}.') + + +def s3_to_oci(s3_bucket_name: str, oci_bucket_name: str) -> None: + """Creates a one-time transfer from Amazon S3 to OCI Object Storage. + Args: + s3_bucket_name: str; Name of the Amazon S3 Bucket + oci_bucket_name: str; Name of the OCI Bucket + """ + # TODO(HysunHe): Implement sync with other clouds (s3, gs) + raise NotImplementedError('Moving data directly from S3 to OCI bucket ' + 'is currently not supported. 
Please specify ' + 'a local source for the storage object.') + + +def gcs_to_oci(gs_bucket_name: str, oci_bucket_name: str) -> None: + """Creates a one-time transfer from Google Cloud Storage to + OCI Object Storage. + Args: + gs_bucket_name: str; Name of the Google Cloud Storage Bucket + oci_bucket_name: str; Name of the OCI Bucket + """ + # TODO(HysunHe): Implement sync with other clouds (s3, gs) + raise NotImplementedError('Moving data directly from GCS to OCI bucket ' + 'is currently not supported. Please specify ' + 'a local source for the storage object.') + + +def r2_to_oci(r2_bucket_name: str, oci_bucket_name: str) -> None: + """Creates a one-time transfer from Cloudflare R2 to OCI Bucket. + Args: + r2_bucket_name: str; Name of the Cloudflare R2 Bucket + oci_bucket_name: str; Name of the OCI Bucket + """ + raise NotImplementedError( + 'Moving data directly from Cloudflare R2 to OCI ' + 'bucket is currently not supported. Please specify ' + 'a local source for the storage object.') diff --git a/sky/data/data_utils.py b/sky/data/data_utils.py index d66c79afeb0..05c2b42c844 100644 --- a/sky/data/data_utils.py +++ b/sky/data/data_utils.py @@ -730,3 +730,14 @@ def _remove_bucket_profile_rclone(bucket_name: str, lines_to_keep.append(line) return lines_to_keep + + +def split_oci_path(oci_path: str) -> Tuple[str, str]: + """Splits OCI Path into Bucket name and Relative Path to Bucket + Args: + oci_path: str; OCI Path, e.g. oci://imagenet/train/ + """ + path_parts = oci_path.replace('oci://', '').split('/') + bucket = path_parts.pop(0) + key = '/'.join(path_parts) + return bucket, key diff --git a/sky/data/mounting_utils.py b/sky/data/mounting_utils.py index 22b26c372c4..b713a1f1cc5 100644 --- a/sky/data/mounting_utils.py +++ b/sky/data/mounting_utils.py @@ -19,6 +19,7 @@ _BLOBFUSE_CACHE_ROOT_DIR = '~/.sky/blobfuse2_cache' _BLOBFUSE_CACHE_DIR = ('~/.sky/blobfuse2_cache/' '{storage_account_name}_{container_name}') +RCLONE_VERSION = 'v1.68.2' def get_s3_mount_install_cmd() -> str: @@ -158,6 +159,48 @@ def get_cos_mount_cmd(rclone_config_data: str, rclone_config_path: str, return mount_cmd +def get_rclone_install_cmd() -> str: + """ RClone installation for both apt-get and rpm. + This would be common command. + """ + # pylint: disable=line-too-long + install_cmd = ( + f'(which dpkg > /dev/null 2>&1 && (which rclone > /dev/null || (cd ~ > /dev/null' + f' && curl -O https://downloads.rclone.org/{RCLONE_VERSION}/rclone-{RCLONE_VERSION}-linux-amd64.deb' + f' && sudo dpkg -i rclone-{RCLONE_VERSION}-linux-amd64.deb' + f' && rm -f rclone-{RCLONE_VERSION}-linux-amd64.deb)))' + f' || (which rclone > /dev/null || (cd ~ > /dev/null' + f' && curl -O https://downloads.rclone.org/{RCLONE_VERSION}/rclone-{RCLONE_VERSION}-linux-amd64.rpm' + f' && sudo yum --nogpgcheck install rclone-{RCLONE_VERSION}-linux-amd64.rpm -y' + f' && rm -f rclone-{RCLONE_VERSION}-linux-amd64.rpm))') + return install_cmd + + +def get_oci_mount_cmd(mount_path: str, store_name: str, region: str, + namespace: str, compartment: str, config_file: str, + config_profile: str) -> str: + """ OCI specific RClone mount command for oci object storage. 
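Editor's aside: `data_utils.split_oci_path()`, introduced earlier in this patch, is what the OCI store and cloud-store helpers use to turn an `oci://` URL into a `(bucket, key)` pair. A standalone, runnable equivalent for reference (same logic as the patch; the asserts are illustrative):

from typing import Tuple

def split_oci_path(oci_path: str) -> Tuple[str, str]:
    # 'oci://bucket/some/key' -> ('bucket', 'some/key')
    path_parts = oci_path.replace('oci://', '').split('/')
    bucket = path_parts.pop(0)
    key = '/'.join(path_parts)
    return bucket, key

assert split_oci_path('oci://imagenet/train/') == ('imagenet', 'train/')
assert split_oci_path('oci://skybucket') == ('skybucket', '')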
""" + # pylint: disable=line-too-long + mount_cmd = ( + f'sudo chown -R `whoami` {mount_path}' + f' && rclone config create oos_{store_name} oracleobjectstorage' + f' provider user_principal_auth namespace {namespace}' + f' compartment {compartment} region {region}' + f' oci-config-file {config_file}' + f' oci-config-profile {config_profile}' + f' && sed -i "s/oci-config-file/config_file/g;' + f' s/oci-config-profile/config_profile/g" ~/.config/rclone/rclone.conf' + f' && ([ ! -f /bin/fusermount3 ] && sudo ln -s /bin/fusermount /bin/fusermount3 || true)' + f' && (grep -q {mount_path} /proc/mounts || rclone mount oos_{store_name}:{store_name} {mount_path} --daemon --allow-non-empty)' + ) + return mount_cmd + + +def get_rclone_version_check_cmd() -> str: + """ RClone version check. This would be common command. """ + return f'rclone --version | grep -q {RCLONE_VERSION}' + + def _get_mount_binary(mount_cmd: str) -> str: """Returns mounting binary in string given as the mount command. diff --git a/sky/data/storage.py b/sky/data/storage.py index 73bdf3aff00..188c97b9545 100644 --- a/sky/data/storage.py +++ b/sky/data/storage.py @@ -24,6 +24,7 @@ from sky.adaptors import cloudflare from sky.adaptors import gcp from sky.adaptors import ibm +from sky.adaptors import oci from sky.data import data_transfer from sky.data import data_utils from sky.data import mounting_utils @@ -54,7 +55,9 @@ str(clouds.AWS()), str(clouds.GCP()), str(clouds.Azure()), - str(clouds.IBM()), cloudflare.NAME + str(clouds.IBM()), + str(clouds.OCI()), + cloudflare.NAME, ] # Maximum number of concurrent rsync upload processes @@ -115,6 +118,7 @@ class StoreType(enum.Enum): AZURE = 'AZURE' R2 = 'R2' IBM = 'IBM' + OCI = 'OCI' @classmethod def from_cloud(cls, cloud: str) -> 'StoreType': @@ -128,6 +132,8 @@ def from_cloud(cls, cloud: str) -> 'StoreType': return StoreType.R2 elif cloud.lower() == str(clouds.Azure()).lower(): return StoreType.AZURE + elif cloud.lower() == str(clouds.OCI()).lower(): + return StoreType.OCI elif cloud.lower() == str(clouds.Lambda()).lower(): with ux_utils.print_exception_no_traceback(): raise ValueError('Lambda Cloud does not provide cloud storage.') @@ -149,6 +155,8 @@ def from_store(cls, store: 'AbstractStore') -> 'StoreType': return StoreType.R2 elif isinstance(store, IBMCosStore): return StoreType.IBM + elif isinstance(store, OciStore): + return StoreType.OCI else: with ux_utils.print_exception_no_traceback(): raise ValueError(f'Unknown store type: {store}') @@ -165,6 +173,8 @@ def store_prefix(self) -> str: return 'r2://' elif self == StoreType.IBM: return 'cos://' + elif self == StoreType.OCI: + return 'oci://' else: with ux_utils.print_exception_no_traceback(): raise ValueError(f'Unknown store type: {self}') @@ -564,6 +574,8 @@ def __init__(self, self.add_store(StoreType.R2) elif self.source.startswith('cos://'): self.add_store(StoreType.IBM) + elif self.source.startswith('oci://'): + self.add_store(StoreType.OCI) @staticmethod def _validate_source( @@ -644,7 +656,7 @@ def _validate_local_source(local_source): 'using a bucket by writing : ' f'{source} in the file_mounts section of your YAML') is_local_source = True - elif split_path.scheme in ['s3', 'gs', 'https', 'r2', 'cos']: + elif split_path.scheme in ['s3', 'gs', 'https', 'r2', 'cos', 'oci']: is_local_source = False # Storage mounting does not support mounting specific files from # cloud store - ensure path points to only a directory @@ -668,7 +680,7 @@ def _validate_local_source(local_source): with ux_utils.print_exception_no_traceback(): 
raise exceptions.StorageSourceError( f'Supported paths: local, s3://, gs://, https://, ' - f'r2://, cos://. Got: {source}') + f'r2://, cos://, oci://. Got: {source}') return source, is_local_source def _validate_storage_spec(self, name: Optional[str]) -> None: @@ -683,7 +695,7 @@ def validate_name(name): """ prefix = name.split('://')[0] prefix = prefix.lower() - if prefix in ['s3', 'gs', 'https', 'r2', 'cos']: + if prefix in ['s3', 'gs', 'https', 'r2', 'cos', 'oci']: with ux_utils.print_exception_no_traceback(): raise exceptions.StorageNameError( 'Prefix detected: `name` cannot start with ' @@ -798,6 +810,11 @@ def _add_store_from_metadata( s_metadata, source=self.source, sync_on_reconstruction=self.sync_on_reconstruction) + elif s_type == StoreType.OCI: + store = OciStore.from_metadata( + s_metadata, + source=self.source, + sync_on_reconstruction=self.sync_on_reconstruction) else: with ux_utils.print_exception_no_traceback(): raise ValueError(f'Unknown store type: {s_type}') @@ -886,6 +903,8 @@ def add_store(self, store_cls = R2Store elif store_type == StoreType.IBM: store_cls = IBMCosStore + elif store_type == StoreType.OCI: + store_cls = OciStore else: with ux_utils.print_exception_no_traceback(): raise exceptions.StorageSpecError( @@ -1149,6 +1168,9 @@ def _validate(self): assert data_utils.verify_ibm_cos_bucket(self.name), ( f'Source specified as {self.source}, a COS bucket. ', 'COS Bucket should exist.') + elif self.source.startswith('oci://'): + raise NotImplementedError( + 'Moving data from OCI to S3 is currently not supported.') # Validate name self.name = self.validate_name(self.name) @@ -1260,6 +1282,8 @@ def upload(self): self._transfer_to_s3() elif self.source.startswith('r2://'): self._transfer_to_s3() + elif self.source.startswith('oci://'): + self._transfer_to_s3() else: self.batch_aws_rsync([self.source]) except exceptions.StorageUploadError: @@ -1588,6 +1612,9 @@ def _validate(self): assert data_utils.verify_ibm_cos_bucket(self.name), ( f'Source specified as {self.source}, a COS bucket. ', 'COS Bucket should exist.') + elif self.source.startswith('oci://'): + raise NotImplementedError( + 'Moving data from OCI to GCS is currently not supported.') # Validate name self.name = self.validate_name(self.name) # Check if the storage is enabled @@ -1696,6 +1723,8 @@ def upload(self): self._transfer_to_gcs() elif self.source.startswith('r2://'): self._transfer_to_gcs() + elif self.source.startswith('oci://'): + self._transfer_to_gcs() else: # If a single directory is specified in source, upload # contents to root of bucket by suffixing /*. @@ -2122,6 +2151,9 @@ def _validate(self): assert data_utils.verify_ibm_cos_bucket(self.name), ( f'Source specified as {self.source}, a COS bucket. ', 'COS Bucket should exist.') + elif self.source.startswith('oci://'): + raise NotImplementedError( + 'Moving data from OCI to AZureBlob is not supported.') # Validate name self.name = self.validate_name(self.name) @@ -2474,6 +2506,8 @@ def upload(self): raise NotImplementedError(error_message.format('R2')) elif self.source.startswith('cos://'): raise NotImplementedError(error_message.format('IBM COS')) + elif self.source.startswith('oci://'): + raise NotImplementedError(error_message.format('OCI')) else: self.batch_az_blob_sync([self.source]) except exceptions.StorageUploadError: @@ -2833,6 +2867,10 @@ def _validate(self): assert data_utils.verify_ibm_cos_bucket(self.name), ( f'Source specified as {self.source}, a COS bucket. 
', 'COS Bucket should exist.') + elif self.source.startswith('oci://'): + raise NotImplementedError( + 'Moving data from OCI to R2 is currently not supported.') + # Validate name self.name = S3Store.validate_name(self.name) # Check if the storage is enabled @@ -2884,6 +2922,8 @@ def upload(self): self._transfer_to_r2() elif self.source.startswith('r2://'): pass + elif self.source.startswith('oci://'): + self._transfer_to_r2() else: self.batch_aws_rsync([self.source]) except exceptions.StorageUploadError: @@ -3590,3 +3630,413 @@ def _delete_cos_bucket(self): if e.__class__.__name__ == 'NoSuchBucket': logger.debug('bucket already removed') Rclone.delete_rclone_bucket_profile(self.name, Rclone.RcloneClouds.IBM) + + +class OciStore(AbstractStore): + """OciStore inherits from Storage Object and represents the backend + for OCI buckets. + """ + + _ACCESS_DENIED_MESSAGE = 'AccessDeniedException' + + def __init__(self, + name: str, + source: str, + region: Optional[str] = None, + is_sky_managed: Optional[bool] = None, + sync_on_reconstruction: Optional[bool] = True): + self.client: Any + self.bucket: StorageHandle + self.oci_config_file: str + self.config_profile: str + self.compartment: str + self.namespace: str + + # Bucket region should be consistence with the OCI config file + region = oci.get_oci_config()['region'] + + super().__init__(name, source, region, is_sky_managed, + sync_on_reconstruction) + + def _validate(self): + if self.source is not None and isinstance(self.source, str): + if self.source.startswith('oci://'): + assert self.name == data_utils.split_oci_path(self.source)[0], ( + 'OCI Bucket is specified as path, the name should be ' + 'the same as OCI bucket.') + elif not re.search(r'^\w+://', self.source): + # Treat it as local path. + pass + else: + raise NotImplementedError( + f'Moving data from {self.source} to OCI is not supported.') + + # Validate name + self.name = self.validate_name(self.name) + # Check if the storage is enabled + if not _is_storage_cloud_enabled(str(clouds.OCI())): + with ux_utils.print_exception_no_traceback(): + raise exceptions.ResourcesUnavailableError( + 'Storage \'store: oci\' specified, but ' \ + 'OCI access is disabled. To fix, enable '\ + 'OCI by running `sky check`. '\ + 'More info: https://skypilot.readthedocs.io/en/latest/getting-started/installation.html.' # pylint: disable=line-too-long + ) + + @classmethod + def validate_name(cls, name) -> str: + """Validates the name of the OCI store. + + Source for rules: https://docs.oracle.com/en-us/iaas/Content/Object/Tasks/managingbuckets.htm#Managing_Buckets # pylint: disable=line-too-long + """ + + def _raise_no_traceback_name_error(err_str): + with ux_utils.print_exception_no_traceback(): + raise exceptions.StorageNameError(err_str) + + if name is not None and isinstance(name, str): + # Check for overall length + if not 1 <= len(name) <= 256: + _raise_no_traceback_name_error( + f'Invalid store name: name {name} must contain 1-256 ' + 'characters.') + + # Check for valid characters and start/end with a number or letter + pattern = r'^[A-Za-z0-9-._]+$' + if not re.match(pattern, name): + _raise_no_traceback_name_error( + f'Invalid store name: name {name} can only contain ' + 'upper or lower case letters, numeric characters, hyphens ' + '(-), underscores (_), and dots (.). Spaces are not ' + 'allowed. 
Names must start and end with a number or ' + 'letter.') + else: + _raise_no_traceback_name_error('Store name must be specified.') + return name + + def initialize(self): + """Initializes the OCI store object on the cloud. + + Initialization involves fetching bucket if exists, or creating it if + it does not. + + Raises: + StorageBucketCreateError: If bucket creation fails + StorageBucketGetError: If fetching existing bucket fails + StorageInitError: If general initialization fails. + """ + # pylint: disable=import-outside-toplevel + from sky.clouds.utils import oci_utils + from sky.provision.oci.query_utils import query_helper + + self.oci_config_file = oci.get_config_file() + self.config_profile = oci_utils.oci_config.get_profile() + + ## pylint: disable=line-too-long + # What's compartment? See thttps://docs.oracle.com/en/cloud/foundation/cloud_architecture/governance/compartments.html + self.compartment = query_helper.find_compartment(self.region) + self.client = oci.get_object_storage_client(region=self.region, + profile=self.config_profile) + self.namespace = self.client.get_namespace( + compartment_id=oci.get_oci_config()['tenancy']).data + + self.bucket, is_new_bucket = self._get_bucket() + if self.is_sky_managed is None: + # If is_sky_managed is not specified, then this is a new storage + # object (i.e., did not exist in global_user_state) and we should + # set the is_sky_managed property. + # If is_sky_managed is specified, then we take no action. + self.is_sky_managed = is_new_bucket + + def upload(self): + """Uploads source to store bucket. + + Upload must be called by the Storage handler - it is not called on + Store initialization. + + Raises: + StorageUploadError: if upload fails. + """ + try: + if isinstance(self.source, list): + self.batch_oci_rsync(self.source, create_dirs=True) + elif self.source is not None: + if self.source.startswith('oci://'): + pass + else: + self.batch_oci_rsync([self.source]) + except exceptions.StorageUploadError: + raise + except Exception as e: + raise exceptions.StorageUploadError( + f'Upload failed for store {self.name}') from e + + def delete(self) -> None: + deleted_by_skypilot = self._delete_oci_bucket(self.name) + if deleted_by_skypilot: + msg_str = f'Deleted OCI bucket {self.name}.' + else: + msg_str = (f'OCI bucket {self.name} may have been deleted ' + f'externally. Removing from local state.') + logger.info(f'{colorama.Fore.GREEN}{msg_str}' + f'{colorama.Style.RESET_ALL}') + + def get_handle(self) -> StorageHandle: + return self.client.get_bucket(namespace_name=self.namespace, + bucket_name=self.name).data + + def batch_oci_rsync(self, + source_path_list: List[Path], + create_dirs: bool = False) -> None: + """Invokes oci sync to batch upload a list of local paths to Bucket + + Use OCI bulk operation to batch process the file upload + + Args: + source_path_list: List of paths to local files or directories + create_dirs: If the local_path is a directory and this is set to + False, the contents of the directory are directly uploaded to + root of the bucket. If the local_path is a directory and this is + set to True, the directory is created in the bucket root and + contents are uploaded to it. 
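A short illustration of the create_dirs semantics described above (editor's sketch; `dest_key` is a hypothetical helper, not patch code, and only shows where a file would land in the bucket):

import os

def dest_key(local_dir: str, rel_file: str, create_dirs: bool) -> str:
    # create_dirs=True keeps the top-level directory name as a prefix;
    # create_dirs=False uploads the directory contents to the bucket root.
    prefix = os.path.basename(local_dir.rstrip('/')) + '/' if create_dirs else ''
    return prefix + rel_file

assert dest_key('/data/train', 'img1.png', create_dirs=False) == 'img1.png'
assert dest_key('/data/train', 'img1.png', create_dirs=True) == 'train/img1.png'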
+ """ + + @oci.with_oci_env + def get_file_sync_command(base_dir_path, file_names): + includes = ' '.join( + [f'--include "{file_name}"' for file_name in file_names]) + sync_command = ( + 'oci os object bulk-upload --no-follow-symlinks --overwrite ' + f'--bucket-name {self.name} --namespace-name {self.namespace} ' + f'--src-dir "{base_dir_path}" {includes}') + + return sync_command + + @oci.with_oci_env + def get_dir_sync_command(src_dir_path, dest_dir_name): + if dest_dir_name and not str(dest_dir_name).endswith('/'): + dest_dir_name = f'{dest_dir_name}/' + + excluded_list = storage_utils.get_excluded_files(src_dir_path) + excluded_list.append('.git/*') + excludes = ' '.join([ + f'--exclude {shlex.quote(file_name)}' + for file_name in excluded_list + ]) + + # we exclude .git directory from the sync + sync_command = ( + 'oci os object bulk-upload --no-follow-symlinks --overwrite ' + f'--bucket-name {self.name} --namespace-name {self.namespace} ' + f'--object-prefix "{dest_dir_name}" --src-dir "{src_dir_path}" ' + f'{excludes} ') + + return sync_command + + # Generate message for upload + if len(source_path_list) > 1: + source_message = f'{len(source_path_list)} paths' + else: + source_message = source_path_list[0] + + log_path = sky_logging.generate_tmp_logging_file_path( + _STORAGE_LOG_FILE_NAME) + sync_path = f'{source_message} -> oci://{self.name}/' + with rich_utils.safe_status( + ux_utils.spinner_message(f'Syncing {sync_path}', + log_path=log_path)): + data_utils.parallel_upload( + source_path_list=source_path_list, + filesync_command_generator=get_file_sync_command, + dirsync_command_generator=get_dir_sync_command, + log_path=log_path, + bucket_name=self.name, + access_denied_message=self._ACCESS_DENIED_MESSAGE, + create_dirs=create_dirs, + max_concurrent_uploads=1) + + logger.info( + ux_utils.finishing_message(f'Storage synced: {sync_path}', + log_path)) + + def _get_bucket(self) -> Tuple[StorageHandle, bool]: + """Obtains the OCI bucket. + If the bucket exists, this method will connect to the bucket. + + If the bucket does not exist, there are three cases: + 1) Raise an error if the bucket source starts with oci:// + 2) Return None if bucket has been externally deleted and + sync_on_reconstruction is False + 3) Create and return a new bucket otherwise + + Return tuple (Bucket, Boolean): The first item is the bucket + json payload from the OCI API call, the second item indicates + if this is a new created bucket(True) or an existing bucket(False). + + Raises: + StorageBucketCreateError: If creating the bucket fails + StorageBucketGetError: If fetching a bucket fails + """ + try: + get_bucket_response = self.client.get_bucket( + namespace_name=self.namespace, bucket_name=self.name) + bucket = get_bucket_response.data + return bucket, False + except oci.service_exception() as e: + if e.status == 404: # Not Found + if isinstance(self.source, + str) and self.source.startswith('oci://'): + with ux_utils.print_exception_no_traceback(): + raise exceptions.StorageBucketGetError( + 'Attempted to connect to a non-existent bucket: ' + f'{self.source}') from e + else: + # If bucket cannot be found (i.e., does not exist), it is + # to be created by Sky. However, creation is skipped if + # Store object is being reconstructed for deletion. + if self.sync_on_reconstruction: + bucket = self._create_oci_bucket(self.name) + return bucket, True + else: + return None, False + elif e.status == 401: # Unauthorized + # AccessDenied error for buckets that are private and not + # owned by user. 
+ command = ( + f'oci os object list --namespace-name {self.namespace} ' + f'--bucket-name {self.name}') + with ux_utils.print_exception_no_traceback(): + raise exceptions.StorageBucketGetError( + _BUCKET_FAIL_TO_CONNECT_MESSAGE.format(name=self.name) + + f' To debug, consider running `{command}`.') from e + else: + # Unknown / unexpected error happened. This might happen when + # Object storage service itself functions not normal (e.g. + # maintainance event causes internal server error or request + # timeout, etc). + with ux_utils.print_exception_no_traceback(): + raise exceptions.StorageBucketGetError( + f'Failed to connect to OCI bucket {self.name}') from e + + def mount_command(self, mount_path: str) -> str: + """Returns the command to mount the bucket to the mount_path. + + Uses Rclone to mount the bucket. + + Args: + mount_path: str; Path to mount the bucket to. + """ + install_cmd = mounting_utils.get_rclone_install_cmd() + mount_cmd = mounting_utils.get_oci_mount_cmd( + mount_path=mount_path, + store_name=self.name, + region=str(self.region), + namespace=self.namespace, + compartment=self.bucket.compartment_id, + config_file=self.oci_config_file, + config_profile=self.config_profile) + version_check_cmd = mounting_utils.get_rclone_version_check_cmd() + + return mounting_utils.get_mounting_command(mount_path, install_cmd, + mount_cmd, version_check_cmd) + + def _download_file(self, remote_path: str, local_path: str) -> None: + """Downloads file from remote to local on OCI bucket + + Args: + remote_path: str; Remote path on OCI bucket + local_path: str; Local path on user's device + """ + if remote_path.startswith(f'/{self.name}'): + # If the remote path is /bucket_name, we need to + # remove the leading / + remote_path = remote_path.lstrip('/') + + filename = os.path.basename(remote_path) + if not local_path.endswith(filename): + local_path = os.path.join(local_path, filename) + + @oci.with_oci_env + def get_file_download_command(remote_path, local_path): + download_command = (f'oci os object get --bucket-name {self.name} ' + f'--namespace-name {self.namespace} ' + f'--name {remote_path} --file {local_path}') + + return download_command + + download_command = get_file_download_command(remote_path, local_path) + + try: + with rich_utils.safe_status( + f'[bold cyan]Downloading: {remote_path} -> {local_path}[/]' + ): + subprocess.check_output(download_command, + stderr=subprocess.STDOUT, + shell=True) + except subprocess.CalledProcessError as e: + logger.error(f'Download failed: {remote_path} -> {local_path}.\n' + f'Detail errors: {e.output}') + with ux_utils.print_exception_no_traceback(): + raise exceptions.StorageBucketDeleteError( + f'Failed download file {self.name}:{remote_path}.') from e + + def _create_oci_bucket(self, bucket_name: str) -> StorageHandle: + """Creates OCI bucket with specific name in specific region + + Args: + bucket_name: str; Name of bucket + region: str; Region name, e.g. us-central1, us-west1 + """ + logger.debug(f'_create_oci_bucket: {bucket_name}') + try: + create_bucket_response = self.client.create_bucket( + namespace_name=self.namespace, + create_bucket_details=oci.oci.object_storage.models. 
+ CreateBucketDetails( + name=bucket_name, + compartment_id=self.compartment, + )) + bucket = create_bucket_response.data + return bucket + except oci.service_exception() as e: + with ux_utils.print_exception_no_traceback(): + raise exceptions.StorageBucketCreateError( + f'Failed to create OCI bucket: {self.name}') from e + + def _delete_oci_bucket(self, bucket_name: str) -> bool: + """Deletes OCI bucket, including all objects in bucket + + Args: + bucket_name: str; Name of bucket + + Returns: + bool; True if bucket was deleted, False if it was deleted externally. + """ + logger.debug(f'_delete_oci_bucket: {bucket_name}') + + @oci.with_oci_env + def get_bucket_delete_command(bucket_name): + remove_command = (f'oci os bucket delete --bucket-name ' + f'{bucket_name} --empty --force') + + return remove_command + + remove_command = get_bucket_delete_command(bucket_name) + + try: + with rich_utils.safe_status( + f'[bold cyan]Deleting OCI bucket {bucket_name}[/]'): + subprocess.check_output(remove_command.split(' '), + stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as e: + if 'BucketNotFound' in e.output.decode('utf-8'): + logger.debug( + _BUCKET_EXTERNALLY_DELETED_DEBUG_MESSAGE.format( + bucket_name=bucket_name)) + return False + else: + logger.error(e.output) + with ux_utils.print_exception_no_traceback(): + raise exceptions.StorageBucketDeleteError( + f'Failed to delete OCI bucket {bucket_name}.') + return True diff --git a/sky/task.py b/sky/task.py index bd454216b0f..edd2fd211a3 100644 --- a/sky/task.py +++ b/sky/task.py @@ -1031,6 +1031,16 @@ def sync_storage_mounts(self) -> None: storage.name, data_utils.Rclone.RcloneClouds.IBM) blob_path = f'cos://{cos_region}/{storage.name}' self.update_file_mounts({mnt_path: blob_path}) + elif store_type is storage_lib.StoreType.OCI: + if storage.source is not None and not isinstance( + storage.source, + list) and storage.source.startswith('oci://'): + blob_path = storage.source + else: + blob_path = 'oci://' + storage.name + self.update_file_mounts({ + mnt_path: blob_path, + }) else: with ux_utils.print_exception_no_traceback(): raise ValueError(f'Storage Type {store_type} ' diff --git a/tests/smoke_tests/test_mount_and_storage.py b/tests/smoke_tests/test_mount_and_storage.py index 93a5f22c274..aa61282aa11 100644 --- a/tests/smoke_tests/test_mount_and_storage.py +++ b/tests/smoke_tests/test_mount_and_storage.py @@ -85,6 +85,23 @@ def test_scp_file_mounts(): smoke_tests_utils.run_one_test(test) +@pytest.mark.oci # For OCI object storage mounts and file mounts. +def test_oci_mounts(): + name = smoke_tests_utils.get_cluster_name() + test_commands = [ + *smoke_tests_utils.STORAGE_SETUP_COMMANDS, + f'sky launch -y -c {name} --cloud oci --num-nodes 2 examples/oci/oci-mounts.yaml', + f'sky logs {name} 1 --status', # Ensure the job succeeded. 
+ ] + test = smoke_tests_utils.Test( + 'oci_mounts', + test_commands, + f'sky down -y {name}', + timeout=20 * 60, # 20 mins + ) + smoke_tests_utils.run_one_test(test) + + @pytest.mark.no_fluidstack # Requires GCP to be enabled def test_using_file_mounts_with_env_vars(generic_cloud: str): name = smoke_tests_utils.get_cluster_name() From 7e40bcdce7437f601bb07e6eaf9fd954efdd12c6 Mon Sep 17 00:00:00 2001 From: zpoint Date: Mon, 30 Dec 2024 16:57:30 +0800 Subject: [PATCH 2/5] [Jobs] Allowing to specify intermediate bucket for file upload (#4257) * debug * support workdir_bucket_name config on yaml file * change the match statement to if else due to mypy limit * pass mypy * yapf format fix * reformat * remove debug line * all dir to same bucket * private member function * fix mypy * support sub dir config to separate to different directory * rename and add smoke test * bucketname * support sub dir mount * private member for _bucket_sub_path and smoke test fix * support copy mount for sub dir * support gcs, s3 delete folder * doc * r2 remove_objects_from_sub_path * support azure remove directory and cos remove * doc string for remove_objects_from_sub_path * fix sky jobs subdir issue * test case update * rename to _bucket_sub_path * change the config schema * setter * bug fix and test update * delete bucket depends on user config or sky generated * add test case * smoke test bug fix * robust smoke test * fix comment * bug fix * set the storage manually * better structure * fix mypy * Update docs/source/reference/config.rst Co-authored-by: Romil Bhardwaj * Update docs/source/reference/config.rst Co-authored-by: Romil Bhardwaj * limit creation for bucket and delete sub dir only * resolve comment * Update docs/source/reference/config.rst Co-authored-by: Romil Bhardwaj * Update sky/utils/controller_utils.py Co-authored-by: Romil Bhardwaj * resolve PR comment * bug fix * bug fix * fix test case * bug fix * fix * fix test case * bug fix * support is_sky_managed param in config * pass param intermediate_bucket_is_sky_managed * resolve PR comment * Update sky/utils/controller_utils.py Co-authored-by: Romil Bhardwaj * hide bucket creation log * reset green color * rename is_sky_managed to _is_sky_managed * bug fix * retrieve _is_sky_managed from stores * propogate the log --------- Co-authored-by: Romil Bhardwaj --- docs/source/reference/config.rst | 4 + sky/data/mounting_utils.py | 63 +- sky/data/storage.py | 577 ++++++++++++++---- sky/skylet/constants.py | 10 +- sky/task.py | 19 +- sky/utils/controller_utils.py | 150 +++-- sky/utils/schemas.py | 11 + tests/smoke_tests/test_managed_job.py | 61 +- tests/smoke_tests/test_mount_and_storage.py | 132 +++- tests/test_yamls/intermediate_bucket.yaml | 21 + .../use_intermediate_bucket_config.yaml | 2 + 11 files changed, 862 insertions(+), 188 deletions(-) create mode 100644 tests/test_yamls/intermediate_bucket.yaml create mode 100644 tests/test_yamls/use_intermediate_bucket_config.yaml diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst index d5ee4d2134a..99bd347942a 100644 --- a/docs/source/reference/config.rst +++ b/docs/source/reference/config.rst @@ -24,6 +24,10 @@ Available fields and semantics: # # Ref: https://docs.skypilot.co/en/latest/examples/managed-jobs.html#customizing-job-controller-resources jobs: + # Bucket to store managed jobs mount files and tmp files. Bucket must already exist. + # Optional. If not set, SkyPilot will create a new bucket for each managed job launch. 
+ # Supports s3://, gs://, https://.blob.core.windows.net/, r2://, cos:/// + bucket: s3://my-bucket/ controller: resources: # same spec as 'resources' in a task YAML cloud: gcp diff --git a/sky/data/mounting_utils.py b/sky/data/mounting_utils.py index b713a1f1cc5..d2a95a3c20b 100644 --- a/sky/data/mounting_utils.py +++ b/sky/data/mounting_utils.py @@ -31,12 +31,19 @@ def get_s3_mount_install_cmd() -> str: return install_cmd -def get_s3_mount_cmd(bucket_name: str, mount_path: str) -> str: +# pylint: disable=invalid-name +def get_s3_mount_cmd(bucket_name: str, + mount_path: str, + _bucket_sub_path: Optional[str] = None) -> str: """Returns a command to mount an S3 bucket using goofys.""" + if _bucket_sub_path is None: + _bucket_sub_path = '' + else: + _bucket_sub_path = f':{_bucket_sub_path}' mount_cmd = ('goofys -o allow_other ' f'--stat-cache-ttl {_STAT_CACHE_TTL} ' f'--type-cache-ttl {_TYPE_CACHE_TTL} ' - f'{bucket_name} {mount_path}') + f'{bucket_name}{_bucket_sub_path} {mount_path}') return mount_cmd @@ -50,15 +57,20 @@ def get_gcs_mount_install_cmd() -> str: return install_cmd -def get_gcs_mount_cmd(bucket_name: str, mount_path: str) -> str: +# pylint: disable=invalid-name +def get_gcs_mount_cmd(bucket_name: str, + mount_path: str, + _bucket_sub_path: Optional[str] = None) -> str: """Returns a command to mount a GCS bucket using gcsfuse.""" - + bucket_sub_path_arg = f'--only-dir {_bucket_sub_path} '\ + if _bucket_sub_path else '' mount_cmd = ('gcsfuse -o allow_other ' '--implicit-dirs ' f'--stat-cache-capacity {_STAT_CACHE_CAPACITY} ' f'--stat-cache-ttl {_STAT_CACHE_TTL} ' f'--type-cache-ttl {_TYPE_CACHE_TTL} ' f'--rename-dir-limit {_RENAME_DIR_LIMIT} ' + f'{bucket_sub_path_arg}' f'{bucket_name} {mount_path}') return mount_cmd @@ -79,10 +91,12 @@ def get_az_mount_install_cmd() -> str: return install_cmd +# pylint: disable=invalid-name def get_az_mount_cmd(container_name: str, storage_account_name: str, mount_path: str, - storage_account_key: Optional[str] = None) -> str: + storage_account_key: Optional[str] = None, + _bucket_sub_path: Optional[str] = None) -> str: """Returns a command to mount an AZ Container using blobfuse2. Args: @@ -91,6 +105,7 @@ def get_az_mount_cmd(container_name: str, belongs to. mount_path: Path where the container will be mounting. storage_account_key: Access key for the given storage account. + _bucket_sub_path: Sub path of the mounting container. Returns: str: Command used to mount AZ container with blobfuse2. 
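Editor's sketch of the sub-path plumbing added to these helpers: for the goofys-based mounts (S3 and R2), the optional `_bucket_sub_path` is appended to the mount target as `bucket:prefix`, so goofys mounts only that prefix. A standalone illustration (`goofys_mount_target` is a hypothetical name, not patch code):

from typing import Optional

def goofys_mount_target(bucket_name: str,
                        _bucket_sub_path: Optional[str] = None) -> str:
    # goofys mounts only the given prefix when the target is 'bucket:prefix'.
    suffix = f':{_bucket_sub_path}' if _bucket_sub_path else ''
    return f'{bucket_name}{suffix}'

assert goofys_mount_target('my-bucket') == 'my-bucket'
assert goofys_mount_target('my-bucket', 'job-1/workdir') == 'my-bucket:job-1/workdir'

gcsfuse and blobfuse2 carry the same information through `--only-dir` and `--subdirectory=` instead, as the surrounding hunks show.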
@@ -107,25 +122,38 @@ def get_az_mount_cmd(container_name: str, cache_path = _BLOBFUSE_CACHE_DIR.format( storage_account_name=storage_account_name, container_name=container_name) + if _bucket_sub_path is None: + bucket_sub_path_arg = '' + else: + bucket_sub_path_arg = f'--subdirectory={_bucket_sub_path}/ ' mount_cmd = (f'AZURE_STORAGE_ACCOUNT={storage_account_name} ' f'{key_env_var} ' f'blobfuse2 {mount_path} --allow-other --no-symlinks ' '-o umask=022 -o default_permissions ' f'--tmp-path {cache_path} ' + f'{bucket_sub_path_arg}' f'--container-name {container_name}') return mount_cmd -def get_r2_mount_cmd(r2_credentials_path: str, r2_profile_name: str, - endpoint_url: str, bucket_name: str, - mount_path: str) -> str: +# pylint: disable=invalid-name +def get_r2_mount_cmd(r2_credentials_path: str, + r2_profile_name: str, + endpoint_url: str, + bucket_name: str, + mount_path: str, + _bucket_sub_path: Optional[str] = None) -> str: """Returns a command to install R2 mount utility goofys.""" + if _bucket_sub_path is None: + _bucket_sub_path = '' + else: + _bucket_sub_path = f':{_bucket_sub_path}' mount_cmd = (f'AWS_SHARED_CREDENTIALS_FILE={r2_credentials_path} ' f'AWS_PROFILE={r2_profile_name} goofys -o allow_other ' f'--stat-cache-ttl {_STAT_CACHE_TTL} ' f'--type-cache-ttl {_TYPE_CACHE_TTL} ' f'--endpoint {endpoint_url} ' - f'{bucket_name} {mount_path}') + f'{bucket_name}{_bucket_sub_path} {mount_path}') return mount_cmd @@ -137,9 +165,12 @@ def get_cos_mount_install_cmd() -> str: return install_cmd -def get_cos_mount_cmd(rclone_config_data: str, rclone_config_path: str, - bucket_rclone_profile: str, bucket_name: str, - mount_path: str) -> str: +def get_cos_mount_cmd(rclone_config_data: str, + rclone_config_path: str, + bucket_rclone_profile: str, + bucket_name: str, + mount_path: str, + _bucket_sub_path: Optional[str] = None) -> str: """Returns a command to mount an IBM COS bucket using rclone.""" # creates a fusermount soft link on older (<22) Ubuntu systems for # rclone's mount utility. @@ -151,10 +182,14 @@ def get_cos_mount_cmd(rclone_config_data: str, rclone_config_path: str, 'mkdir -p ~/.config/rclone/ && ' f'echo "{rclone_config_data}" >> ' f'{rclone_config_path}') + if _bucket_sub_path is None: + sub_path_arg = f'{bucket_name}/{_bucket_sub_path}' + else: + sub_path_arg = f'/{bucket_name}' # --daemon will keep the mounting process running in the background. mount_cmd = (f'{configure_rclone_profile} && ' 'rclone mount ' - f'{bucket_rclone_profile}:{bucket_name} {mount_path} ' + f'{bucket_rclone_profile}:{sub_path_arg} {mount_path} ' '--daemon') return mount_cmd @@ -252,7 +287,7 @@ def get_mounting_script( script = textwrap.dedent(f""" #!/usr/bin/env bash set -e - + {command_runner.ALIAS_SUDO_TO_EMPTY_FOR_ROOT_CMD} MOUNT_PATH={mount_path} diff --git a/sky/data/storage.py b/sky/data/storage.py index 188c97b9545..018cb2797ca 100644 --- a/sky/data/storage.py +++ b/sky/data/storage.py @@ -200,6 +200,45 @@ def get_endpoint_url(cls, store: 'AbstractStore', path: str) -> str: bucket_endpoint_url = f'{store_type.store_prefix()}{path}' return bucket_endpoint_url + @classmethod + def get_fields_from_store_url( + cls, store_url: str + ) -> Tuple['StoreType', Type['AbstractStore'], str, str, Optional[str], + Optional[str]]: + """Returns the store type, store class, bucket name, and sub path from + a store URL, and the storage account name and region if applicable. + + Args: + store_url: str; The store URL. 
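Editor's aside: the method above resolves a store URL by matching it against each `StoreType`'s prefix before extracting the bucket name and sub path. A reduced sketch of just that matching step, covering the prefixes handled in this hunk (`store_type_of` is a hypothetical helper; Azure's `https://` prefix follows the `store_prefix()` mapping defined earlier in this file):

PREFIX_TO_STORE = {
    's3://': 'S3',
    'gs://': 'GCS',
    'https://': 'AZURE',
    'r2://': 'R2',
    'cos://': 'IBM',
}

def store_type_of(store_url: str) -> str:
    for prefix, store_type in PREFIX_TO_STORE.items():
        if store_url.startswith(prefix):
            return store_type
    raise ValueError(f'Unknown store URL: {store_url}')

assert store_type_of('s3://my-bucket/sub/dir') == 'S3'
assert store_type_of('cos://us-east/my-bucket/sub/dir') == 'IBM'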
+ """ + # The full path from the user config of IBM COS contains the region, + # and Azure Blob Storage contains the storage account name, we need to + # pass these information to the store constructor. + storage_account_name = None + region = None + for store_type in StoreType: + if store_url.startswith(store_type.store_prefix()): + if store_type == StoreType.AZURE: + storage_account_name, bucket_name, sub_path = \ + data_utils.split_az_path(store_url) + store_cls: Type['AbstractStore'] = AzureBlobStore + elif store_type == StoreType.IBM: + bucket_name, sub_path, region = data_utils.split_cos_path( + store_url) + store_cls = IBMCosStore + elif store_type == StoreType.R2: + bucket_name, sub_path = data_utils.split_r2_path(store_url) + store_cls = R2Store + elif store_type == StoreType.GCS: + bucket_name, sub_path = data_utils.split_gcs_path(store_url) + store_cls = GcsStore + elif store_type == StoreType.S3: + bucket_name, sub_path = data_utils.split_s3_path(store_url) + store_cls = S3Store + return store_type, store_cls,bucket_name, \ + sub_path, storage_account_name, region + raise ValueError(f'Unknown store URL: {store_url}') + class StorageMode(enum.Enum): MOUNT = 'MOUNT' @@ -226,25 +265,29 @@ def __init__(self, name: str, source: Optional[SourceType], region: Optional[str] = None, - is_sky_managed: Optional[bool] = None): + is_sky_managed: Optional[bool] = None, + _bucket_sub_path: Optional[str] = None): self.name = name self.source = source self.region = region self.is_sky_managed = is_sky_managed + self._bucket_sub_path = _bucket_sub_path def __repr__(self): return (f'StoreMetadata(' f'\n\tname={self.name},' f'\n\tsource={self.source},' f'\n\tregion={self.region},' - f'\n\tis_sky_managed={self.is_sky_managed})') + f'\n\tis_sky_managed={self.is_sky_managed},' + f'\n\t_bucket_sub_path={self._bucket_sub_path})') def __init__(self, name: str, source: Optional[SourceType], region: Optional[str] = None, is_sky_managed: Optional[bool] = None, - sync_on_reconstruction: Optional[bool] = True): + sync_on_reconstruction: Optional[bool] = True, + _bucket_sub_path: Optional[str] = None): # pylint: disable=invalid-name """Initialize AbstractStore Args: @@ -258,7 +301,11 @@ def __init__(self, there. This is set to false when the Storage object is created not for direct use, e.g. for 'sky storage delete', or the storage is being re-used, e.g., for `sky start` on a stopped cluster. - + _bucket_sub_path: str; The prefix of the bucket directory to be + created in the store, e.g. if _bucket_sub_path=my-dir, the files + will be uploaded to s3:///my-dir/. + This only works if source is a local directory. + # TODO(zpoint): Add support for non-local source. Raises: StorageBucketCreateError: If bucket creation fails StorageBucketGetError: If fetching existing bucket fails @@ -269,10 +316,29 @@ def __init__(self, self.region = region self.is_sky_managed = is_sky_managed self.sync_on_reconstruction = sync_on_reconstruction + + # To avoid mypy error + self._bucket_sub_path: Optional[str] = None + # Trigger the setter to strip any leading/trailing slashes. + self.bucket_sub_path = _bucket_sub_path # Whether sky is responsible for the lifecycle of the Store. 
self._validate() self.initialize() + @property + def bucket_sub_path(self) -> Optional[str]: + """Get the bucket_sub_path.""" + return self._bucket_sub_path + + @bucket_sub_path.setter + # pylint: disable=invalid-name + def bucket_sub_path(self, bucket_sub_path: Optional[str]) -> None: + """Set the bucket_sub_path, stripping any leading/trailing slashes.""" + if bucket_sub_path is not None: + self._bucket_sub_path = bucket_sub_path.strip('/') + else: + self._bucket_sub_path = None + @classmethod def from_metadata(cls, metadata: StoreMetadata, **override_args): """Create a Store from a StoreMetadata object. @@ -280,19 +346,26 @@ def from_metadata(cls, metadata: StoreMetadata, **override_args): Used when reconstructing Storage and Store objects from global_user_state. """ - return cls(name=override_args.get('name', metadata.name), - source=override_args.get('source', metadata.source), - region=override_args.get('region', metadata.region), - is_sky_managed=override_args.get('is_sky_managed', - metadata.is_sky_managed), - sync_on_reconstruction=override_args.get( - 'sync_on_reconstruction', True)) + return cls( + name=override_args.get('name', metadata.name), + source=override_args.get('source', metadata.source), + region=override_args.get('region', metadata.region), + is_sky_managed=override_args.get('is_sky_managed', + metadata.is_sky_managed), + sync_on_reconstruction=override_args.get('sync_on_reconstruction', + True), + # backward compatibility + _bucket_sub_path=override_args.get( + '_bucket_sub_path', + metadata._bucket_sub_path # pylint: disable=protected-access + ) if hasattr(metadata, '_bucket_sub_path') else None) def get_metadata(self) -> StoreMetadata: return self.StoreMetadata(name=self.name, source=self.source, region=self.region, - is_sky_managed=self.is_sky_managed) + is_sky_managed=self.is_sky_managed, + _bucket_sub_path=self._bucket_sub_path) def initialize(self): """Initializes the Store object on the cloud. @@ -320,7 +393,11 @@ def upload(self) -> None: raise NotImplementedError def delete(self) -> None: - """Removes the Storage object from the cloud.""" + """Removes the Storage from the cloud.""" + raise NotImplementedError + + def _delete_sub_path(self) -> None: + """Removes objects from the sub path in the bucket.""" raise NotImplementedError def get_handle(self) -> StorageHandle: @@ -464,13 +541,19 @@ def remove_store(self, store: AbstractStore) -> None: if storetype in self.sky_stores: del self.sky_stores[storetype] - def __init__(self, - name: Optional[str] = None, - source: Optional[SourceType] = None, - stores: Optional[Dict[StoreType, AbstractStore]] = None, - persistent: Optional[bool] = True, - mode: StorageMode = StorageMode.MOUNT, - sync_on_reconstruction: bool = True) -> None: + def __init__( + self, + name: Optional[str] = None, + source: Optional[SourceType] = None, + stores: Optional[Dict[StoreType, AbstractStore]] = None, + persistent: Optional[bool] = True, + mode: StorageMode = StorageMode.MOUNT, + sync_on_reconstruction: bool = True, + # pylint: disable=invalid-name + _is_sky_managed: Optional[bool] = None, + # pylint: disable=invalid-name + _bucket_sub_path: Optional[str] = None + ) -> None: """Initializes a Storage object. Three fields are required: the name of the storage, the source @@ -508,6 +591,18 @@ def __init__(self, there. This is set to false when the Storage object is created not for direct use, e.g. for 'sky storage delete', or the storage is being re-used, e.g., for `sky start` on a stopped cluster. 
+ _is_sky_managed: Optional[bool]; Indicates if the storage is managed + by Sky. Without this argument, the controller's behavior differs + from the local machine. For example, if a bucket does not exist: + Local Machine (is_sky_managed=True) → + Controller (is_sky_managed=False). + With this argument, the controller aligns with the local machine, + ensuring it retains the is_sky_managed information from the YAML. + During teardown, if is_sky_managed is True, the controller should + delete the bucket. Otherwise, it might mistakenly delete only the + sub-path, assuming is_sky_managed is False. + _bucket_sub_path: Optional[str]; The subdirectory to use for the + storage object. """ self.name: str self.source = source @@ -515,6 +610,8 @@ def __init__(self, self.mode = mode assert mode in StorageMode self.sync_on_reconstruction = sync_on_reconstruction + self._is_sky_managed = _is_sky_managed + self._bucket_sub_path = _bucket_sub_path # TODO(romilb, zhwu): This is a workaround to support storage deletion # for spot. Once sky storage supports forced management for external @@ -577,6 +674,12 @@ def __init__(self, elif self.source.startswith('oci://'): self.add_store(StoreType.OCI) + def get_bucket_sub_path_prefix(self, blob_path: str) -> str: + """Adds the bucket sub path prefix to the blob path.""" + if self._bucket_sub_path is not None: + return f'{blob_path}/{self._bucket_sub_path}' + return blob_path + @staticmethod def _validate_source( source: SourceType, mode: StorageMode, @@ -787,34 +890,40 @@ def _add_store_from_metadata( store = S3Store.from_metadata( s_metadata, source=self.source, - sync_on_reconstruction=self.sync_on_reconstruction) + sync_on_reconstruction=self.sync_on_reconstruction, + _bucket_sub_path=self._bucket_sub_path) elif s_type == StoreType.GCS: store = GcsStore.from_metadata( s_metadata, source=self.source, - sync_on_reconstruction=self.sync_on_reconstruction) + sync_on_reconstruction=self.sync_on_reconstruction, + _bucket_sub_path=self._bucket_sub_path) elif s_type == StoreType.AZURE: assert isinstance(s_metadata, AzureBlobStore.AzureBlobStoreMetadata) store = AzureBlobStore.from_metadata( s_metadata, source=self.source, - sync_on_reconstruction=self.sync_on_reconstruction) + sync_on_reconstruction=self.sync_on_reconstruction, + _bucket_sub_path=self._bucket_sub_path) elif s_type == StoreType.R2: store = R2Store.from_metadata( s_metadata, source=self.source, - sync_on_reconstruction=self.sync_on_reconstruction) + sync_on_reconstruction=self.sync_on_reconstruction, + _bucket_sub_path=self._bucket_sub_path) elif s_type == StoreType.IBM: store = IBMCosStore.from_metadata( s_metadata, source=self.source, - sync_on_reconstruction=self.sync_on_reconstruction) + sync_on_reconstruction=self.sync_on_reconstruction, + _bucket_sub_path=self._bucket_sub_path) elif s_type == StoreType.OCI: store = OciStore.from_metadata( s_metadata, source=self.source, - sync_on_reconstruction=self.sync_on_reconstruction) + sync_on_reconstruction=self.sync_on_reconstruction, + _bucket_sub_path=self._bucket_sub_path) else: with ux_utils.print_exception_no_traceback(): raise ValueError(f'Unknown store type: {s_type}') @@ -834,7 +943,6 @@ def _add_store_from_metadata( 'to be reconstructed while the corresponding ' 'bucket was externally deleted.') continue - self._add_store(store, is_reconstructed=True) @classmethod @@ -890,6 +998,7 @@ def add_store(self, f'storage account {storage_account_name!r}.') else: logger.info(f'Storage type {store_type} already exists.') + return self.stores[store_type] 
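Editor's sketch of how the two new private constructor arguments are meant to be used together (hypothetical usage, assuming SkyPilot with this patch and valid cloud credentials; not taken from the patch itself):

from sky.data import storage as storage_lib

# A Storage that writes into a sub-directory of a user-provided bucket.
# _is_sky_managed=False tells teardown to remove only the sub path rather
# than the whole bucket; _bucket_sub_path is the prefix files go under.
store = storage_lib.Storage(
    name='my-existing-bucket',        # hypothetical bucket name
    source='./workdir',               # hypothetical local directory
    persistent=False,
    _is_sky_managed=False,
    _bucket_sub_path='job-42/workdir')
store.add_store(storage_lib.StoreType.S3)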
store_cls: Type[AbstractStore] @@ -909,21 +1018,24 @@ def add_store(self, with ux_utils.print_exception_no_traceback(): raise exceptions.StorageSpecError( f'{store_type} not supported as a Store.') - - # Initialize store object and get/create bucket try: store = store_cls( name=self.name, source=self.source, region=region, - sync_on_reconstruction=self.sync_on_reconstruction) + sync_on_reconstruction=self.sync_on_reconstruction, + is_sky_managed=self._is_sky_managed, + _bucket_sub_path=self._bucket_sub_path) except exceptions.StorageBucketCreateError: # Creation failed, so this must be sky managed store. Add failure # to state. logger.error(f'Could not create {store_type} store ' f'with name {self.name}.') - global_user_state.set_storage_status(self.name, - StorageStatus.INIT_FAILED) + try: + global_user_state.set_storage_status(self.name, + StorageStatus.INIT_FAILED) + except ValueError as e: + logger.error(f'Error setting storage status: {e}') raise except exceptions.StorageBucketGetError: # Bucket get failed, so this is not sky managed. Do not update state @@ -1039,12 +1151,15 @@ def warn_for_git_dir(source: str): def from_yaml_config(cls, config: Dict[str, Any]) -> 'Storage': common_utils.validate_schema(config, schemas.get_storage_schema(), 'Invalid storage YAML: ') - name = config.pop('name', None) source = config.pop('source', None) store = config.pop('store', None) mode_str = config.pop('mode', None) force_delete = config.pop('_force_delete', None) + # pylint: disable=invalid-name + _is_sky_managed = config.pop('_is_sky_managed', None) + # pylint: disable=invalid-name + _bucket_sub_path = config.pop('_bucket_sub_path', None) if force_delete is None: force_delete = False @@ -1064,7 +1179,9 @@ def from_yaml_config(cls, config: Dict[str, Any]) -> 'Storage': storage_obj = cls(name=name, source=source, persistent=persistent, - mode=mode) + mode=mode, + _is_sky_managed=_is_sky_managed, + _bucket_sub_path=_bucket_sub_path) if store is not None: storage_obj.add_store(StoreType(store.upper())) @@ -1072,7 +1189,7 @@ def from_yaml_config(cls, config: Dict[str, Any]) -> 'Storage': storage_obj.force_delete = force_delete return storage_obj - def to_yaml_config(self) -> Dict[str, str]: + def to_yaml_config(self) -> Dict[str, Any]: config = {} def add_if_not_none(key: str, value: Optional[Any]): @@ -1088,13 +1205,18 @@ def add_if_not_none(key: str, value: Optional[Any]): add_if_not_none('source', self.source) stores = None + is_sky_managed = self._is_sky_managed if self.stores: stores = ','.join([store.value for store in self.stores]) + is_sky_managed = list(self.stores.values())[0].is_sky_managed add_if_not_none('store', stores) + add_if_not_none('_is_sky_managed', is_sky_managed) add_if_not_none('persistent', self.persistent) add_if_not_none('mode', self.mode.value) if self.force_delete: config['_force_delete'] = True + if self._bucket_sub_path is not None: + config['_bucket_sub_path'] = self._bucket_sub_path return config @@ -1116,7 +1238,8 @@ def __init__(self, source: str, region: Optional[str] = _DEFAULT_REGION, is_sky_managed: Optional[bool] = None, - sync_on_reconstruction: bool = True): + sync_on_reconstruction: bool = True, + _bucket_sub_path: Optional[str] = None): self.client: 'boto3.client.Client' self.bucket: 'StorageHandle' # TODO(romilb): This is purely a stopgap fix for @@ -1129,7 +1252,7 @@ def __init__(self, f'{self._DEFAULT_REGION} for bucket {name!r}.') region = self._DEFAULT_REGION super().__init__(name, source, region, is_sky_managed, - sync_on_reconstruction) + 
sync_on_reconstruction, _bucket_sub_path) def _validate(self): if self.source is not None and isinstance(self.source, str): @@ -1293,6 +1416,9 @@ def upload(self): f'Upload failed for store {self.name}') from e def delete(self) -> None: + if self._bucket_sub_path is not None and not self.is_sky_managed: + return self._delete_sub_path() + deleted_by_skypilot = self._delete_s3_bucket(self.name) if deleted_by_skypilot: msg_str = f'Deleted S3 bucket {self.name}.' @@ -1302,6 +1428,19 @@ def delete(self) -> None: logger.info(f'{colorama.Fore.GREEN}{msg_str}' f'{colorama.Style.RESET_ALL}') + def _delete_sub_path(self) -> None: + assert self._bucket_sub_path is not None, 'bucket_sub_path is not set' + deleted_by_skypilot = self._delete_s3_bucket_sub_path( + self.name, self._bucket_sub_path) + if deleted_by_skypilot: + msg_str = f'Removed objects from S3 bucket ' \ + f'{self.name}/{self._bucket_sub_path}.' + else: + msg_str = f'Failed to remove objects from S3 bucket ' \ + f'{self.name}/{self._bucket_sub_path}.' + logger.info(f'{colorama.Fore.GREEN}{msg_str}' + f'{colorama.Style.RESET_ALL}') + def get_handle(self) -> StorageHandle: return aws.resource('s3').Bucket(self.name) @@ -1332,9 +1471,11 @@ def get_file_sync_command(base_dir_path, file_names): for file_name in file_names ]) base_dir_path = shlex.quote(base_dir_path) + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') sync_command = ('aws s3 sync --no-follow-symlinks --exclude="*" ' f'{includes} {base_dir_path} ' - f's3://{self.name}') + f's3://{self.name}{sub_path}') return sync_command def get_dir_sync_command(src_dir_path, dest_dir_name): @@ -1346,9 +1487,11 @@ def get_dir_sync_command(src_dir_path, dest_dir_name): for file_name in excluded_list ]) src_dir_path = shlex.quote(src_dir_path) + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') sync_command = (f'aws s3 sync --no-follow-symlinks {excludes} ' f'{src_dir_path} ' - f's3://{self.name}/{dest_dir_name}') + f's3://{self.name}{sub_path}/{dest_dir_name}') return sync_command # Generate message for upload @@ -1466,7 +1609,8 @@ def mount_command(self, mount_path: str) -> str: """ install_cmd = mounting_utils.get_s3_mount_install_cmd() mount_cmd = mounting_utils.get_s3_mount_cmd(self.bucket.name, - mount_path) + mount_path, + self._bucket_sub_path) return mounting_utils.get_mounting_command(mount_path, install_cmd, mount_cmd) @@ -1516,6 +1660,27 @@ def _create_s3_bucket(self, ) from e return aws.resource('s3').Bucket(bucket_name) + def _execute_s3_remove_command(self, command: str, bucket_name: str, + hint_operating: str, + hint_failed: str) -> bool: + try: + with rich_utils.safe_status( + ux_utils.spinner_message(hint_operating)): + subprocess.check_output(command.split(' '), + stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as e: + if 'NoSuchBucket' in e.output.decode('utf-8'): + logger.debug( + _BUCKET_EXTERNALLY_DELETED_DEBUG_MESSAGE.format( + bucket_name=bucket_name)) + return False + else: + with ux_utils.print_exception_no_traceback(): + raise exceptions.StorageBucketDeleteError( + f'{hint_failed}' + f'Detailed error: {e.output}') + return True + def _delete_s3_bucket(self, bucket_name: str) -> bool: """Deletes S3 bucket, including all objects in bucket @@ -1533,29 +1698,28 @@ def _delete_s3_bucket(self, bucket_name: str) -> bool: # The fastest way to delete is to run `aws s3 rb --force`, # which removes the bucket by force. 
remove_command = f'aws s3 rb s3://{bucket_name} --force' - try: - with rich_utils.safe_status( - ux_utils.spinner_message( - f'Deleting S3 bucket [green]{bucket_name}')): - subprocess.check_output(remove_command.split(' '), - stderr=subprocess.STDOUT) - except subprocess.CalledProcessError as e: - if 'NoSuchBucket' in e.output.decode('utf-8'): - logger.debug( - _BUCKET_EXTERNALLY_DELETED_DEBUG_MESSAGE.format( - bucket_name=bucket_name)) - return False - else: - with ux_utils.print_exception_no_traceback(): - raise exceptions.StorageBucketDeleteError( - f'Failed to delete S3 bucket {bucket_name}.' - f'Detailed error: {e.output}') + success = self._execute_s3_remove_command( + remove_command, bucket_name, + f'Deleting S3 bucket [green]{bucket_name}[/]', + f'Failed to delete S3 bucket {bucket_name}.') + if not success: + return False # Wait until bucket deletion propagates on AWS servers while data_utils.verify_s3_bucket(bucket_name): time.sleep(0.1) return True + def _delete_s3_bucket_sub_path(self, bucket_name: str, + sub_path: str) -> bool: + """Deletes the sub path from the bucket.""" + remove_command = f'aws s3 rm s3://{bucket_name}/{sub_path}/ --recursive' + return self._execute_s3_remove_command( + remove_command, bucket_name, f'Removing objects from S3 bucket ' + f'[green]{bucket_name}/{sub_path}[/]', + f'Failed to remove objects from S3 bucket {bucket_name}/{sub_path}.' + ) + class GcsStore(AbstractStore): """GcsStore inherits from Storage Object and represents the backend @@ -1569,11 +1733,12 @@ def __init__(self, source: str, region: Optional[str] = 'us-central1', is_sky_managed: Optional[bool] = None, - sync_on_reconstruction: Optional[bool] = True): + sync_on_reconstruction: Optional[bool] = True, + _bucket_sub_path: Optional[str] = None): self.client: 'storage.Client' self.bucket: StorageHandle super().__init__(name, source, region, is_sky_managed, - sync_on_reconstruction) + sync_on_reconstruction, _bucket_sub_path) def _validate(self): if self.source is not None and isinstance(self.source, str): @@ -1736,6 +1901,9 @@ def upload(self): f'Upload failed for store {self.name}') from e def delete(self) -> None: + if self._bucket_sub_path is not None and not self.is_sky_managed: + return self._delete_sub_path() + deleted_by_skypilot = self._delete_gcs_bucket(self.name) if deleted_by_skypilot: msg_str = f'Deleted GCS bucket {self.name}.' @@ -1745,6 +1913,19 @@ def delete(self) -> None: logger.info(f'{colorama.Fore.GREEN}{msg_str}' f'{colorama.Style.RESET_ALL}') + def _delete_sub_path(self) -> None: + assert self._bucket_sub_path is not None, 'bucket_sub_path is not set' + deleted_by_skypilot = self._delete_gcs_bucket(self.name, + self._bucket_sub_path) + if deleted_by_skypilot: + msg_str = f'Deleted objects in GCS bucket ' \ + f'{self.name}/{self._bucket_sub_path}.' + else: + msg_str = f'GCS bucket {self.name} may have ' \ + 'been deleted externally.' 
+ logger.info(f'{colorama.Fore.GREEN}{msg_str}' + f'{colorama.Style.RESET_ALL}') + def get_handle(self) -> StorageHandle: return self.client.get_bucket(self.name) @@ -1818,9 +1999,11 @@ def get_file_sync_command(base_dir_path, file_names): sync_format = '|'.join(file_names) gsutil_alias, alias_gen = data_utils.get_gsutil_command() base_dir_path = shlex.quote(base_dir_path) + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') sync_command = (f'{alias_gen}; {gsutil_alias} ' f'rsync -e -x \'^(?!{sync_format}$).*\' ' - f'{base_dir_path} gs://{self.name}') + f'{base_dir_path} gs://{self.name}{sub_path}') return sync_command def get_dir_sync_command(src_dir_path, dest_dir_name): @@ -1830,9 +2013,11 @@ def get_dir_sync_command(src_dir_path, dest_dir_name): excludes = '|'.join(excluded_list) gsutil_alias, alias_gen = data_utils.get_gsutil_command() src_dir_path = shlex.quote(src_dir_path) + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') sync_command = (f'{alias_gen}; {gsutil_alias} ' f'rsync -e -r -x \'({excludes})\' {src_dir_path} ' - f'gs://{self.name}/{dest_dir_name}') + f'gs://{self.name}{sub_path}/{dest_dir_name}') return sync_command # Generate message for upload @@ -1937,7 +2122,8 @@ def mount_command(self, mount_path: str) -> str: """ install_cmd = mounting_utils.get_gcs_mount_install_cmd() mount_cmd = mounting_utils.get_gcs_mount_cmd(self.bucket.name, - mount_path) + mount_path, + self._bucket_sub_path) version_check_cmd = ( f'gcsfuse --version | grep -q {mounting_utils.GCSFUSE_VERSION}') return mounting_utils.get_mounting_command(mount_path, install_cmd, @@ -1977,19 +2163,33 @@ def _create_gcs_bucket(self, f'{new_bucket.storage_class}{colorama.Style.RESET_ALL}') return new_bucket - def _delete_gcs_bucket(self, bucket_name: str) -> bool: - """Deletes GCS bucket, including all objects in bucket + def _delete_gcs_bucket( + self, + bucket_name: str, + # pylint: disable=invalid-name + _bucket_sub_path: Optional[str] = None + ) -> bool: + """Deletes objects in GCS bucket Args: bucket_name: str; Name of bucket + _bucket_sub_path: str; Sub path in the bucket, if provided only + objects in the sub path will be deleted, else the whole bucket will + be deleted Returns: bool; True if bucket was deleted, False if it was deleted externally. """ - + if _bucket_sub_path is not None: + command_suffix = f'/{_bucket_sub_path}' + hint_text = 'objects in ' + else: + command_suffix = '' + hint_text = '' with rich_utils.safe_status( ux_utils.spinner_message( - f'Deleting GCS bucket [green]{bucket_name}')): + f'Deleting {hint_text}GCS bucket ' + f'[green]{bucket_name}{command_suffix}[/]')): try: self.client.get_bucket(bucket_name) except gcp.forbidden_exception() as e: @@ -2007,8 +2207,9 @@ def _delete_gcs_bucket(self, bucket_name: str) -> bool: return False try: gsutil_alias, alias_gen = data_utils.get_gsutil_command() - remove_obj_command = (f'{alias_gen};{gsutil_alias} ' - f'rm -r gs://{bucket_name}') + remove_obj_command = ( + f'{alias_gen};{gsutil_alias} ' + f'rm -r gs://{bucket_name}{command_suffix}') subprocess.check_output(remove_obj_command, stderr=subprocess.STDOUT, shell=True, @@ -2017,7 +2218,8 @@ def _delete_gcs_bucket(self, bucket_name: str) -> bool: except subprocess.CalledProcessError as e: with ux_utils.print_exception_no_traceback(): raise exceptions.StorageBucketDeleteError( - f'Failed to delete GCS bucket {bucket_name}.' + f'Failed to delete {hint_text}GCS bucket ' + f'{bucket_name}{command_suffix}.' 
f'Detailed error: {e.output}') @@ -2069,7 +2271,8 @@ def __init__(self, storage_account_name: str = '', region: Optional[str] = 'eastus', is_sky_managed: Optional[bool] = None, - sync_on_reconstruction: bool = True): + sync_on_reconstruction: bool = True, + _bucket_sub_path: Optional[str] = None): self.storage_client: 'storage.Client' self.resource_client: 'storage.Client' self.container_name: str @@ -2081,7 +2284,7 @@ def __init__(self, if region is None: region = 'eastus' super().__init__(name, source, region, is_sky_managed, - sync_on_reconstruction) + sync_on_reconstruction, _bucket_sub_path) @classmethod def from_metadata(cls, metadata: AbstractStore.StoreMetadata, @@ -2231,6 +2434,17 @@ def initialize(self): """ self.storage_client = data_utils.create_az_client('storage') self.resource_client = data_utils.create_az_client('resource') + self._update_storage_account_name_and_resource() + + self.container_name, is_new_bucket = self._get_bucket() + if self.is_sky_managed is None: + # If is_sky_managed is not specified, then this is a new storage + # object (i.e., did not exist in global_user_state) and we should + # set the is_sky_managed property. + # If is_sky_managed is specified, then we take no action. + self.is_sky_managed = is_new_bucket + + def _update_storage_account_name_and_resource(self): self.storage_account_name, self.resource_group_name = ( self._get_storage_account_and_resource_group()) @@ -2241,13 +2455,13 @@ def initialize(self): self.storage_account_name, self.resource_group_name, self.storage_client, self.resource_client) - self.container_name, is_new_bucket = self._get_bucket() - if self.is_sky_managed is None: - # If is_sky_managed is not specified, then this is a new storage - # object (i.e., did not exist in global_user_state) and we should - # set the is_sky_managed property. - # If is_sky_managed is specified, then we take no action. 
- self.is_sky_managed = is_new_bucket + def update_storage_attributes(self, **kwargs: Dict[str, Any]): + assert 'storage_account_name' in kwargs, ( + 'only storage_account_name supported') + assert isinstance(kwargs['storage_account_name'], + str), ('storage_account_name must be a string') + self.storage_account_name = kwargs['storage_account_name'] + self._update_storage_account_name_and_resource() @staticmethod def get_default_storage_account_name(region: Optional[str]) -> str: @@ -2518,6 +2732,9 @@ def upload(self): def delete(self) -> None: """Deletes the storage.""" + if self._bucket_sub_path is not None and not self.is_sky_managed: + return self._delete_sub_path() + deleted_by_skypilot = self._delete_az_bucket(self.name) if deleted_by_skypilot: msg_str = (f'Deleted AZ Container {self.name!r} under storage ' @@ -2528,6 +2745,32 @@ def delete(self) -> None: logger.info(f'{colorama.Fore.GREEN}{msg_str}' f'{colorama.Style.RESET_ALL}') + def _delete_sub_path(self) -> None: + assert self._bucket_sub_path is not None, 'bucket_sub_path is not set' + try: + container_url = data_utils.AZURE_CONTAINER_URL.format( + storage_account_name=self.storage_account_name, + container_name=self.name) + container_client = data_utils.create_az_client( + client_type='container', + container_url=container_url, + storage_account_name=self.storage_account_name, + resource_group_name=self.resource_group_name) + # List and delete blobs in the specified directory + blobs = container_client.list_blobs( + name_starts_with=self._bucket_sub_path + '/') + for blob in blobs: + container_client.delete_blob(blob.name) + logger.info( + f'Deleted objects from sub path {self._bucket_sub_path} ' + f'in container {self.name}.') + except Exception as e: # pylint: disable=broad-except + logger.error( + f'Failed to delete objects from sub path ' + f'{self._bucket_sub_path} in container {self.name}. 
' + f'Details: {common_utils.format_exception(e, use_bracket=True)}' + ) + def get_handle(self) -> StorageHandle: """Returns the Storage Handle object.""" return self.storage_client.blob_containers.get( @@ -2554,13 +2797,15 @@ def get_file_sync_command(base_dir_path, file_names) -> str: includes_list = ';'.join(file_names) includes = f'--include-pattern "{includes_list}"' base_dir_path = shlex.quote(base_dir_path) + container_path = (f'{self.container_name}/{self._bucket_sub_path}' + if self._bucket_sub_path else self.container_name) sync_command = (f'az storage blob sync ' f'--account-name {self.storage_account_name} ' f'--account-key {self.storage_account_key} ' f'{includes} ' '--delete-destination false ' f'--source {base_dir_path} ' - f'--container {self.container_name}') + f'--container {container_path}') return sync_command def get_dir_sync_command(src_dir_path, dest_dir_name) -> str: @@ -2571,8 +2816,11 @@ def get_dir_sync_command(src_dir_path, dest_dir_name) -> str: [file_name.rstrip('*') for file_name in excluded_list]) excludes = f'--exclude-path "{excludes_list}"' src_dir_path = shlex.quote(src_dir_path) - container_path = (f'{self.container_name}/{dest_dir_name}' - if dest_dir_name else self.container_name) + container_path = (f'{self.container_name}/{self._bucket_sub_path}' + if self._bucket_sub_path else + f'{self.container_name}') + if dest_dir_name: + container_path = f'{container_path}/{dest_dir_name}' sync_command = (f'az storage blob sync ' f'--account-name {self.storage_account_name} ' f'--account-key {self.storage_account_key} ' @@ -2695,6 +2943,7 @@ def _get_bucket(self) -> Tuple[str, bool]: f'{self.storage_account_name!r}.' 'Details: ' f'{common_utils.format_exception(e, use_bracket=True)}') + # If the container cannot be found in both private and public settings, # the container is to be created by Sky. However, creation is skipped # if Store object is being reconstructed for deletion or re-mount with @@ -2725,7 +2974,8 @@ def mount_command(self, mount_path: str) -> str: mount_cmd = mounting_utils.get_az_mount_cmd(self.container_name, self.storage_account_name, mount_path, - self.storage_account_key) + self.storage_account_key, + self._bucket_sub_path) return mounting_utils.get_mounting_command(mount_path, install_cmd, mount_cmd) @@ -2824,11 +3074,12 @@ def __init__(self, source: str, region: Optional[str] = 'auto', is_sky_managed: Optional[bool] = None, - sync_on_reconstruction: Optional[bool] = True): + sync_on_reconstruction: Optional[bool] = True, + _bucket_sub_path: Optional[str] = None): self.client: 'boto3.client.Client' self.bucket: 'StorageHandle' super().__init__(name, source, region, is_sky_managed, - sync_on_reconstruction) + sync_on_reconstruction, _bucket_sub_path) def _validate(self): if self.source is not None and isinstance(self.source, str): @@ -2933,6 +3184,9 @@ def upload(self): f'Upload failed for store {self.name}') from e def delete(self) -> None: + if self._bucket_sub_path is not None and not self.is_sky_managed: + return self._delete_sub_path() + deleted_by_skypilot = self._delete_r2_bucket(self.name) if deleted_by_skypilot: msg_str = f'Deleted R2 bucket {self.name}.' 
@@ -2942,6 +3196,19 @@ def delete(self) -> None: logger.info(f'{colorama.Fore.GREEN}{msg_str}' f'{colorama.Style.RESET_ALL}') + def _delete_sub_path(self) -> None: + assert self._bucket_sub_path is not None, 'bucket_sub_path is not set' + deleted_by_skypilot = self._delete_r2_bucket_sub_path( + self.name, self._bucket_sub_path) + if deleted_by_skypilot: + msg_str = f'Removed objects from R2 bucket ' \ + f'{self.name}/{self._bucket_sub_path}.' + else: + msg_str = f'Failed to remove objects from R2 bucket ' \ + f'{self.name}/{self._bucket_sub_path}.' + logger.info(f'{colorama.Fore.GREEN}{msg_str}' + f'{colorama.Style.RESET_ALL}') + def get_handle(self) -> StorageHandle: return cloudflare.resource('s3').Bucket(self.name) @@ -2973,11 +3240,13 @@ def get_file_sync_command(base_dir_path, file_names): ]) endpoint_url = cloudflare.create_endpoint() base_dir_path = shlex.quote(base_dir_path) + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') sync_command = ('AWS_SHARED_CREDENTIALS_FILE=' f'{cloudflare.R2_CREDENTIALS_PATH} ' 'aws s3 sync --no-follow-symlinks --exclude="*" ' f'{includes} {base_dir_path} ' - f's3://{self.name} ' + f's3://{self.name}{sub_path} ' f'--endpoint {endpoint_url} ' f'--profile={cloudflare.R2_PROFILE_NAME}') return sync_command @@ -2992,11 +3261,13 @@ def get_dir_sync_command(src_dir_path, dest_dir_name): ]) endpoint_url = cloudflare.create_endpoint() src_dir_path = shlex.quote(src_dir_path) + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') sync_command = ('AWS_SHARED_CREDENTIALS_FILE=' f'{cloudflare.R2_CREDENTIALS_PATH} ' f'aws s3 sync --no-follow-symlinks {excludes} ' f'{src_dir_path} ' - f's3://{self.name}/{dest_dir_name} ' + f's3://{self.name}{sub_path}/{dest_dir_name} ' f'--endpoint {endpoint_url} ' f'--profile={cloudflare.R2_PROFILE_NAME}') return sync_command @@ -3127,11 +3398,9 @@ def mount_command(self, mount_path: str) -> str: endpoint_url = cloudflare.create_endpoint() r2_credential_path = cloudflare.R2_CREDENTIALS_PATH r2_profile_name = cloudflare.R2_PROFILE_NAME - mount_cmd = mounting_utils.get_r2_mount_cmd(r2_credential_path, - r2_profile_name, - endpoint_url, - self.bucket.name, - mount_path) + mount_cmd = mounting_utils.get_r2_mount_cmd( + r2_credential_path, r2_profile_name, endpoint_url, self.bucket.name, + mount_path, self._bucket_sub_path) return mounting_utils.get_mounting_command(mount_path, install_cmd, mount_cmd) @@ -3164,6 +3433,43 @@ def _create_r2_bucket(self, f'{self.name} but failed.') from e return cloudflare.resource('s3').Bucket(bucket_name) + def _execute_r2_remove_command(self, command: str, bucket_name: str, + hint_operating: str, + hint_failed: str) -> bool: + try: + with rich_utils.safe_status( + ux_utils.spinner_message(hint_operating)): + subprocess.check_output(command.split(' '), + stderr=subprocess.STDOUT, + shell=True) + except subprocess.CalledProcessError as e: + if 'NoSuchBucket' in e.output.decode('utf-8'): + logger.debug( + _BUCKET_EXTERNALLY_DELETED_DEBUG_MESSAGE.format( + bucket_name=bucket_name)) + return False + else: + with ux_utils.print_exception_no_traceback(): + raise exceptions.StorageBucketDeleteError( + f'{hint_failed}' + f'Detailed error: {e.output}') + return True + + def _delete_r2_bucket_sub_path(self, bucket_name: str, + sub_path: str) -> bool: + """Deletes the sub path from the bucket.""" + endpoint_url = cloudflare.create_endpoint() + remove_command = ( + f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} ' + f'aws s3 rm 
s3://{bucket_name}/{sub_path}/ --recursive ' + f'--endpoint {endpoint_url} ' + f'--profile={cloudflare.R2_PROFILE_NAME}') + return self._execute_r2_remove_command( + remove_command, bucket_name, + f'Removing objects from R2 bucket {bucket_name}/{sub_path}', + f'Failed to remove objects from R2 bucket {bucket_name}/{sub_path}.' + ) + def _delete_r2_bucket(self, bucket_name: str) -> bool: """Deletes R2 bucket, including all objects in bucket @@ -3186,24 +3492,12 @@ def _delete_r2_bucket(self, bucket_name: str) -> bool: f'aws s3 rb s3://{bucket_name} --force ' f'--endpoint {endpoint_url} ' f'--profile={cloudflare.R2_PROFILE_NAME}') - try: - with rich_utils.safe_status( - ux_utils.spinner_message( - f'Deleting R2 bucket {bucket_name}')): - subprocess.check_output(remove_command, - stderr=subprocess.STDOUT, - shell=True) - except subprocess.CalledProcessError as e: - if 'NoSuchBucket' in e.output.decode('utf-8'): - logger.debug( - _BUCKET_EXTERNALLY_DELETED_DEBUG_MESSAGE.format( - bucket_name=bucket_name)) - return False - else: - with ux_utils.print_exception_no_traceback(): - raise exceptions.StorageBucketDeleteError( - f'Failed to delete R2 bucket {bucket_name}.' - f'Detailed error: {e.output}') + + success = self._execute_r2_remove_command( + remove_command, bucket_name, f'Deleting R2 bucket {bucket_name}', + f'Failed to delete R2 bucket {bucket_name}.') + if not success: + return False # Wait until bucket deletion propagates on AWS servers while data_utils.verify_r2_bucket(bucket_name): @@ -3222,11 +3516,12 @@ def __init__(self, source: str, region: Optional[str] = 'us-east', is_sky_managed: Optional[bool] = None, - sync_on_reconstruction: bool = True): + sync_on_reconstruction: bool = True, + _bucket_sub_path: Optional[str] = None): self.client: 'storage.Client' self.bucket: 'StorageHandle' super().__init__(name, source, region, is_sky_managed, - sync_on_reconstruction) + sync_on_reconstruction, _bucket_sub_path) self.bucket_rclone_profile = \ Rclone.generate_rclone_bucket_profile_name( self.name, Rclone.RcloneClouds.IBM) @@ -3371,10 +3666,22 @@ def upload(self): f'Upload failed for store {self.name}') from e def delete(self) -> None: + if self._bucket_sub_path is not None and not self.is_sky_managed: + return self._delete_sub_path() + self._delete_cos_bucket() logger.info(f'{colorama.Fore.GREEN}Deleted COS bucket {self.name}.' 
f'{colorama.Style.RESET_ALL}') + def _delete_sub_path(self) -> None: + assert self._bucket_sub_path is not None, 'bucket_sub_path is not set' + bucket = self.s3_resource.Bucket(self.name) + try: + self._delete_cos_bucket_objects(bucket, self._bucket_sub_path + '/') + except ibm.ibm_botocore.exceptions.ClientError as e: + if e.__class__.__name__ == 'NoSuchBucket': + logger.debug('bucket already removed') + def get_handle(self) -> StorageHandle: return self.s3_resource.Bucket(self.name) @@ -3415,10 +3722,13 @@ def get_dir_sync_command(src_dir_path, dest_dir_name) -> str: # .git directory is excluded from the sync # wrapping src_dir_path with "" to support path with spaces src_dir_path = shlex.quote(src_dir_path) + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') sync_command = ( 'rclone copy --exclude ".git/*" ' f'{src_dir_path} ' - f'{self.bucket_rclone_profile}:{self.name}/{dest_dir_name}') + f'{self.bucket_rclone_profile}:{self.name}{sub_path}' + f'/{dest_dir_name}') return sync_command def get_file_sync_command(base_dir_path, file_names) -> str: @@ -3444,9 +3754,12 @@ def get_file_sync_command(base_dir_path, file_names) -> str: for file_name in file_names ]) base_dir_path = shlex.quote(base_dir_path) - sync_command = ('rclone copy ' - f'{includes} {base_dir_path} ' - f'{self.bucket_rclone_profile}:{self.name}') + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') + sync_command = ( + 'rclone copy ' + f'{includes} {base_dir_path} ' + f'{self.bucket_rclone_profile}:{self.name}{sub_path}') return sync_command # Generate message for upload @@ -3531,6 +3844,7 @@ def _get_bucket(self) -> Tuple[StorageHandle, bool]: Rclone.RcloneClouds.IBM, self.region, # type: ignore ) + if not bucket_region and self.sync_on_reconstruction: # bucket doesn't exist return self._create_cos_bucket(self.name, self.region), True @@ -3577,7 +3891,8 @@ def mount_command(self, mount_path: str) -> str: Rclone.RCLONE_CONFIG_PATH, self.bucket_rclone_profile, self.bucket.name, - mount_path) + mount_path, + self._bucket_sub_path) return mounting_utils.get_mounting_command(mount_path, install_cmd, mount_cmd) @@ -3615,15 +3930,27 @@ def _create_cos_bucket(self, return self.bucket - def _delete_cos_bucket(self): - bucket = self.s3_resource.Bucket(self.name) - try: - bucket_versioning = self.s3_resource.BucketVersioning(self.name) - if bucket_versioning.status == 'Enabled': + def _delete_cos_bucket_objects(self, + bucket: Any, + prefix: Optional[str] = None): + bucket_versioning = self.s3_resource.BucketVersioning(bucket.name) + if bucket_versioning.status == 'Enabled': + if prefix is not None: + res = list( + bucket.object_versions.filter(Prefix=prefix).delete()) + else: res = list(bucket.object_versions.delete()) + else: + if prefix is not None: + res = list(bucket.objects.filter(Prefix=prefix).delete()) else: res = list(bucket.objects.delete()) - logger.debug(f'Deleted bucket\'s content:\n{res}') + logger.debug(f'Deleted bucket\'s content:\n{res}, prefix: {prefix}') + + def _delete_cos_bucket(self): + bucket = self.s3_resource.Bucket(self.name) + try: + self._delete_cos_bucket_objects(bucket) bucket.delete() bucket.wait_until_not_exists() except ibm.ibm_botocore.exceptions.ClientError as e: @@ -3644,7 +3971,8 @@ def __init__(self, source: str, region: Optional[str] = None, is_sky_managed: Optional[bool] = None, - sync_on_reconstruction: Optional[bool] = True): + sync_on_reconstruction: Optional[bool] = True, + _bucket_sub_path: Optional[str] = None): self.client: Any 
self.bucket: StorageHandle self.oci_config_file: str @@ -3656,7 +3984,8 @@ def __init__(self, region = oci.get_oci_config()['region'] super().__init__(name, source, region, is_sky_managed, - sync_on_reconstruction) + sync_on_reconstruction, _bucket_sub_path) + # TODO(zpoint): add _bucket_sub_path to the sync/mount/delete commands def _validate(self): if self.source is not None and isinstance(self.source, str): diff --git a/sky/skylet/constants.py b/sky/skylet/constants.py index 0b2a5b08e1b..96651eddc39 100644 --- a/sky/skylet/constants.py +++ b/sky/skylet/constants.py @@ -268,12 +268,16 @@ # Used for translate local file mounts to cloud storage. Please refer to # sky/execution.py::_maybe_translate_local_file_mounts_and_sync_up for # more details. -WORKDIR_BUCKET_NAME = 'skypilot-workdir-{username}-{id}' -FILE_MOUNTS_BUCKET_NAME = 'skypilot-filemounts-folder-{username}-{id}' -FILE_MOUNTS_FILE_ONLY_BUCKET_NAME = 'skypilot-filemounts-files-{username}-{id}' +FILE_MOUNTS_BUCKET_NAME = 'skypilot-filemounts-{username}-{id}' FILE_MOUNTS_LOCAL_TMP_DIR = 'skypilot-filemounts-files-{id}' FILE_MOUNTS_REMOTE_TMP_DIR = '/tmp/sky-{}-filemounts-files' +# Used when an managed jobs are created and +# files are synced up to the cloud. +FILE_MOUNTS_WORKDIR_SUBPATH = 'job-{run_id}/workdir' +FILE_MOUNTS_SUBPATH = 'job-{run_id}/local-file-mounts/{i}' +FILE_MOUNTS_TMP_SUBPATH = 'job-{run_id}/tmp-files' + # The default idle timeout for SkyPilot controllers. This include spot # controller and sky serve controller. # TODO(tian): Refactor to controller_utils. Current blocker: circular import. diff --git a/sky/task.py b/sky/task.py index edd2fd211a3..bbf6d59b2ae 100644 --- a/sky/task.py +++ b/sky/task.py @@ -948,12 +948,22 @@ def _get_preferred_store( store_type = storage_lib.StoreType.from_cloud(storage_cloud_str) return store_type, storage_region - def sync_storage_mounts(self) -> None: + def sync_storage_mounts(self, force_sync: bool = False) -> None: """(INTERNAL) Eagerly syncs storage mounts to cloud storage. After syncing up, COPY-mode storage mounts are translated into regular file_mounts of the form ``{ /remote/path: {s3,gs,..}:// }``. + + Args: + force_sync: If True, forces the synchronization of storage mounts. + If the store object is added via storage.add_store(), + the sync will happen automatically via add_store. + However, if it is passed via the construction function + of storage, it is usually because the user passed an + intermediate bucket name in the config and we need to + construct from the user config. In this case, set + force_sync to True. """ for storage in self.storage_mounts.values(): if not storage.stores: @@ -961,6 +971,8 @@ def sync_storage_mounts(self) -> None: self.storage_plans[storage] = store_type storage.add_store(store_type, store_region) else: + if force_sync: + storage.sync_all_stores() # We will download the first store that is added to remote. 
self.storage_plans[storage] = list(storage.stores.keys())[0] @@ -977,6 +989,7 @@ def sync_storage_mounts(self) -> None: else: assert storage.name is not None, storage blob_path = 's3://' + storage.name + blob_path = storage.get_bucket_sub_path_prefix(blob_path) self.update_file_mounts({ mnt_path: blob_path, }) @@ -987,6 +1000,7 @@ def sync_storage_mounts(self) -> None: else: assert storage.name is not None, storage blob_path = 'gs://' + storage.name + blob_path = storage.get_bucket_sub_path_prefix(blob_path) self.update_file_mounts({ mnt_path: blob_path, }) @@ -1005,6 +1019,7 @@ def sync_storage_mounts(self) -> None: blob_path = data_utils.AZURE_CONTAINER_URL.format( storage_account_name=storage_account_name, container_name=storage.name) + blob_path = storage.get_bucket_sub_path_prefix(blob_path) self.update_file_mounts({ mnt_path: blob_path, }) @@ -1015,6 +1030,7 @@ def sync_storage_mounts(self) -> None: blob_path = storage.source else: blob_path = 'r2://' + storage.name + blob_path = storage.get_bucket_sub_path_prefix(blob_path) self.update_file_mounts({ mnt_path: blob_path, }) @@ -1030,6 +1046,7 @@ def sync_storage_mounts(self) -> None: cos_region = data_utils.Rclone.get_region_from_rclone( storage.name, data_utils.Rclone.RcloneClouds.IBM) blob_path = f'cos://{cos_region}/{storage.name}' + blob_path = storage.get_bucket_sub_path_prefix(blob_path) self.update_file_mounts({mnt_path: blob_path}) elif store_type is storage_lib.StoreType.OCI: if storage.source is not None and not isinstance( diff --git a/sky/utils/controller_utils.py b/sky/utils/controller_utils.py index 0166a16ff16..39623085bbb 100644 --- a/sky/utils/controller_utils.py +++ b/sky/utils/controller_utils.py @@ -649,10 +649,27 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', still sync up any storage mounts with local source paths (which do not undergo translation). """ + # ================================================================ # Translate the workdir and local file mounts to cloud file mounts. # ================================================================ + def _sub_path_join(sub_path: Optional[str], path: str) -> str: + if sub_path is None: + return path + return os.path.join(sub_path, path).strip('/') + + def assert_no_bucket_creation(store: storage_lib.AbstractStore) -> None: + if store.is_sky_managed: + # Bucket was created, this should not happen since use configured + # the bucket and we assumed it already exists. + store.delete() + with ux_utils.print_exception_no_traceback(): + raise exceptions.StorageBucketCreateError( + f'Jobs bucket {store.name!r} does not exist. ' + 'Please check jobs.bucket configuration in ' + 'your SkyPilot config.') + run_id = common_utils.get_usage_run_id()[:8] original_file_mounts = task.file_mounts if task.file_mounts else {} original_storage_mounts = task.storage_mounts if task.storage_mounts else {} @@ -679,11 +696,27 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', ux_utils.spinner_message( f'Translating {msg} to SkyPilot Storage...')) + # Get the bucket name for the workdir and file mounts, + # we store all these files in same bucket from config. 
+ bucket_wth_prefix = skypilot_config.get_nested(('jobs', 'bucket'), None) + store_kwargs: Dict[str, Any] = {} + if bucket_wth_prefix is None: + store_type = store_cls = sub_path = None + storage_account_name = region = None + bucket_name = constants.FILE_MOUNTS_BUCKET_NAME.format( + username=common_utils.get_cleaned_username(), id=run_id) + else: + store_type, store_cls, bucket_name, sub_path, storage_account_name, \ + region = storage_lib.StoreType.get_fields_from_store_url( + bucket_wth_prefix) + if storage_account_name is not None: + store_kwargs['storage_account_name'] = storage_account_name + if region is not None: + store_kwargs['region'] = region + # Step 1: Translate the workdir to SkyPilot storage. new_storage_mounts = {} if task.workdir is not None: - bucket_name = constants.WORKDIR_BUCKET_NAME.format( - username=common_utils.get_cleaned_username(), id=run_id) workdir = task.workdir task.workdir = None if (constants.SKY_REMOTE_WORKDIR in original_file_mounts or @@ -691,14 +724,28 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', raise ValueError( f'Cannot mount {constants.SKY_REMOTE_WORKDIR} as both the ' 'workdir and file_mounts contains it as the target.') - new_storage_mounts[ - constants. - SKY_REMOTE_WORKDIR] = storage_lib.Storage.from_yaml_config({ - 'name': bucket_name, - 'source': workdir, - 'persistent': False, - 'mode': 'COPY', - }) + bucket_sub_path = _sub_path_join( + sub_path, + constants.FILE_MOUNTS_WORKDIR_SUBPATH.format(run_id=run_id)) + stores = None + if store_type is not None: + assert store_cls is not None + with sky_logging.silent(): + stores = { + store_type: store_cls(name=bucket_name, + source=workdir, + _bucket_sub_path=bucket_sub_path, + **store_kwargs) + } + assert_no_bucket_creation(stores[store_type]) + + storage_obj = storage_lib.Storage(name=bucket_name, + source=workdir, + persistent=False, + mode=storage_lib.StorageMode.COPY, + stores=stores, + _bucket_sub_path=bucket_sub_path) + new_storage_mounts[constants.SKY_REMOTE_WORKDIR] = storage_obj # Check of the existence of the workdir in file_mounts is done in # the task construction. logger.info(f' {colorama.Style.DIM}Workdir: {workdir!r} ' @@ -716,27 +763,37 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', if os.path.isfile(os.path.abspath(os.path.expanduser(src))): copy_mounts_with_file_in_src[dst] = src continue - bucket_name = constants.FILE_MOUNTS_BUCKET_NAME.format( - username=common_utils.get_cleaned_username(), - id=f'{run_id}-{i}', - ) - new_storage_mounts[dst] = storage_lib.Storage.from_yaml_config({ - 'name': bucket_name, - 'source': src, - 'persistent': False, - 'mode': 'COPY', - }) + bucket_sub_path = _sub_path_join( + sub_path, constants.FILE_MOUNTS_SUBPATH.format(i=i, run_id=run_id)) + stores = None + if store_type is not None: + assert store_cls is not None + with sky_logging.silent(): + store = store_cls(name=bucket_name, + source=src, + _bucket_sub_path=bucket_sub_path, + **store_kwargs) + + stores = {store_type: store} + assert_no_bucket_creation(stores[store_type]) + storage_obj = storage_lib.Storage(name=bucket_name, + source=src, + persistent=False, + mode=storage_lib.StorageMode.COPY, + stores=stores, + _bucket_sub_path=bucket_sub_path) + new_storage_mounts[dst] = storage_obj logger.info(f' {colorama.Style.DIM}Folder : {src!r} ' f'-> storage: {bucket_name!r}.{colorama.Style.RESET_ALL}') # Step 3: Translate local file mounts with file in src to SkyPilot storage. 
# Hard link the files in src to a temporary directory, and upload folder. + file_mounts_tmp_subpath = _sub_path_join( + sub_path, constants.FILE_MOUNTS_TMP_SUBPATH.format(run_id=run_id)) local_fm_path = os.path.join( tempfile.gettempdir(), constants.FILE_MOUNTS_LOCAL_TMP_DIR.format(id=run_id)) os.makedirs(local_fm_path, exist_ok=True) - file_bucket_name = constants.FILE_MOUNTS_FILE_ONLY_BUCKET_NAME.format( - username=common_utils.get_cleaned_username(), id=run_id) file_mount_remote_tmp_dir = constants.FILE_MOUNTS_REMOTE_TMP_DIR.format( path) if copy_mounts_with_file_in_src: @@ -745,14 +802,27 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', src_to_file_id[src] = i os.link(os.path.abspath(os.path.expanduser(src)), os.path.join(local_fm_path, f'file-{i}')) - - new_storage_mounts[ - file_mount_remote_tmp_dir] = storage_lib.Storage.from_yaml_config({ - 'name': file_bucket_name, - 'source': local_fm_path, - 'persistent': False, - 'mode': 'MOUNT', - }) + stores = None + if store_type is not None: + assert store_cls is not None + with sky_logging.silent(): + stores = { + store_type: store_cls( + name=bucket_name, + source=local_fm_path, + _bucket_sub_path=file_mounts_tmp_subpath, + **store_kwargs) + } + assert_no_bucket_creation(stores[store_type]) + storage_obj = storage_lib.Storage( + name=bucket_name, + source=local_fm_path, + persistent=False, + mode=storage_lib.StorageMode.MOUNT, + stores=stores, + _bucket_sub_path=file_mounts_tmp_subpath) + + new_storage_mounts[file_mount_remote_tmp_dir] = storage_obj if file_mount_remote_tmp_dir in original_storage_mounts: with ux_utils.print_exception_no_traceback(): raise ValueError( @@ -762,8 +832,9 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', sources = list(src_to_file_id.keys()) sources_str = '\n '.join(sources) logger.info(f' {colorama.Style.DIM}Files (listed below) ' - f' -> storage: {file_bucket_name}:' + f' -> storage: {bucket_name}:' f'\n {sources_str}{colorama.Style.RESET_ALL}') + rich_utils.force_update_status( ux_utils.spinner_message('Uploading translated local files/folders')) task.update_storage_mounts(new_storage_mounts) @@ -779,7 +850,7 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', ux_utils.spinner_message('Uploading local sources to storage[/] ' '[dim]View storages: sky storage ls')) try: - task.sync_storage_mounts() + task.sync_storage_mounts(force_sync=bucket_wth_prefix is not None) except (ValueError, exceptions.NoCloudAccessError) as e: if 'No enabled cloud for storage' in str(e) or isinstance( e, exceptions.NoCloudAccessError): @@ -809,10 +880,11 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', # file_mount_remote_tmp_dir will only exist when there are files in # the src for copy mounts. 
storage_obj = task.storage_mounts[file_mount_remote_tmp_dir] - store_type = list(storage_obj.stores.keys())[0] - store_object = storage_obj.stores[store_type] + curr_store_type = list(storage_obj.stores.keys())[0] + store_object = storage_obj.stores[curr_store_type] bucket_url = storage_lib.StoreType.get_endpoint_url( - store_object, file_bucket_name) + store_object, bucket_name) + bucket_url += f'/{file_mounts_tmp_subpath}' for dst, src in copy_mounts_with_file_in_src.items(): file_id = src_to_file_id[src] new_file_mounts[dst] = bucket_url + f'/file-{file_id}' @@ -829,8 +901,8 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', store_types = list(storage_obj.stores.keys()) assert len(store_types) == 1, ( 'We only support one store type for now.', storage_obj.stores) - store_type = store_types[0] - store_object = storage_obj.stores[store_type] + curr_store_type = store_types[0] + store_object = storage_obj.stores[curr_store_type] storage_obj.source = storage_lib.StoreType.get_endpoint_url( store_object, storage_obj.name) storage_obj.force_delete = True @@ -847,8 +919,8 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', store_types = list(storage_obj.stores.keys()) assert len(store_types) == 1, ( 'We only support one store type for now.', storage_obj.stores) - store_type = store_types[0] - store_object = storage_obj.stores[store_type] + curr_store_type = store_types[0] + store_object = storage_obj.stores[curr_store_type] source = storage_lib.StoreType.get_endpoint_url( store_object, storage_obj.name) new_storage = storage_lib.Storage.from_yaml_config({ diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 851e77a57fc..a424ae074b9 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -299,6 +299,12 @@ def get_storage_schema(): mode.value for mode in storage.StorageMode ] }, + '_is_sky_managed': { + 'type': 'boolean', + }, + '_bucket_sub_path': { + 'type': 'string', + }, '_force_delete': { 'type': 'boolean', } @@ -721,6 +727,11 @@ def get_config_schema(): 'resources': resources_schema, } }, + 'bucket': { + 'type': 'string', + 'pattern': '^(https|s3|gs|r2|cos)://.+', + 'required': [], + } } } cloud_configs = { diff --git a/tests/smoke_tests/test_managed_job.py b/tests/smoke_tests/test_managed_job.py index c8ef5c1a502..f39dba6f47e 100644 --- a/tests/smoke_tests/test_managed_job.py +++ b/tests/smoke_tests/test_managed_job.py @@ -23,6 +23,7 @@ # > pytest tests/smoke_tests/test_managed_job.py --generic-cloud aws import pathlib +import re import tempfile import time @@ -742,14 +743,70 @@ def test_managed_jobs_storage(generic_cloud: str): # Check if file was written to the mounted output bucket output_check_cmd ], - (f'sky jobs cancel -y -n {name}', - f'; sky storage delete {output_storage_name} || true'), + (f'sky jobs cancel -y -n {name}' + f'; sky storage delete {output_storage_name} -y || true'), # Increase timeout since sky jobs queue -r can be blocked by other spot tests. 
timeout=20 * 60, ) smoke_tests_utils.run_one_test(test) +@pytest.mark.aws +def test_managed_jobs_intermediate_storage(generic_cloud: str): + """Test storage with managed job""" + name = smoke_tests_utils.get_cluster_name() + yaml_str = pathlib.Path( + 'examples/managed_job_with_storage.yaml').read_text() + timestamp = int(time.time()) + storage_name = f'sky-test-{timestamp}' + output_storage_name = f'sky-test-output-{timestamp}' + + yaml_str_user_config = pathlib.Path( + 'tests/test_yamls/use_intermediate_bucket_config.yaml').read_text() + intermediate_storage_name = f'intermediate-smoke-test-{timestamp}' + + yaml_str = yaml_str.replace('sky-workdir-zhwu', storage_name) + yaml_str = yaml_str.replace('sky-output-bucket', output_storage_name) + yaml_str_user_config = re.sub(r'bucket-jobs-[\w\d]+', + intermediate_storage_name, + yaml_str_user_config) + + with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f_user_config: + f_user_config.write(yaml_str_user_config) + f_user_config.flush() + user_config_path = f_user_config.name + with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f_task: + f_task.write(yaml_str) + f_task.flush() + file_path = f_task.name + + test = smoke_tests_utils.Test( + 'managed_jobs_intermediate_storage', + [ + *smoke_tests_utils.STORAGE_SETUP_COMMANDS, + # Verify command fails with correct error - run only once + f'err=$(sky jobs launch -n {name} --cloud {generic_cloud} {file_path} -y 2>&1); ret=$?; echo "$err" ; [ $ret -eq 0 ] || ! echo "$err" | grep "StorageBucketCreateError: Jobs bucket \'{intermediate_storage_name}\' does not exist. Please check jobs.bucket configuration in your SkyPilot config." > /dev/null && exit 1 || exit 0', + f'aws s3api create-bucket --bucket {intermediate_storage_name}', + f'sky jobs launch -n {name} --cloud {generic_cloud} {file_path} -y', + # fail because the bucket does not exist + smoke_tests_utils. + get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.SUCCEEDED], + timeout=60 + smoke_tests_utils.BUMP_UP_SECONDS), + # check intermediate bucket exists, it won't be deletd if its user specific + f'[ $(aws s3api list-buckets --query "Buckets[?contains(Name, \'{intermediate_storage_name}\')].Name" --output text | wc -l) -eq 1 ]', + ], + (f'sky jobs cancel -y -n {name}' + f'; aws s3 rb s3://{intermediate_storage_name} --force' + f'; sky storage delete {output_storage_name} -y || true'), + env={'SKYPILOT_CONFIG': user_config_path}, + # Increase timeout since sky jobs queue -r can be blocked by other spot tests. 
+ timeout=20 * 60, + ) + smoke_tests_utils.run_one_test(test) + + # ---------- Testing spot TPU ---------- @pytest.mark.gcp @pytest.mark.managed_jobs diff --git a/tests/smoke_tests/test_mount_and_storage.py b/tests/smoke_tests/test_mount_and_storage.py index aa61282aa11..89a849ad090 100644 --- a/tests/smoke_tests/test_mount_and_storage.py +++ b/tests/smoke_tests/test_mount_and_storage.py @@ -19,6 +19,7 @@ # Change cloud for generic tests to aws # > pytest tests/smoke_tests/test_mount_and_storage.py --generic-cloud aws +import json import os import pathlib import shlex @@ -37,6 +38,7 @@ import sky from sky import global_user_state from sky import skypilot_config +from sky.adaptors import azure from sky.adaptors import cloudflare from sky.adaptors import ibm from sky.data import data_utils @@ -629,21 +631,69 @@ def cli_delete_cmd(store_type, bucket_name, Rclone.RcloneClouds.IBM) return f'rclone purge {bucket_rclone_profile}:{bucket_name} && rclone config delete {bucket_rclone_profile}' + @classmethod + def list_all_files(cls, store_type, bucket_name): + cmd = cls.cli_ls_cmd(store_type, bucket_name, recursive=True) + if store_type == storage_lib.StoreType.GCS: + try: + out = subprocess.check_output(cmd, + shell=True, + stderr=subprocess.PIPE) + files = [line[5:] for line in out.decode('utf-8').splitlines()] + except subprocess.CalledProcessError as e: + error_output = e.stderr.decode('utf-8') + if "One or more URLs matched no objects" in error_output: + files = [] + else: + raise + elif store_type == storage_lib.StoreType.AZURE: + out = subprocess.check_output(cmd, shell=True) + try: + blobs = json.loads(out.decode('utf-8')) + files = [blob['name'] for blob in blobs] + except json.JSONDecodeError: + files = [] + elif store_type == storage_lib.StoreType.IBM: + # rclone ls format: " 1234 path/to/file" + out = subprocess.check_output(cmd, shell=True) + files = [] + for line in out.decode('utf-8').splitlines(): + # Skip empty lines + if not line.strip(): + continue + # Split by whitespace and get the file path (last column) + parts = line.strip().split( + None, 1) # Split into max 2 parts (size and path) + if len(parts) == 2: + files.append(parts[1]) + else: + out = subprocess.check_output(cmd, shell=True) + files = [ + line.split()[-1] for line in out.decode('utf-8').splitlines() + ] + return files + @staticmethod - def cli_ls_cmd(store_type, bucket_name, suffix=''): + def cli_ls_cmd(store_type, bucket_name, suffix='', recursive=False): if store_type == storage_lib.StoreType.S3: if suffix: url = f's3://{bucket_name}/{suffix}' else: url = f's3://{bucket_name}' - return f'aws s3 ls {url}' + cmd = f'aws s3 ls {url}' + if recursive: + cmd += ' --recursive' + return cmd if store_type == storage_lib.StoreType.GCS: if suffix: url = f'gs://{bucket_name}/{suffix}' else: url = f'gs://{bucket_name}' + if recursive: + url = f'"{url}/**"' return f'gsutil ls {url}' if store_type == storage_lib.StoreType.AZURE: + # azure isrecursive by default default_region = 'eastus' config_storage_account = skypilot_config.get_nested( ('azure', 'storage_account'), None) @@ -665,8 +715,10 @@ def cli_ls_cmd(store_type, bucket_name, suffix=''): url = f's3://{bucket_name}/{suffix}' else: url = f's3://{bucket_name}' - return f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} aws s3 ls {url} --endpoint {endpoint_url} --profile=r2' + recursive_flag = '--recursive' if recursive else '' + return f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} aws s3 ls {url} --endpoint {endpoint_url} --profile=r2 
{recursive_flag}' if store_type == storage_lib.StoreType.IBM: + # rclone ls is recursive by default bucket_rclone_profile = Rclone.generate_rclone_bucket_profile_name( bucket_name, Rclone.RcloneClouds.IBM) return f'rclone ls {bucket_rclone_profile}:{bucket_name}/{suffix}' @@ -764,6 +816,12 @@ def tmp_source(self, tmp_path): circle_link.symlink_to(tmp_dir, target_is_directory=True) yield str(tmp_dir) + @pytest.fixture + def tmp_sub_path(self): + tmp_dir1 = uuid.uuid4().hex[:8] + tmp_dir2 = uuid.uuid4().hex[:8] + yield "/".join([tmp_dir1, tmp_dir2]) + @staticmethod def generate_bucket_name(): # Creates a temporary bucket name @@ -783,13 +841,15 @@ def yield_storage_object( stores: Optional[Dict[storage_lib.StoreType, storage_lib.AbstractStore]] = None, persistent: Optional[bool] = True, - mode: storage_lib.StorageMode = storage_lib.StorageMode.MOUNT): + mode: storage_lib.StorageMode = storage_lib.StorageMode.MOUNT, + _bucket_sub_path: Optional[str] = None): # Creates a temporary storage object. Stores must be added in the test. storage_obj = storage_lib.Storage(name=name, source=source, stores=stores, persistent=persistent, - mode=mode) + mode=mode, + _bucket_sub_path=_bucket_sub_path) yield storage_obj handle = global_user_state.get_handle_from_storage_name( storage_obj.name) @@ -856,6 +916,15 @@ def tmp_local_storage_obj(self, tmp_bucket_name, tmp_source): yield from self.yield_storage_object(name=tmp_bucket_name, source=tmp_source) + @pytest.fixture + def tmp_local_storage_obj_with_sub_path(self, tmp_bucket_name, tmp_source, + tmp_sub_path): + # Creates a temporary storage object with sub. Stores must be added in the test. + list_source = [tmp_source, tmp_source + '/tmp-file'] + yield from self.yield_storage_object(name=tmp_bucket_name, + source=list_source, + _bucket_sub_path=tmp_sub_path) + @pytest.fixture def tmp_local_list_storage_obj(self, tmp_bucket_name, tmp_source): # Creates a temp storage object which uses a list of paths as source. @@ -1014,6 +1083,59 @@ def test_new_bucket_creation_and_deletion(self, tmp_local_storage_obj, out = subprocess.check_output(['sky', 'storage', 'ls']) assert tmp_local_storage_obj.name not in out.decode('utf-8') + @pytest.mark.no_fluidstack + @pytest.mark.parametrize('store_type', [ + pytest.param(storage_lib.StoreType.S3, marks=pytest.mark.aws), + pytest.param(storage_lib.StoreType.GCS, marks=pytest.mark.gcp), + pytest.param(storage_lib.StoreType.AZURE, marks=pytest.mark.azure), + pytest.param(storage_lib.StoreType.IBM, marks=pytest.mark.ibm), + pytest.param(storage_lib.StoreType.R2, marks=pytest.mark.cloudflare) + ]) + def test_bucket_sub_path(self, tmp_local_storage_obj_with_sub_path, + store_type): + # Creates a new bucket with a local source, uploads files to it + # and deletes it. 
+ tmp_local_storage_obj_with_sub_path.add_store(store_type) + + # Check files under bucket and filter by prefix + files = self.list_all_files(store_type, + tmp_local_storage_obj_with_sub_path.name) + assert len(files) > 0 + if store_type == storage_lib.StoreType.GCS: + assert all([ + file.startswith( + tmp_local_storage_obj_with_sub_path.name + '/' + + tmp_local_storage_obj_with_sub_path._bucket_sub_path) + for file in files + ]) + else: + assert all([ + file.startswith( + tmp_local_storage_obj_with_sub_path._bucket_sub_path) + for file in files + ]) + + # Check bucket is empty, all files under sub directory should be deleted + store = tmp_local_storage_obj_with_sub_path.stores[store_type] + store.is_sky_managed = False + if store_type == storage_lib.StoreType.AZURE: + azure.assign_storage_account_iam_role( + storage_account_name=store.storage_account_name, + resource_group_name=store.resource_group_name) + store.delete() + files = self.list_all_files(store_type, + tmp_local_storage_obj_with_sub_path.name) + assert len(files) == 0 + + # Now, delete the entire bucket + store.is_sky_managed = True + tmp_local_storage_obj_with_sub_path.delete() + + # Run sky storage ls to check if storage object is deleted + out = subprocess.check_output(['sky', 'storage', 'ls']) + assert tmp_local_storage_obj_with_sub_path.name not in out.decode( + 'utf-8') + @pytest.mark.no_fluidstack @pytest.mark.xdist_group('multiple_bucket_deletion') @pytest.mark.parametrize('store_type', [ diff --git a/tests/test_yamls/intermediate_bucket.yaml b/tests/test_yamls/intermediate_bucket.yaml new file mode 100644 index 00000000000..fe9aafd0675 --- /dev/null +++ b/tests/test_yamls/intermediate_bucket.yaml @@ -0,0 +1,21 @@ +name: intermediate-bucket + +file_mounts: + /setup.py: ./setup.py + /sky: . + /train-00001-of-01024: gs://cloud-tpu-test-datasets/fake_imagenet/train-00001-of-01024 + +workdir: . + + +setup: | + echo "running setup" + +run: | + echo "listing workdir" + ls . + echo "listing file_mounts" + ls /setup.py + ls /sky + ls /train-00001-of-01024 + echo "task run finish" diff --git a/tests/test_yamls/use_intermediate_bucket_config.yaml b/tests/test_yamls/use_intermediate_bucket_config.yaml new file mode 100644 index 00000000000..cdfb5fbabc1 --- /dev/null +++ b/tests/test_yamls/use_intermediate_bucket_config.yaml @@ -0,0 +1,2 @@ +jobs: + bucket: "s3://bucket-jobs-s3" From 13501e2e30987010e04c728b713b820a8acd3471 Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Mon, 30 Dec 2024 13:04:13 -0800 Subject: [PATCH 3/5] [Core] Deprecate LocalDockerBackend (#4516) Deprecate local docker backend --- examples/local_docker/docker_in_docker.yaml | 19 ------------------ examples/local_docker/ping.py | 22 --------------------- examples/local_docker/ping.yaml | 19 ------------------ sky/cli.py | 11 +++++++++-- 4 files changed, 9 insertions(+), 62 deletions(-) delete mode 100644 examples/local_docker/docker_in_docker.yaml delete mode 100644 examples/local_docker/ping.py delete mode 100644 examples/local_docker/ping.yaml diff --git a/examples/local_docker/docker_in_docker.yaml b/examples/local_docker/docker_in_docker.yaml deleted file mode 100644 index bdb6ed70ecf..00000000000 --- a/examples/local_docker/docker_in_docker.yaml +++ /dev/null @@ -1,19 +0,0 @@ -# Runs a docker container as a SkyPilot task. -# -# This demo can be run using the --docker flag, demonstrating the -# docker-in-docker (dind) capabilities of SkyPilot docker mode. 
-# -# Usage: -# sky launch --docker -c dind docker_in_docker.yaml -# sky down dind - -name: dind - -resources: - cloud: aws - -setup: | - echo "No setup required!" - -run: | - docker run --rm hello-world diff --git a/examples/local_docker/ping.py b/examples/local_docker/ping.py deleted file mode 100644 index c3a90c62243..00000000000 --- a/examples/local_docker/ping.py +++ /dev/null @@ -1,22 +0,0 @@ -"""An example app which pings localhost. - -This script is designed to demonstrate the use of different backends with -SkyPilot. It is useful to support a LocalDockerBackend that users can use to -debug their programs even before they run them on the Sky. -""" - -import sky - -# Set backend here. It can be either LocalDockerBackend or CloudVmRayBackend. -backend = sky.backends.LocalDockerBackend( -) # or sky.backends.CloudVmRayBackend() - -with sky.Dag() as dag: - resources = sky.Resources(accelerators={'K80': 1}) - setup_commands = 'apt-get update && apt-get install -y iputils-ping' - task = sky.Task(run='ping 127.0.0.1 -c 100', - docker_image='ubuntu', - setup=setup_commands, - name='ping').set_resources(resources) - -sky.launch(dag, backend=backend) diff --git a/examples/local_docker/ping.yaml b/examples/local_docker/ping.yaml deleted file mode 100644 index 0d0efd12419..00000000000 --- a/examples/local_docker/ping.yaml +++ /dev/null @@ -1,19 +0,0 @@ -# A minimal ping example. -# -# Runs a task that pings localhost 100 times. -# -# Usage: -# sky launch --docker -c ping ping.yaml -# sky down ping - -name: ping - -resources: - cloud: aws - -setup: | - sudo apt-get update --allow-insecure-repositories - sudo apt-get install -y iputils-ping - -run: | - ping 127.0.0.1 -c 100 diff --git a/sky/cli.py b/sky/cli.py index 5d4f07d535f..d00aae9b646 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -998,8 +998,10 @@ def cli(): @click.option('--docker', 'backend_name', flag_value=backends.LocalDockerBackend.NAME, - default=False, - help='If used, runs locally inside a docker container.') + hidden=True, + help=('(Deprecated) Local docker support is deprecated. ' + 'To run locally, create a local Kubernetes cluster with ' + '``sky local up``.')) @_add_click_options(_TASK_OPTIONS_WITH_NAME + _EXTRA_RESOURCES_OPTIONS) @click.option( '--idle-minutes-to-autostop', @@ -1142,6 +1144,11 @@ def launch( backend: backends.Backend if backend_name == backends.LocalDockerBackend.NAME: backend = backends.LocalDockerBackend() + click.secho( + 'WARNING: LocalDockerBackend is deprecated and will be ' + 'removed in a future release. To run locally, create a local ' + 'Kubernetes cluster with `sky local up`.', + fg='yellow') elif backend_name == backends.CloudVmRayBackend.NAME: backend = backends.CloudVmRayBackend() else: From 3715be2865cd145f387996b7fc8b32fada61c431 Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Mon, 30 Dec 2024 14:45:40 -0800 Subject: [PATCH 4/5] [docs] Add newer examples for AI tutorial and distributed training (#4509) * Update tutorial and distributed training examples. 
* Add examples link * add rdvz --- docs/source/getting-started/tutorial.rst | 48 +++++++++---------- docs/source/running-jobs/distributed-jobs.rst | 48 ++++++++++--------- 2 files changed, 49 insertions(+), 47 deletions(-) diff --git a/docs/source/getting-started/tutorial.rst b/docs/source/getting-started/tutorial.rst index 175f1391a6d..9b067be2876 100644 --- a/docs/source/getting-started/tutorial.rst +++ b/docs/source/getting-started/tutorial.rst @@ -2,19 +2,20 @@ Tutorial: AI Training ====================== -This example uses SkyPilot to train a Transformer-based language model from HuggingFace. +This example uses SkyPilot to train a GPT-like model (inspired by Karpathy's `minGPT `_) with Distributed Data Parallel (DDP) in PyTorch. -First, define a :ref:`task YAML ` with the resource requirements, the setup commands, +We define a :ref:`task YAML ` with the resource requirements, the setup commands, and the commands to run: .. code-block:: yaml - # dnn.yaml + # train.yaml - name: huggingface + name: minGPT-ddp resources: - accelerators: V100:4 + cpus: 4+ + accelerators: L4:4 # Or A100:8, H100:8 # Optional: upload a working directory to remote ~/sky_workdir. # Commands in "setup" and "run" will be executed under it. @@ -30,26 +31,21 @@ and the commands to run: # ~/.netrc: ~/.netrc setup: | - set -e # Exit if any command failed. - git clone https://github.com/huggingface/transformers/ || true - cd transformers - pip install . - cd examples/pytorch/text-classification - pip install -r requirements.txt torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 + git clone --depth 1 https://github.com/pytorch/examples || true + cd examples + git filter-branch --prune-empty --subdirectory-filter distributed/minGPT-ddp + # SkyPilot's default image on AWS/GCP has CUDA 11.6 (Azure 11.5). + uv pip install -r requirements.txt "numpy<2" "torch==1.12.1+cu113" --extra-index-url https://download.pytorch.org/whl/cu113 run: | - set -e # Exit if any command failed. - cd transformers/examples/pytorch/text-classification - python run_glue.py \ - --model_name_or_path bert-base-cased \ - --dataset_name imdb \ - --do_train \ - --max_seq_length 128 \ - --per_device_train_batch_size 32 \ - --learning_rate 2e-5 \ - --max_steps 50 \ - --output_dir /tmp/imdb/ --overwrite_output_dir \ - --fp16 + cd examples/mingpt + export LOGLEVEL=INFO + + echo "Starting minGPT-ddp training" + + torchrun \ + --nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \ + main.py .. tip:: @@ -57,11 +53,15 @@ and the commands to run: learn about how to use them to mount local dirs/files or object store buckets (S3, GCS, R2) into your cluster, see :ref:`sync-code-artifacts`. +.. tip:: + + The ``SKYPILOT_NUM_GPUS_PER_NODE`` environment variable is automatically set by SkyPilot to the number of GPUs per node. See :ref:`env-vars` for more. + Then, launch training: .. code-block:: console - $ sky launch -c lm-cluster dnn.yaml + $ sky launch -c mingpt train.yaml This will provision the cheapest cluster with the required resources, execute the setup commands, then execute the run commands. diff --git a/docs/source/running-jobs/distributed-jobs.rst b/docs/source/running-jobs/distributed-jobs.rst index f6c8cba9c9d..7c3421aa276 100644 --- a/docs/source/running-jobs/distributed-jobs.rst +++ b/docs/source/running-jobs/distributed-jobs.rst @@ -6,39 +6,40 @@ Distributed Multi-Node Jobs SkyPilot supports multi-node cluster provisioning and distributed execution on many nodes. 
-For example, here is a simple PyTorch Distributed training example: +For example, here is a simple example to train a GPT-like model (inspired by Karpathy's `minGPT `_) across 2 nodes with Distributed Data Parallel (DDP) in PyTorch. .. code-block:: yaml - :emphasize-lines: 6-6,21-21,23-26 + :emphasize-lines: 6,19,23-24,26 - name: resnet-distributed-app + name: minGPT-ddp - resources: - accelerators: A100:8 + resources: + accelerators: A100:8 - num_nodes: 2 + num_nodes: 2 - setup: | - pip3 install --upgrade pip - git clone https://github.com/michaelzhiluo/pytorch-distributed-resnet - cd pytorch-distributed-resnet - # SkyPilot's default image on AWS/GCP has CUDA 11.6 (Azure 11.5). - pip3 install -r requirements.txt torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 - mkdir -p data && mkdir -p saved_models && cd data && \ - wget -c --quiet https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz - tar -xvzf cifar-10-python.tar.gz + setup: | + git clone --depth 1 https://github.com/pytorch/examples || true + cd examples + git filter-branch --prune-empty --subdirectory-filter distributed/minGPT-ddp + # SkyPilot's default image on AWS/GCP has CUDA 11.6 (Azure 11.5). + uv pip install -r requirements.txt "numpy<2" "torch==1.12.1+cu113" --extra-index-url https://download.pytorch.org/whl/cu113 - run: | - cd pytorch-distributed-resnet + run: | + cd examples/mingpt + export LOGLEVEL=INFO + + MASTER_ADDR=$(echo "$SKYPILOT_NODE_IPS" | head -n1) + echo "Starting distributed training, head node: $MASTER_ADDR" - MASTER_ADDR=`echo "$SKYPILOT_NODE_IPS" | head -n1` - torchrun \ + torchrun \ --nnodes=$SKYPILOT_NUM_NODES \ - --master_addr=$MASTER_ADDR \ --nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \ - --node_rank=$SKYPILOT_NODE_RANK \ - --master_port=12375 \ - resnet_ddp.py --num_epochs 20 + --master_addr=$MASTER_ADDR \ + --node_rank=${SKYPILOT_NODE_RANK} \ + --master_port=8008 \ + main.py + In the above, @@ -55,6 +56,7 @@ In the above, ulimit -n 65535 +You can find more `distributed training examples `_ (including `using rdvz backend for pytorch `_) in our `GitHub repository `_. 
Environment variables ----------------------------------------- From 7c3340393f9450478c18f1e3898fca33bd5df1ac Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Mon, 30 Dec 2024 15:03:21 -0800 Subject: [PATCH 5/5] [k8s] Fix L40 detection for nvidia GFD labels (#4511) Fix L40 detection --- sky/provision/kubernetes/utils.py | 5 +++-- .../kubernetes/test_gpu_label_formatters.py | 22 +++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) create mode 100644 tests/unit_tests/kubernetes/test_gpu_label_formatters.py diff --git a/sky/provision/kubernetes/utils.py b/sky/provision/kubernetes/utils.py index 87ccd6b105d..487868d1d9e 100644 --- a/sky/provision/kubernetes/utils.py +++ b/sky/provision/kubernetes/utils.py @@ -340,14 +340,15 @@ def get_accelerator_from_label_value(cls, value: str) -> str: """ canonical_gpu_names = [ 'A100-80GB', 'A100', 'A10G', 'H100', 'K80', 'M60', 'T4g', 'T4', - 'V100', 'A10', 'P4000', 'P100', 'P40', 'P4', 'L4' + 'V100', 'A10', 'P4000', 'P100', 'P40', 'P4', 'L40', 'L4' ] for canonical_name in canonical_gpu_names: # A100-80G accelerator is A100-SXM-80GB or A100-PCIE-80GB if canonical_name == 'A100-80GB' and re.search( r'A100.*-80GB', value): return canonical_name - elif canonical_name in value: + # Use word boundary matching to prevent substring matches + elif re.search(rf'\b{re.escape(canonical_name)}\b', value): return canonical_name # If we didn't find a canonical name: diff --git a/tests/unit_tests/kubernetes/test_gpu_label_formatters.py b/tests/unit_tests/kubernetes/test_gpu_label_formatters.py new file mode 100644 index 00000000000..cd7337dc7a1 --- /dev/null +++ b/tests/unit_tests/kubernetes/test_gpu_label_formatters.py @@ -0,0 +1,22 @@ +"""Tests for GPU label formatting in Kubernetes integration. + +Tests verify correct GPU detection from Kubernetes labels. +""" +import pytest + +from sky.provision.kubernetes.utils import GFDLabelFormatter + + +def test_gfd_label_formatter(): + """Test word boundary regex matching in GFDLabelFormatter.""" + # Test various GPU name patterns + test_cases = [ + ('NVIDIA-L4-24GB', 'L4'), + ('NVIDIA-L40-48GB', 'L40'), + ('NVIDIA-L400', 'L400'), # Should not match L4 or L40 + ('NVIDIA-L4', 'L4'), + ('L40-GPU', 'L40'), + ] + for input_value, expected in test_cases: + result = GFDLabelFormatter.get_accelerator_from_label_value(input_value) + assert result == expected, f'Failed for {input_value}'
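
As a quick illustration of why [PATCH 5/5] switches from a substring check to a word-boundary regex: with a plain "in" test, the canonical name 'L4' also matches label values such as 'NVIDIA-L40-48GB', so L40 GPUs were mis-detected as L4. The sketch below is a standalone approximation of that matching logic, not SkyPilot's actual implementation; the helper name and the trimmed two-entry canonical list are assumptions made for the example, and the real code's fallback for non-canonical names (e.g. 'L400') is omitted.

import re

# Trimmed, illustrative list; the real canonical list lives in
# sky/provision/kubernetes/utils.py and is much longer.
_CANONICAL_GPU_NAMES = ['L40', 'L4']


def match_canonical_gpu(label_value: str):
    """Return the canonical GPU name found in a GFD label value, or None."""
    for name in _CANONICAL_GPU_NAMES:
        # A substring test ('L4' in 'NVIDIA-L40-48GB') matches too eagerly.
        # \b requires the candidate name to be delimited by non-word
        # characters (e.g. '-') or by the start/end of the string.
        if re.search(rf'\b{re.escape(name)}\b', label_value):
            return name
    return None


if __name__ == '__main__':
    assert match_canonical_gpu('NVIDIA-L4-24GB') == 'L4'
    assert match_canonical_gpu('NVIDIA-L40-48GB') == 'L40'
    assert match_canonical_gpu('L40-GPU') == 'L40'
    # 'L400' is not a canonical name, so the word-boundary search finds nothing.
    assert match_canonical_gpu('NVIDIA-L400') is None
    print('word-boundary matching behaves as expected')

These are the same label values exercised by the new unit test above; the only behavioral difference is the omitted fallback path for names that are not in the canonical list.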