|
1 | 1 | import jax
|
| 2 | +from packaging import version |
| 3 | +import warnings |
| 4 | + |
2 | 5 | from .lib import sphericart_jax_cpu
|
3 | 6 | from .spherical_harmonics import spherical_harmonics, solid_harmonics # noqa: F401
|
4 | 7 |
|
5 | 8 |
|
| 9 | +def get_minimum_cuda_version_for_jax(jax_version): |
| 10 | + """ |
| 11 | + Get the minimum required CUDA version for a specific JAX version. |
| 12 | +
|
| 13 | + Args: |
| 14 | + jax_version (str): Installed JAX version, e.g., '0.4.11'. |
| 15 | +
|
| 16 | + Returns: |
| 17 | + tuple: Minimum required CUDA version as (major, minor), e.g., (11, 8). |
| 18 | + """ |
| 19 | + # Define ranges of JAX versions and their corresponding minimum CUDA versions |
| 20 | + version_ranges = [ |
| 21 | + ( |
| 22 | + version.parse("0.4.26"), |
| 23 | + version.parse("999.999.999"), |
| 24 | + (12, 1), |
| 25 | + ), # JAX 0.4.26 and later: CUDA 12.1+ |
| 26 | + ( |
| 27 | + version.parse("0.4.11"), |
| 28 | + version.parse("0.4.25"), |
| 29 | + (11, 8), |
| 30 | + ), # JAX 0.4.11 - 0.4.25: CUDA 11.8+ |
| 31 | + ] |
| 32 | + |
| 33 | + jax_ver = version.parse(jax_version) |
| 34 | + |
| 35 | + # Find the appropriate CUDA version range |
| 36 | + for start, end, cuda_version in version_ranges: |
| 37 | + if start <= jax_ver <= end: |
| 38 | + return cuda_version |
| 39 | + |
| 40 | + raise ValueError(f"Unsupported JAX version: {jax_version}") |
| 41 | + |
| 42 | + |
6 | 43 | # register the operations to xla
|
7 | 44 | for _name, _value in sphericart_jax_cpu.registrations().items():
|
8 | 45 | jax.lib.xla_client.register_custom_call_target(_name, _value, platform="cpu")
|
9 | 46 |
|
| 47 | +has_sphericart_jax_cuda = False |
10 | 48 | try:
|
11 | 49 | from .lib import sphericart_jax_cuda
|
12 | 50 |
|
| 51 | + has_sphericart_jax_cuda = True |
13 | 52 | # register the operations to xla
|
14 | 53 | for _name, _value in sphericart_jax_cuda.registrations().items():
|
15 | 54 | jax.lib.xla_client.register_custom_call_target(_name, _value, platform="gpu")
|
16 |
| - |
17 | 55 | except ImportError:
|
| 56 | + has_sphericart_jax_cuda = False |
18 | 57 | pass
|
| 58 | + |
| 59 | +if has_sphericart_jax_cuda: |
| 60 | + from .lib.sphericart_jax_cuda import get_cuda_runtime_version |
| 61 | + |
| 62 | + # check the jaxlib version is suitable for the host cudatoolkit. |
| 63 | + cuda_version = get_cuda_runtime_version() |
| 64 | + cuda_version = (cuda_version["major"], cuda_version["minor"]) |
| 65 | + jax_version = jax.__version__ |
| 66 | + required_version = get_minimum_cuda_version_for_jax(jax_version) |
| 67 | + if cuda_version < required_version: |
| 68 | + warnings.warn( |
| 69 | + "The installed CUDA Toolkit version is " |
| 70 | + f"{cuda_version[0]}.{cuda_version[1]}, which " |
| 71 | + f"is not compatible with the installed JAX version {jax_version}. " |
| 72 | + "The minimum required CUDA Toolkit for your JAX version " |
| 73 | + f"is {required_version[0]}.{required_version[1]}. " |
| 74 | + "Please upgrade your CUDA Toolkit to meet the requirements, or ", |
| 75 | + "downgrade JAX to a compatible version.", |
| 76 | + stacklevel=2, |
| 77 | + ) |
0 commit comments