From 337ec47d138749a3cf68e47a50b5a37fc6fc5e9b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 16 May 2022 00:24:04 +0000 Subject: [PATCH] Fix jax 0.3.11 GPU breakge when used with jaxlib 0.3.10. --- jax/_src/lib/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index ea9b2d9a844f..f45bf729551d 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -183,7 +183,7 @@ def _parse_version(v: str) -> Tuple[int, ...]: hip_linalg = None try: - import jaxlib.cuda_linalg as gpu_linalg # pytype: disable=import-error + import jaxlib.gpu_linalg as gpu_linalg # pytype: disable=import-error except ImportError: gpu_linalg = None