From 609dfac29452e4842c62168c9c9036f38976a57d Mon Sep 17 00:00:00 2001 From: Yunlong Liu Date: Mon, 11 Nov 2024 22:59:24 -0800 Subject: [PATCH] Adds a flag to control proxy env checking. name typo fix. Fixes comments. --- jax/_src/clusters/cluster.py | 6 ------ jax/_src/distributed.py | 38 ++++++++++++++++++++++++------------ 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/jax/_src/clusters/cluster.py b/jax/_src/clusters/cluster.py index 2fb13fde72cf..69ef77a6421d 100644 --- a/jax/_src/clusters/cluster.py +++ b/jax/_src/clusters/cluster.py @@ -49,12 +49,6 @@ def auto_detect_unset_distributed_params(cls, initialization_timeout: int | None, ) -> tuple[str | None, int | None, int | None, Sequence[int] | None]: - - if all(p is not None for p in (coordinator_address, num_processes, - process_id, local_device_ids)): - return (coordinator_address, num_processes, process_id, - local_device_ids) - # First, we check the spec detection method because it will ignore submitted values # If if succeeds. if cluster_detection_method is not None: diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index 5b9130fc0455..f80f90bde186 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -27,6 +27,13 @@ logger = logging.getLogger(__name__) +_CHECK_PROXY_ENVS = config.bool_flag( + name="jax_check_proxy_envs", + default=True, + help="Checks proxy vars in user envs and emit warnings.", +) + + class State: process_id: int = 0 num_processes: int = 1 @@ -55,16 +62,17 @@ def initialize(self, if local_device_ids is None and (env_ids := os.environ.get('JAX_LOCAL_DEVICE_IDS')): local_device_ids = list(map(int, env_ids.split(","))) - (coordinator_address, num_processes, process_id, local_device_ids) = ( - clusters.ClusterEnv.auto_detect_unset_distributed_params( - coordinator_address, - num_processes, - process_id, - local_device_ids, - cluster_detection_method, - initialization_timeout, - ) - ) + if None in (coordinator_address, num_processes, process_id, local_device_ids): + (coordinator_address, num_processes, process_id, local_device_ids) = ( + clusters.ClusterEnv.auto_detect_unset_distributed_params( + coordinator_address, + num_processes, + process_id, + local_device_ids, + cluster_detection_method, + initialization_timeout, + ) + ) if coordinator_address is None: raise ValueError('coordinator_address should be defined.') @@ -92,8 +100,10 @@ def initialize(self, self.process_id = process_id - # Emit a warning about PROXY variables if they are in the user's env: - proxy_vars = [ key for key in os.environ.keys() if '_proxy' in key.lower()] + proxy_vars = [] + if _CHECK_PROXY_ENVS.value: + proxy_vars = [key for key in os.environ.keys() + if '_proxy' in key.lower()] if len(proxy_vars) > 0: vars = " ".join(proxy_vars) + ". " @@ -179,7 +189,9 @@ def initialize(coordinator_address: str | None = None, ``cluster_detection_method="mpi4py"`` to bootstrap the required arguments. Otherwise, you must provide the ``coordinator_address``, - ``num_processes``, and ``process_id`` arguments to :func:`~jax.distributed.initialize`. + ``num_processes``, ``process_id``, and ``local_device_ids`` arguments + to :func:`~jax.distributed.initialize`. When all four arguments are provided, cluster + environment auto detection will be skipped. Please note: on some systems, particularly HPC clusters that only access external networks through proxy variables such as HTTP_PROXY, HTTPS_PROXY, etc., the call to