MNIST tutorial broken for Colab TPU #4122

Open
rcrowe-google opened this issue Aug 7, 2024 · 0 comments
@rcrowe-google

https://colab.sandbox.google.com/github/google/flax/blob/main/docs/nnx/mnist_tutorial.ipynb#scrollTo=6

The MNIST tutorial on the NNX website throws the following error when trying to instantiate the model:

---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
    [... skipping hidden 1 frame]

/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py in _init_backend(platform)
    972   logger.debug("Initializing backend '%s'", platform)
--> 973   backend = registration.factory()
    974   # TODO(skye): consider raising more descriptive errors directly from backend

XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.54).

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
    [... skipping hidden 16 frame]

/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py in backends()
    901           else:
    902             err_msg += " (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)"
--> 903           raise RuntimeError(err_msg)
    904 
    905     assert _default_backend is not None

RuntimeError: Unable to initialize backend 'tpu': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.54). (set JAX_PLATFORMS='' to automatically choose an available backend)
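The message reports that the installed TPU plugin speaks PJRT API 0.47 while the installed jaxlib expects 0.54, so the exact package versions on the Colab image are the most useful detail for triage. A minimal sketch for collecting them, using only the standard library's `importlib.metadata`. The distribution names in `CANDIDATE_DISTS` are assumptions (the TPU plugin may be installed as `libtpu` or `libtpu-nightly`, depending on how JAX was installed), and missing ones are simply reported as not installed.

```python
from importlib import metadata

# Assumed distribution names; adjust for the runtime being inspected.
CANDIDATE_DISTS = ("jax", "jaxlib", "libtpu", "libtpu-nightly", "flax")


def installed_versions(names=CANDIDATE_DISTS):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Print one aligned line per package, suitable for pasting into an issue.
    for name, version in installed_versions().items():
        print(f"{name:>15}: {version or 'not installed'}")
```

Pasting this output alongside the traceback makes it easier to tell which side (plugin or jaxlib) is out of date.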
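The error text itself suggests setting `JAX_PLATFORMS=cpu` to skip the failing backend. A minimal sketch of that workaround, assuming the goal is only to get the tutorial running at all: it sidesteps the TPU plugin entirely, so the model then runs on CPU and the version mismatch remains unresolved. The variable has to be set before `jax` is first imported in the process; in Colab that presumably means restarting the runtime if `jax` was already imported.

```python
import os

# Restrict JAX to the CPU backend so the mismatched TPU plugin is never
# initialized. Must happen before the first `import jax` in this process.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp

print(jax.default_backend())  # "cpu"
print(jax.devices())

# Small smoke test that the backend actually executes computations.
total = (jnp.arange(4.0) * 2).sum()
print(float(total))  # 12.0
```

This is a diagnostic fallback rather than a fix; running on the TPU still requires a plugin and jaxlib whose PJRT API versions agree.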