Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JAX/distributed] Small clean ups on warning emissions in jax.distributed. #24853

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

yliu120
Copy link
Contributor

@yliu120 yliu120 commented Nov 12, 2024

Different users have different cluster setups. Some warnings might be meaningful for some users but not all. emitting those warnings regardless of the user's specific cluster env could lead to 1) noisy logging and 2) confusions.

This PR does two tiny things inside jax._src.distributed:

  1. Adds a flag to disable proxy env checks and also the warning emissions.
  2. Refactor the auto_detect_distributed_params method a bit to make it more readable and update the comments.
  • In the _src.distributed code, we should explicitly tell readers when all necesary args are provided, the auto-detect thing is skipped. Otherwise, users need to go in and look around to know this. Also update the comment of the method.

Copy link
Collaborator

@hawkinsp hawkinsp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! A couple of small nits.

)
)
_any_is_none = lambda *args: any(arg is None for arg in args)
if _any_is_none(coordinator_address, num_processes, process_id, local_device_ids):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps:

if None in (coordinator_address, num_processes, process_id, local_device_ids):
  ...

might be shorter? No auxiliary function needed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. This is much better. Done.

``num_processes``, and ``process_id`` arguments to :func:`~jax.distributed.initialize`.
``num_processes``, ``process_id``, and ``local_device_ids`` arguments
to :func:`~jax.distributed.initialize`. When all four arguments are provided, cluster
envs auto detection will be skipped.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: expand the abbreviation.
environment auto detection will be skipped.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

name typo fix.

Fixes comments.
@yliu120
Copy link
Contributor Author

yliu120 commented Nov 16, 2024

Commits squashed and rebase to upstream.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants