Description
🐛 Bug
The maybe_get_jax() function in torch_xla/_internal/jax_workarounds.py, merged in #9521, emits a warning on every call when JAX is not installed. The message is informative, but during training workloads it produces an excessive number of log lines, which clutters the logs and makes genuinely important debug messages hard to spot.
To Reproduce
Steps to reproduce the behavior:
- Create a Python virtual environment (python3 -m venv ptxla_28) on Ubuntu 22.04
- pip install torch==2.8.0 torchvision; pip install torch_xla==2.8.0
- Create a small Python script (let's call it trigger_warning.py):
import sys
sys.path.insert(0, 'ptxla_28/lib/python3.10/site-packages')
from torch_xla._internal.jax_workarounds import maybe_get_jax
maybe_get_jax()
- Execute the script:
bash -c "source ptxla_28/bin/activate && python trigger_warning.py"
- You should see warning messages like the following:
WARNING:root:Defaulting to PJRT_DEVICE=CPU
WARNING:root:You are trying to use a feature that requires jax/pallas.You can install Jax/Pallas via pip install torch_xla[pallas]
Expected behavior
Remove or suppress this warning, or emit it only once per process instead of on every invocation.
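A minimal sketch of the once-per-process option follows. It assumes the warning goes through Python's `logging` module (the `WARNING:root:` prefix in the output above suggests the root logger). The names `maybe_get_jax_sketch` and `_warn_jax_missing_once` are hypothetical; this is not the actual torch_xla code, whose body is not shown in this issue.

```python
import functools
import logging

_JAX_MISSING_MSG = (
    "You are trying to use a feature that requires jax/pallas. "
    "You can install Jax/Pallas via pip install torch_xla[pallas]"
)


@functools.lru_cache(maxsize=None)
def _warn_jax_missing_once() -> None:
    # Hypothetical helper: lru_cache on a zero-argument function means the
    # body (and therefore the warning) runs only on the first call.
    logging.warning(_JAX_MISSING_MSG)


def maybe_get_jax_sketch():
    """Hypothetical stand-in for maybe_get_jax(): return jax, or None if missing."""
    try:
        import jax
        return jax
    except ImportError:
        # Every call still returns None, but only the first one logs.
        _warn_jax_missing_once()
        return None
```

Decorating `maybe_get_jax_sketch` itself with the cache would also skip the repeated failed import on later calls; the helper is split out here only to keep the deduplication step explicit.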
Environment
- Reproducible on XLA backend [CPU/TPU/CUDA]: CPU
- torch_xla version: 2.8.0
- Relevant Code:
def maybe_get_jax():
Additional context
The current behavior produces thousands of repeated warning lines in workloads that do not use JAX, which hurts the developer experience. Reducing or removing this warning would noticeably clean up logs for long-running or large-scale training jobs, improving usability without hiding relevant errors.
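Until the warning is deduplicated upstream, a user could hide it with a standard `logging.Filter`. This is a sketch under two assumptions: that the message is logged directly on the root logger (as the `WARNING:root:` prefix suggests), and that matching on the substring "requires jax/pallas" from the quoted message is specific enough. `DropJaxMissingWarning` and `install_jax_warning_filter` are hypothetical names, not torch_xla APIs.

```python
import logging


class DropJaxMissingWarning(logging.Filter):
    """Hypothetical filter that drops the JAX/Pallas install hint."""

    # Substring taken from the warning text quoted in this issue (an assumption
    # that the wording stays stable across torch_xla versions).
    NEEDLE = "requires jax/pallas"

    def filter(self, record: logging.LogRecord) -> bool:
        # Returning False discards the record; all other records pass through.
        return self.NEEDLE not in record.getMessage()


def install_jax_warning_filter() -> None:
    # Logger-level filters only see records created on that same logger, so
    # this assumes the warning is logged on the root logger. If it came from a
    # child logger instead, the filter would need to go on the root handlers.
    logging.getLogger().addFilter(DropJaxMissingWarning())
```

Calling `install_jax_warning_filter()` once at the top of a training script would leave other warnings, such as `Defaulting to PJRT_DEVICE=CPU`, untouched.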