Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion python/ray/_private/runtime_env/packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,7 @@ async def download_and_unpack_package(
gcs_client: Optional[GcsClient] = None,
logger: Optional[logging.Logger] = default_logger,
overwrite: bool = False,
transport_params: Optional[dict] = None,
) -> str:
"""Download the package corresponding to this URI and unpack it if zipped.

Expand All @@ -792,6 +793,11 @@ async def download_and_unpack_package(
gcs_client: Client to use for downloading from the GCS.
logger: The logger to use.
overwrite: If True, overwrite the existing package.
transport_params: Optional transport parameters for smart_open. These parameters
will be passed to smart_open when downloading from remote storage.
For protocols with default configurations (S3, Azure, GCS, ABFSS),
these parameters will be merged with defaults, with custom parameters
taking precedence.

Returns:
Path to the local directory containing the unpacked package files.
Expand Down Expand Up @@ -874,7 +880,12 @@ async def download_and_unpack_package(
else:
return str(pkg_file)
elif protocol in Protocol.remote_protocols():
protocol.download_remote_uri(source_uri=pkg_uri, dest_file=pkg_file)
# Use provided transport_params or fall back to None
protocol.download_remote_uri(
source_uri=pkg_uri,
dest_file=pkg_file,
transport_params=transport_params,
)

if pkg_file.suffix in [".zip", ".jar"]:
unzip_package(
Expand Down
57 changes: 49 additions & 8 deletions python/ray/_private/runtime_env/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,13 +233,20 @@ def open_file(uri, mode, *, transport_params=None):
return open_file, None

@classmethod
def download_remote_uri(cls, protocol: str, source_uri: str, dest_file: str):
def download_remote_uri(
cls,
protocol: str,
source_uri: str,
dest_file: str,
transport_params: dict = None,
):
"""Download file from remote URI to destination file.

Args:
protocol: The protocol to use for downloading (e.g., 's3', 'https').
source_uri: The source URI to download from.
dest_file: The destination file path to save to.
transport_params: Optional transport parameters for smart_open.

Raises:
ImportError: If required dependencies for the protocol are not installed.
Expand All @@ -254,17 +261,28 @@ def download_remote_uri(cls, protocol: str, source_uri: str, dest_file: str):

def open_file(uri, mode, *, transport_params=None):
return open(uri, mode)

tp = transport_params

elif protocol == "https":
open_file, tp = cls._handle_https_protocol()
open_file, default_tp = cls._handle_https_protocol()
tp = cls._merge_transport_params(default_tp, transport_params)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HTTPS headers lost when custom headers provided

Medium Severity

When custom headers are passed via transport_params for HTTPS, the default User-Agent and Accept headers are completely replaced rather than merged. The _handle_https_protocol() returns None for default transport params, so _merge_transport_params(None, custom_params) just returns the custom params unchanged. Then inside open_file, params.update(transport_params) does a shallow update that replaces the entire headers dict. This breaks the documented GitLab use case where users add authentication headers but lose the default headers that some servers may require.

Additional Locations (1)

Fix in Cursor Fix in Web

elif protocol == "s3":
open_file, tp = cls._handle_s3_protocol()
open_file, default_tp = cls._handle_s3_protocol()
# Merge custom transport_params with defaults
tp = cls._merge_transport_params(default_tp, transport_params)
elif protocol == "gs":
open_file, tp = cls._handle_gs_protocol()
open_file, default_tp = cls._handle_gs_protocol()
# Merge custom transport_params with defaults
tp = cls._merge_transport_params(default_tp, transport_params)
elif protocol == "azure":
open_file, tp = cls._handle_azure_protocol()
open_file, default_tp = cls._handle_azure_protocol()
# Merge custom transport_params with defaults
tp = cls._merge_transport_params(default_tp, transport_params)
elif protocol == "abfss":
open_file, tp = cls._handle_abfss_protocol()
open_file, default_tp = cls._handle_abfss_protocol()
# Merge custom transport_params with defaults
tp = cls._merge_transport_params(default_tp, transport_params)
else:
try:
from smart_open import open as open_file
Expand All @@ -274,11 +292,32 @@ def open_file(uri, mode, *, transport_params=None):
f"to fetch {protocol.upper()} URIs. "
+ cls._MISSING_DEPENDENCIES_WARNING
)
tp = transport_params

with open_file(source_uri, "rb", transport_params=tp) as fin:
with open(dest_file, "wb") as fout:
fout.write(fin.read())

@classmethod
def _merge_transport_params(cls, default_params, custom_params: dict):
"""
Merge custom transport parameters with default parameters.
Custom parameters take precedence over default parameters.
"""
if custom_params is None:
return default_params

if default_params is None:
return custom_params

# Create a copy of default params to avoid modifying the original
merged = default_params.copy()

# Update with custom params, which take precedence
merged.update(custom_params)

return merged


Protocol = enum.Enum(
"Protocol",
Expand All @@ -298,8 +337,10 @@ def _remote_protocols(cls):
Protocol.remote_protocols = _remote_protocols


def _download_remote_uri(self, source_uri, dest_file):
return ProtocolsProvider.download_remote_uri(self.value, source_uri, dest_file)
def _download_remote_uri(self, source_uri, dest_file, transport_params=None):
return ProtocolsProvider.download_remote_uri(
self.value, source_uri, dest_file, transport_params
)


Protocol.download_remote_uri = _download_remote_uri
9 changes: 8 additions & 1 deletion python/ray/_private/runtime_env/py_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,16 @@ async def create(
context: RuntimeEnvContext,
logger: Optional[logging.Logger] = default_logger,
) -> int:
# Extract transport_params from runtime_env config if available
config = runtime_env.get("config")
transport_params = config.get("transport_params") if config else None

module_dir = await download_and_unpack_package(
uri, self._resources_dir, self._gcs_client, logger=logger
uri,
self._resources_dir,
self._gcs_client,
logger=logger,
transport_params=transport_params,
)

if is_whl_uri(uri):
Expand Down
5 changes: 5 additions & 0 deletions python/ray/_private/runtime_env/working_dir.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,12 +192,17 @@ async def create(
context: RuntimeEnvContext,
logger: logging.Logger = default_logger,
) -> int:
# Extract transport_params from runtime_env config if available
config = runtime_env.get("config")
transport_params = config.get("transport_params") if config else None

local_dir = await download_and_unpack_package(
uri,
self._resources_dir,
self._gcs_client,
logger=logger,
overwrite=True,
transport_params=transport_params,
)
return get_directory_size_bytes(local_dir)

Expand Down
29 changes: 28 additions & 1 deletion python/ray/runtime_env/runtime_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,32 @@ class RuntimeEnvConfig(dict):
eager_install: Indicates whether to install the runtime environment
on the cluster at `ray.init()` time, before the workers are leased.
This flag is set to `True` by default.
transport_params: Optional transport parameters for downloading packages
from remote storage (e.g., S3, GCS, Azure). This is a dictionary
of type Dict[str, Any] that will be passed to smart_open when
downloading remote URIs.
"""

known_fields: Set[str] = {"setup_timeout_seconds", "eager_install", "log_files"}
known_fields: Set[str] = {
"setup_timeout_seconds",
"eager_install",
"log_files",
"transport_params",
}

_default_config: Dict = {
"setup_timeout_seconds": DEFAULT_RUNTIME_ENV_TIMEOUT_SECONDS,
"eager_install": True,
"log_files": [],
"transport_params": None,
}

def __init__(
self,
setup_timeout_seconds: int = DEFAULT_RUNTIME_ENV_TIMEOUT_SECONDS,
eager_install: bool = True,
log_files: Optional[List[str]] = None,
transport_params: Optional[Dict[str, Any]] = None,
):
super().__init__()
if not isinstance(setup_timeout_seconds, int):
Expand Down Expand Up @@ -91,6 +102,17 @@ def __init__(

self["log_files"] = log_files

if transport_params is not None:
if not isinstance(transport_params, dict):
raise TypeError(
"transport_params must be a dict or None, got "
f"{transport_params} with type {type(transport_params)}."
)
else:
transport_params = self._default_config["transport_params"]

self["transport_params"] = transport_params

@staticmethod
def parse_and_validate_runtime_env_config(
config: Union[Dict, "RuntimeEnvConfig"]
Expand Down Expand Up @@ -137,6 +159,7 @@ def from_proto(cls, runtime_env_config: ProtoRuntimeEnvConfig):
# assign the default value to setup_timeout_seconds.
if setup_timeout_seconds == 0:
setup_timeout_seconds = cls._default_config["setup_timeout_seconds"]

return cls(
setup_timeout_seconds=setup_timeout_seconds,
eager_install=runtime_env_config.eager_install,
Expand Down Expand Up @@ -277,6 +300,10 @@ class MyClass:
config: config for runtime environment. Either
a dict or a RuntimeEnvConfig. Field: (1) setup_timeout_seconds, the
timeout of runtime environment creation, timeout is in seconds.
(2) transport_params, optional transport parameters for downloading
packages from remote storage (e.g., S3, GCS, Azure). This is a
dictionary of type Dict[str, Any] that will be passed to smart_open
when downloading remote URIs.
image_uri: URI to a container image. The Ray worker process runs
in a container with this image. This parameter only works alone,
or with the ``config`` or ``env_vars`` parameters.
Expand Down