Skip to content

Conversation

rajkthakur
Copy link

@rajkthakur rajkthakur commented Sep 3, 2025

Fixes #9569

Changes:

  • Add maybe_get_jax() function with env var logic in jax_workarounds.py
    • TORCH_XLA_ENABLE_JAX=1: enables JAX, warns if missing jax
    • Unset/empty: warns with guidance to set explicit env value for TORCH_XLA_ENABLE_JAX
    • TORCH_XLA_ENABLE_JAX=0/other: silent operation
  • Add test covering environment variable scenarios
  • Preserve existing JAX functionality when explicitly enabled
  • Improve user experience with clear guidance messages

@jeffhataws jeffhataws requested review from qihqi and bhavya01 September 3, 2025 17:56
@qihqi
Copy link
Collaborator

qihqi commented Sep 4, 2025

Rebasing to HEAD might fix the build issue.

jax.config.update('jax_use_shardy_partitioner', False)
return jax
except (ModuleNotFoundError, ImportError):
logging.warn('You are trying to use a feature that requires jax/pallas.'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do something simpler.

Instead of the env variable, just remove this logging.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can probably fix the test if that's a concern? Or do you see some other issue?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just the tests. Although I don't need the logging in this function, so it can be simplified if so.

i.e. it's reasonable to push the logging to the callsites of maybe_get_jax.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Remove excessive warn message in maybe_get_jax as it creates too many log lines during training
2 participants