Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion python/ray/_private/runtime_env/packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,7 @@ async def download_and_unpack_package(
gcs_client: Optional[GcsClient] = None,
logger: Optional[logging.Logger] = default_logger,
overwrite: bool = False,
transport_params: Optional[dict] = None,
) -> str:
"""Download the package corresponding to this URI and unpack it if zipped.

Expand All @@ -792,6 +793,11 @@ async def download_and_unpack_package(
gcs_client: Client to use for downloading from the GCS.
logger: The logger to use.
overwrite: If True, overwrite the existing package.
transport_params: Optional transport parameters for smart_open. These parameters
will be passed to smart_open when downloading from remote storage.
For protocols with default configurations (S3, Azure, GCS, ABFSS),
these parameters will be merged with defaults, with custom parameters
taking precedence.

Returns:
Path to the local directory containing the unpacked package files.
Expand Down Expand Up @@ -874,7 +880,12 @@ async def download_and_unpack_package(
else:
return str(pkg_file)
elif protocol in Protocol.remote_protocols():
protocol.download_remote_uri(source_uri=pkg_uri, dest_file=pkg_file)
# Forward the caller-supplied transport_params (may be None) unchanged.
protocol.download_remote_uri(
source_uri=pkg_uri,
dest_file=pkg_file,
transport_params=transport_params,
)

if pkg_file.suffix in [".zip", ".jar"]:
unzip_package(
Expand Down
45 changes: 42 additions & 3 deletions python/ray/_private/runtime_env/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,13 +233,20 @@ def open_file(uri, mode, *, transport_params=None):
return open_file, None

@classmethod
def download_remote_uri(cls, protocol: str, source_uri: str, dest_file: str):
def download_remote_uri(
cls,
protocol: str,
source_uri: str,
dest_file: str,
transport_params: dict = None,
):
"""Download file from remote URI to destination file.

Args:
protocol: The protocol to use for downloading (e.g., 's3', 'https').
source_uri: The source URI to download from.
dest_file: The destination file path to save to.
transport_params: Optional transport parameters for smart_open.

Raises:
ImportError: If required dependencies for the protocol are not installed.
Expand Down Expand Up @@ -275,10 +282,40 @@ def open_file(uri, mode, *, transport_params=None):
+ cls._MISSING_DEPENDENCIES_WARNING
)

if transport_params:
tp = cls._merge_transport_params(tp, transport_params)

with open_file(source_uri, "rb", transport_params=tp) as fin:
with open(dest_file, "wb") as fout:
fout.write(fin.read())

@classmethod
def _merge_transport_params(cls, default_params, custom_params: dict):
    """Merge custom transport parameters over the default parameters.

    Keys in ``custom_params`` take precedence over keys in
    ``default_params``. When both sides hold a dict under the same key, the
    two dicts are merged recursively rather than the custom dict replacing
    the default one wholesale.

    Args:
        default_params: Default transport parameters, or None.
        custom_params: Caller-supplied transport parameters, or None.

    Returns:
        ``default_params`` as-is when ``custom_params`` is None (there is
        nothing to merge). Otherwise a new top-level dict holding the merged
        result; neither input dict is mutated. Values that are not merged
        (including nested dicts present on only one side) are shared by
        reference, not copied.
    """
    if custom_params is None:
        return default_params

    # Start from a shallow copy of the defaults. With no defaults we start
    # from an empty dict so the result is still a fresh object instead of
    # the caller's own ``custom_params`` dict.
    merged = {} if default_params is None else default_params.copy()

    for key, value in custom_params.items():
        existing = merged.get(key)
        if isinstance(existing, dict) and isinstance(value, dict):
            # Both sides are dicts: merge key-by-key, custom values winning.
            merged[key] = cls._merge_transport_params(existing, value)
        else:
            # Scalar, type mismatch, or new key: the custom value wins.
            merged[key] = value

    return merged


Protocol = enum.Enum(
"Protocol",
Expand All @@ -298,8 +335,10 @@ def _remote_protocols(cls):
Protocol.remote_protocols = _remote_protocols


def _download_remote_uri(self, source_uri, dest_file):
return ProtocolsProvider.download_remote_uri(self.value, source_uri, dest_file)
def _download_remote_uri(self, source_uri, dest_file, transport_params=None):
    """Download ``source_uri`` to ``dest_file`` using this protocol member.

    Delegates to ``ProtocolsProvider.download_remote_uri``, passing this
    member's ``value`` as the protocol and forwarding ``transport_params``
    unchanged.
    """
    protocol_name = self.value
    return ProtocolsProvider.download_remote_uri(
        protocol_name,
        source_uri,
        dest_file,
        transport_params,
    )


Protocol.download_remote_uri = _download_remote_uri
9 changes: 8 additions & 1 deletion python/ray/_private/runtime_env/py_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,16 @@ async def create(
context: RuntimeEnvContext,
logger: Optional[logging.Logger] = default_logger,
) -> int:
# Extract transport_params from runtime_env config if available
config = runtime_env.get("config")
transport_params = config.get("transport_params") if config else None

module_dir = await download_and_unpack_package(
uri, self._resources_dir, self._gcs_client, logger=logger
uri,
self._resources_dir,
self._gcs_client,
logger=logger,
transport_params=transport_params,
)

if is_whl_uri(uri):
Expand Down
5 changes: 5 additions & 0 deletions python/ray/_private/runtime_env/working_dir.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,12 +192,17 @@ async def create(
context: RuntimeEnvContext,
logger: logging.Logger = default_logger,
) -> int:
# Extract transport_params from runtime_env config if available
config = runtime_env.get("config")
transport_params = config.get("transport_params") if config else None

local_dir = await download_and_unpack_package(
uri,
self._resources_dir,
self._gcs_client,
logger=logger,
overwrite=True,
transport_params=transport_params,
)
return get_directory_size_bytes(local_dir)

Expand Down
29 changes: 28 additions & 1 deletion python/ray/runtime_env/runtime_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,32 @@ class RuntimeEnvConfig(dict):
eager_install: Indicates whether to install the runtime environment
on the cluster at `ray.init()` time, before the workers are leased.
This flag is set to `True` by default.
transport_params: Optional transport parameters for downloading packages
from remote storage (e.g., S3, GCS, Azure). This is a dictionary
of type Dict[str, Any] that will be passed to smart_open when
downloading remote URIs.
"""

known_fields: Set[str] = {"setup_timeout_seconds", "eager_install", "log_files"}
known_fields: Set[str] = {
"setup_timeout_seconds",
"eager_install",
"log_files",
"transport_params",
}

_default_config: Dict = {
"setup_timeout_seconds": DEFAULT_RUNTIME_ENV_TIMEOUT_SECONDS,
"eager_install": True,
"log_files": [],
"transport_params": None,
}

def __init__(
self,
setup_timeout_seconds: int = DEFAULT_RUNTIME_ENV_TIMEOUT_SECONDS,
eager_install: bool = True,
log_files: Optional[List[str]] = None,
transport_params: Optional[Dict[str, Any]] = None,
):
super().__init__()
if not isinstance(setup_timeout_seconds, int):
Expand Down Expand Up @@ -91,6 +102,17 @@ def __init__(

self["log_files"] = log_files

if transport_params is not None:
if not isinstance(transport_params, dict):
raise TypeError(
"transport_params must be a dict or None, got "
f"{transport_params} with type {type(transport_params)}."
)
else:
transport_params = self._default_config["transport_params"]

self["transport_params"] = transport_params

@staticmethod
def parse_and_validate_runtime_env_config(
config: Union[Dict, "RuntimeEnvConfig"]
Expand Down Expand Up @@ -137,6 +159,7 @@ def from_proto(cls, runtime_env_config: ProtoRuntimeEnvConfig):
# assign the default value to setup_timeout_seconds.
if setup_timeout_seconds == 0:
setup_timeout_seconds = cls._default_config["setup_timeout_seconds"]

return cls(
setup_timeout_seconds=setup_timeout_seconds,
eager_install=runtime_env_config.eager_install,
Expand Down Expand Up @@ -277,6 +300,10 @@ class MyClass:
config: config for runtime environment. Either
a dict or a RuntimeEnvConfig. Field: (1) setup_timeout_seconds, the
timeout of runtime environment creation, timeout is in seconds.
(2) transport_params, optional transport parameters for downloading
packages from remote storage (e.g., S3, GCS, Azure). This is a
dictionary of type Dict[str, Any] that will be passed to smart_open
when downloading remote URIs.
image_uri: URI to a container image. The Ray worker process runs
in a container with this image. This parameter only works alone,
or with the ``config`` or ``env_vars`` parameters.
Expand Down