I've experienced the issue described in #24821, but I decided to open a new issue, since I've figured out the cause. Here is the code example that fails on the latest jax[tpu] with libtpu_nightly-0.1.dev20241010+nightly.cleanup.
It fails with a variety of errors. Sometimes the TPU cannot initialize itself:
/home/user/.local/lib/python3.11/site-packages/jax/__init__.py:31: UserWarning: cloud_tpu_init failed: AttributeError("module 'libtpu' has no attribute 'get_library_path'")
This a JAX bug; please report an issue at https://github.com/jax-ml/jax/issues
_warn(f"cloud_tpu_init failed: {exc!r}\n This a JAX bug; please report "
Or, sometimes, it fails with:
RuntimeError: Unable to initialize backend 'cpu': ALREADY_EXISTS: Config key cpu:local_topology/cpu/3 already exists.
Additional GRPC error information from remote target coordination_service while calling /tensorflow.CoordinationService/InsertKeyValue:
:{"created":"@1731217318.560649460","description":"Error received from peer ipv4:some ip:8476","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Config key cpu:local_topology/cpu/3 already exists.","grpc_status":6} (set JAX_PLATFORMS='' to automatically choose an available backend)
Or it can even fail with deeply internal errors like:
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /home/arst/.local/lib/python3.11/site-packages/equinox/_jit.py:55:14: error: All components of the offset index in a gather op must either be a offset dimension or explicitly collapsed or explicitly batched; got len(slice_sizes)=4, output_slice_sizes=2, collapsed_slice_dims=1,2, operand_batching_dims=.:
...
@ 0x7fb4ce901fe4 (unknown)
@ 0x7fb67c825f8d xla::InitializeArgsAndCompile()
@ 0x7fb67c8266f6 xla::PjRtCApiClient::Compile()
@ 0x7fb68255566c xla::ifrt::PjRtLoadedExecutable::Create()
@ 0x7fb682550a51 xla::ifrt::PjRtCompiler::Compile()
@ 0x7fb681ce452e xla::PyClient::CompileIfrtProgram()
@ 0x7fb681ce532e xla::PyClient::Compile()
The solution is to explicitly downgrade libtpu to libtpu_nightly==0.1.dev20241009; some older versions also don't work, but this one seems to be OK.
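Something along these lines should pin it (the -f index URL is the usual libtpu releases index from the JAX install docs, not something specific to this report):
# python3.11 -m pip install -U libtpu_nightly==0.1.dev20241009 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
On a multi-host slice such as a v4-32, the downgrade needs to be applied on every worker.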
System info (python version, jaxlib version, accelerator, etc.)
# python3.11 -m pip freeze
jax==0.4.35
jaxlib==0.4.35
libtpu==0.0.2
libtpu_nightly==0.1.dev20241010+nightly.cleanup
... (a bunch of other unrelated packages)
# python3.11 --version
3.11.10
# neofetch
OS: Ubuntu 20.04.4 LTS x86_64
Host: Google Compute Engine
Kernel: 5.13.0-1023-gcp
CPU: AMD EPYC 7B12 (240) @ 2.249GHz
The machine is a v4-32 from TRC.
Hi @knyazer, since the OS is Ubuntu 20.04, it looks like you may have created the VM with --version=tpu-ubuntu2004-base instead of tpu-ubuntu2204-base; could that be right? (See the docs here.) I haven't been able to reproduce the failure with tpu-ubuntu2204-base. If you used a different version, could you try tpu-ubuntu2204-base instead?
If there are instructions somewhere that say to use tpu-ubuntu2004-base, please let us know so we can update them.
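For example, a creation command along these lines should pick up the 22.04 image (the TPU name and zone here are placeholders, not taken from this report; adjust to your project):
# gcloud compute tpus tpu-vm create my-v4-32 --zone=us-central2-b --accelerator-type=v4-32 --version=tpu-ubuntu2204-base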