Skip to content

Commit ec9bf36

Browse files
Merge branch 'main' into release-sphericart-1.0.3
2 parents 8ed4dd4 + 3704357 commit ec9bf36

File tree

15 files changed

+88
-81
lines changed

15 files changed

+88
-81
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ jobs:
6666
name: check Python build
6767
strategy:
6868
matrix:
69-
python-version: ['3.8', '3.12']
69+
python-version: ['3.9', '3.12']
7070
steps:
7171
- uses: actions/checkout@v3
7272

ci/containers/Containerfile

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
FROM docker://nvidia/cuda:12.9.1-cudnn-devel-ubuntu24.04
2+
3+
RUN apt update && apt install -y python3 python3-pip python3-venv git cmake && apt clean
4+
5+
ENV PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cu129
6+
7+
RUN python3 -m pip install --break-system-packages tox torch
8+
9+
ENV CUDA_HOME="/usr/local/cuda"
10+
ENV Torch_DIR=/usr/local/lib/python3.12/dist-packages/torch/share/cmake/Torch/

ci/pipeline.yml

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,54 @@ include:
22
- remote: 'https://gitlab.com/cscs-ci/recipes/-/raw/master/templates/v2/.ci-ext.yml'
33

44
stages:
5+
- build
56
- test
67

7-
test_job:
8+
build:
9+
stage: build
10+
extends: .container-builder-cscs-gh200
11+
before_script:
12+
- TAG_DOCKERFILE=`sha256sum $DOCKERFILE | head -c 8`
13+
- TAG=${TAG_DOCKERFILE}
14+
- export PERSIST_IMAGE_NAME=$CSCS_REGISTRY_PATH/base/mtt-base:$TAG
15+
- echo "BASE_IMAGE=$PERSIST_IMAGE_NAME" > build.env
16+
- 'echo "INFO: Building image $PERSIST_IMAGE_NAME"'
17+
artifacts:
18+
reports:
19+
dotenv: build.env
20+
variables:
21+
DOCKERFILE: ci/containers/Containerfile
22+
23+
test tox:
824
stage: test
925
extends: .container-runner-daint-gh200
10-
image: nvcr.io/nvidia/pytorch:24.12-py3
11-
timeout: 2h
26+
image: $BASE_IMAGE
27+
timeout: 1h
1228
script:
13-
- export CUDA_HOME="/usr/local/cuda"
14-
- python3 -m pip install --upgrade pip
15-
- python3 -m pip install tox
16-
- tox
17-
- export Torch_DIR=/usr/local/lib/python3.12/dist-packages/torch/share/cmake/Torch/
18-
- mkdir buildcpp
19-
- cd buildcpp
20-
- cmake .. -DSPHERICART_BUILD_TESTS=ON -DSPHERICART_OPENMP=ON -DSPHERICART_BUILD_EXAMPLES=ON -DSPHERICART_ENABLE_CUDA=ON -DSPHERICART_BUILD_TORCH=ON
21-
- cmake --build . --parallel
22-
- ctest
29+
- tox -vv -e tests
30+
- tox -vv -e torch-tests
31+
- LD_LIBRARY_PATH=$CUDA_HOME/lib64/:$LD_LIBRARY_PATH tox -vv -e jax-tests
32+
- tox -vv -e examples
33+
variables:
34+
SLURM_JOB_NUM_NODES: 1
35+
SLURM_PARTITION: normal
36+
SLURM_NTASKS: 1
37+
SLURM_TIMELIMIT: '00:30:00'
38+
GIT_STRATEGY: fetch
2339

40+
test cpp:
41+
stage: test
42+
extends: .container-runner-daint-gh200
43+
image: $BASE_IMAGE
44+
timeout: 1h
45+
script:
46+
- mkdir buildcpp
47+
- cmake -B buildcpp -S . -DSPHERICART_BUILD_TESTS=ON -DSPHERICART_OPENMP=ON -DSPHERICART_BUILD_EXAMPLES=ON -DSPHERICART_ENABLE_CUDA=ON -DSPHERICART_BUILD_TORCH=ON
48+
- cmake --build buildcpp --parallel
49+
- ctest --test-dir buildcpp --output-on-failure
2450
variables:
2551
SLURM_JOB_NUM_NODES: 1
2652
SLURM_PARTITION: normal
2753
SLURM_NTASKS: 1
28-
SLURM_TIMELIMIT: '02:30:00'
54+
SLURM_TIMELIMIT: '00:10:00'
2955
GIT_STRATEGY: fetch

pyproject.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
[project]
22
name = "sphericart"
33
dynamic = ["version", "optional-dependencies"]
4-
requires-python = ">=3.8"
4+
requires-python = ">=3.9"
55
dependencies = ["numpy"]
66

77
readme = "README.md"
8-
license = {text = "Apache-2.0 or MIT"}
8+
license = "Apache-2.0 or MIT"
99
description = "Fast calculation of spherical harmonics"
1010
authors = [
1111
{name = "Filippo Bigi"},
@@ -18,7 +18,6 @@ keywords = ["spherical harmonics"]
1818
classifiers = [
1919
"Development Status :: 4 - Beta",
2020
"Intended Audience :: Science/Research",
21-
"License :: OSI Approved :: Apache Software License",
2221
"Operating System :: POSIX",
2322
"Operating System :: MacOS :: MacOS X",
2423
"Programming Language :: Python",

python/tests/test_vs_scipy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ def scipy_real_sph(xyz, l, m): # noqa E741
1616
r = np.sqrt(x**2 + y**2 + z**2)
1717
theta = np.arccos(z / r)
1818
phi = np.arctan2(y, x)
19-
complex_sh_scipy_l_m = scipy.special.sph_harm(m, l, phi, theta)
20-
complex_sh_scipy_l_negm = scipy.special.sph_harm(-m, l, phi, theta)
19+
complex_sh_scipy_l_m = scipy.special.sph_harm_y(l, m, theta, phi)
20+
complex_sh_scipy_l_negm = scipy.special.sph_harm_y(l, -m, theta, phi)
2121

2222
if m > 0:
2323
sh_scipy_l_m = (

sphericart-jax/pyproject.toml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@ name = "sphericart-jax"
33
dynamic = ["version"]
44
requires-python = ">=3.9"
55
dependencies = [
6-
"jax >=0.4.18,<0.6",
6+
"jax >=0.4.32,<0.8",
77
"packaging",
88
]
99

1010
readme = "README.md"
11-
license = {text = "Apache-2.0"}
11+
license = "Apache-2.0 or MIT"
1212
description = "JAX bindings to sphericart"
1313
authors = [
1414
{name = "Filippo Bigi"},
@@ -22,7 +22,6 @@ keywords = ["spherical harmonics", "jax"]
2222
classifiers = [
2323
"Development Status :: 4 - Beta",
2424
"Intended Audience :: Science/Research",
25-
"License :: OSI Approved :: BSD License",
2625
"Operating System :: POSIX",
2726
"Operating System :: MacOS :: MacOS X",
2827
"Environment :: GPU :: NVIDIA CUDA",
@@ -43,7 +42,7 @@ repository = "https://github.com/lab-cosmo/sphericart"
4342

4443
[build-system]
4544
requires = [
46-
"setuptools >=44",
45+
"setuptools >=77",
4746
"wheel >=0.36",
4847
"cmake >=3.30",
4948
"pybind11>=2.8.0",

sphericart-jax/python/sphericart/jax/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def get_minimum_cuda_version_for_jax(jax_version):
4444

4545
# register the operations to xla
4646
for _name, _value in sphericart_jax_cpu.registrations().items():
47-
jax.lib.xla_client.register_custom_call_target(_name, _value, platform="cpu")
47+
jax.ffi.register_ffi_target(_name, _value, platform="cpu", api_version=0)
4848

4949
has_sphericart_jax_cuda = False
5050
try:
@@ -53,7 +53,7 @@ def get_minimum_cuda_version_for_jax(jax_version):
5353
has_sphericart_jax_cuda = True
5454
# register the operations to xla
5555
for _name, _value in sphericart_jax_cuda.registrations().items():
56-
jax.lib.xla_client.register_custom_call_target(_name, _value, platform="gpu")
56+
jax.ffi.register_ffi_target(_name, _value, platform="gpu", api_version=0)
5757
except ImportError:
5858
has_sphericart_jax_cuda = False
5959
pass

sphericart-jax/python/sphericart/jax/ddsph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from functools import partial
33

44
import jax
5-
from jax import core
5+
from jax import extend
66
from jax.core import ShapedArray
77
from jax.interpreters import mlir, xla
88
from jax.interpreters.mlir import custom_call, ir
@@ -14,7 +14,7 @@
1414
# as well as some transformation rules. For more information and comments,
1515
# see sph.py
1616

17-
_ddsph_p = core.Primitive("ddsph")
17+
_ddsph_p = extend.core.Primitive("ddsph")
1818
_ddsph_p.multiple_results = True
1919
_ddsph_p.def_impl(partial(xla.apply_primitive, _ddsph_p))
2020

sphericart-jax/python/sphericart/jax/dsph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import jax
55
import jax.numpy as jnp
6-
from jax import core
6+
from jax import extend
77
from jax.core import ShapedArray
88
from jax.interpreters import ad, mlir, xla
99
from jax.interpreters.mlir import custom_call, ir
@@ -16,7 +16,7 @@
1616
# as well as some transformation rules. For more information and comments,
1717
# see sph.py
1818

19-
_dsph_p = core.Primitive("dsph")
19+
_dsph_p = extend.core.Primitive("dsph")
2020
_dsph_p.multiple_results = True
2121
_dsph_p.def_impl(partial(xla.apply_primitive, _dsph_p))
2222

sphericart-jax/python/sphericart/jax/sph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import jax
55
import jax.numpy as jnp
6-
from jax import core
6+
from jax import extend
77
from jax.core import ShapedArray
88
from jax.interpreters import ad, mlir, xla
99
from jax.interpreters.mlir import custom_call, ir
@@ -13,7 +13,7 @@
1313

1414

1515
# register the sph primitive
16-
_sph_p = core.Primitive("sph_fwd")
16+
_sph_p = extend.core.Primitive("sph_fwd")
1717
_sph_p.def_impl(partial(xla.apply_primitive, _sph_p))
1818

1919

0 commit comments

Comments
 (0)