Skip to content

Commit

Permalink
formatting.
Browse files Browse the repository at this point in the history
  • Loading branch information
nickjbrowning committed Jan 16, 2025
1 parent debb3fc commit fec2534
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 16 deletions.
34 changes: 19 additions & 15 deletions sphericart-jax/python/sphericart/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,16 @@ def get_minimum_cuda_version_for_jax(jax_version):
"""
# Define ranges of JAX versions and their corresponding minimum CUDA versions
version_ranges = [
(version.parse("0.4.26"), version.parse("999.999.999"),
(12, 1)), # JAX 0.4.26 and later: CUDA 12.1+
(version.parse("0.4.11"), version.parse("0.4.25"),
(11, 8)), # JAX 0.4.11 - 0.4.25: CUDA 11.8+
(
version.parse("0.4.26"),
version.parse("999.999.999"),
(12, 1),
), # JAX 0.4.26 and later: CUDA 12.1+
(
version.parse("0.4.11"),
version.parse("0.4.25"),
(11, 8),
), # JAX 0.4.11 - 0.4.25: CUDA 11.8+
]

# Parse the current JAX version
Expand All @@ -37,30 +43,28 @@ def get_minimum_cuda_version_for_jax(jax_version):

# register the operations to xla
for _name, _value in sphericart_jax_cpu.registrations().items():
jax.lib.xla_client.register_custom_call_target(
_name, _value, platform="cpu")
jax.lib.xla_client.register_custom_call_target(_name, _value, platform="cpu")

try:
from .lib import sphericart_jax_cuda
from .lib.sphericart_jax_cuda import get_cuda_runtime_version

cuda_version = get_cuda_runtime_version()
cuda_version = (cuda_version['major'], cuda_version['minor'])
cuda_version = (cuda_version["major"], cuda_version["minor"])
jax_version = jax.__version__
required_version = get_minimum_cuda_version_for_jax(jax_version)
if (cuda_version < required_version):
if cuda_version < required_version:
raise RuntimeError(
f"Incompatible setup detected:\n"
f"- Installed CUDA version: {cuda_version[0]}.{cuda_version[1]}\n"
f"- Installed JAX version: {jax_version}\n"
f"- Minimum required CUDA version for JAX {jax_version}: {required_version[0]}.{required_version[1]}\n"
f"Please upgrade your CUDA Toolkit to meet the requirements."
f"Installed CUDA Toolkit: {cuda_version[0]}.{cuda_version[1]} \
is not compatible with installed JAX version {jax_version}. \
Minimum required CUDA Toolkit for your JAX version \
is {required_version[0]}.{required_version[1]}. \
Please upgrade your CUDA Toolkit to meet the requirements."
)

# register the operations to xla
for _name, _value in sphericart_jax_cuda.registrations().items():
jax.lib.xla_client.register_custom_call_target(
_name, _value, platform="gpu")
jax.lib.xla_client.register_custom_call_target(_name, _value, platform="gpu")

except ImportError:
pass
2 changes: 1 addition & 1 deletion sphericart-jax/python/sphericart/jax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ def build_sph_descriptor(a, b):
"Trying to use sphericart-jax on CUDA, "
"but sphericart-jax was installed without CUDA support. "
"Please re-install sphericart-jax with CUDA support"
)
)

0 comments on commit fec2534

Please sign in to comment.