diff --git a/python/ray/_private/runtime_env/packaging.py b/python/ray/_private/runtime_env/packaging.py index 5691c6796f30..258d3b382adb 100644 --- a/python/ray/_private/runtime_env/packaging.py +++ b/python/ray/_private/runtime_env/packaging.py @@ -779,6 +779,7 @@ async def download_and_unpack_package( gcs_client: Optional[GcsClient] = None, logger: Optional[logging.Logger] = default_logger, overwrite: bool = False, + transport_params: Optional[dict] = None, ) -> str: """Download the package corresponding to this URI and unpack it if zipped. @@ -792,6 +793,11 @@ async def download_and_unpack_package( gcs_client: Client to use for downloading from the GCS. logger: The logger to use. overwrite: If True, overwrite the existing package. + transport_params: Optional transport parameters for smart_open. These parameters + will be passed to smart_open when downloading from remote storage. + For protocols with default configurations (S3, Azure, GCS, ABFSS), + these parameters will be merged with defaults, with custom parameters + taking precedence. Returns: Path to the local directory containing the unpacked package files. @@ -874,7 +880,12 @@ async def download_and_unpack_package( else: return str(pkg_file) elif protocol in Protocol.remote_protocols(): - protocol.download_remote_uri(source_uri=pkg_uri, dest_file=pkg_file) + # Use provided transport_params or fall back to None + protocol.download_remote_uri( + source_uri=pkg_uri, + dest_file=pkg_file, + transport_params=transport_params, + ) if pkg_file.suffix in [".zip", ".jar"]: unzip_package( diff --git a/python/ray/_private/runtime_env/protocol.py b/python/ray/_private/runtime_env/protocol.py index c293cc47034f..dac098a84930 100644 --- a/python/ray/_private/runtime_env/protocol.py +++ b/python/ray/_private/runtime_env/protocol.py @@ -233,13 +233,20 @@ def open_file(uri, mode, *, transport_params=None): return open_file, None @classmethod - def download_remote_uri(cls, protocol: str, source_uri: str, dest_file: str): + def download_remote_uri( + cls, + protocol: str, + source_uri: str, + dest_file: str, + transport_params: dict = None, + ): """Download file from remote URI to destination file. Args: protocol: The protocol to use for downloading (e.g., 's3', 'https'). source_uri: The source URI to download from. dest_file: The destination file path to save to. + transport_params: Optional transport parameters for smart_open. Raises: ImportError: If required dependencies for the protocol are not installed. @@ -275,10 +282,40 @@ def open_file(uri, mode, *, transport_params=None): + cls._MISSING_DEPENDENCIES_WARNING ) + if transport_params: + tp = cls._merge_transport_params(tp, transport_params) + with open_file(source_uri, "rb", transport_params=tp) as fin: with open(dest_file, "wb") as fout: fout.write(fin.read()) + @classmethod + def _merge_transport_params(cls, default_params, custom_params: dict): + """ + Merge custom transport parameters with default parameters. + Custom parameters take precedence over default parameters. + Performs a deep merge for nested dictionaries. + """ + if custom_params is None: + return default_params + + if default_params is None: + return custom_params + + merged = default_params.copy() + + for key, value in custom_params.items(): + if ( + key in merged + and isinstance(merged[key], dict) + and isinstance(value, dict) + ): + merged[key] = cls._merge_transport_params(merged[key], value) + else: + merged[key] = value + + return merged + Protocol = enum.Enum( "Protocol", @@ -298,8 +335,10 @@ def _remote_protocols(cls): Protocol.remote_protocols = _remote_protocols -def _download_remote_uri(self, source_uri, dest_file): - return ProtocolsProvider.download_remote_uri(self.value, source_uri, dest_file) +def _download_remote_uri(self, source_uri, dest_file, transport_params=None): + return ProtocolsProvider.download_remote_uri( + self.value, source_uri, dest_file, transport_params + ) Protocol.download_remote_uri = _download_remote_uri diff --git a/python/ray/_private/runtime_env/py_modules.py b/python/ray/_private/runtime_env/py_modules.py index 2a5dbe7423c4..bf40f3497072 100644 --- a/python/ray/_private/runtime_env/py_modules.py +++ b/python/ray/_private/runtime_env/py_modules.py @@ -205,9 +205,16 @@ async def create( context: RuntimeEnvContext, logger: Optional[logging.Logger] = default_logger, ) -> int: + # Extract transport_params from runtime_env config if available + config = runtime_env.get("config") + transport_params = config.get("transport_params") if config else None module_dir = await download_and_unpack_package( - uri, self._resources_dir, self._gcs_client, logger=logger + uri, + self._resources_dir, + self._gcs_client, + logger=logger, + transport_params=transport_params, ) if is_whl_uri(uri): diff --git a/python/ray/_private/runtime_env/working_dir.py b/python/ray/_private/runtime_env/working_dir.py index 27579e550e3a..38a6a611100a 100644 --- a/python/ray/_private/runtime_env/working_dir.py +++ b/python/ray/_private/runtime_env/working_dir.py @@ -192,12 +192,17 @@ async def create( context: RuntimeEnvContext, logger: logging.Logger = default_logger, ) -> int: + # Extract transport_params from runtime_env config if available + config = runtime_env.get("config") + transport_params = config.get("transport_params") if config else None + local_dir = await download_and_unpack_package( uri, self._resources_dir, self._gcs_client, logger=logger, overwrite=True, + transport_params=transport_params, ) return get_directory_size_bytes(local_dir) diff --git a/python/ray/runtime_env/runtime_env.py b/python/ray/runtime_env/runtime_env.py index 2b7d5d9be5f2..2b7ef8c7f415 100644 --- a/python/ray/runtime_env/runtime_env.py +++ b/python/ray/runtime_env/runtime_env.py @@ -42,14 +42,24 @@ class RuntimeEnvConfig(dict): eager_install: Indicates whether to install the runtime environment on the cluster at `ray.init()` time, before the workers are leased. This flag is set to `True` by default. + transport_params: Optional transport parameters for downloading packages + from remote storage (e.g., S3, GCS, Azure). This is a dictionary + of type Dict[str, Any] that will be passed to smart_open when + downloading remote URIs. """ - known_fields: Set[str] = {"setup_timeout_seconds", "eager_install", "log_files"} + known_fields: Set[str] = { + "setup_timeout_seconds", + "eager_install", + "log_files", + "transport_params", + } _default_config: Dict = { "setup_timeout_seconds": DEFAULT_RUNTIME_ENV_TIMEOUT_SECONDS, "eager_install": True, "log_files": [], + "transport_params": None, } def __init__( @@ -57,6 +67,7 @@ def __init__( setup_timeout_seconds: int = DEFAULT_RUNTIME_ENV_TIMEOUT_SECONDS, eager_install: bool = True, log_files: Optional[List[str]] = None, + transport_params: Optional[Dict[str, Any]] = None, ): super().__init__() if not isinstance(setup_timeout_seconds, int): @@ -91,6 +102,17 @@ def __init__( self["log_files"] = log_files + if transport_params is not None: + if not isinstance(transport_params, dict): + raise TypeError( + "transport_params must be a dict or None, got " + f"{transport_params} with type {type(transport_params)}." + ) + else: + transport_params = self._default_config["transport_params"] + + self["transport_params"] = transport_params + @staticmethod def parse_and_validate_runtime_env_config( config: Union[Dict, "RuntimeEnvConfig"] @@ -137,6 +159,7 @@ def from_proto(cls, runtime_env_config: ProtoRuntimeEnvConfig): # assign the default value to setup_timeout_seconds. if setup_timeout_seconds == 0: setup_timeout_seconds = cls._default_config["setup_timeout_seconds"] + return cls( setup_timeout_seconds=setup_timeout_seconds, eager_install=runtime_env_config.eager_install, @@ -277,6 +300,10 @@ class MyClass: config: config for runtime environment. Either a dict or a RuntimeEnvConfig. Field: (1) setup_timeout_seconds, the timeout of runtime environment creation, timeout is in seconds. + (2) transport_params, optional transport parameters for downloading + packages from remote storage (e.g., S3, GCS, Azure). This is a + dictionary of type Dict[str, Any] that will be passed to smart_open + when downloading remote URIs. image_uri: URI to a container image. The Ray worker process runs in a container with this image. This parameter only works alone, or with the ``config`` or ``env_vars`` parameters.