Skip to content

Commit

Permalink
linting.
Browse files Browse the repository at this point in the history
  • Loading branch information
nickjbrowning committed Jan 16, 2025
1 parent d5d83f5 commit f98015b
Showing 1 changed file with 18 additions and 16 deletions.
34 changes: 18 additions & 16 deletions sphericart-jax/python/sphericart/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,26 +46,28 @@ def get_minimum_cuda_version_for_jax(jax_version):
has_sphericart_jax_cuda = False
try:
from .lib import sphericart_jax_cuda

has_sphericart_jax_cuda = True
# register the operations to xla
for _name, _value in sphericart_jax_cuda.registrations().items():
jax.lib.xla_client.register_custom_call_target(_name, _value, platform="gpu")
except ImportError:
has_sphericart_jax_cuda=False
has_sphericart_jax_cuda = False
pass

if (has_sphericart_jax_cuda):
from .lib.sphericart_jax_cuda import get_cuda_runtime_version
#check the jaxlib version is suitable for the host cudatoolkit.
cuda_version = get_cuda_runtime_version()
cuda_version = (cuda_version["major"], cuda_version["minor"])
jax_version = jax.__version__
required_version = get_minimum_cuda_version_for_jax(jax_version)
if cuda_version < required_version:
raise RuntimeError(
f"Installed CUDA Toolkit: {cuda_version[0]}.{cuda_version[1]} \
is not compatible with installed JAX version {jax_version}. \
Minimum required CUDA Toolkit for your JAX version \
is {required_version[0]}.{required_version[1]}. \
Please upgrade your CUDA Toolkit to meet the requirements."
)
if has_sphericart_jax_cuda:
from .lib.sphericart_jax_cuda import get_cuda_runtime_version

# check the jaxlib version is suitable for the host cudatoolkit.
cuda_version = get_cuda_runtime_version()
cuda_version = (cuda_version["major"], cuda_version["minor"])
jax_version = jax.__version__
required_version = get_minimum_cuda_version_for_jax(jax_version)
if cuda_version < required_version:
raise RuntimeError(
f"Installed CUDA Toolkit: {cuda_version[0]}.{cuda_version[1]} \
is not compatible with installed JAX version {jax_version}. \
Minimum required CUDA Toolkit for your JAX version \
is {required_version[0]}.{required_version[1]}. \
Please upgrade your CUDA Toolkit to meet the requirements."
)

0 comments on commit f98015b

Please sign in to comment.