
Bug in the latest libtpu nightly release #24829

Open
knyazer opened this issue Nov 11, 2024 · 1 comment
Labels
bug Something isn't working


knyazer commented Nov 11, 2024

Description

I ran into the issue in #24821, but I decided to open a new issue since I've tracked down the cause. Here is a code example that fails on the latest jax[tpu] with libtpu_nightly-0.1.dev20241010+nightly.cleanup:

```python
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.sharding import NamedSharding, PartitionSpec as P

def f(data, key):
    # Sample 100 row indices, then gather rows of a constant (100, 2) index table.
    indices = jr.choice(key, 100, shape=(100,))
    indexer = jnp.ones((100, 2), dtype=jnp.int32)[indices]
    # Advanced indexing with two integer arrays; this lowers to an XLA gather.
    return data[indexer[:, 0], indexer[:, 1]]

# 1-D mesh over all devices; shard the leading (batch) axis along "x".
sharding = NamedSharding(jax.make_mesh((len(jax.devices()),), ("x",)), P("x"))

sharded_data = jnp.zeros((256, 10, 10, 3))
sharded_data = jax.device_put(sharded_data, sharding)
keys = jr.split(jr.key(0), 256)

jax.vmap(f)(sharded_data, keys)
```

It fails with several different errors. Sometimes the TPU cannot initialize itself:

```
/home/user/.local/lib/python3.11/site-packages/jax/__init__.py:31: UserWarning: cloud_tpu_init failed: AttributeError("module 'libtpu' has no attribute 'get_library_path'")
 This a JAX bug; please report an issue at https://github.com/jax-ml/jax/issues
  _warn(f"cloud_tpu_init failed: {exc!r}\n This a JAX bug; please report "
```
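The warning above says the imported `libtpu` module has no `get_library_path` attribute. A minimal diagnostic sketch, assuming only the standard library: it reports where Python would load `libtpu` from and whether that module exposes the attribute. `inspect_libtpu` is a hypothetical helper, not part of JAX or libtpu.

```python
import importlib
import importlib.util


def inspect_libtpu():
    """Report which `libtpu` module Python would import and whether it
    exposes `get_library_path` (the attribute missing in the warning)."""
    spec = importlib.util.find_spec("libtpu")
    if spec is None:
        # No importable `libtpu` module at all.
        return {"installed": False, "origin": None,
                "search_locations": [], "has_get_library_path": False}
    module = importlib.import_module("libtpu")
    return {
        "installed": True,
        # File the module was loaded from (None for namespace packages).
        "origin": spec.origin,
        # Package directories Python searches for `libtpu` submodules.
        "search_locations": list(spec.submodule_search_locations or []),
        "has_get_library_path": hasattr(module, "get_library_path"),
    }


if __name__ == "__main__":
    for key, value in inspect_libtpu().items():
        print(f"{key}: {value}")
```

Running it in the same environment as the failing script shows which installed `libtpu` files are actually being picked up.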

Other times it fails with:

```
RuntimeError: Unable to initialize backend 'cpu': ALREADY_EXISTS: Config key cpu:local_topology/cpu/3 already exists.
Additional GRPC error information from remote target coordination_service while calling /tensorflow.CoordinationService/InsertKeyValue:
:{"created":"@1731217318.560649460","description":"Error received from peer ipv4:some ip:8476","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Config key cpu:local_topology/cpu/3 already exists.","grpc_status":6} (set JAX_PLATFORMS='' to automatically choose an available backend)
```

It can even fail with very internal errors, like:

```
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /home/arst/.local/lib/python3.11/site-packages/equinox/_jit.py:55:14: error: All components of the offset index in a gather op must either be a offset dimension or explicitly collapsed or explicitly batched; got len(slice_sizes)=4, output_slice_sizes=2, collapsed_slice_dims=1,2, operand_batching_dims=.: 
...
    @     0x7fb4ce901fe4  (unknown)
    @     0x7fb67c825f8d  xla::InitializeArgsAndCompile()
    @     0x7fb67c8266f6  xla::PjRtCApiClient::Compile()
    @     0x7fb68255566c  xla::ifrt::PjRtLoadedExecutable::Create()
    @     0x7fb682550a51  xla::ifrt::PjRtCompiler::Compile()
    @     0x7fb681ce452e  xla::PyClient::CompileIfrtProgram()
    @     0x7fb681ce532e  xla::PyClient::Compile()
```
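The gather in this last error comes from the advanced indexing `data[indexer[:, 0], indexer[:, 1]]` in `f`. A minimal sketch for checking that the program itself is valid independently of libtpu, assuming a CPU-only JAX install (forced via `JAX_PLATFORMS=cpu`). It uses random data instead of zeros so the gathered values are actually compared against a NumPy reference. `cpu_reference_check` is a hypothetical helper.

```python
import os

# Assumption: force the CPU backend so this check never loads libtpu.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np


def f(data, key):
    # Same per-example function as in the repro.
    indices = jr.choice(key, 100, shape=(100,))
    indexer = jnp.ones((100, 2), dtype=jnp.int32)[indices]
    return data[indexer[:, 0], indexer[:, 1]]


def cpu_reference_check(batch=256):
    data = jr.normal(jr.key(1), (batch, 10, 10, 3))
    keys = jr.split(jr.key(0), batch)
    out = jax.jit(jax.vmap(f))(data, keys)

    # Every row of `indexer` is (1, 1), so each example should be
    # data[b, 1, 1, :] repeated 100 times.
    expected = np.broadcast_to(np.asarray(data)[:, 1:2, 1, :], (batch, 100, 3))
    np.testing.assert_allclose(np.asarray(out), expected)
    return out.shape


if __name__ == "__main__":
    print(cpu_reference_check())  # (256, 100, 3)
```

If this passes on CPU while the TPU run fails, the problem more likely sits in the TPU compilation path than in the user program.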

The workaround is to explicitly downgrade libtpu to libtpu_nightly==0.1.dev20241009; some older versions also don't work, but this one seems to be OK.
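To confirm which nightly is actually installed before and after the downgrade, a small sketch using `importlib.metadata` from the standard library. It only knows the two versions mentioned in this report (one failing, one reported working) and labels everything else "unknown"; `installed_version` and `classify_libtpu_nightly` are hypothetical helpers.

```python
from importlib import metadata

# Versions taken from this report only; other nightlies are not classified.
REPORTED_BAD = "0.1.dev20241010"
REPORTED_GOOD = "0.1.dev20241009"


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def classify_libtpu_nightly(version):
    """Classify a libtpu_nightly version string against this report."""
    if version is None:
        return "not installed"
    # Drop a local suffix such as "+nightly.cleanup" before comparing.
    base = version.split("+", 1)[0]
    if base == REPORTED_BAD:
        return "reported failing"
    if base == REPORTED_GOOD:
        return "reported working"
    return "unknown"


if __name__ == "__main__":
    v = installed_version("libtpu_nightly")
    print(v, "->", classify_libtpu_nightly(v))
```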

System info (python version, jaxlib version, accelerator, etc.)

```
# python3.11 -m pip freeze
jax==0.4.35
jaxlib==0.4.35
libtpu==0.0.2
libtpu_nightly==0.1.dev20241010+nightly.cleanup
... (a bunch of other unrelated packages)

# python3.11 --version
3.11.10

# neofetch
OS: Ubuntu 20.04.4 LTS x86_64
Host: Google Compute Engine
Kernel: 5.13.0-1023-gcp
CPU: AMD EPYC 7B12 (240) @ 2.249GHz
```

The machine is a v4-32 from TRC.

knyazer added the bug label on Nov 11, 2024
knyazer changed the title from "Bug in the latest TPU nightly release" to "Bug in the latest libtpu nightly release" on Nov 11, 2024
@emilyfertig
Collaborator

Hi @knyazer, since the OS is Ubuntu 20.04, it looks like you may have created the VM with --version=tpu-ubuntu2004-base instead of tpu-ubuntu2204-base. Could that be right? (see docs here). I haven't been able to reproduce the failure with tpu-ubuntu2204-base. If you used a different version, could you try tpu-ubuntu2204-base instead?

If there are instructions somewhere that say to use tpu-ubuntu2004-base, please let us know so we can update them.
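A quick way to check the point above from inside the VM is to read the OS release fields. A minimal sketch, assuming Python 3.10+ for `platform.freedesktop_os_release()` (it returns an empty dict on older Pythons or when no os-release file exists); `read_os_release` and `looks_like_ubuntu_2004` are hypothetical helpers, and the 20.04 check only suggests, not proves, that the tpu-ubuntu2004-base image was used.

```python
import platform


def read_os_release():
    """Return os-release fields, or {} if unavailable (needs Python 3.10+)."""
    reader = getattr(platform, "freedesktop_os_release", None)
    if reader is None:
        return {}
    try:
        return reader()
    except OSError:
        # No /etc/os-release or /usr/lib/os-release on this system.
        return {}


def looks_like_ubuntu_2004(info):
    """True if the os-release fields describe Ubuntu 20.04."""
    return info.get("ID") == "ubuntu" and info.get("VERSION_ID") == "20.04"


if __name__ == "__main__":
    info = read_os_release()
    if looks_like_ubuntu_2004(info):
        print("Ubuntu 20.04: the VM may use tpu-ubuntu2004-base, not tpu-ubuntu2204-base.")
    else:
        print(info.get("PRETTY_NAME", "unknown OS"))
```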
