Skip to content

Commit e94a3bb

Browse files
CUDA Toolkit version + Jax incompatibility check (#166)
1 parent f7b5b71 commit e94a3bb

File tree

4 files changed

+83
-2
lines changed

4 files changed

+83
-2
lines changed

sphericart-jax/pyproject.toml

+4-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
name = "sphericart-jax"
33
dynamic = ["version"]
44
requires-python = ">=3.9"
5-
dependencies = ["jax >= 0.4.18"]
5+
dependencies = [
6+
"jax >= 0.4.18",
7+
"packaging",
8+
]
69

710
readme = "README.md"
811
license = {text = "Apache-2.0"}
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,77 @@
11
import jax
2+
from packaging import version
3+
import warnings
4+
25
from .lib import sphericart_jax_cpu
36
from .spherical_harmonics import spherical_harmonics, solid_harmonics # noqa: F401
47

58

9+
def get_minimum_cuda_version_for_jax(jax_version):
10+
"""
11+
Get the minimum required CUDA version for a specific JAX version.
12+
13+
Args:
14+
jax_version (str): Installed JAX version, e.g., '0.4.11'.
15+
16+
Returns:
17+
tuple: Minimum required CUDA version as (major, minor), e.g., (11, 8).
18+
"""
19+
# Define ranges of JAX versions and their corresponding minimum CUDA versions
20+
version_ranges = [
21+
(
22+
version.parse("0.4.26"),
23+
version.parse("999.999.999"),
24+
(12, 1),
25+
), # JAX 0.4.26 and later: CUDA 12.1+
26+
(
27+
version.parse("0.4.11"),
28+
version.parse("0.4.25"),
29+
(11, 8),
30+
), # JAX 0.4.11 - 0.4.25: CUDA 11.8+
31+
]
32+
33+
jax_ver = version.parse(jax_version)
34+
35+
# Find the appropriate CUDA version range
36+
for start, end, cuda_version in version_ranges:
37+
if start <= jax_ver <= end:
38+
return cuda_version
39+
40+
raise ValueError(f"Unsupported JAX version: {jax_version}")
41+
42+
643
# register the operations to xla
744
for _name, _value in sphericart_jax_cpu.registrations().items():
845
jax.lib.xla_client.register_custom_call_target(_name, _value, platform="cpu")
946

47+
has_sphericart_jax_cuda = False
1048
try:
1149
from .lib import sphericart_jax_cuda
1250

51+
has_sphericart_jax_cuda = True
1352
# register the operations to xla
1453
for _name, _value in sphericart_jax_cuda.registrations().items():
1554
jax.lib.xla_client.register_custom_call_target(_name, _value, platform="gpu")
16-
1755
except ImportError:
56+
has_sphericart_jax_cuda = False
1857
pass
58+
59+
if has_sphericart_jax_cuda:
60+
from .lib.sphericart_jax_cuda import get_cuda_runtime_version
61+
62+
# check the jaxlib version is suitable for the host cudatoolkit.
63+
cuda_version = get_cuda_runtime_version()
64+
cuda_version = (cuda_version["major"], cuda_version["minor"])
65+
jax_version = jax.__version__
66+
required_version = get_minimum_cuda_version_for_jax(jax_version)
67+
if cuda_version < required_version:
68+
warnings.warn(
69+
"The installed CUDA Toolkit version is "
70+
f"{cuda_version[0]}.{cuda_version[1]}, which "
71+
f"is not compatible with the installed JAX version {jax_version}. "
72+
"The minimum required CUDA Toolkit for your JAX version "
73+
f"is {required_version[0]}.{required_version[1]}. "
74+
"Please upgrade your CUDA Toolkit to meet the requirements, or ",
75+
"downgrade JAX to a compatible version.",
76+
stacklevel=2,
77+
)

sphericart-jax/src/sphericart_jax_cuda.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,12 @@
88
#include <mutex>
99
#include <tuple>
1010

11+
#include "dynamic_cuda.hpp"
1112
#include "sphericart_cuda.hpp"
1213
#include "sphericart/pybind11_kernel_helpers.hpp"
1314

15+
using namespace pybind11::literals;
16+
1417
struct SphDescriptor {
1518
std::int64_t n_samples;
1619
std::int64_t lmax;
@@ -115,11 +118,23 @@ pybind11::dict Registrations() {
115118
return dict;
116119
}
117120

121+
std::pair<int, int> getCUDARuntimeVersion() {
122+
int version;
123+
CUDART_SAFE_CALL(CUDART_INSTANCE.cudaRuntimeGetVersion(&version));
124+
int major = version / 1000;
125+
int minor = (version % 1000) / 10;
126+
return {major, minor};
127+
}
128+
118129
PYBIND11_MODULE(sphericart_jax_cuda, m) {
119130
m.def("registrations", &Registrations);
120131
m.def("build_sph_descriptor", [](std::int64_t n_samples, std::int64_t lmax) {
121132
return PackDescriptor(SphDescriptor{n_samples, lmax});
122133
});
134+
m.def("get_cuda_runtime_version", []() {
135+
auto [major, minor] = getCUDARuntimeVersion();
136+
return pybind11::dict("major"_a = major, "minor"_a = minor);
137+
});
123138
}
124139

125140
} // namespace cuda

sphericart/include/dynamic_cuda.hpp

+4
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ class CUDART {
9292
using cudaDeviceSynchronize_t = cudaError_t (*)(void);
9393
using cudaPointerGetAttributes_t = cudaError_t (*)(cudaPointerAttributes*, const void*);
9494
using cudaFree_t = cudaError_t (*)(void*);
95+
using cudaRuntimeGetVersion_t = cudaError_t (*)(int*);
9596

9697
cudaGetDeviceCount_t cudaGetDeviceCount;
9798
cudaGetDevice_t cudaGetDevice;
@@ -103,6 +104,7 @@ class CUDART {
103104
cudaDeviceSynchronize_t cudaDeviceSynchronize;
104105
cudaPointerGetAttributes_t cudaPointerGetAttributes;
105106
cudaFree_t cudaFree;
107+
cudaRuntimeGetVersion_t cudaRuntimeGetVersion;
106108

107109
CUDART() {
108110
#ifdef __linux__
@@ -124,6 +126,8 @@ class CUDART {
124126
cudaPointerGetAttributes =
125127
load<cudaPointerGetAttributes_t>(cudartHandle, "cudaPointerGetAttributes");
126128
cudaFree = load<cudaFree_t>(cudartHandle, "cudaFree");
129+
cudaRuntimeGetVersion =
130+
load<cudaRuntimeGetVersion_t>(cudartHandle, "cudaRuntimeGetVersion");
127131
}
128132
}
129133

0 commit comments

Comments
 (0)