From f1e1e99dff672ba90973e3c3771e11566fdaca48 Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Tue, 29 Oct 2024 10:24:25 +0100 Subject: [PATCH 001/106] Add .dacecache to gitignore --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index b6e4761..8bac872 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,6 @@ dmypy.json # Pyre type checker .pyre/ + +# dace +.dacecache/ From b16a0af049a27de99e89aa914425883260205fac Mon Sep 17 00:00:00 2001 From: Filip Jaksic Date: Tue, 29 Oct 2024 12:25:55 +0100 Subject: [PATCH 002/106] Add Jax framework --- framework_info/jax.json | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 framework_info/jax.json diff --git a/framework_info/jax.json b/framework_info/jax.json new file mode 100644 index 0000000..d1a2318 --- /dev/null +++ b/framework_info/jax.json @@ -0,0 +1,10 @@ +{ + "framework": { + "simple_name": "jax", + "full_name": "Jax", + "prefix": "jax", + "postfix": "jax", + "class": "Framework", + "arch": "cpu" + } +} \ No newline at end of file From a331c456d5778e13dc1c80e2f8d3724e03bd5cee Mon Sep 17 00:00:00 2001 From: Filip Jaksic Date: Tue, 29 Oct 2024 12:26:56 +0100 Subject: [PATCH 003/106] Add Jax "compute" benchmark --- npbench/benchmarks/compute/compute_jax.py | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 npbench/benchmarks/compute/compute_jax.py diff --git a/npbench/benchmarks/compute/compute_jax.py b/npbench/benchmarks/compute/compute_jax.py new file mode 100644 index 0000000..93d23f0 --- /dev/null +++ b/npbench/benchmarks/compute/compute_jax.py @@ -0,0 +1,8 @@ +# https://cython.readthedocs.io/en/latest/src/userguide/numpy_tutorial.html + +import jax.numpy as jnp +import jax + +@jax.jit +def compute(array_1, array_2, a, b, c): + return jnp.clip(array_1, 2, 10) * a + array_2 * b + c From 3de6ec74df477106015d68417af6436cf995e324 Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Sat, 2 Nov 2024 19:44:52 +0100 Subject: [PATCH 004/106] Add jax_framework Define the JaxFramework class for implementing block_until_ready() calls after kernel computation for correctness of profiling, convert np.ndarray to jax Array in the copy_func() --- npbench/infrastructure/jax_framework.py | 46 +++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 npbench/infrastructure/jax_framework.py diff --git a/npbench/infrastructure/jax_framework.py b/npbench/infrastructure/jax_framework.py new file mode 100644 index 0000000..b5f3f52 --- /dev/null +++ b/npbench/infrastructure/jax_framework.py @@ -0,0 +1,46 @@ +# Copyright 2021 ETH Zurich and the NPBench authors. All rights reserved. +import jax.numpy as jnp +import jax +jax.config.update("jax_enable_x64", True) + +from npbench.infrastructure import Benchmark, Framework +from typing import Any, Callable, Dict + + +class JaxFramework(Framework): + """ A class for reading and processing framework information. """ + + def __init__(self, fname: str): + """ Reads framework information. + :param fname: The framework name. + """ + + super().__init__(fname) + + def imports(self) -> Dict[str, Any]: + return {'jax': jax} + + def copy_func(self) -> Callable: + """ Returns the copy-method that should be used + for copying the benchmark arguments. """ + return jnp.array + + def exec_str(self, bench: Benchmark, impl: Callable = None): + """ Generates the execution-string that should be used to call + the benchmark implementation. + :param bench: A benchmark. + :param impl: A benchmark implementation. 
+ """ + + arg_str = self.arg_str(bench, impl) + # param_str = self.param_str(bench, impl) + main_exec_str = "__npb_result = __npb_impl({a})".format(a=arg_str) + sync_str = """ +if isinstance(__npb_result, jax.Array): + __npb_result.block_until_ready() +elif isinstance(__npb_result, tuple): + for item in __npb_result: + if isinstance(item, jax.Array): + item.block_until_ready() +""" + return main_exec_str + sync_str From 76cd00a5142b37bdf6a2128b03f1c5a6ea582025 Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Sat, 2 Nov 2024 19:46:17 +0100 Subject: [PATCH 005/106] Add jax_framework Define the JaxFramework class for implementing block_until_ready() calls after kernel computation for correctness of profiling, convert np.ndarray to jax Array in the copy_func() --- framework_info/jax.json | 2 +- npbench/infrastructure/__init__.py | 1 + npbench/infrastructure/jax_framework.py | 46 +++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 1 deletion(-) create mode 100644 npbench/infrastructure/jax_framework.py diff --git a/framework_info/jax.json b/framework_info/jax.json index d1a2318..c3506b8 100644 --- a/framework_info/jax.json +++ b/framework_info/jax.json @@ -4,7 +4,7 @@ "full_name": "Jax", "prefix": "jax", "postfix": "jax", - "class": "Framework", + "class": "JaxFramework", "arch": "cpu" } } \ No newline at end of file diff --git a/npbench/infrastructure/__init__.py b/npbench/infrastructure/__init__.py index ac403b7..8b13b1b 100644 --- a/npbench/infrastructure/__init__.py +++ b/npbench/infrastructure/__init__.py @@ -10,3 +10,4 @@ from .legate_framework import * from .numba_framework import * from .pythran_framework import * +from .jax_framework import * diff --git a/npbench/infrastructure/jax_framework.py b/npbench/infrastructure/jax_framework.py new file mode 100644 index 0000000..b5f3f52 --- /dev/null +++ b/npbench/infrastructure/jax_framework.py @@ -0,0 +1,46 @@ +# Copyright 2021 ETH Zurich and the NPBench authors. All rights reserved. +import jax.numpy as jnp +import jax +jax.config.update("jax_enable_x64", True) + +from npbench.infrastructure import Benchmark, Framework +from typing import Any, Callable, Dict + + +class JaxFramework(Framework): + """ A class for reading and processing framework information. """ + + def __init__(self, fname: str): + """ Reads framework information. + :param fname: The framework name. + """ + + super().__init__(fname) + + def imports(self) -> Dict[str, Any]: + return {'jax': jax} + + def copy_func(self) -> Callable: + """ Returns the copy-method that should be used + for copying the benchmark arguments. """ + return jnp.array + + def exec_str(self, bench: Benchmark, impl: Callable = None): + """ Generates the execution-string that should be used to call + the benchmark implementation. + :param bench: A benchmark. + :param impl: A benchmark implementation. 
+ """ + + arg_str = self.arg_str(bench, impl) + # param_str = self.param_str(bench, impl) + main_exec_str = "__npb_result = __npb_impl({a})".format(a=arg_str) + sync_str = """ +if isinstance(__npb_result, jax.Array): + __npb_result.block_until_ready() +elif isinstance(__npb_result, tuple): + for item in __npb_result: + if isinstance(item, jax.Array): + item.block_until_ready() +""" + return main_exec_str + sync_str From cacf7f50f2b0912b420c2ea5df69e6da41bb22aa Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Sat, 2 Nov 2024 19:50:10 +0100 Subject: [PATCH 006/106] Add azimint_hist jax implementation --- .../azimint_hist/azimint_hist_jax.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 npbench/benchmarks/azimint_hist/azimint_hist_jax.py diff --git a/npbench/benchmarks/azimint_hist/azimint_hist_jax.py b/npbench/benchmarks/azimint_hist/azimint_hist_jax.py new file mode 100644 index 0000000..2103621 --- /dev/null +++ b/npbench/benchmarks/azimint_hist/azimint_hist_jax.py @@ -0,0 +1,19 @@ +# Copyright 2014 Jérôme Kieffer et al. +# This is an open-access article distributed under the terms of the +# Creative Commons Attribution License, which permits unrestricted use, +# distribution, and reproduction in any medium, provided the original author +# and source are credited. +# http://creativecommons.org/licenses/by/3.0/ +# Jérôme Kieffer and Giannis Ashiotis. Pyfai: a python library for +# high performance azimuthal integration on gpu, 2014. In Proceedings of the +# 7th European Conference on Python in Science (EuroSciPy 2014). + +import jax +import jax.numpy as jnp +from functools import partial + +@partial(jax.jit, static_argnums=(2,)) +def azimint_hist(data: jax.Array, radius: jax.Array, npt): + histu = jnp.histogram(radius, npt)[0] + histw = jnp.histogram(radius, npt, weights=data)[0] + return histw / histu From e9ae93442a60d1851454ba4e953f9e771beedd02 Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Sat, 2 Nov 2024 19:50:17 +0100 Subject: [PATCH 007/106] Add go_fast jax implementation --- npbench/benchmarks/go_fast/go_fast_jax.py | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 npbench/benchmarks/go_fast/go_fast_jax.py diff --git a/npbench/benchmarks/go_fast/go_fast_jax.py b/npbench/benchmarks/go_fast/go_fast_jax.py new file mode 100644 index 0000000..93e6582 --- /dev/null +++ b/npbench/benchmarks/go_fast/go_fast_jax.py @@ -0,0 +1,11 @@ +# https://numba.readthedocs.io/en/stable/user/5minguide.html + +import jax +import jax.numpy as jnp + +@jax.jit +def go_fast(a: jax.Array): + trace = jnp.float64(0) + for i in range(a.shape[0]): + trace += jnp.tanh(a[i, i]) + return a + trace From 40dd6416c526f2cc8bd5f6bae51802ad8c76cd58 Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Sat, 2 Nov 2024 19:50:32 +0100 Subject: [PATCH 008/106] Add bicg jax implementation --- npbench/benchmarks/polybench/bicg/bicg_jax.py | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 npbench/benchmarks/polybench/bicg/bicg_jax.py diff --git a/npbench/benchmarks/polybench/bicg/bicg_jax.py b/npbench/benchmarks/polybench/bicg/bicg_jax.py new file mode 100644 index 0000000..9114b40 --- /dev/null +++ b/npbench/benchmarks/polybench/bicg/bicg_jax.py @@ -0,0 +1,7 @@ +import jax +import jax.numpy as jnp + +@jax.jit +def kernel(A: jax.Array, p: jax.Array, r: jax.Array): + + return r @ A, A @ p From 053d87b4d6d2dea3323f64e7bb7e8b351b4a9301 Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Sat, 2 Nov 2024 19:50:47 +0100 Subject: [PATCH 009/106] Add cholesky2 jax 
implementation --- npbench/benchmarks/polybench/cholesky2/cholesky2_jax.py | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 npbench/benchmarks/polybench/cholesky2/cholesky2_jax.py diff --git a/npbench/benchmarks/polybench/cholesky2/cholesky2_jax.py b/npbench/benchmarks/polybench/cholesky2/cholesky2_jax.py new file mode 100644 index 0000000..009d8eb --- /dev/null +++ b/npbench/benchmarks/polybench/cholesky2/cholesky2_jax.py @@ -0,0 +1,8 @@ +import jax +import jax.numpy as jnp + +@jax.jit +def kernel(A: jax.Array): + A = A.at[:].set(jnp.linalg.cholesky(A) + jnp.triu(A, k=1)) + + return A From 7ddec27b7b091eb19dee0fd3bef846b75fa05dac Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Sat, 2 Nov 2024 19:50:57 +0100 Subject: [PATCH 010/106] Add gesummv jax implementation --- npbench/benchmarks/polybench/gesummv/gesummv_jax.py | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 npbench/benchmarks/polybench/gesummv/gesummv_jax.py diff --git a/npbench/benchmarks/polybench/gesummv/gesummv_jax.py b/npbench/benchmarks/polybench/gesummv/gesummv_jax.py new file mode 100644 index 0000000..776f590 --- /dev/null +++ b/npbench/benchmarks/polybench/gesummv/gesummv_jax.py @@ -0,0 +1,7 @@ +import jax +import jax.numpy as jnp + +@jax.jit +def kernel(alpha, beta, A: jax.Array, B: jax.Array, x: jax.Array): + + return alpha * A @ x + beta * B @ x From a69848c9c8fb250483997d7c0f6ddc99502823cb Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Sat, 2 Nov 2024 19:51:13 +0100 Subject: [PATCH 011/106] Add jacobi_1d jax implementation --- .../benchmarks/polybench/jacobi_1d/jacobi_1d_jax.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 npbench/benchmarks/polybench/jacobi_1d/jacobi_1d_jax.py diff --git a/npbench/benchmarks/polybench/jacobi_1d/jacobi_1d_jax.py b/npbench/benchmarks/polybench/jacobi_1d/jacobi_1d_jax.py new file mode 100644 index 0000000..2cb043b --- /dev/null +++ b/npbench/benchmarks/polybench/jacobi_1d/jacobi_1d_jax.py @@ -0,0 +1,12 @@ +import jax +import jax.numpy as jnp +from functools import partial + +@partial(jax.jit, static_argnums=(0,)) +def kernel(TSTEPS: int, A: jax.Array, B: jax.Array): + + for t in range(1, TSTEPS): + B = B.at[1:-1].set(0.33333 * (A[:-2] + A[1:-1] + A[2:])) + A = A.at[1:-1].set(0.33333 * (B[:-2] + B[1:-1] + B[2:])) + + return A, B \ No newline at end of file From aea5a703456da5e921dcb54005f16c23cf4eb2df Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Sat, 2 Nov 2024 19:51:25 +0100 Subject: [PATCH 012/106] Add jacobi_2d jax implementation --- .../polybench/jacobi_2d/jacobi_2d_jax.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 npbench/benchmarks/polybench/jacobi_2d/jacobi_2d_jax.py diff --git a/npbench/benchmarks/polybench/jacobi_2d/jacobi_2d_jax.py b/npbench/benchmarks/polybench/jacobi_2d/jacobi_2d_jax.py new file mode 100644 index 0000000..33cdeff --- /dev/null +++ b/npbench/benchmarks/polybench/jacobi_2d/jacobi_2d_jax.py @@ -0,0 +1,14 @@ +import jax +import jax.numpy as jnp +from functools import partial + +@partial(jax.jit, static_argnums=(0,)) +def kernel(TSTEPS: int, A: jax.Array, B: jax.Array): + + for t in range(1, TSTEPS): + B = B.at[1:-1, 1:-1].set(0.2 * (A[1:-1, 1:-1] + A[1:-1, :-2] + A[1:-1, 2:] + + A[2:, 1:-1] + A[:-2, 1:-1])) + A = A.at[1:-1, 1:-1].set(0.2 * (B[1:-1, 1:-1] + B[1:-1, :-2] + B[1:-1, 2:] + + B[2:, 1:-1] + B[:-2, 1:-1])) + + return A, B From 24b0622359ac79ab9d205997ea2e192cd228fdf2 Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Sat, 2 Nov 2024 20:31:21 +0100 Subject: [PATCH 
013/106] Add cholesky jax implementation --- .../polybench/cholesky/cholesky_jax.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 npbench/benchmarks/polybench/cholesky/cholesky_jax.py diff --git a/npbench/benchmarks/polybench/cholesky/cholesky_jax.py b/npbench/benchmarks/polybench/cholesky/cholesky_jax.py new file mode 100644 index 0000000..71d4651 --- /dev/null +++ b/npbench/benchmarks/polybench/cholesky/cholesky_jax.py @@ -0,0 +1,46 @@ +import jax +import jax.numpy as jnp +from jax import lax + +@jax.jit +def kernel(A: jax.Array): + + # Set A[0, 0] to its square root + A = A.at[0, 0].set(jnp.sqrt(A[0, 0])) + + def row_update(i, A): + + def col_update(j, A): + # Create a mask for elements up to `j` + mask = jnp.arange(A.shape[1]) < j + + # Equivalent of A[i, :j] and A[j, :j] using masking + A_i_slice = jnp.where(mask, A[i, :], 0) + A_j_slice = jnp.where(mask, A[j, :], 0) + + # A[i, j] -= dot(A[i, :j], A[j, :j]) + dot_product = jnp.dot(A_i_slice, A_j_slice) + A = A.at[i, j].set(A[i, j] - dot_product) + + # A[i, j] /= A[j, j] + A = A.at[i, j].divide(A[j, j]) + + return A + + # Column update for all `j` in range(0, i) + A = lax.fori_loop(0, i, col_update, A) + + # Equivalent of A[i, i] -= dot(A[i, :i], A[i, :i]) + A_i_slice = jnp.where(jnp.arange(A.shape[1]) < i, A[i, :], 0) + dot_product = jnp.dot(A_i_slice, A_i_slice) + A = A.at[i, i].set(A[i, i] - dot_product) + + # Set A[i, i] to its square root + A = A.at[i, i].set(jnp.sqrt(A[i, i])) + + return A + + # Apply row update for all `i` in range(1, A.shape[0]) + A = lax.fori_loop(1, A.shape[0], row_update, A) + + return A From 3d43c3bd89626fc7688b333bf2c97385d8ca50b5 Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Sun, 3 Nov 2024 12:06:35 +0100 Subject: [PATCH 014/106] Update go_fast jax implementation --- npbench/benchmarks/go_fast/go_fast_jax.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/npbench/benchmarks/go_fast/go_fast_jax.py b/npbench/benchmarks/go_fast/go_fast_jax.py index 93e6582..4db63ab 100644 --- a/npbench/benchmarks/go_fast/go_fast_jax.py +++ b/npbench/benchmarks/go_fast/go_fast_jax.py @@ -5,7 +5,8 @@ @jax.jit def go_fast(a: jax.Array): - trace = jnp.float64(0) - for i in range(a.shape[0]): - trace += jnp.tanh(a[i, i]) + # Calculate the trace of the tanh of the diagonal elements + trace = jnp.sum(jnp.tanh(jnp.diag(a))) + + # Add the result to the original matrix return a + trace From 5d23224b72593f5f3c7e6b6eb7471dfd2166ccc7 Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Sun, 3 Nov 2024 12:10:45 +0100 Subject: [PATCH 015/106] Update cholesky2 jax implementation --- npbench/benchmarks/polybench/cholesky2/cholesky2_jax.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/npbench/benchmarks/polybench/cholesky2/cholesky2_jax.py b/npbench/benchmarks/polybench/cholesky2/cholesky2_jax.py index 009d8eb..fa30afb 100644 --- a/npbench/benchmarks/polybench/cholesky2/cholesky2_jax.py +++ b/npbench/benchmarks/polybench/cholesky2/cholesky2_jax.py @@ -3,6 +3,8 @@ @jax.jit def kernel(A: jax.Array): - A = A.at[:].set(jnp.linalg.cholesky(A) + jnp.triu(A, k=1)) - return A + L = jnp.linalg.cholesky(A) + upper_A = jnp.triu(A, k=1) + + return L + upper_A From 069029d17940c4e86e298e58756f3158985f794a Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Sun, 3 Nov 2024 12:13:19 +0100 Subject: [PATCH 016/106] Update gesummv jax implementation --- npbench/benchmarks/polybench/gesummv/gesummv_jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
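[Editor's note] The cholesky kernel above leans on an idiom worth spelling out: under jit, a slice like A[i, :j] is illegal when j is a traced loop index, because its length would vary between iterations. Masking with jnp.where keeps the shape fixed. A standalone sketch of the idiom (names are illustrative):

    import jax.numpy as jnp

    def masked_dot(A, i, j):
        # Fixed-shape replacement for A[i, :j] @ A[j, :j]: entries at or
        # beyond column j are zeroed out instead of sliced away.
        mask = jnp.arange(A.shape[1]) < j
        return jnp.dot(jnp.where(mask, A[i, :], 0.0),
                       jnp.where(mask, A[j, :], 0.0))

The zeroed tail contributes nothing to the dot product, so the result matches the dynamic slice at the cost of touching the full row.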
a/npbench/benchmarks/polybench/gesummv/gesummv_jax.py b/npbench/benchmarks/polybench/gesummv/gesummv_jax.py index 776f590..f4f3aa2 100644 --- a/npbench/benchmarks/polybench/gesummv/gesummv_jax.py +++ b/npbench/benchmarks/polybench/gesummv/gesummv_jax.py @@ -4,4 +4,4 @@ @jax.jit def kernel(alpha, beta, A: jax.Array, B: jax.Array, x: jax.Array): - return alpha * A @ x + beta * B @ x + return (alpha * A + beta * B) @ x From cc88ecd745c11c71a1f2602d7755727a2f361247 Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Sun, 3 Nov 2024 12:16:23 +0100 Subject: [PATCH 017/106] Update jacobi_1d jax implementation --- npbench/benchmarks/polybench/jacobi_1d/jacobi_1d_jax.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/npbench/benchmarks/polybench/jacobi_1d/jacobi_1d_jax.py b/npbench/benchmarks/polybench/jacobi_1d/jacobi_1d_jax.py index 2cb043b..781b4c6 100644 --- a/npbench/benchmarks/polybench/jacobi_1d/jacobi_1d_jax.py +++ b/npbench/benchmarks/polybench/jacobi_1d/jacobi_1d_jax.py @@ -1,12 +1,19 @@ import jax import jax.numpy as jnp +from jax import lax from functools import partial @partial(jax.jit, static_argnums=(0,)) def kernel(TSTEPS: int, A: jax.Array, B: jax.Array): - for t in range(1, TSTEPS): + def body_fn(t, arrays): + A, B = arrays + # Update B based on A, and then update A based on the new B B = B.at[1:-1].set(0.33333 * (A[:-2] + A[1:-1] + A[2:])) A = A.at[1:-1].set(0.33333 * (B[:-2] + B[1:-1] + B[2:])) + return A, B + + # Run the loop for TSTEPS iterations + A, B = lax.fori_loop(1, TSTEPS, body_fn, (A, B)) return A, B \ No newline at end of file From 7d9e7f26b0cc9659f326b60b4275983e3c4d1a8b Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Sun, 3 Nov 2024 12:18:48 +0100 Subject: [PATCH 018/106] Update jacobi_2d jax implementation --- .../polybench/jacobi_2d/jacobi_2d_jax.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/npbench/benchmarks/polybench/jacobi_2d/jacobi_2d_jax.py b/npbench/benchmarks/polybench/jacobi_2d/jacobi_2d_jax.py index 33cdeff..054e414 100644 --- a/npbench/benchmarks/polybench/jacobi_2d/jacobi_2d_jax.py +++ b/npbench/benchmarks/polybench/jacobi_2d/jacobi_2d_jax.py @@ -1,14 +1,22 @@ import jax import jax.numpy as jnp +from jax import lax from functools import partial @partial(jax.jit, static_argnums=(0,)) def kernel(TSTEPS: int, A: jax.Array, B: jax.Array): - - for t in range(1, TSTEPS): + + def body_fn(t, arrays): + A, B = arrays + # Update B based on A B = B.at[1:-1, 1:-1].set(0.2 * (A[1:-1, 1:-1] + A[1:-1, :-2] + A[1:-1, 2:] + - A[2:, 1:-1] + A[:-2, 1:-1])) + A[2:, 1:-1] + A[:-2, 1:-1])) + # Update A based on the new B A = A.at[1:-1, 1:-1].set(0.2 * (B[1:-1, 1:-1] + B[1:-1, :-2] + B[1:-1, 2:] + - B[2:, 1:-1] + B[:-2, 1:-1])) - + B[2:, 1:-1] + B[:-2, 1:-1])) + return A, B + + # Execute the loop for TSTEPS iterations + A, B = lax.fori_loop(1, TSTEPS, body_fn, (A, B)) + return A, B From dc24a9ff9020d92f879806004423550a5dc49c37 Mon Sep 17 00:00:00 2001 From: Sushant S Date: Mon, 4 Nov 2024 08:22:38 +0100 Subject: [PATCH 019/106] add covariance jax --- .../benchmarks/polybench/covariance/covariance_jax.py | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 npbench/benchmarks/polybench/covariance/covariance_jax.py diff --git a/npbench/benchmarks/polybench/covariance/covariance_jax.py b/npbench/benchmarks/polybench/covariance/covariance_jax.py new file mode 100644 index 0000000..e788823 --- /dev/null +++ b/npbench/benchmarks/polybench/covariance/covariance_jax.py @@ -0,0 +1,10 @@ +import jax 
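[Editor's note] Patches 017 and 018 above replace the Python-level time loop with lax.fori_loop. The key constraint is that the loop carry must keep a fixed structure and fixed shapes across iterations; in exchange, the trip count no longer has to be unrolled at trace time. A condensed sketch of the pattern the Jacobi kernels use:

    import jax
    from jax import lax

    @jax.jit
    def smooth(steps, A, B):
        def body(t, carry):
            A, B = carry
            B = B.at[1:-1].set(0.33333 * (A[:-2] + A[1:-1] + A[2:]))
            A = A.at[1:-1].set(0.33333 * (B[:-2] + B[1:-1] + B[2:]))
            return A, B
        return lax.fori_loop(1, steps, body, (A, B))

Compared with a Python for loop under jit, this keeps the traced program size constant instead of emitting one copy of the body per time step.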
+import jax.numpy as jnp + + +@jax.jit +def kernel(M, float_n, data): + + cov = jnp.cov(data, rowvar=False) + + return cov \ No newline at end of file From 6ffc18dfa3e5cbebbc9a27e540f1bd7588ab1eed Mon Sep 17 00:00:00 2001 From: Filip Jaksic Date: Mon, 4 Nov 2024 13:55:52 +0100 Subject: [PATCH 020/106] Add atax Jax implementation --- npbench/benchmarks/polybench/atax/atax_jax.py | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 npbench/benchmarks/polybench/atax/atax_jax.py diff --git a/npbench/benchmarks/polybench/atax/atax_jax.py b/npbench/benchmarks/polybench/atax/atax_jax.py new file mode 100644 index 0000000..26ed193 --- /dev/null +++ b/npbench/benchmarks/polybench/atax/atax_jax.py @@ -0,0 +1,6 @@ +import jax +import jax.numpy as jnp + +@jax.jit +def kernel(A: jax.Array, x: jax.Array): + return (A @ x) @ A From de3d8f635da871400742f3df739149f87644c316 Mon Sep 17 00:00:00 2001 From: Filip Jaksic Date: Mon, 4 Nov 2024 14:39:25 +0100 Subject: [PATCH 021/106] Add doitgen Jax implementation --- .../benchmarks/polybench/doitgen/doitgen_jax.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 npbench/benchmarks/polybench/doitgen/doitgen_jax.py diff --git a/npbench/benchmarks/polybench/doitgen/doitgen_jax.py b/npbench/benchmarks/polybench/doitgen/doitgen_jax.py new file mode 100644 index 0000000..8e6caf9 --- /dev/null +++ b/npbench/benchmarks/polybench/doitgen/doitgen_jax.py @@ -0,0 +1,14 @@ +import jax +import jax.numpy as jnp + +from functools import partial + +@partial(jax.jit, static_argnums=(0, 1, 2)) +def kernel(NR: int, NQ: int, NP: int, A:jax.Array, C4:jax.Array): + + # for r in range(NR): + # for q in range(NQ): + # sum[:] = A[r, q, :] @ C4 + # A[r, q, :] = sum + A = A.at[:].set(jnp.reshape(jnp.reshape(A, (NR, NQ, 1, NP)) @ C4, (NR, NQ, NP))) + return A From 3967a008fffcdd754c507c620b4fb8dc0195b06f Mon Sep 17 00:00:00 2001 From: Filip Jaksic Date: Mon, 4 Nov 2024 14:48:06 +0100 Subject: [PATCH 022/106] Add gemm Jax implementation --- npbench/benchmarks/polybench/gemm/gemm_jax.py | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 npbench/benchmarks/polybench/gemm/gemm_jax.py diff --git a/npbench/benchmarks/polybench/gemm/gemm_jax.py b/npbench/benchmarks/polybench/gemm/gemm_jax.py new file mode 100644 index 0000000..6dfe60a --- /dev/null +++ b/npbench/benchmarks/polybench/gemm/gemm_jax.py @@ -0,0 +1,8 @@ +import jax +import jax.numpy as jnp + +@jax.jit +def kernel(alpha: jnp.float64, beta: jnp.float64, C:jax.Array, A:jax.Array, B:jax.Array): + + C = C.at[:].set(alpha * A @ B + beta * C) + return C From eca1dc8f107cae8411ed95fcf23fe2330a364737 Mon Sep 17 00:00:00 2001 From: Filip Jaksic Date: Mon, 4 Nov 2024 14:51:15 +0100 Subject: [PATCH 023/106] Add k2mm Jax implementation --- npbench/benchmarks/polybench/k2mm/k2mm_jax.py | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 npbench/benchmarks/polybench/k2mm/k2mm_jax.py diff --git a/npbench/benchmarks/polybench/k2mm/k2mm_jax.py b/npbench/benchmarks/polybench/k2mm/k2mm_jax.py new file mode 100644 index 0000000..db2a15f --- /dev/null +++ b/npbench/benchmarks/polybench/k2mm/k2mm_jax.py @@ -0,0 +1,9 @@ +import jax +import jax.numpy as jnp + + +@jax.jit +def kernel(alpha: jnp.float64, beta: jnp.float64, A: jax.Array, B: jax.Array, C: jax.Array, D: jax.Array): + + D = D.at[:].set(alpha * A @ B @ C + beta * D) + return D From da9e2b5ae6fd0ee7de0bd9f0309a5b0d52c02c25 Mon Sep 17 00:00:00 2001 From: Filip Jaksic Date: Mon, 4 Nov 2024 14:53:54 +0100 Subject: [PATCH 024/106] Add k3mm Jax 
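[Editor's note] A recurring pattern in the kernels above (gemm, k2mm, doitgen): jax.Array is immutable, so NumPy-style in-place writes become functional .at[...] updates that return a new array. A minimal sketch:

    import jax.numpy as jnp

    C = jnp.zeros((3, 3))
    C1 = C.at[0, 0].set(1.0)   # C is untouched; C1 carries the update
    C2 = C1.at[1, :].add(2.0)  # also .mul, .divide, .min, .max, ...
    print(C[0, 0], C1[0, 0])   # 0.0 1.0

Under jit, XLA typically turns these copies back into in-place buffer updates, so the functional style costs little in practice.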
implementation --- npbench/benchmarks/polybench/k3mm/k3mm_jax.py | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 npbench/benchmarks/polybench/k3mm/k3mm_jax.py diff --git a/npbench/benchmarks/polybench/k3mm/k3mm_jax.py b/npbench/benchmarks/polybench/k3mm/k3mm_jax.py new file mode 100644 index 0000000..354d9d0 --- /dev/null +++ b/npbench/benchmarks/polybench/k3mm/k3mm_jax.py @@ -0,0 +1,7 @@ +import jax +import jax.numpy as jnp + +@jax.jit +def kernel(A, B, C, D): + + return A @ B @ C @ D From 7264536cd5467534a338395aed4b71a36bc00bfb Mon Sep 17 00:00:00 2001 From: Filip Jaksic Date: Mon, 4 Nov 2024 15:05:07 +0100 Subject: [PATCH 025/106] Add mvt Jax implementation --- npbench/benchmarks/polybench/mvt/mvt_jax.py | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 npbench/benchmarks/polybench/mvt/mvt_jax.py diff --git a/npbench/benchmarks/polybench/mvt/mvt_jax.py b/npbench/benchmarks/polybench/mvt/mvt_jax.py new file mode 100644 index 0000000..732051f --- /dev/null +++ b/npbench/benchmarks/polybench/mvt/mvt_jax.py @@ -0,0 +1,10 @@ +import jax +import jax.numpy as jnp + +@jax.jit +def kernel(x1, x2, y_1, y_2, A): + + x1 += A @ y_1 + x2 += y_2 @ A + + return (x1, x2) From 92cdc38790ee8f433058385fb6f4952e418ff358 Mon Sep 17 00:00:00 2001 From: Filip Jaksic Date: Mon, 4 Nov 2024 15:08:22 +0100 Subject: [PATCH 026/106] Add softmax Jax implementation --- .../benchmarks/deep_learning/softmax/softmax_jax.py | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 npbench/benchmarks/deep_learning/softmax/softmax_jax.py diff --git a/npbench/benchmarks/deep_learning/softmax/softmax_jax.py b/npbench/benchmarks/deep_learning/softmax/softmax_jax.py new file mode 100644 index 0000000..c9a4f76 --- /dev/null +++ b/npbench/benchmarks/deep_learning/softmax/softmax_jax.py @@ -0,0 +1,11 @@ +import jax +import jax.numpy as jnp + + +# Numerically-stable version of softmax +@jax.jit +def softmax(x): + tmp_max = jnp.max(x, axis=-1, keepdims=True) + tmp_out = jnp.exp(x - tmp_max) + tmp_sum = jnp.sum(tmp_out, axis=-1, keepdims=True) + return tmp_out / tmp_sum From 058faf91c54bf65614bfcec0281844667b758c9c Mon Sep 17 00:00:00 2001 From: Filip Jaksic Date: Mon, 4 Nov 2024 16:32:10 +0100 Subject: [PATCH 027/106] Add trmm Jax implementation --- npbench/benchmarks/polybench/trmm/trmm_jax.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 npbench/benchmarks/polybench/trmm/trmm_jax.py diff --git a/npbench/benchmarks/polybench/trmm/trmm_jax.py b/npbench/benchmarks/polybench/trmm/trmm_jax.py new file mode 100644 index 0000000..cb27935 --- /dev/null +++ b/npbench/benchmarks/polybench/trmm/trmm_jax.py @@ -0,0 +1,15 @@ +import jax +import jax.numpy as jnp + +# import numpy as np + + + +@jax.jit +def kernel(alpha, A, B): + + L = jnp.triu(A, 1) # 1 excludes the main diagonal + B += L @ B + B *= alpha + + return B From f877c2b5ac008a7dcce7064365e1db81c15845dd Mon Sep 17 00:00:00 2001 From: Sushant S Date: Tue, 5 Nov 2024 06:38:49 +0100 Subject: [PATCH 028/106] add fdtd_2d --- .../polybench/fdtd_2d/fdtd_2d_jax.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 npbench/benchmarks/polybench/fdtd_2d/fdtd_2d_jax.py diff --git a/npbench/benchmarks/polybench/fdtd_2d/fdtd_2d_jax.py b/npbench/benchmarks/polybench/fdtd_2d/fdtd_2d_jax.py new file mode 100644 index 0000000..98cf9bb --- /dev/null +++ b/npbench/benchmarks/polybench/fdtd_2d/fdtd_2d_jax.py @@ -0,0 +1,19 @@ +import jax +import jax.numpy as jnp +from jax import lax + + +@jax.jit +def 
kernel(TMAX, ex, ey, hz, _fict_): + + def loop_body(t, loop_vars): + ex, ey, hz = loop_vars + ey = ey.at[0, :].set(_fict_[t]) + ey = ey.at[1:, :].set(ey[1:, :] - 0.5 * (hz[1:, :] - hz[:-1, :])) + ex = ex.at[:, 1:].set(ex[:, 1:] - 0.5 * (hz[:, 1:] - hz[:, :-1])) + hz = hz.at[:-1, :-1].set(hz[:-1, :-1] - 0.7 * (ex[:-1, 1:] - ex[:-1, :-1] + + ey[1:, :-1] - ey[:-1, :-1])) + return ex, ey, hz + + ex, ey, hz = lax.fori_loop(0, TMAX, loop_body, (ex, ey, hz)) + return ex, ey, hz, _fict_ From b39be288835600af02154023d04cfdb147ac39c7 Mon Sep 17 00:00:00 2001 From: Sushant S Date: Tue, 5 Nov 2024 07:11:52 +0100 Subject: [PATCH 029/106] add floyd_warshall --- .../polybench/floyd_warshall/floyd_warshall_jax.py | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 npbench/benchmarks/polybench/floyd_warshall/floyd_warshall_jax.py diff --git a/npbench/benchmarks/polybench/floyd_warshall/floyd_warshall_jax.py b/npbench/benchmarks/polybench/floyd_warshall/floyd_warshall_jax.py new file mode 100644 index 0000000..a930f3a --- /dev/null +++ b/npbench/benchmarks/polybench/floyd_warshall/floyd_warshall_jax.py @@ -0,0 +1,11 @@ +import jax +import jax.numpy as jnp + + +@jax.jit +def kernel(path): + + for k in range(path.shape[0]): + path = path.at[:].set(jnp.minimum(path[:], jnp.add.outer(path[:, k], path[k, :]))) + + return path \ No newline at end of file From 84559e758b47f33e19ce4276fd8fc33528c84e61 Mon Sep 17 00:00:00 2001 From: Sushant S Date: Tue, 5 Nov 2024 07:30:13 +0100 Subject: [PATCH 030/106] add lu --- npbench/benchmarks/polybench/lu/lu_jax.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 npbench/benchmarks/polybench/lu/lu_jax.py diff --git a/npbench/benchmarks/polybench/lu/lu_jax.py b/npbench/benchmarks/polybench/lu/lu_jax.py new file mode 100644 index 0000000..399ee01 --- /dev/null +++ b/npbench/benchmarks/polybench/lu/lu_jax.py @@ -0,0 +1,15 @@ +import jax +import jax.numpy as jnp + + +@jax.jit +def kernel(A): + + for i in range(A.shape[0]): + for j in range(i): + A = A.at[i, j].set(A[i, j] - A[i, :j] @ A[:j, j]) + A = A.at[i, j].set(A[i, j] / A[j, j]) + for j in range(i, A.shape[0]): + A = A.at[i, j].set(A[i, j] - A[i, :i] @ A[:i, j]) + + return A From 30635b6468016f1bd21372150a37be42e629ef48 Mon Sep 17 00:00:00 2001 From: Sushant S Date: Tue, 5 Nov 2024 07:58:29 +0100 Subject: [PATCH 031/106] add seidel_2d --- .../polybench/seidel_2d/seidel_2d_jax.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 npbench/benchmarks/polybench/seidel_2d/seidel_2d_jax.py diff --git a/npbench/benchmarks/polybench/seidel_2d/seidel_2d_jax.py b/npbench/benchmarks/polybench/seidel_2d/seidel_2d_jax.py new file mode 100644 index 0000000..a5b0ccb --- /dev/null +++ b/npbench/benchmarks/polybench/seidel_2d/seidel_2d_jax.py @@ -0,0 +1,32 @@ +import jax +import jax.numpy as jnp +from jax import lax + + +@jax.jit +def kernel(TSTEPS, N, A): + + def loop1(t, A): + + def loop2(i, A): + + def loop3(j, A): + A = A.at[i, j].set(A[i, j] + A[i, j - 1]) + A = A.at[i, j].set(A[i, j] / 9.0) + return A + + A = A.at[i, 1:-1].set( + A[i, 1:-1] + (A[i - 1, :-2] + A[i - 1, 1:-1] + A[i - 1, 2:] + + A[i, 2:] + A[i + 1, :-2] + A[i + 1, 1:-1] + + A[i + 1, 2:]) + ) + + A = lax.fori_loop(1, N - 1, loop3, A) + return A + + A = lax.fori_loop(1, N - 1, loop2, A) + return A + + A = lax.fori_loop(0, TSTEPS - 1, loop1, A) + + return A From 0d4e6256762b307c7f488e5451d4a6719be3f7ef Mon Sep 17 00:00:00 2001 From: Sushant S Date: Tue, 5 Nov 2024 10:19:01 +0100 Subject: [PATCH 032/106] add 
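[Editor's note] One detail of the floyd_warshall patch above: jnp.add.outer relies on jax.numpy exposing NumPy-style ufunc methods, which only newer JAX releases provide. If that assumption does not hold for a given installation, plain broadcasting is equivalent:

    import jax.numpy as jnp

    path = jnp.array([[0., 3., 8.],
                      [2., 0., 5.],
                      [9., 1., 0.]])
    k = 1
    via_outer = jnp.minimum(path, jnp.add.outer(path[:, k], path[k, :]))
    via_bcast = jnp.minimum(path, path[:, k, None] + path[None, k, :])
    assert bool((via_outer == via_bcast).all())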
syrk --- npbench/benchmarks/polybench/syrk/syrk_jax.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 npbench/benchmarks/polybench/syrk/syrk_jax.py diff --git a/npbench/benchmarks/polybench/syrk/syrk_jax.py b/npbench/benchmarks/polybench/syrk/syrk_jax.py new file mode 100644 index 0000000..310c72d --- /dev/null +++ b/npbench/benchmarks/polybench/syrk/syrk_jax.py @@ -0,0 +1,13 @@ +import jax +import jax.numpy as jnp + + +@jax.jit +def kernel(alpha, beta, C, A): + + for i in range(A.shape[0]): + C = C.at[i, :i + 1].set(C[i, :i + 1] * beta) + for k in range(A.shape[1]): + C = C.at[i, :i + 1].set(C[i, :i + 1] + alpha * A[i, k] * A[:i + 1, k]) + + return C From d298aa9991d90d068a169f45e66c1cd98aa85f8f Mon Sep 17 00:00:00 2001 From: Sushant S Date: Tue, 5 Nov 2024 10:27:00 +0100 Subject: [PATCH 033/106] add syr2k --- npbench/benchmarks/polybench/syr2k/syr2k_jax.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 npbench/benchmarks/polybench/syr2k/syr2k_jax.py diff --git a/npbench/benchmarks/polybench/syr2k/syr2k_jax.py b/npbench/benchmarks/polybench/syr2k/syr2k_jax.py new file mode 100644 index 0000000..263dcff --- /dev/null +++ b/npbench/benchmarks/polybench/syr2k/syr2k_jax.py @@ -0,0 +1,13 @@ +import jax +import jax.numpy as jnp + + +@jax.jit +def kernel(alpha, beta, C, A, B): + for i in range(A.shape[0]): + C = C.at[i, :i + 1].set(C[i, :i + 1] * beta) + for k in range(A.shape[1]): + C = C.at[i, :i + 1].set(C[i, :i + 1] + A[:i + 1, k] * alpha * B[i, k] + + B[:i + 1, k] * alpha * A[i, k]) + + return C From f0fd399da7aadeddb87324b02b804b205f2a2d72 Mon Sep 17 00:00:00 2001 From: Sushant S Date: Tue, 5 Nov 2024 10:30:45 +0100 Subject: [PATCH 034/106] add trisolv --- npbench/benchmarks/polybench/trisolv/trisolv_jax.py | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 npbench/benchmarks/polybench/trisolv/trisolv_jax.py diff --git a/npbench/benchmarks/polybench/trisolv/trisolv_jax.py b/npbench/benchmarks/polybench/trisolv/trisolv_jax.py new file mode 100644 index 0000000..9122d9c --- /dev/null +++ b/npbench/benchmarks/polybench/trisolv/trisolv_jax.py @@ -0,0 +1,11 @@ +import jax +import jax.numpy as jnp + + +@jax.jit +def kernel(L, x, b): + + for i in range(x.shape[0]): + x = x.at[i].set((b[i] - L[i, :i] @ x[:i]) / L[i, i]) + + return x From 4b871d855981f3a3865dfc7c911749954207a7bd Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Wed, 6 Nov 2024 22:23:28 +0100 Subject: [PATCH 035/106] Update block_until_ready() call in exec_str so returning modified arrays in jax benchmarks is optional --- npbench/infrastructure/jax_framework.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/npbench/infrastructure/jax_framework.py b/npbench/infrastructure/jax_framework.py index b5f3f52..dac358c 100644 --- a/npbench/infrastructure/jax_framework.py +++ b/npbench/infrastructure/jax_framework.py @@ -33,14 +33,6 @@ def exec_str(self, bench: Benchmark, impl: Callable = None): """ arg_str = self.arg_str(bench, impl) - # param_str = self.param_str(bench, impl) - main_exec_str = "__npb_result = __npb_impl({a})".format(a=arg_str) - sync_str = """ -if isinstance(__npb_result, jax.Array): - __npb_result.block_until_ready() -elif isinstance(__npb_result, tuple): - for item in __npb_result: - if isinstance(item, jax.Array): - item.block_until_ready() -""" - return main_exec_str + sync_str + main_exec_str = "__npb_result = jax.block_until_ready(__npb_impl({a}))".format(a=arg_str) + + return main_exec_str From 
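[Editor's note] Patch 035 above can simplify the framework because jax.block_until_ready accepts an arbitrary pytree and blocks on every array leaf, which subsumes the earlier isinstance checks for tuples. A small sketch:

    import jax
    import jax.numpy as jnp

    @jax.jit
    def kernel(A, B):
        return A + 1.0, B * 2.0   # a tuple, i.e. a pytree of two arrays

    out = jax.block_until_ready(kernel(jnp.ones(4), jnp.ones(4)))
    # Equivalent to calling .block_until_ready() on each jax.Array leaf;
    # non-array leaves are passed through unchanged.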
27337cfcf06ef3344474e1e901853e176cd3245c Mon Sep 17 00:00:00 2001 From: Sushant S Date: Sun, 10 Nov 2024 10:58:03 +0100 Subject: [PATCH 036/106] update floyd_warshall --- .../polybench/floyd_warshall/floyd_warshall_jax.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/npbench/benchmarks/polybench/floyd_warshall/floyd_warshall_jax.py b/npbench/benchmarks/polybench/floyd_warshall/floyd_warshall_jax.py index a930f3a..d5cdc25 100644 --- a/npbench/benchmarks/polybench/floyd_warshall/floyd_warshall_jax.py +++ b/npbench/benchmarks/polybench/floyd_warshall/floyd_warshall_jax.py @@ -1,11 +1,15 @@ import jax import jax.numpy as jnp +from jax import lax @jax.jit def kernel(path): - for k in range(path.shape[0]): + def loop_func(k, path): path = path.at[:].set(jnp.minimum(path[:], jnp.add.outer(path[:, k], path[k, :]))) + return path + + path = lax.fori_loop(0, path.shape[0], loop_func, path) - return path \ No newline at end of file + return path From fea3c9cf50bbca3ce69d678928d938c69d0d3fbd Mon Sep 17 00:00:00 2001 From: Sushant S Date: Tue, 12 Nov 2024 18:24:41 +0100 Subject: [PATCH 037/106] add azimint_naive --- .../azimint_naive/azimint_naive_jax.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 npbench/benchmarks/azimint_naive/azimint_naive_jax.py diff --git a/npbench/benchmarks/azimint_naive/azimint_naive_jax.py b/npbench/benchmarks/azimint_naive/azimint_naive_jax.py new file mode 100644 index 0000000..2f4f242 --- /dev/null +++ b/npbench/benchmarks/azimint_naive/azimint_naive_jax.py @@ -0,0 +1,28 @@ +# Copyright 2014 Jérôme Kieffer et al. +# This is an open-access article distributed under the terms of the +# Creative Commons Attribution License, which permits unrestricted use, +# distribution, and reproduction in any medium, provided the original author +# and source are credited. +# http://creativecommons.org/licenses/by/3.0/ +# Jérôme Kieffer and Giannis Ashiotis. Pyfai: a python library for +# high performance azimuthal integration on gpu, 2014. In Proceedings of the +# 7th European Conference on Python in Science (EuroSciPy 2014). 
+ +import jax +import jax.numpy as jnp +from functools import partial + + +@partial(jax.jit, static_argnums=(2,)) +def azimint_naive(data, radius, npt): + rmax = radius.max() + res = jnp.zeros(npt, dtype=jnp.float64) + + for i in range(npt): + r1 = rmax * i / npt + r2 = rmax * (i + 1) / npt + mask_r12 = jnp.logical_and((r1 <= radius), (radius < r2)) + mean = jnp.where(mask_r12, data, 0).mean(where=mask_r12) + res = res.at[i].set(mean) + + return res From b7a140d63c6d754884748f4ab9d97f7d47c5caf0 Mon Sep 17 00:00:00 2001 From: Sushant S Date: Tue, 12 Nov 2024 18:53:13 +0100 Subject: [PATCH 038/106] add correlation --- .../polybench/correlation/correlation_jax.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 npbench/benchmarks/polybench/correlation/correlation_jax.py diff --git a/npbench/benchmarks/polybench/correlation/correlation_jax.py b/npbench/benchmarks/polybench/correlation/correlation_jax.py new file mode 100644 index 0000000..af4f56c --- /dev/null +++ b/npbench/benchmarks/polybench/correlation/correlation_jax.py @@ -0,0 +1,20 @@ +import jax +import jax.numpy as jnp +from jax import lax +from functools import partial + +@partial(jax.jit, static_argnums=(0,)) +def kernel(M, float_n, data): + + mean = jnp.mean(data, axis=0) + stddev = jnp.std(data, axis=0) + stddev = jnp.where(stddev <= 0.1, 1.0, stddev) + data = data - mean + data = data / (jnp.sqrt(float_n) * stddev) + corr = jnp.eye(M, dtype=data.dtype) + + for i in range(M - 1): + corr = corr.at[i + 1:M, i].set(data[:, i] @ data[:, i + 1:M]) + corr = corr.at[i, i + 1:M].set(data[:, i] @ data[:, i + 1:M]) + + return corr From c23ec0496a5d9e0c3423ab6f796b9ccbb1d7f460 Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Sat, 16 Nov 2024 12:56:59 +0100 Subject: [PATCH 039/106] Add symm jax implementation --- npbench/benchmarks/polybench/symm/symm_jax.py | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 npbench/benchmarks/polybench/symm/symm_jax.py diff --git a/npbench/benchmarks/polybench/symm/symm_jax.py b/npbench/benchmarks/polybench/symm/symm_jax.py new file mode 100644 index 0000000..4827b62 --- /dev/null +++ b/npbench/benchmarks/polybench/symm/symm_jax.py @@ -0,0 +1,42 @@ +import jax +import jax.numpy as jnp +from jax import lax + +def kernel(alpha, beta, C: jax.Array, A: jax.Array, B: jax.Array): + + # Scale C by beta + C = C * beta + + def row_update(i, C_and_temp2): + C, _ = C_and_temp2 + temp2 = jnp.zeros((C.shape[1],), dtype=C.dtype) + + def col_update(j, val): + C, temp2 = val + + # Create a mask for elements up to `i` + mask = jnp.arange(C.shape[0]) < i + + # Masked elements to simulate A[i, :i] and B[:i, j] + A_slice = jnp.where(mask, A[i, :], 0) # A[i, :i] with mask + B_slice = jnp.where(mask, B[:, j], 0) # B[:i, j] with mask + + # Equivalent to C[:i, j] += alpha * B[i, j] * A[i, :i] + C = C.at[:, j].add(alpha * B[i, j] * A_slice) + + # Set temp2[j] with the masked dot product result + temp2 = temp2.at[j].set(jnp.dot(B_slice, A_slice)) + return C, temp2 + + # Update columns in row `i` + C, temp2 = lax.fori_loop(0, C.shape[1], col_update, (C, temp2)) + + # Update row `i` after column updates + C = C.at[i, :].add(alpha * B[i, :] * A[i, i] + alpha * temp2) + + return C, temp2 + + # Apply the row update across all rows + C, _ = lax.fori_loop(0, C.shape[0], row_update, (C, jnp.zeros(C.shape[1], dtype=C.dtype))) + return C + From 6e53257ac6618659e0d3fd8cb867eb1e41a1e994 Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Sat, 16 Nov 2024 14:58:33 +0100 Subject: [PATCH 040/106] 
Add conv2d jax implementation --- .../deep_learning/conv2d_bias/conv2d_jax.py | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 npbench/benchmarks/deep_learning/conv2d_bias/conv2d_jax.py diff --git a/npbench/benchmarks/deep_learning/conv2d_bias/conv2d_jax.py b/npbench/benchmarks/deep_learning/conv2d_bias/conv2d_jax.py new file mode 100644 index 0000000..36c07a5 --- /dev/null +++ b/npbench/benchmarks/deep_learning/conv2d_bias/conv2d_jax.py @@ -0,0 +1,48 @@ +import jax.numpy as jnp +import jax +from jax import lax + + +# Deep learning convolutional operator (stride = 1) +@jax.jit +def conv2d(input, weights): + K = weights.shape[0] # Assuming square kernel + N = input.shape[0] + H_out = input.shape[1] - K + 1 + W_out = input.shape[2] - K + 1 + C_out = weights.shape[3] + output = jnp.empty((N, H_out, W_out, C_out), dtype=jnp.float32) + + def col_update(j, arrays): + input, weights, output, i = arrays + + input_slice = lax.dynamic_slice( + input, + (0, i, j, 0), + (N, K, K, input.shape[-1]) + ) + conv_result = jnp.sum( + input_slice[:, :, :, :, jnp.newaxis] * weights[jnp.newaxis, :, :, :], + axis=(1, 2, 3) + ) + output = lax.dynamic_update_slice( + output, + conv_result[:, jnp.newaxis, jnp.newaxis, :], + (0, i, j, 0) + ) + return input, weights, output, i + + def row_update(i, arrays): + input, weights, output = arrays + arrays = (input, weights, output, i) + _, _, output, _ = lax.fori_loop(0, W_out, col_update, arrays) + return input, weights, output + + _, _, output = lax.fori_loop(0, H_out, row_update, (input, weights, output)) + + return output + + +@jax.jit +def conv2d_bias(input, weights, bias): + return conv2d(input, weights) + bias From e46e344377e35e006db5e39001839f71090f4b7c Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Sat, 16 Nov 2024 23:20:53 +0100 Subject: [PATCH 041/106] Add heat_3d jax implementation --- .../polybench/heat_3d/heat_3d_jax.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 npbench/benchmarks/polybench/heat_3d/heat_3d_jax.py diff --git a/npbench/benchmarks/polybench/heat_3d/heat_3d_jax.py b/npbench/benchmarks/polybench/heat_3d/heat_3d_jax.py new file mode 100644 index 0000000..858b8fd --- /dev/null +++ b/npbench/benchmarks/polybench/heat_3d/heat_3d_jax.py @@ -0,0 +1,26 @@ +import jax +import jax.numpy as jnp +from jax import lax + +@jax.jit +def kernel(TSTEPS: int, A: jnp.ndarray, B: jnp.ndarray): + def time_step(t, arrays): + A, B = arrays + + B = B.at[1:-1, 1:-1, 1:-1].set( + 0.125 * (A[2:, 1:-1, 1:-1] - 2.0 * A[1:-1, 1:-1, 1:-1] + A[:-2, 1:-1, 1:-1]) + + 0.125 * (A[1:-1, 2:, 1:-1] - 2.0 * A[1:-1, 1:-1, 1:-1] + A[1:-1, :-2, 1:-1]) + + 0.125 * (A[1:-1, 1:-1, 2:] - 2.0 * A[1:-1, 1:-1, 1:-1] + A[1:-1, 1:-1, :-2]) + + A[1:-1, 1:-1, 1:-1] + ) + A = A.at[1:-1, 1:-1, 1:-1].set( + 0.125 * (B[2:, 1:-1, 1:-1] - 2.0 * B[1:-1, 1:-1, 1:-1] + B[:-2, 1:-1, 1:-1]) + + 0.125 * (B[1:-1, 2:, 1:-1] - 2.0 * B[1:-1, 1:-1, 1:-1] + B[1:-1, :-2, 1:-1]) + + 0.125 * (B[1:-1, 1:-1, 2:] - 2.0 * B[1:-1, 1:-1, 1:-1] + B[1:-1, 1:-1, :-2]) + + B[1:-1, 1:-1, 1:-1] + ) + + return A, B + + A, B = lax.fori_loop(1, TSTEPS, time_step, (A, B)) + return A, B From 57d7aa4804a4a74d29feb631a41a0afe3e926206 Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Sat, 16 Nov 2024 23:57:11 +0100 Subject: [PATCH 042/106] Add mlp jax implementation --- .../benchmarks/deep_learning/mlp/mlp_jax.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 npbench/benchmarks/deep_learning/mlp/mlp_jax.py diff --git 
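[Editor's note] The conv2d kernel above uses lax.dynamic_slice rather than basic indexing because the window origin (i, j) is a traced loop variable: under jit, slice sizes must be static while start indices may be traced. A toy sketch of the pair of primitives:

    import jax.numpy as jnp
    from jax import lax

    x = jnp.arange(36.0).reshape(6, 6)
    i, j = 2, 1   # could equally be traced values inside a fori_loop body
    window = lax.dynamic_slice(x, (i, j), (3, 3))      # like x[i:i+3, j:j+3]
    x2 = lax.dynamic_update_slice(x, jnp.zeros((3, 3)), (i, j))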
a/npbench/benchmarks/deep_learning/mlp/mlp_jax.py b/npbench/benchmarks/deep_learning/mlp/mlp_jax.py new file mode 100644 index 0000000..57b8c53 --- /dev/null +++ b/npbench/benchmarks/deep_learning/mlp/mlp_jax.py @@ -0,0 +1,24 @@ +import jax.numpy as jnp +import jax + +@jax.jit +def relu(x): + return jnp.maximum(x, 0) + + +# Numerically-stable version of softmax +@jax.jit +def softmax(x): + tmp_max = jnp.max(x, axis=-1, keepdims=True) + tmp_out = jnp.exp(x - tmp_max) + tmp_sum = jnp.sum(tmp_out, axis=-1, keepdims=True) + return tmp_out / tmp_sum + + +# 3-layer MLP +@jax.jit +def mlp(input, w1, b1, w2, b2, w3, b3): + x = relu(input @ w1 + b1) + x = relu(x @ w2 + b2) + x = softmax(x @ w3 + b3) # Softmax call can be omitted if necessary + return x From 09d404abfa17bb15b02c3b9c411f5eadf13cd9e6 Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Sun, 17 Nov 2024 00:21:56 +0100 Subject: [PATCH 043/106] Fix symm jax implementation --- npbench/benchmarks/polybench/symm/symm_jax.py | 38 ++++++++----------- 1 file changed, 15 insertions(+), 23 deletions(-) diff --git a/npbench/benchmarks/polybench/symm/symm_jax.py b/npbench/benchmarks/polybench/symm/symm_jax.py index 4827b62..abeafce 100644 --- a/npbench/benchmarks/polybench/symm/symm_jax.py +++ b/npbench/benchmarks/polybench/symm/symm_jax.py @@ -2,41 +2,33 @@ import jax.numpy as jnp from jax import lax +@jax.jit def kernel(alpha, beta, C: jax.Array, A: jax.Array, B: jax.Array): - # Scale C by beta - C = C * beta + temp2 = jnp.empty((C.shape[1], ), dtype=C.dtype) + C *= beta - def row_update(i, C_and_temp2): - C, _ = C_and_temp2 - temp2 = jnp.zeros((C.shape[1],), dtype=C.dtype) + def row_update(i, arrays): + C, temp2 = arrays def col_update(j, val): C, temp2 = val - # Create a mask for elements up to `i` - mask = jnp.arange(C.shape[0]) < i - - # Masked elements to simulate A[i, :i] and B[:i, j] - A_slice = jnp.where(mask, A[i, :], 0) # A[i, :i] with mask - B_slice = jnp.where(mask, B[:, j], 0) # B[:i, j] with mask - - # Equivalent to C[:i, j] += alpha * B[i, j] * A[i, :i] - C = C.at[:, j].add(alpha * B[i, j] * A_slice) - - # Set temp2[j] with the masked dot product result - temp2 = temp2.at[j].set(jnp.dot(B_slice, A_slice)) + A_slice = jnp.where(jnp.arange(A.shape[1]) < i, A[i, :], 0.0) + B_slice = jnp.where(jnp.arange(B.shape[0]) < i, B[:, j], 0.0) + + C = lax.dynamic_update_slice( + C, + (C[:,j] + (alpha * B[i, j] * A_slice))[:, None], + (0, j) + ) + temp2 = temp2.at[j].set(B_slice @ A_slice) return C, temp2 - # Update columns in row `i` C, temp2 = lax.fori_loop(0, C.shape[1], col_update, (C, temp2)) - - # Update row `i` after column updates C = C.at[i, :].add(alpha * B[i, :] * A[i, i] + alpha * temp2) - return C, temp2 - # Apply the row update across all rows - C, _ = lax.fori_loop(0, C.shape[0], row_update, (C, jnp.zeros(C.shape[1], dtype=C.dtype))) + C, _ = lax.fori_loop(0, C.shape[0], row_update, (C, temp2)) return C From a01496eb4a6a4cb9274354beed4dc42c20f546b4 Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Sun, 17 Nov 2024 00:24:31 +0100 Subject: [PATCH 044/106] Add contour_integral jax implementation --- .../contour_integral/contour_integral_jax.py | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 npbench/benchmarks/contour_integral/contour_integral_jax.py diff --git a/npbench/benchmarks/contour_integral/contour_integral_jax.py b/npbench/benchmarks/contour_integral/contour_integral_jax.py new file mode 100644 index 0000000..b319bda --- /dev/null +++ b/npbench/benchmarks/contour_integral/contour_integral_jax.py @@ 
-0,0 +1,35 @@ +import jax +import jax.numpy as jnp +from functools import partial + +@partial(jax.jit, static_argnums=(0, 1, 2)) +def contour_integral(NR, NM, slab_per_bc, Ham, int_pts, Y): + P0 = jnp.zeros((NR, NM), dtype=jnp.complex128) + P1 = jnp.zeros((NR, NM), dtype=jnp.complex128) + + def body_fun(i, accum): + P0, P1 = accum + z = int_pts[i] + Tz = jnp.zeros((NR, NR), dtype=jnp.complex128) + + def compute_Tz(n, Tz): + zz = jnp.power(z, slab_per_bc / 2 - n) + Tz += zz * Ham[n] + return Tz + + Tz = jax.lax.fori_loop(0, slab_per_bc + 1, compute_Tz, Tz) + + if NR == NM: + X = jnp.linalg.inv(Tz) + else: + X = jnp.linalg.solve(Tz, Y) + X = jax.lax.cond(abs(z) < 1.0, lambda x: -x, lambda x: x, X) + + P0 += X + P1 += z * X + + return P0, P1 + + P0, P1 = jax.lax.fori_loop(0, int_pts.shape[0], body_fun, (P0, P1)) + + return P0, P1 From 87df0165b58de2ec4a53d939ffb7ef691d5263ed Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Sun, 17 Nov 2024 00:26:52 +0100 Subject: [PATCH 045/106] Fix jacobi_1d jax implementation Remove comments for code optimization, remove TSTEPS as statically compiled parameter (not required with fori_loop) --- npbench/benchmarks/polybench/jacobi_1d/jacobi_1d_jax.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/npbench/benchmarks/polybench/jacobi_1d/jacobi_1d_jax.py b/npbench/benchmarks/polybench/jacobi_1d/jacobi_1d_jax.py index 781b4c6..600ea06 100644 --- a/npbench/benchmarks/polybench/jacobi_1d/jacobi_1d_jax.py +++ b/npbench/benchmarks/polybench/jacobi_1d/jacobi_1d_jax.py @@ -1,19 +1,14 @@ import jax -import jax.numpy as jnp from jax import lax -from functools import partial -@partial(jax.jit, static_argnums=(0,)) +@jax.jit def kernel(TSTEPS: int, A: jax.Array, B: jax.Array): def body_fn(t, arrays): A, B = arrays - # Update B based on A, and then update A based on the new B B = B.at[1:-1].set(0.33333 * (A[:-2] + A[1:-1] + A[2:])) A = A.at[1:-1].set(0.33333 * (B[:-2] + B[1:-1] + B[2:])) return A, B - # Run the loop for TSTEPS iterations A, B = lax.fori_loop(1, TSTEPS, body_fn, (A, B)) - return A, B \ No newline at end of file From 17e00aba2f927df7c996560bf1fa88c32b83b938 Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Sun, 17 Nov 2024 00:29:21 +0100 Subject: [PATCH 046/106] Add gramschmidt jax implementation --- .../polybench/gramschmidt/gramschmidt_jax.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 npbench/benchmarks/polybench/gramschmidt/gramschmidt_jax.py diff --git a/npbench/benchmarks/polybench/gramschmidt/gramschmidt_jax.py b/npbench/benchmarks/polybench/gramschmidt/gramschmidt_jax.py new file mode 100644 index 0000000..55d7dbf --- /dev/null +++ b/npbench/benchmarks/polybench/gramschmidt/gramschmidt_jax.py @@ -0,0 +1,28 @@ +import jax +import jax.numpy as jnp + +@jax.jit +def kernel(A): + + Q = jnp.zeros_like(A) + R = jnp.zeros((A.shape[1], A.shape[1]), dtype=A.dtype) + + def body_fun(k, arrays): + Q, R, A = arrays + + nrm = jnp.dot(A[:, k], A[:, k]) + R = R.at[k, k].set(jnp.sqrt(nrm)) + Q = Q.at[:, k].set(A[:, k] / R[k, k]) + + def inner_body_fun(j, arrays): + Q, R, A = arrays + R = R.at[k, j].set(jnp.dot(Q[:, k], A[:, j])) + A = A.at[:, j].add(-Q[:, k] * R[k, j]) + return Q, R, A + + Q, R, A = jax.lax.fori_loop(k + 1, A.shape[1], inner_body_fun, (Q, R, A)) + return Q, R, A + + Q, R, A = jax.lax.fori_loop(0, A.shape[1], body_fun, (Q, R, A)) + + return Q, R From 34db9e131bc72623d9be570138645f9ae27280ef Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Sun, 17 Nov 2024 11:27:36 +0100 Subject: [PATCH 047/106] 
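[Editor's note] The contour_integral patch above mixes two kinds of branching: `if NR == NM` is resolved at trace time because both are static arguments, while the sign flip depends on a traced value and therefore goes through jax.lax.cond. A minimal sketch of the traced case (illustrative names):

    import jax
    import jax.numpy as jnp
    from jax import lax

    @jax.jit
    def maybe_negate(z, X):
        # A Python `if abs(z) < 1.0:` would fail here: the predicate is a
        # traced value, so both branches are staged and lax.cond selects one.
        return lax.cond(jnp.abs(z) < 1.0, lambda x: -x, lambda x: x, X)

    print(maybe_negate(0.5, jnp.ones(3)))   # [-1. -1. -1.]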
Fix cholesky jax implementation --- npbench/benchmarks/polybench/cholesky/cholesky_jax.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/npbench/benchmarks/polybench/cholesky/cholesky_jax.py b/npbench/benchmarks/polybench/cholesky/cholesky_jax.py index 71d4651..0633886 100644 --- a/npbench/benchmarks/polybench/cholesky/cholesky_jax.py +++ b/npbench/benchmarks/polybench/cholesky/cholesky_jax.py @@ -5,42 +5,31 @@ @jax.jit def kernel(A: jax.Array): - # Set A[0, 0] to its square root A = A.at[0, 0].set(jnp.sqrt(A[0, 0])) def row_update(i, A): def col_update(j, A): - # Create a mask for elements up to `j` mask = jnp.arange(A.shape[1]) < j - # Equivalent of A[i, :j] and A[j, :j] using masking A_i_slice = jnp.where(mask, A[i, :], 0) A_j_slice = jnp.where(mask, A[j, :], 0) - # A[i, j] -= dot(A[i, :j], A[j, :j]) dot_product = jnp.dot(A_i_slice, A_j_slice) A = A.at[i, j].set(A[i, j] - dot_product) - - # A[i, j] /= A[j, j] A = A.at[i, j].divide(A[j, j]) return A - # Column update for all `j` in range(0, i) A = lax.fori_loop(0, i, col_update, A) - # Equivalent of A[i, i] -= dot(A[i, :i], A[i, :i]) A_i_slice = jnp.where(jnp.arange(A.shape[1]) < i, A[i, :], 0) dot_product = jnp.dot(A_i_slice, A_i_slice) A = A.at[i, i].set(A[i, i] - dot_product) - - # Set A[i, i] to its square root A = A.at[i, i].set(jnp.sqrt(A[i, i])) return A - # Apply row update for all `i` in range(1, A.shape[0]) A = lax.fori_loop(1, A.shape[0], row_update, A) return A From 591f79f7057645600e98ca71b6359ea422d9c040 Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Sun, 17 Nov 2024 12:07:01 +0100 Subject: [PATCH 048/106] Add spmv jax implementation --- npbench/benchmarks/spmv/spmv_jax.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 npbench/benchmarks/spmv/spmv_jax.py diff --git a/npbench/benchmarks/spmv/spmv_jax.py b/npbench/benchmarks/spmv/spmv_jax.py new file mode 100644 index 0000000..0a8e109 --- /dev/null +++ b/npbench/benchmarks/spmv/spmv_jax.py @@ -0,0 +1,22 @@ +# Sparse Matrix-Vector Multiplication (SpMV) +import jax.numpy as jnp +import jax +from jax import lax + +# Matrix-Vector Multiplication with the matrix given in Compressed Sparse Row +# (CSR) format +@jax.jit +def spmv(A_row, A_col, A_val, x): + y = jnp.empty(A_row.size - 1, dtype=A_val.dtype) + + def row_update(i, y): + mask = (jnp.arange(A_col.size) >= A_row[i]) & (jnp.arange(A_col.size) < A_row[i + 1]) + + cols = jnp.where(mask, A_col, 0) + vals = jnp.where(mask, A_val, 0) + y = y.at[i].set(vals @ x[cols]) + + return y + + y = lax.fori_loop(0, A_row.size - 1, row_update, y) + return y From 330511d848c6cab475252da3ed8ffa87810d63b4 Mon Sep 17 00:00:00 2001 From: Sushant S Date: Mon, 18 Nov 2024 06:25:45 +0100 Subject: [PATCH 049/106] add durbin --- .../benchmarks/polybench/durbin/durbin_jax.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 npbench/benchmarks/polybench/durbin/durbin_jax.py diff --git a/npbench/benchmarks/polybench/durbin/durbin_jax.py b/npbench/benchmarks/polybench/durbin/durbin_jax.py new file mode 100644 index 0000000..06bae7e --- /dev/null +++ b/npbench/benchmarks/polybench/durbin/durbin_jax.py @@ -0,0 +1,19 @@ +import jax +import jax.numpy as jnp + + +@jax.jit +def kernel(r): + + y = jnp.empty_like(r) + alpha = -r[0] + beta = 1.0 + y = y.at[0].set(-r[0]) + + for k in range(1, r.shape[0]): + beta *= 1.0 - alpha * alpha + alpha = -(r[k] + jnp.dot(jnp.flip(r[:k]), y[:k])) / beta + y = y.at[:k].add(alpha * jnp.flip(y[:k])) + y = y.at[k].set(alpha) + + return y From 
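[Editor's note] The spmv kernel above keeps shapes static by masking the whole nonzero array once per row, which is O(rows x nnz) work. For larger matrices one could instead sketch a segment-sum formulation; this is an alternative, not what the patch does, and it assumes jax.ops.segment_sum plus a standard CSR row-pointer array:

    import jax.numpy as jnp
    from jax import ops

    def spmv_segments(A_row, A_col, A_val, x):
        # Map each stored nonzero back to its CSR row via the row pointers,
        # then reduce the products row-by-row in a single segmented sum.
        rows = jnp.searchsorted(A_row, jnp.arange(A_val.shape[0]),
                                side='right') - 1
        return ops.segment_sum(A_val * x[A_col], rows,
                               num_segments=A_row.shape[0] - 1)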
6986abe30a99329127bcbaf6b3c7db3a68954b86 Mon Sep 17 00:00:00 2001 From: Sushant S Date: Mon, 18 Nov 2024 07:11:46 +0100 Subject: [PATCH 050/106] add lud_cmp --- .../benchmarks/polybench/ludcmp/ludcmp_jax.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 npbench/benchmarks/polybench/ludcmp/ludcmp_jax.py diff --git a/npbench/benchmarks/polybench/ludcmp/ludcmp_jax.py b/npbench/benchmarks/polybench/ludcmp/ludcmp_jax.py new file mode 100644 index 0000000..528fe59 --- /dev/null +++ b/npbench/benchmarks/polybench/ludcmp/ludcmp_jax.py @@ -0,0 +1,50 @@ +import jax +import jax.numpy as jnp +from jax import lax + +@jax.jit +def kernel(A, b): + + x = jnp.zeros_like(b) + y = jnp.zeros_like(b) + + def loop_body_1(i, A): + def inner_loop_1(j, A): + A_slice_1 = jnp.where(jnp.arange(A.shape[1]) < j, A[i, :], 0.0) + A_slice_2 = jnp.where(jnp.arange(A.shape[0]) < j, A[:, j], 0.0) + + A = A.at[i, j].set(A[i, j] - A_slice_1 @ A_slice_2) + A = A.at[i, j].set(A[i, j] / A[j, j]) + return A + + def inner_loop_2(j, A): + A_slice_1 = jnp.where(jnp.arange(A.shape[1]) < i, A[i, :], 0.0) + A_slice_2 = jnp.where(jnp.arange(A.shape[0]) < i, A[:, j], 0.0) + A = A.at[i, j].set(A[i, j] - A_slice_1 @ A_slice_2) + return A + + A = lax.fori_loop(0, i, inner_loop_1, A) + A = lax.fori_loop(i, A.shape[0], inner_loop_2, A) + + return A + + def loop_body_2(i, loop_vars): + A, y, b = loop_vars + A_slice = jnp.where(jnp.arange(A.shape[1]) < i, A[i, :], 0.0) + y_slice = jnp.where(jnp.arange(y.shape[0]) < i, y, 0.0) + y = y.at[i].set(b[i] - A_slice @ y_slice) + return A, y, b + + def loop_body_3(t, loop_vars): + A, x, y = loop_vars + i = A.shape[0] - 1 - t + A_slice = jnp.where(jnp.arange(A.shape[1]) > i, A[i, :], 0.0) + x_slice = jnp.where(jnp.arange(x.shape[0]) > i, x, 0.0) + x = x.at[i].set((y[i] - A_slice @ x_slice) / A[i, i]) + return A, x, y + + A = lax.fori_loop(0, A.shape[0], loop_body_1, A) + A, y, b = lax.fori_loop(0, A.shape[0], loop_body_2, (A, y, b)) + A, x, y = lax.fori_loop(0, A.shape[0], loop_body_3, (A, x, y)) + + return x, y From e6b37d17b8fa3dad0e3553478a5d6c8dc8245fe0 Mon Sep 17 00:00:00 2001 From: Sushant S Date: Mon, 18 Nov 2024 07:19:04 +0100 Subject: [PATCH 051/106] update lu --- npbench/benchmarks/polybench/lu/lu_jax.py | 28 ++++++++++++++++++----- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/npbench/benchmarks/polybench/lu/lu_jax.py b/npbench/benchmarks/polybench/lu/lu_jax.py index 399ee01..72ac80b 100644 --- a/npbench/benchmarks/polybench/lu/lu_jax.py +++ b/npbench/benchmarks/polybench/lu/lu_jax.py @@ -1,15 +1,31 @@ import jax import jax.numpy as jnp +from jax import lax @jax.jit def kernel(A): - - for i in range(A.shape[0]): - for j in range(i): - A = A.at[i, j].set(A[i, j] - A[i, :j] @ A[:j, j]) + + def loop_body(i, A): + + def inner_loop_1(j, A): + A_slice_1 = jnp.where(jnp.arange(A.shape[1]) < j, A[i, :], 0.0) + A_slice_2 = jnp.where(jnp.arange(A.shape[0]) < j, A[:, j], 0.0) + A = A.at[i, j].set(A[i, j] - A_slice_1 @ A_slice_2) A = A.at[i, j].set(A[i, j] / A[j, j]) - for j in range(i, A.shape[0]): - A = A.at[i, j].set(A[i, j] - A[i, :i] @ A[:i, j]) + return A + + def inner_loop_2(j, A): + A_slice_1 = jnp.where(jnp.arange(A.shape[1]) < i, A[i, :], 0.0) + A_slice_2 = jnp.where(jnp.arange(A.shape[0]) < i, A[:, j], 0.0) + A = A.at[i, j].set(A[i, j] - A_slice_1 @ A_slice_2) + return A + + A = lax.fori_loop(0, i, inner_loop_1, A) + A = lax.fori_loop(i, A.shape[0], inner_loop_2, A) + return A + + A = lax.fori_loop(0, A.shape[0], loop_body, A) + return A From 
bb8661dcbb2981402c0bbc6da0e1aaf139c4c9bf Mon Sep 17 00:00:00 2001 From: Sushant S Date: Mon, 18 Nov 2024 07:23:19 +0100 Subject: [PATCH 052/106] update trisolv --- npbench/benchmarks/polybench/trisolv/trisolv_jax.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/npbench/benchmarks/polybench/trisolv/trisolv_jax.py b/npbench/benchmarks/polybench/trisolv/trisolv_jax.py index 9122d9c..60b2ed3 100644 --- a/npbench/benchmarks/polybench/trisolv/trisolv_jax.py +++ b/npbench/benchmarks/polybench/trisolv/trisolv_jax.py @@ -5,7 +5,13 @@ @jax.jit def kernel(L, x, b): - for i in range(x.shape[0]): - x = x.at[i].set((b[i] - L[i, :i] @ x[:i]) / L[i, i]) + def loop_body(i, loop_vars): + L, x, b = loop_vars + L_slice = jnp.where(jnp.arange(L.shape[1]) < i, L[i, :], 0.0) + x_slice = jnp.where(jnp.arange(x.shape[0]) < i, x, 0.0) + x = x.at[i].set((b[i] - L_slice @ x_slice) / L[i, i]) + return L, x, b + + _, x, _ = jax.lax.fori_loop(0, x.shape[0], loop_body, (L, x, b)) return x From 936a90ae53e65845f7e7022f00735e4c4ea125ad Mon Sep 17 00:00:00 2001 From: Sushant S Date: Mon, 18 Nov 2024 07:57:29 +0100 Subject: [PATCH 053/106] update syr2k --- .../benchmarks/polybench/syr2k/syr2k_jax.py | 38 ++++++++++++++++--- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/npbench/benchmarks/polybench/syr2k/syr2k_jax.py b/npbench/benchmarks/polybench/syr2k/syr2k_jax.py index 263dcff..2458c69 100644 --- a/npbench/benchmarks/polybench/syr2k/syr2k_jax.py +++ b/npbench/benchmarks/polybench/syr2k/syr2k_jax.py @@ -1,13 +1,41 @@ import jax import jax.numpy as jnp +from jax import lax @jax.jit def kernel(alpha, beta, C, A, B): - for i in range(A.shape[0]): - C = C.at[i, :i + 1].set(C[i, :i + 1] * beta) - for k in range(A.shape[1]): - C = C.at[i, :i + 1].set(C[i, :i + 1] + A[:i + 1, k] * alpha * B[i, k] + - B[:i + 1, k] * alpha * A[i, k]) + + def loop_body(i, loop_vars): + + def inner_loop(k, loop_vars): + alpha, C, A, B = loop_vars + A_update_slice = jnp.where(jnp.arange(A.shape[0]) < i + 1, A[:, k], 0.0) + A_update_slice *= alpha * B[i, k] + + + B_update_slice = jnp.where(jnp.arange(B.shape[0]) < i + 1, B[:, k], 0.0) + B_update_slice *= alpha * A[i, k] + + C_update_slice = jnp.where(jnp.arange(C.shape[1]) < i + 1, C[i, :], 0.0) + C_update_slice += A_update_slice + B_update_slice + C_update_slice = jnp.where(jnp.arange(C.shape[1]) < i + 1, C_update_slice, C[i, :]) + + C = lax.dynamic_update_slice(C, C_update_slice[None, :], (i, 0)) + return alpha, C, A, B + + + alpha, beta, C, A, B = loop_vars + C_slice = jnp.where(jnp.arange(C.shape[1]) < i + 1, C[i, :], 0.0) + C_slice = C_slice * beta + C_slice = jnp.where(jnp.arange(C.shape[1]) < i + 1, C_slice, C[i, :]) + + C = lax.dynamic_update_slice(C, C_slice[None, :], (i, 0)) + + _, C, _, _ = lax.fori_loop(0, A.shape[1], inner_loop, (alpha, C, A, B)) + + return alpha, beta, C, A, B + + _, _, C, _, _ = lax.fori_loop(0, A.shape[0], loop_body, (alpha, beta, C, A, B)) return C From 4c15b77f3204c9a9b6fdc31f7527af44e9c1b7b7 Mon Sep 17 00:00:00 2001 From: Sushant S Date: Mon, 18 Nov 2024 08:04:43 +0100 Subject: [PATCH 054/106] update syrk --- npbench/benchmarks/polybench/syrk/syrk_jax.py | 31 ++++++++++++++++--- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/npbench/benchmarks/polybench/syrk/syrk_jax.py b/npbench/benchmarks/polybench/syrk/syrk_jax.py index 310c72d..8dcab55 100644 --- a/npbench/benchmarks/polybench/syrk/syrk_jax.py +++ b/npbench/benchmarks/polybench/syrk/syrk_jax.py @@ -1,13 +1,36 @@ import jax import jax.numpy as jnp +from 
jax import lax @jax.jit def kernel(alpha, beta, C, A): - for i in range(A.shape[0]): - C = C.at[i, :i + 1].set(C[i, :i + 1] * beta) - for k in range(A.shape[1]): - C = C.at[i, :i + 1].set(C[i, :i + 1] + alpha * A[i, k] * A[:i + 1, k]) + def loop_body(i, loop_vars): + + def inner_loop(k, loop_vars): + alpha, C, A = loop_vars + A_update_slice = jnp.where(jnp.arange(A.shape[0]) < i + 1, A[:, k], 0.0) + A_update_slice *= alpha * A[i, k] + + C_update_slice = jnp.where(jnp.arange(C.shape[1]) < i + 1, C[i, :], 0.0) + C_update_slice += A_update_slice + C_update_slice = jnp.where(jnp.arange(C.shape[1]) < i + 1, C_update_slice, C[i, :]) + + C = lax.dynamic_update_slice(C, C_update_slice[None, :], (i, 0)) + return alpha, C, A + + alpha, beta, C, A = loop_vars + + C_slice = jnp.where(jnp.arange(C.shape[1]) < i + 1, C[i, :], 0.0) + C_slice = C_slice * beta + C_slice = jnp.where(jnp.arange(C.shape[1]) < i + 1, C_slice, C[i, :]) + C = lax.dynamic_update_slice(C, C_slice[None, :], (i, 0)) + + _, C, _ = lax.fori_loop(0, A.shape[1], inner_loop, (alpha, C, A)) + + return alpha, beta, C, A + + _, _, C, _ = lax.fori_loop(0, A.shape[0], loop_body, (alpha, beta, C, A)) return C From be2c3203a936f2ad438d7620fcd3350b420a9967 Mon Sep 17 00:00:00 2001 From: Filip Jaksic Date: Mon, 18 Nov 2024 14:44:43 +0100 Subject: [PATCH 055/106] Add mandelbrot1 Jax implementation --- .../benchmarks/mandelbrot1/mandelbrot1_jax.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 npbench/benchmarks/mandelbrot1/mandelbrot1_jax.py diff --git a/npbench/benchmarks/mandelbrot1/mandelbrot1_jax.py b/npbench/benchmarks/mandelbrot1/mandelbrot1_jax.py new file mode 100644 index 0000000..4953f02 --- /dev/null +++ b/npbench/benchmarks/mandelbrot1/mandelbrot1_jax.py @@ -0,0 +1,30 @@ +# ----------------------------------------------------------------------------- +# From Numpy to Python +# Copyright (2017) Nicolas P. Rougier - BSD license +# More information at https://github.com/rougier/numpy-book +# ----------------------------------------------------------------------------- + +import jax +import jax.numpy as jnp +from functools import partial + +@partial(jax.jit, static_argnames=["xn", "yn", "maxiter"]) +def mandelbrot(xmin, xmax, ymin, ymax, xn, yn, maxiter, horizon=2.0): + # Adapted from https://www.ibm.com/developerworks/community/blogs/jfp/... 
+ # .../entry/How_To_Compute_Mandelbrodt_Set_Quickly?lang=en + X = jnp.linspace(xmin, xmax, xn, dtype=jnp.float64) + Y = jnp.linspace(ymin, ymax, yn, dtype=jnp.float64) + C = X + Y[:, None] * 1j + N = jnp.zeros(C.shape, dtype=jnp.int64) + Z = jnp.zeros(C.shape, dtype=jnp.complex128) + + def body_fun(n, state): + Z, N = state + I = jnp.less(jnp.abs(Z), horizon) + new_N = jnp.where(I, n, N) + new_Z = jnp.where(I, Z**2 + C, Z) + return new_Z, new_N + + Z, N = jax.lax.fori_loop(0, maxiter, body_fun, (Z, N)) + N = jnp.where(N == maxiter-1, 0, N) + return Z, N From 855778d2b7d0cbfe93dfb14414b8ecb1627f4a8d Mon Sep 17 00:00:00 2001 From: Filip Jaksic Date: Mon, 18 Nov 2024 16:47:58 +0100 Subject: [PATCH 056/106] Add stockham_fft implementation --- .../stockham_fft/stockham_fft_jax.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 npbench/benchmarks/stockham_fft/stockham_fft_jax.py diff --git a/npbench/benchmarks/stockham_fft/stockham_fft_jax.py b/npbench/benchmarks/stockham_fft/stockham_fft_jax.py new file mode 100644 index 0000000..b222339 --- /dev/null +++ b/npbench/benchmarks/stockham_fft/stockham_fft_jax.py @@ -0,0 +1,32 @@ +import jax +import jax.numpy as jnp +from functools import partial + + +@partial(jax.jit, static_argnames=["N", "R", "K"]) +def stockham_fft(N, R, K, x, y): + + # Generate DFT matrix for radix R. + # Define transient variable for matrix. + i_coord, j_coord = jnp.mgrid[0:R, 0:R] + # dft_mat = jnp.empty((R, R), dtype=jnp.complex128) + dft_mat = jnp.exp(-2.0j * jnp.pi * i_coord * j_coord / R) + y = x + + ii_coord, jj_coord = jnp.mgrid[0:R, 0:R**K] + + # Main Stockham loop + for i in range(K): + # Stride permutation + yv = jnp.reshape(y, (R**i, R, R**(K - i - 1))) + tmp_perm = jnp.transpose(yv, axes=(1, 0, 2)) + + # Twiddle Factor multiplication + tmp = jnp.exp(-2.0j * jnp.pi * ii_coord[:, :R**i] * jj_coord[:, :R**i] / R**(i + 1)) + D = jnp.repeat(jnp.reshape(tmp, (R, R**i, 1)), R**(K - i - 1), axis=2) + tmp_twid = jnp.reshape(tmp_perm, (N, )) * jnp.reshape(D, (N, )) + + # Product with Butterfly + y = jnp.reshape(dft_mat @ jnp.reshape(tmp_twid, (R, R**(K - 1))),(N, )) + + return y \ No newline at end of file From 8d6af1ab5c298c36f314fd2b0ce9d7aff76507ff Mon Sep 17 00:00:00 2001 From: Filip Jaksic Date: Mon, 18 Nov 2024 18:45:51 +0100 Subject: [PATCH 057/106] Add scattering_self_energies Jax implementation --- .../scattering_self_energies_jax.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 npbench/benchmarks/scattering_self_energies/scattering_self_energies_jax.py diff --git a/npbench/benchmarks/scattering_self_energies/scattering_self_energies_jax.py b/npbench/benchmarks/scattering_self_energies/scattering_self_energies_jax.py new file mode 100644 index 0000000..a8f34cc --- /dev/null +++ b/npbench/benchmarks/scattering_self_energies/scattering_self_energies_jax.py @@ -0,0 +1,39 @@ +import jax +import jax.numpy as jnp + +@jax.jit +def scattering_self_energies(neigh_idx, dH, G, D, Sigma): + def body_fun(sigma, idx): + k, E, q, w, a, b, i, j = idx + + dHG = G[k, E - w, neigh_idx[a, b]] @ dH[a, b, i] + dHD = dH[a, b, j] * D[q, w, a, b, i, j] + + update = jnp.where(E >= w, dHG @ dHD, 0.0) + + return sigma.at[k, E, a].add(update), None + + # Create all possible index combinations + k_range = jnp.arange(G.shape[0]) + E_range = jnp.arange(G.shape[1]) + q_range = jnp.arange(D.shape[0]) + w_range = jnp.arange(D.shape[1]) + a_range = jnp.arange(neigh_idx.shape[0]) + b_range = jnp.arange(neigh_idx.shape[1]) + i_range = 
jnp.arange(D.shape[-2]) + j_range = jnp.arange(D.shape[-1]) + + # Create meshgrid of indices + indices = jnp.meshgrid( + k_range, E_range, q_range, w_range, + a_range, b_range, i_range, j_range, + indexing='ij' + ) + + # Reshape indices into a single array of 8-tuples + indices = jnp.stack([idx.ravel() for idx in indices], axis=1) + + # Use scan to iterate over all index combinations + result, _ = jax.lax.scan(body_fun, Sigma, indices) + + return result From 099134005d3d111ca03bc2a70dd3dc0340d9a7aa Mon Sep 17 00:00:00 2001 From: Filip Jaksic Date: Mon, 18 Nov 2024 19:37:36 +0100 Subject: [PATCH 058/106] Add deriche Jax implementation --- .../polybench/deriche/deriche_jax.py | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 npbench/benchmarks/polybench/deriche/deriche_jax.py diff --git a/npbench/benchmarks/polybench/deriche/deriche_jax.py b/npbench/benchmarks/polybench/deriche/deriche_jax.py new file mode 100644 index 0000000..a390a33 --- /dev/null +++ b/npbench/benchmarks/polybench/deriche/deriche_jax.py @@ -0,0 +1,71 @@ +import jax +import jax.numpy as jnp +from jax import lax + +@jax.jit +def kernel(alpha, imgIn): + + k = (1.0 - jnp.exp(-alpha)) * (1.0 - jnp.exp(-alpha)) / ( + 1.0 + alpha * jnp.exp(-alpha) - jnp.exp(2.0 * alpha)) + a1 = a5 = k + a2 = a6 = k * jnp.exp(-alpha) * (alpha - 1.0) + a3 = a7 = k * jnp.exp(-alpha) * (alpha + 1.0) + a4 = a8 = -k * jnp.exp(-2.0 * alpha) + b1 = 2.0**(-alpha) + b2 = -jnp.exp(-2.0 * alpha) + c1 = c2 = 1 + + y1 = jnp.empty_like(imgIn) + y1 = y1.at[:, 0].set(a1 * imgIn[:, 0]) + y1 = y1.at[:, 1].set(a1 * imgIn[:, 1] + a2 * imgIn[:, 0] + b1 * y1[:, 0]) + + def horizontal_forward(j, y1): + return y1.at[:, j].set( + a1 * imgIn[:, j] + a2 * imgIn[:, j-1] + + b1 * y1[:, j-1] + b2 * y1[:, j-2] + ) + + y1 = lax.fori_loop(2, imgIn.shape[1], horizontal_forward, y1) + + y2 = jnp.empty_like(imgIn) + y2 = y2.at[:, -1].set(0.0) + y2 = y2.at[:, -2].set(a3 * imgIn[:, -1]) + + def horizontal_backward(j, y2): + idx = imgIn.shape[1] - 3 - j + return y2.at[:, idx].set( + a3 * imgIn[:, idx+1] + a4 * imgIn[:, idx+2] + + b1 * y2[:, idx+1] + b2 * y2[:, idx+2] + ) + + y2 = lax.fori_loop(0, imgIn.shape[1]-2, horizontal_backward, y2) + + imgOut = c1 * (y1 + y2) + + # First vertical pass + y1 = jnp.empty_like(imgOut) + y1 = y1.at[0, :].set(a5 * imgOut[0, :]) + y1 = y1.at[1, :].set(a5 * imgOut[1, :] + a6 * imgOut[0, :] + b1 * y1[0, :]) + + def vertical_forward(i, y1): + return y1.at[i, :].set( + a5 * imgOut[i, :] + a6 * imgOut[i-1, :] + + b1 * y1[i-1, :] + b2 * y1[i-2, :] + ) + + y1 = lax.fori_loop(2, imgIn.shape[0], vertical_forward, y1) + + y2 = jnp.empty_like(imgOut) + y2 = y2.at[-1, :].set(0.0) + y2 = y2.at[-2, :].set(a7 * imgOut[-1, :]) + + def vertical_backward(i, y2): + idx = imgIn.shape[0] - 3 - i + return y2.at[idx, :].set( + a7 * imgOut[idx+1, :] + a8 * imgOut[idx+2, :] + + b1 * y2[idx+1, :] + b2 * y2[idx+2, :] + ) + + y2 = lax.fori_loop(0, imgIn.shape[0]-2, vertical_backward, y2) + + return c2 * (y1 + y2) \ No newline at end of file From b670765f8dca0d60a23dc63f7e6f78f1bb4e73aa Mon Sep 17 00:00:00 2001 From: Filip Jaksic Date: Tue, 19 Nov 2024 19:52:02 +0100 Subject: [PATCH 059/106] Add adi Jax implementation --- npbench/benchmarks/polybench/adi/adi_jax.py | 89 +++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 npbench/benchmarks/polybench/adi/adi_jax.py diff --git a/npbench/benchmarks/polybench/adi/adi_jax.py b/npbench/benchmarks/polybench/adi/adi_jax.py new file mode 100644 index 0000000..65b3c15 --- /dev/null +++ 
b/npbench/benchmarks/polybench/adi/adi_jax.py @@ -0,0 +1,89 @@ +import jax +import jax.numpy as jnp +from jax import lax + +def kernel(TSTEPS, N, u): + # Initialize arrays + v = jnp.zeros_like(u) + p = jnp.zeros_like(u) + q = jnp.zeros_like(u) + + # Constants + DX = 1.0 / N + DY = 1.0 / N + DT = 1.0 / TSTEPS + B1 = 2.0 + B2 = 1.0 + mul1 = B1 * DT / (DX * DX) + mul2 = B2 * DT / (DY * DY) + + a = -mul1 / 2.0 + b = 1.0 + mul2 + c = a + d = -mul2 / 2.0 + e = 1.0 + mul2 + f = d + + def first_j_loop_body(j, carry): + p, q, u = carry + p = p.at[1:N-1, j].set(-c / (a * p[1:N-1, j-1] + b)) + q = q.at[1:N-1, j].set( + (-d * u[j, 0:N-2] + (1.0 + 2.0 * d) * u[j, 1:N-1] - f * u[j, 2:N] - + a * q[1:N-1, j-1]) / (a * p[1:N-1, j-1] + b)) + return (p, q, u) + + def first_backward_j_loop_body(j, carry): + v, p, q = carry + idx = N-2-j # Calculate the actual index for backward iteration + v = v.at[idx, 1:N-1].set(p[1:N-1, idx] * v[idx+1, 1:N-1] + q[1:N-1, idx]) + return (v, p, q) + + def second_j_loop_body(j, carry): + p, q, v = carry + p = p.at[1:N-1, j].set(-f / (d * p[1:N-1, j-1] + e)) + q = q.at[1:N-1, j].set( + (-a * v[0:N-2, j] + (1.0 + 2.0 * a) * v[1:N-1, j] - c * v[2:N, j] - + d * q[1:N-1, j-1]) / (d * p[1:N-1, j-1] + e)) + return (p, q, v) + + def second_backward_j_loop_body(j, carry): + u, p, q = carry + idx = N-2-j # Calculate the actual index for backward iteration + u = u.at[1:N-1, idx].set(p[1:N-1, idx] * u[1:N-1, idx+1] + q[1:N-1, idx]) + return (u, p, q) + + def time_step_body(t, carry): + u, v, p, q = carry + + # First part + v = v.at[0, 1:N-1].set(1.0) + p = p.at[1:N-1, 0].set(0.0) + q = q.at[1:N-1, 0].set(v[0, 1:N-1]) + + # First j loop + p, q, u = lax.fori_loop(1, N-1, first_j_loop_body, (p, q, u)) + + v = v.at[N-1, 1:N-1].set(1.0) + + # First backward j loop + v, p, q = lax.fori_loop(0, N-2, first_backward_j_loop_body, (v, p, q)) + + # Second part + u = u.at[1:N-1, 0].set(1.0) + p = p.at[1:N-1, 0].set(0.0) + q = q.at[1:N-1, 0].set(u[1:N-1, 0]) + + # Second j loop + p, q, v = lax.fori_loop(1, N-1, second_j_loop_body, (p, q, v)) + + u = u.at[1:N-1, N-1].set(1.0) + + # Second backward j loop + u, p, q = lax.fori_loop(0, N-2, second_backward_j_loop_body, (u, p, q)) + + return (u, v, p, q) + + # Main time loop + u, v, p, q = lax.fori_loop(1, TSTEPS + 1, time_step_body, (u, v, p, q)) + + return u \ No newline at end of file From 731f0d872e8b5056db42736f6eae6d3f4b228d37 Mon Sep 17 00:00:00 2001 From: Sushant S Date: Thu, 21 Nov 2024 08:38:53 +0100 Subject: [PATCH 060/106] update durbin --- .../benchmarks/polybench/durbin/durbin_jax.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/npbench/benchmarks/polybench/durbin/durbin_jax.py b/npbench/benchmarks/polybench/durbin/durbin_jax.py index 06bae7e..e3f6e4d 100644 --- a/npbench/benchmarks/polybench/durbin/durbin_jax.py +++ b/npbench/benchmarks/polybench/durbin/durbin_jax.py @@ -1,5 +1,6 @@ import jax import jax.numpy as jnp +from jax import lax @jax.jit @@ -10,10 +11,21 @@ def kernel(r): beta = 1.0 y = y.at[0].set(-r[0]) - for k in range(1, r.shape[0]): + + def loop_body(k, loop_vars): + alpha, beta, y, r = loop_vars beta *= 1.0 - alpha * alpha - alpha = -(r[k] + jnp.dot(jnp.flip(r[:k]), y[:k])) / beta - y = y.at[:k].add(alpha * jnp.flip(y[:k])) + + r_slice = jnp.where(jnp.arange(r.shape[0]) < k, jnp.roll(jnp.flip(r), [k], 0), 0.0) + y_slice = jnp.where(jnp.arange(y.shape[0]) < k, y, 0.0) + alpha = -(r[k] + jnp.dot(r_slice, y_slice)) / beta + + y_update_slice = jnp.where(jnp.arange(y.shape[0]) < k, 
jnp.roll(jnp.flip(y), [k], 0) * alpha, 0.0) + y += y_update_slice y = y.at[k].set(alpha) + return alpha, beta, y, r + + _, _, y, _ = lax.fori_loop(1, r.shape[0], loop_body, (alpha, beta, y, r)) + return y From 091b29579c70fa0b859d881c328ae9273d1a4ef3 Mon Sep 17 00:00:00 2001 From: Sushant S Date: Thu, 21 Nov 2024 09:28:13 +0100 Subject: [PATCH 061/106] update_correlation --- .../polybench/correlation/correlation_jax.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/npbench/benchmarks/polybench/correlation/correlation_jax.py b/npbench/benchmarks/polybench/correlation/correlation_jax.py index af4f56c..2dde7a8 100644 --- a/npbench/benchmarks/polybench/correlation/correlation_jax.py +++ b/npbench/benchmarks/polybench/correlation/correlation_jax.py @@ -6,6 +6,14 @@ @partial(jax.jit, static_argnums=(0,)) def kernel(M, float_n, data): + + def loop_body(i, loop_vars): + corr, data = loop_vars + corr = lax.dynamic_update_slice(corr, jnp.roll(data[i, None], -(i + 1), axis=1), (i, i + 1)) + corr = lax.dynamic_update_slice(corr, jnp.roll(data[i, None], -(i + 1), axis=1).T, (i + 1, i)) + + return corr, data + mean = jnp.mean(data, axis=0) stddev = jnp.std(data, axis=0) stddev = jnp.where(stddev <= 0.1, 1.0, stddev) @@ -13,8 +21,8 @@ def kernel(M, float_n, data): data = data / (jnp.sqrt(float_n) * stddev) corr = jnp.eye(M, dtype=data.dtype) - for i in range(M - 1): - corr = corr.at[i + 1:M, i].set(data[:, i] @ data[:, i + 1:M]) - corr = corr.at[i, i + 1:M].set(data[:, i] @ data[:, i + 1:M]) + data_mul = jnp.dot(data.T, data) + + corr, _ = lax.fori_loop(0, M - 1, loop_body, (corr, data_mul)) return corr From eed0974783f7321c1360cd46951e4b7fd5ecf2c7 Mon Sep 17 00:00:00 2001 From: Sushant S Date: Thu, 21 Nov 2024 09:33:23 +0100 Subject: [PATCH 062/106] update azimint_naive --- npbench/benchmarks/azimint_naive/azimint_naive_jax.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/npbench/benchmarks/azimint_naive/azimint_naive_jax.py b/npbench/benchmarks/azimint_naive/azimint_naive_jax.py index 2f4f242..5523a2d 100644 --- a/npbench/benchmarks/azimint_naive/azimint_naive_jax.py +++ b/npbench/benchmarks/azimint_naive/azimint_naive_jax.py @@ -10,6 +10,7 @@ import jax import jax.numpy as jnp +from jax import lax from functools import partial @@ -18,11 +19,14 @@ def azimint_naive(data, radius, npt): rmax = radius.max() res = jnp.zeros(npt, dtype=jnp.float64) - for i in range(npt): + def loop_body(i, res): r1 = rmax * i / npt r2 = rmax * (i + 1) / npt mask_r12 = jnp.logical_and((r1 <= radius), (radius < r2)) mean = jnp.where(mask_r12, data, 0).mean(where=mask_r12) res = res.at[i].set(mean) + return res + + res = lax.fori_loop(0, npt, loop_body, res) return res From 2a2a98123ab1189865edb20f57a2356c0b9a6def Mon Sep 17 00:00:00 2001 From: Sushant S Date: Thu, 21 Nov 2024 10:13:10 +0100 Subject: [PATCH 063/106] add crc16 --- npbench/benchmarks/crc16/crc16_jax.py | 33 +++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 npbench/benchmarks/crc16/crc16_jax.py diff --git a/npbench/benchmarks/crc16/crc16_jax.py b/npbench/benchmarks/crc16/crc16_jax.py new file mode 100644 index 0000000..7c630ce --- /dev/null +++ b/npbench/benchmarks/crc16/crc16_jax.py @@ -0,0 +1,33 @@ +import jax +import jax.numpy as jnp +from jax import lax + +@jax.jit +def crc16(data, poly=0x8408): + ''' + CRC-16-CCITT Algorithm + ''' + crc = 0xFFFF + + def loop_body(crc, b): + cur_byte = 0xFF & b + + def inner_loop_body(carry, data): + crc, cur_byte = carry + xor_flag = 
((crc & 0x0001) ^ (cur_byte & 0x0001)) == 1  # lax.select requires a boolean predicate
+            crc = lax.select(xor_flag, (crc >> 1) ^ poly, crc >> 1)
+            cur_byte >>= 1
+
+            return (crc, cur_byte), None
+
+        (crc, cur_byte), _ = lax.scan(inner_loop_body, (crc, cur_byte), jnp.arange(8))
+
+        return crc, None
+
+    crc, _ = lax.scan(loop_body, crc, data)
+
+    crc = (~crc & 0xFFFF)
+    crc = (crc << 8) | ((crc >> 8) & 0xFF)
+
+
+    return crc & 0xFFFF
From f99a8161f2c495ab7174b60c671c74e4fc6aab86 Mon Sep 17 00:00:00 2001
From: Sushant S
Date: Thu, 21 Nov 2024 10:25:09 +0100
Subject: [PATCH 064/106] add arc_distance

---
 .../pythran/arc_distance/arc_distance_jax.py | 42 +++++++++++++++++++
 1 file changed, 42 insertions(+)
 create mode 100644 npbench/benchmarks/pythran/arc_distance/arc_distance_jax.py

diff --git a/npbench/benchmarks/pythran/arc_distance/arc_distance_jax.py b/npbench/benchmarks/pythran/arc_distance/arc_distance_jax.py
new file mode 100644
index 0000000..883c7ac
--- /dev/null
+++ b/npbench/benchmarks/pythran/arc_distance/arc_distance_jax.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2019, Serge Guelton
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# Neither the name of HPCProject, Serge Guelton nor the names of its
+# contributors may be used to endorse or promote products derived from this
+# software without specific prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import jax
+import jax.numpy as jnp
+
+
+@jax.jit
+def arc_distance(theta_1, phi_1, theta_2, phi_2):
+    """
+    Calculates the pairwise arc distance between all points in vector a and b.
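+
+    For reference, this is the haversine form: writing
+    t = sin((theta_2 - theta_1) / 2)**2
+        + cos(theta_1) * cos(theta_2) * sin((phi_2 - phi_1) / 2)**2,
+    the arc distance evaluates to 2 * arctan2(sqrt(t), sqrt(1 - t)).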
+ """ + temp = jnp.sin((theta_2 - theta_1) / + 2)**2 + jnp.cos(theta_1) * jnp.cos(theta_2) * jnp.sin( + (phi_2 - phi_1) / 2)**2 + distance_matrix = 2 * (jnp.arctan2(jnp.sqrt(temp), jnp.sqrt(1 - temp))) + return distance_matrix \ No newline at end of file From caf01414410ce10c70a3b3b0f1f1632df2b41b81 Mon Sep 17 00:00:00 2001 From: Sushant S Date: Thu, 21 Nov 2024 10:28:40 +0100 Subject: [PATCH 065/106] add gemver --- npbench/benchmarks/polybench/gemver/gemver_jax.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 npbench/benchmarks/polybench/gemver/gemver_jax.py diff --git a/npbench/benchmarks/polybench/gemver/gemver_jax.py b/npbench/benchmarks/polybench/gemver/gemver_jax.py new file mode 100644 index 0000000..563afb0 --- /dev/null +++ b/npbench/benchmarks/polybench/gemver/gemver_jax.py @@ -0,0 +1,12 @@ +import jax +import jax.numpy as jnp + + +@jax.jit +def kernel(alpha, beta, A, u1, v1, u2, v2, w, x, y, z): + + A += jnp.outer(u1, v1) + jnp.outer(u2, v2) + x += beta * y @ A + z + w += alpha * A @ x + + return A, x, w \ No newline at end of file From d70731252e663dc4fa9cf21718ab363382f9975f Mon Sep 17 00:00:00 2001 From: Sushant S Date: Thu, 21 Nov 2024 10:29:04 +0100 Subject: [PATCH 066/106] add gemver --- npbench/benchmarks/polybench/gemver/gemver_jax.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 npbench/benchmarks/polybench/gemver/gemver_jax.py diff --git a/npbench/benchmarks/polybench/gemver/gemver_jax.py b/npbench/benchmarks/polybench/gemver/gemver_jax.py new file mode 100644 index 0000000..2e106cc --- /dev/null +++ b/npbench/benchmarks/polybench/gemver/gemver_jax.py @@ -0,0 +1,12 @@ +import jax +import jax.numpy as jnp + + +@jax.jit +def kernel(alpha, beta, A, u1, v1, u2, v2, w, x, y, z): + + A += jnp.outer(u1, v1) + jnp.outer(u2, v2) + x += beta * y @ A + z + w += alpha * A @ x + + return A, x, w From bbc546e41088e9051e980ffd2e5d78587ef8fc65 Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Thu, 21 Nov 2024 11:49:52 +0100 Subject: [PATCH 067/106] Update conv2d_jax with lax.scan --- .../deep_learning/conv2d_bias/conv2d_jax.py | 50 +++++++++---------- 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/npbench/benchmarks/deep_learning/conv2d_bias/conv2d_jax.py b/npbench/benchmarks/deep_learning/conv2d_bias/conv2d_jax.py index 36c07a5..6a50551 100644 --- a/npbench/benchmarks/deep_learning/conv2d_bias/conv2d_jax.py +++ b/npbench/benchmarks/deep_learning/conv2d_bias/conv2d_jax.py @@ -12,33 +12,29 @@ def conv2d(input, weights): W_out = input.shape[2] - K + 1 C_out = weights.shape[3] output = jnp.empty((N, H_out, W_out, C_out), dtype=jnp.float32) - - def col_update(j, arrays): - input, weights, output, i = arrays - - input_slice = lax.dynamic_slice( - input, - (0, i, j, 0), - (N, K, K, input.shape[-1]) - ) - conv_result = jnp.sum( - input_slice[:, :, :, :, jnp.newaxis] * weights[jnp.newaxis, :, :, :], - axis=(1, 2, 3) - ) - output = lax.dynamic_update_slice( - output, - conv_result[:, jnp.newaxis, jnp.newaxis, :], - (0, i, j, 0) - ) - return input, weights, output, i - - def row_update(i, arrays): - input, weights, output = arrays - arrays = (input, weights, output, i) - _, _, output, _ = lax.fori_loop(0, W_out, col_update, arrays) - return input, weights, output - - _, _, output = lax.fori_loop(0, H_out, row_update, (input, weights, output)) + + def row_update(output, i): + def col_update(output, j): + input_slice = lax.dynamic_slice( + input, + (0, i, j, 0), + (N, K, K, input.shape[-1]) + ) + conv_result = jnp.sum( + 
input_slice[:, :, :, :, None] * weights[None, :, :, :], + axis=(1, 2, 3) + ) + output = lax.dynamic_update_slice( + output, + conv_result[:, None, None, :], + (0, i, j, 0) + ) + return output, None + + output, _ = lax.scan(col_update, output, jnp.arange(W_out)) + return output, None + + output, _ = lax.scan(row_update, output, jnp.arange(H_out)) return output From 1d213cb2c1d127ea374a1183cc2ca614afe206a5 Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Thu, 21 Nov 2024 12:16:07 +0100 Subject: [PATCH 068/106] Add lenet jax implementation --- .../deep_learning/lenet/lenet_jax.py | 87 +++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 npbench/benchmarks/deep_learning/lenet/lenet_jax.py diff --git a/npbench/benchmarks/deep_learning/lenet/lenet_jax.py b/npbench/benchmarks/deep_learning/lenet/lenet_jax.py new file mode 100644 index 0000000..f1dbbd4 --- /dev/null +++ b/npbench/benchmarks/deep_learning/lenet/lenet_jax.py @@ -0,0 +1,87 @@ +import jax.numpy as jnp +import jax +from jax import lax +from functools import partial + +@jax.jit +def relu(x): + return jnp.maximum(x, 0) + + +# Deep learning convolutional operator (stride = 1) +@jax.jit +def conv2d(input, weights): + K = weights.shape[0] # Assuming square kernel + N = input.shape[0] + H_out = input.shape[1] - K + 1 + W_out = input.shape[2] - K + 1 + C_out = weights.shape[3] + output = jnp.empty((N, H_out, W_out, C_out), dtype=jnp.float32) + + def row_update(output, i): + def col_update(output, j): + input_slice = lax.dynamic_slice( + input, + (0, i, j, 0), + (N, K, K, input.shape[-1]) + ) + conv_result = jnp.sum( + input_slice[:, :, :, :, None] * weights[None, :, :, :], + axis=(1, 2, 3) + ) + output = lax.dynamic_update_slice( + output, + conv_result[:, None, None, :], + (0, i, j, 0) + ) + return output, None + + output, _ = lax.scan(col_update, output, jnp.arange(W_out)) + return output, None + + output, _ = lax.scan(row_update, output, jnp.arange(H_out)) + + return output + + +# 2x2 maxpool operator, as used in LeNet-5 +@jax.jit +def maxpool2d(x): + output = jnp.empty( + [x.shape[0], x.shape[1] // 2, x.shape[2] // 2, x.shape[3]], + dtype=x.dtype) + + def row_update(output, i): + def col_update(output, j): + input_slice = lax.dynamic_slice( + x, + (0, 2 * i, 2 * j, 0), + (x.shape[0], 2, 2, x.shape[3]) + ) + output = lax.dynamic_update_slice( + output, + jnp.max(input_slice, axis=(1, 2))[:, None, None, :], + (0, i, j, 0) + ) + return output, None + + output, _ = lax.scan(col_update, output, jnp.arange(x.shape[2] // 2)) + return output, None + + output, _ = lax.scan(row_update, output, jnp.arange(x.shape[1] // 2)) + + return output + + +# LeNet-5 Convolutional Neural Network (inference mode) +@partial(jax.jit, static_argnums=(11, 12)) +def lenet5(input, conv1, conv1bias, conv2, conv2bias, fc1w, fc1b, fc2w, fc2b, + fc3w, fc3b, N, C_before_fc1): + x = relu(conv2d(input, conv1) + conv1bias) + x = maxpool2d(x) + x = relu(conv2d(x, conv2) + conv2bias) + x = maxpool2d(x) + x = jnp.reshape(x, (N, C_before_fc1)) + x = relu(x @ fc1w + fc1b) + x = relu(x @ fc2w + fc2b) + return x @ fc3w + fc3b From 00128b755e17f3f598f1dd74bfb9153ad3abae0b Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Thu, 21 Nov 2024 12:29:17 +0100 Subject: [PATCH 069/106] Add resnet jax implementation --- .../deep_learning/resnet/resnet_jax.py | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 npbench/benchmarks/deep_learning/resnet/resnet_jax.py diff --git a/npbench/benchmarks/deep_learning/resnet/resnet_jax.py 
b/npbench/benchmarks/deep_learning/resnet/resnet_jax.py new file mode 100644 index 0000000..e573da9 --- /dev/null +++ b/npbench/benchmarks/deep_learning/resnet/resnet_jax.py @@ -0,0 +1,73 @@ +import jax.numpy as jnp +import jax +from jax import lax + +@jax.jit +def relu(x): + return jnp.maximum(x, 0) + + +# Deep learning convolutional operator (stride = 1) +@jax.jit +def conv2d(input, weights): + K = weights.shape[0] # Assuming square kernel + N = input.shape[0] + H_out = input.shape[1] - K + 1 + W_out = input.shape[2] - K + 1 + C_out = weights.shape[3] + output = jnp.empty((N, H_out, W_out, C_out), dtype=jnp.float32) + + def row_update(output, i): + def col_update(output, j): + input_slice = lax.dynamic_slice( + input, + (0, i, j, 0), + (N, K, K, input.shape[-1]) + ) + conv_result = jnp.sum( + input_slice[:, :, :, :, None] * weights[None, :, :, :], + axis=(1, 2, 3) + ) + output = lax.dynamic_update_slice( + output, + conv_result[:, None, None, :], + (0, i, j, 0) + ) + return output, None + + output, _ = lax.scan(col_update, output, jnp.arange(W_out)) + return output, None + + output, _ = lax.scan(row_update, output, jnp.arange(H_out)) + return output + + +# Batch normalization operator, as used in ResNet +@jax.jit +def batchnorm2d(x, eps=1e-5): + mean = jnp.mean(x, axis=0, keepdims=True) + std = jnp.std(x, axis=0, keepdims=True) + return (x - mean) / jnp.sqrt(std + eps) + + +# Bottleneck residual block (after initial convolution, without downsampling) +# in the ResNet-50 CNN (inference) +@jax.jit +def resnet_basicblock(input, conv1, conv2, conv3): + # Pad output of first convolution for second convolution + padded = jnp.zeros((input.shape[0], input.shape[1] + 2, input.shape[2] + 2, + conv1.shape[3]), dtype=jnp.float32) + padded = lax.dynamic_update_slice( + padded, + conv2d(input, conv1), + (0, 1, 1, 0) + ) + x = batchnorm2d(padded) + x = relu(x) + + x = conv2d(x, conv2) + x = batchnorm2d(x) + x = relu(x) + x = conv2d(x, conv3) + x = batchnorm2d(x) + return relu(x + input) From 6626d60f0f0fe2c800f61b67e40ef61dcbd1050d Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Thu, 21 Nov 2024 13:15:45 +0100 Subject: [PATCH 070/106] Add channel_flow jax implementation --- .../channel_flow/channel_flow_jax.py | 173 ++++++++++++++++++ 1 file changed, 173 insertions(+) create mode 100644 npbench/benchmarks/channel_flow/channel_flow_jax.py diff --git a/npbench/benchmarks/channel_flow/channel_flow_jax.py b/npbench/benchmarks/channel_flow/channel_flow_jax.py new file mode 100644 index 0000000..be14315 --- /dev/null +++ b/npbench/benchmarks/channel_flow/channel_flow_jax.py @@ -0,0 +1,173 @@ +# Barba, Lorena A., and Forsyth, Gilbert F. (2018). +# CFD Python: the 12 steps to Navier-Stokes equations. +# Journal of Open Source Education, 1(9), 21, +# https://doi.org/10.21105/jose.00021 +# TODO: License +# (c) 2017 Lorena A. Barba, Gilbert F. Forsyth. +# All content is under Creative Commons Attribution CC-BY 4.0, +# and all code is under BSD-3 clause (previously under MIT, and changed on March 8, 2018). 
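+#
+# Porting note (an illustrative sketch, not part of the original course
+# materials): JAX arrays are immutable, so the open-ended convergence loop
+# below is expressed with lax.while_loop, threading all state through an
+# explicit carry instead of mutating arrays in place:
+#
+#     def cond(state):
+#         udiff, *_ = state          # iterate until convergence
+#         return udiff > .001
+#
+#     final_state = lax.while_loop(cond, body, init_state)
+#
+# For the same reason, helper kernels cannot update their arguments in place;
+# callers must re-bind the returned arrays, e.g. p = pressure_poisson_periodic(...).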
+ +import jax.numpy as jnp +import jax +from jax import lax +from functools import partial + + +@partial(jax.jit, static_argnums=(0,)) +def build_up_b(rho, dt, dx, dy, u, v): + b = jnp.zeros_like(u) + b = b.at[1:-1, + 1:-1].set((rho * (1 / dt * ((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx) + + (v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy)) - + ((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx))**2 - 2 * + ((u[2:, 1:-1] - u[0:-2, 1:-1]) / (2 * dy) * + (v[1:-1, 2:] - v[1:-1, 0:-2]) / (2 * dx)) - + ((v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy))**2))) + + # Periodic BC Pressure @ x = 2 + b = b.at[1:-1, -1].set((rho * (1 / dt * ((u[1:-1, 0] - u[1:-1, -2]) / (2 * dx) + + (v[2:, -1] - v[0:-2, -1]) / (2 * dy)) - + ((u[1:-1, 0] - u[1:-1, -2]) / (2 * dx))**2 - 2 * + ((u[2:, -1] - u[0:-2, -1]) / (2 * dy) * + (v[1:-1, 0] - v[1:-1, -2]) / (2 * dx)) - + ((v[2:, -1] - v[0:-2, -1]) / (2 * dy))**2))) + + # Periodic BC Pressure @ x = 0 + b = b.at[1:-1, 0].set((rho * (1 / dt * ((u[1:-1, 1] - u[1:-1, -1]) / (2 * dx) + + (v[2:, 0] - v[0:-2, 0]) / (2 * dy)) - + ((u[1:-1, 1] - u[1:-1, -1]) / (2 * dx))**2 - 2 * + ((u[2:, 0] - u[0:-2, 0]) / (2 * dy) * + (v[1:-1, 1] - v[1:-1, -1]) / + (2 * dx)) - ((v[2:, 0] - v[0:-2, 0]) / (2 * dy))**2))) + + return b + +@partial(jax.jit, static_argnums=(0,)) +def pressure_poisson_periodic(nit, p, dx, dy, b): + pn = jnp.empty_like(p) + + def body_func(p, q): + pn = p.copy() + p = p.at[1:-1, 1:-1].set(((pn[1:-1, 2:] + pn[1:-1, 0:-2]) * dy**2 + + (pn[2:, 1:-1] + pn[0:-2, 1:-1]) * dx**2) / + (2 * (dx**2 + dy**2)) - dx**2 * dy**2 / + (2 * (dx**2 + dy**2)) * b[1:-1, 1:-1]) + + # Periodic BC Pressure @ x = 2 + p = p.at[1:-1, -1].set(((pn[1:-1, 0] + pn[1:-1, -2]) * dy**2 + + (pn[2:, -1] + pn[0:-2, -1]) * dx**2) / + (2 * (dx**2 + dy**2)) - dx**2 * dy**2 / + (2 * (dx**2 + dy**2)) * b[1:-1, -1]) + + # Periodic BC Pressure @ x = 0 + p = p.at[1:-1, + 0].set((((pn[1:-1, 1] + pn[1:-1, -1]) * dy**2 + + (pn[2:, 0] + pn[0:-2, 0]) * dx**2) / (2 * (dx**2 + dy**2)) - + dx**2 * dy**2 / (2 * (dx**2 + dy**2)) * b[1:-1, 0])) + + # Wall boundary conditions, pressure + p = p.at[-1, :].set(p[-2, :]) # dp/dy = 0 at y = 2 + p = p.at[0, :].set(p[1, :]) # dp/dy = 0 at y = 0 + + return p, None + + p, _ = lax.scan(body_func, p, jnp.arange(nit)) + + +@partial(jax.jit, static_argnums=(0,7,8,9)) +def channel_flow(nit, u, v, dt, dx, dy, p, rho, nu, F): + udiff = 1 + stepcount = 0 + + array_vals = (udiff, stepcount, u, v, dt, dx, dy, p) + + def conf_func(array_vals): + udiff, _, _, _, _, _, _, _ = array_vals + return udiff > .001 + + def body_func(array_vals): + _, stepcount, u, v, dt, dx, dy, p = array_vals + + un = u.copy() + vn = v.copy() + + b = build_up_b(rho, dt, dx, dy, u, v) + pressure_poisson_periodic(nit, p, dx, dy, b) + + u = u.at[1:-1, + 1:-1].set(un[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx * + (un[1:-1, 1:-1] - un[1:-1, 0:-2]) - + vn[1:-1, 1:-1] * dt / dy * + (un[1:-1, 1:-1] - un[0:-2, 1:-1]) - dt / (2 * rho * dx) * + (p[1:-1, 2:] - p[1:-1, 0:-2]) + nu * + (dt / dx**2 * + (un[1:-1, 2:] - 2 * un[1:-1, 1:-1] + un[1:-1, 0:-2]) + + dt / dy**2 * + (un[2:, 1:-1] - 2 * un[1:-1, 1:-1] + un[0:-2, 1:-1])) + + F * dt) + + v = v.at[1:-1, + 1:-1].set(vn[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx * + (vn[1:-1, 1:-1] - vn[1:-1, 0:-2]) - + vn[1:-1, 1:-1] * dt / dy * + (vn[1:-1, 1:-1] - vn[0:-2, 1:-1]) - dt / (2 * rho * dy) * + (p[2:, 1:-1] - p[0:-2, 1:-1]) + nu * + (dt / dx**2 * + (vn[1:-1, 2:] - 2 * vn[1:-1, 1:-1] + vn[1:-1, 0:-2]) + + dt / dy**2 * + (vn[2:, 1:-1] - 2 * vn[1:-1, 1:-1] + vn[0:-2, 1:-1]))) + + # Periodic BC u @ x = 2 + u = u.at[1:-1, 
-1].set( + un[1:-1, -1] - un[1:-1, -1] * dt / dx * + (un[1:-1, -1] - un[1:-1, -2]) - vn[1:-1, -1] * dt / dy * + (un[1:-1, -1] - un[0:-2, -1]) - dt / (2 * rho * dx) * + (p[1:-1, 0] - p[1:-1, -2]) + nu * + (dt / dx**2 * + (un[1:-1, 0] - 2 * un[1:-1, -1] + un[1:-1, -2]) + dt / dy**2 * + (un[2:, -1] - 2 * un[1:-1, -1] + un[0:-2, -1])) + F * dt) + + # Periodic BC u @ x = 0 + u = u.at[1:-1, + 0].set(un[1:-1, 0] - un[1:-1, 0] * dt / dx * + (un[1:-1, 0] - un[1:-1, -1]) - vn[1:-1, 0] * dt / dy * + (un[1:-1, 0] - un[0:-2, 0]) - dt / (2 * rho * dx) * + (p[1:-1, 1] - p[1:-1, -1]) + nu * + (dt / dx**2 * + (un[1:-1, 1] - 2 * un[1:-1, 0] + un[1:-1, -1]) + dt / dy**2 * + (un[2:, 0] - 2 * un[1:-1, 0] + un[0:-2, 0])) + F * dt) + + # Periodic BC v @ x = 2 + v = v.at[1:-1, -1].set( + vn[1:-1, -1] - un[1:-1, -1] * dt / dx * + (vn[1:-1, -1] - vn[1:-1, -2]) - vn[1:-1, -1] * dt / dy * + (vn[1:-1, -1] - vn[0:-2, -1]) - dt / (2 * rho * dy) * + (p[2:, -1] - p[0:-2, -1]) + nu * + (dt / dx**2 * + (vn[1:-1, 0] - 2 * vn[1:-1, -1] + vn[1:-1, -2]) + dt / dy**2 * + (vn[2:, -1] - 2 * vn[1:-1, -1] + vn[0:-2, -1]))) + + # Periodic BC v @ x = 0 + v = v.at[1:-1, + 0].set(vn[1:-1, 0] - un[1:-1, 0] * dt / dx * + (vn[1:-1, 0] - vn[1:-1, -1]) - vn[1:-1, 0] * dt / dy * + (vn[1:-1, 0] - vn[0:-2, 0]) - dt / (2 * rho * dy) * + (p[2:, 0] - p[0:-2, 0]) + nu * + (dt / dx**2 * + (vn[1:-1, 1] - 2 * vn[1:-1, 0] + vn[1:-1, -1]) + dt / dy**2 * + (vn[2:, 0] - 2 * vn[1:-1, 0] + vn[0:-2, 0]))) + + # Wall BC: u,v = 0 @ y = 0,2 + u = u.at[0, :].set(0) + u = u.at[-1, :].set(0) + v = v.at[0, :].set(0) + v = v.at[-1, :].set(0) + + udiff = (jnp.sum(u) - jnp.sum(un)) / jnp.sum(u) + stepcount += 1 + + return (udiff, stepcount, u, v, dt, dx, dy, p) + + _, stepcount, _, _, _, _, _, _= lax.while_loop(conf_func, body_func, array_vals) + + return stepcount From 8caad35811c8797dc29235d7ad6d603129d8f369 Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Thu, 21 Nov 2024 13:37:41 +0100 Subject: [PATCH 071/106] Add cavity_flow jax implementation --- .../benchmarks/cavity_flow/cavity_flow_jax.py | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 npbench/benchmarks/cavity_flow/cavity_flow_jax.py diff --git a/npbench/benchmarks/cavity_flow/cavity_flow_jax.py b/npbench/benchmarks/cavity_flow/cavity_flow_jax.py new file mode 100644 index 0000000..2fa1ea6 --- /dev/null +++ b/npbench/benchmarks/cavity_flow/cavity_flow_jax.py @@ -0,0 +1,102 @@ +# Barba, Lorena A., and Forsyth, Gilbert F. (2018). +# CFD Python: the 12 steps to Navier-Stokes equations. +# Journal of Open Source Education, 1(9), 21, +# https://doi.org/10.21105/jose.00021 +# TODO: License +# (c) 2017 Lorena A. Barba, Gilbert F. Forsyth. +# All content is under Creative Commons Attribution CC-BY 4.0, +# and all code is under BSD-3 clause (previously under MIT, and changed on March 8, 2018). 
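+#
+# Porting note (an illustrative sketch, not part of the original course
+# materials): here both nt and nit are fixed trip counts, so the loops are
+# expressed with lax.scan over a dummy index rather than Python for-loops:
+#
+#     def step(state, _):
+#         u, v, p, b = state
+#         # ... one explicit time step, as in body_func below ...
+#         return (u, v, p, b), None
+#
+#     (u, v, p, b), _ = lax.scan(step, (u, v, p, b), jnp.arange(nt))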
+ +import jax.numpy as jnp +import jax +from jax import lax +from functools import partial + + +@partial(jax.jit, static_argnums=(1,)) +def build_up_b(b, rho, dt, u, v, dx, dy): + + b = b.at[1:-1, + 1:-1].set(rho * (1 / dt * ((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx) + + (v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy)) - + ((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx))**2 - 2 * + ((u[2:, 1:-1] - u[0:-2, 1:-1]) / (2 * dy) * + (v[1:-1, 2:] - v[1:-1, 0:-2]) / (2 * dx)) - + ((v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy))**2)) + + return b + + +@partial(jax.jit, static_argnums=(0,)) +def pressure_poisson(nit, p, dx, dy, b): + def body_func(p, _): + pn = p.copy() + p = p.at[1:-1, 1:-1].set(((pn[1:-1, 2:] + pn[1:-1, 0:-2]) * dy**2 + + (pn[2:, 1:-1] + pn[0:-2, 1:-1]) * dx**2) / + (2 * (dx**2 + dy**2)) - dx**2 * dy**2 / + (2 * (dx**2 + dy**2)) * b[1:-1, 1:-1]) + + p = p.at[:, -1].set(p[:, -2]) # dp/dx = 0 at x = 2 + p = p.at[0, :].set(p[1, :]) # dp/dy = 0 at y = 0 + p = p.at[:, 0].set(p[:, 1]) # dp/dx = 0 at x = 0 + p = p.at[-1, :].set(0) # p = 0 at y = 2 + + return p, None + + p, _ = lax.scan(body_func, p, jnp.arange(nit)) + + return p + + +@partial(jax.jit, static_argnums=(0,1,2,3,10,11,)) +def cavity_flow(nx, ny, nt, nit, u, v, dt, dx, dy, p, rho, nu): + b = jnp.zeros((ny, nx)) + array_vals = (u, v, p, b) + + def body_func(array_vals, _): + + u, v, p, b = array_vals + + un = u.copy() + vn = v.copy() + + b = build_up_b(b, rho, dt, u, v, dx, dy) + p = pressure_poisson(nit, p, dx, dy, b) + + u = u.at[1:-1, + 1:-1].set(un[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx * + (un[1:-1, 1:-1] - un[1:-1, 0:-2]) - + vn[1:-1, 1:-1] * dt / dy * + (un[1:-1, 1:-1] - un[0:-2, 1:-1]) - dt / (2 * rho * dx) * + (p[1:-1, 2:] - p[1:-1, 0:-2]) + nu * + (dt / dx**2 * + (un[1:-1, 2:] - 2 * un[1:-1, 1:-1] + un[1:-1, 0:-2]) + + dt / dy**2 * + (un[2:, 1:-1] - 2 * un[1:-1, 1:-1] + un[0:-2, 1:-1]))) + + v = v.at[1:-1, + 1:-1].set(vn[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx * + (vn[1:-1, 1:-1] - vn[1:-1, 0:-2]) - + vn[1:-1, 1:-1] * dt / dy * + (vn[1:-1, 1:-1] - vn[0:-2, 1:-1]) - dt / (2 * rho * dy) * + (p[2:, 1:-1] - p[0:-2, 1:-1]) + nu * + (dt / dx**2 * + (vn[1:-1, 2:] - 2 * vn[1:-1, 1:-1] + vn[1:-1, 0:-2]) + + dt / dy**2 * + (vn[2:, 1:-1] - 2 * vn[1:-1, 1:-1] + vn[0:-2, 1:-1]))) + + u = u.at[0, :].set(0) + u = u.at[:, 0].set(0) + u = u.at[:, -1].set(0) + u = u.at[-1, :].set(1) # set velocity on cavity lid equal to 1 + v = v.at[0, :].set(0) + v = v.at[-1, :].set(0) + v = v.at[:, 0].set(0) + v = v.at[:, -1].set(0) + + return (u, v, p, b), None + + out_vals, _ = lax.scan(body_func, array_vals, jnp.arange(nt)) + u, v, p, b = out_vals + + return u, v, p, b \ No newline at end of file From 5d568b655fc6422318f05036d09f5f4e952e5135 Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Thu, 21 Nov 2024 13:43:33 +0100 Subject: [PATCH 072/106] Update channel_flow_jax --- npbench/benchmarks/channel_flow/channel_flow_jax.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/npbench/benchmarks/channel_flow/channel_flow_jax.py b/npbench/benchmarks/channel_flow/channel_flow_jax.py index be14315..eb5ffbc 100644 --- a/npbench/benchmarks/channel_flow/channel_flow_jax.py +++ b/npbench/benchmarks/channel_flow/channel_flow_jax.py @@ -44,7 +44,6 @@ def build_up_b(rho, dt, dx, dy, u, v): @partial(jax.jit, static_argnums=(0,)) def pressure_poisson_periodic(nit, p, dx, dy, b): - pn = jnp.empty_like(p) def body_func(p, q): pn = p.copy() @@ -79,14 +78,14 @@ def channel_flow(nit, u, v, dt, dx, dy, p, rho, nu, F): udiff = 1 stepcount = 0 - array_vals = 
(udiff, stepcount, u, v, dt, dx, dy, p) + array_vals = (udiff, stepcount, u, v, p) def conf_func(array_vals): - udiff, _, _, _, _, _, _, _ = array_vals + udiff, _, _, _ , _ = array_vals return udiff > .001 def body_func(array_vals): - _, stepcount, u, v, dt, dx, dy, p = array_vals + _, stepcount, u, v, p = array_vals un = u.copy() vn = v.copy() @@ -166,8 +165,8 @@ def body_func(array_vals): udiff = (jnp.sum(u) - jnp.sum(un)) / jnp.sum(u) stepcount += 1 - return (udiff, stepcount, u, v, dt, dx, dy, p) + return (udiff, stepcount, u, v, p) - _, stepcount, _, _, _, _, _, _= lax.while_loop(conf_func, body_func, array_vals) + _, stepcount, _, _, _ = lax.while_loop(conf_func, body_func, array_vals) return stepcount From 1c42505d6b7531694fc708f60ed93dee42eff69d Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Thu, 21 Nov 2024 13:44:58 +0100 Subject: [PATCH 073/106] Update cavity_flow_jax: remove return of modified arrays used for validation --- npbench/benchmarks/cavity_flow/cavity_flow_jax.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/npbench/benchmarks/cavity_flow/cavity_flow_jax.py b/npbench/benchmarks/cavity_flow/cavity_flow_jax.py index 2fa1ea6..47c407f 100644 --- a/npbench/benchmarks/cavity_flow/cavity_flow_jax.py +++ b/npbench/benchmarks/cavity_flow/cavity_flow_jax.py @@ -98,5 +98,3 @@ def body_func(array_vals, _): out_vals, _ = lax.scan(body_func, array_vals, jnp.arange(nt)) u, v, p, b = out_vals - - return u, v, p, b \ No newline at end of file From fb10d475ef08430915cfa63b7cd89f7c79c560df Mon Sep 17 00:00:00 2001 From: Sushant S Date: Thu, 21 Nov 2024 15:42:37 +0100 Subject: [PATCH 074/106] add nbody --- npbench/benchmarks/nbody/nbody_jax.py | 132 ++++++++++++++++++++++++++ 1 file changed, 132 insertions(+) create mode 100644 npbench/benchmarks/nbody/nbody_jax.py diff --git a/npbench/benchmarks/nbody/nbody_jax.py b/npbench/benchmarks/nbody/nbody_jax.py new file mode 100644 index 0000000..d34dfba --- /dev/null +++ b/npbench/benchmarks/nbody/nbody_jax.py @@ -0,0 +1,132 @@ +# Adapted from https://github.com/pmocz/nbody-python/blob/master/nbody.py +# TODO: Add GPL-3.0 License + +import jax +import jax.numpy as jnp +from jax import lax +from functools import partial +""" +Create Your Own N-body Simulation (With Python) +Philip Mocz (2020) Princeton Univeristy, @PMocz +Simulate orbits of stars interacting due to gravity +Code calculates pairwise forces according to Newton's Law of Gravity +""" + +@jax.jit +def getAcc(pos, mass, G, softening): + """ + Calculate the acceleration on each particle due to Newton's Law + pos is an N x 3 matrix of positions + mass is an N x 1 vector of masses + G is Newton's Gravitational constant + softening is the softening length + a is N x 3 matrix of accelerations + """ + # positions r = [x,y,z] for all particles + x = pos[:, 0:1] + y = pos[:, 1:2] + z = pos[:, 2:3] + + # matrix that stores all pairwise particle separations: r_j - r_i + dx = x.T - x + dy = y.T - y + dz = z.T - z + + # matrix that stores 1/r^3 for all particle pairwise particle separations + inv_r3 = (dx**2 + dy**2 + dz**2 + softening**2) + inv_r3 = jnp.where(inv_r3 > 0, inv_r3**(-1.5), inv_r3) + + ax = G * (dx * inv_r3) @ mass + ay = G * (dy * inv_r3) @ mass + az = G * (dz * inv_r3) @ mass + + # pack together the acceleration components + a = jnp.hstack((ax, ay, az)) + + return a + +@jax.jit +def getEnergy(pos, vel, mass, G): + """ + Get kinetic energy (KE) and potential energy (PE) of simulation + pos is N x 3 matrix of positions + vel is N x 3 matrix of velocities + mass is an N x 1 
vector of masses + G is Newton's Gravitational constant + KE is the kinetic energy of the system + PE is the potential energy of the system + """ + # Kinetic Energy: + # KE = 0.5 * np.sum(np.sum( mass * vel**2 )) + KE = 0.5 * jnp.sum(mass * vel**2) + + # Potential Energy: + + # positions r = [x,y,z] for all particles + x = pos[:, 0:1] + y = pos[:, 1:2] + z = pos[:, 2:3] + + # matrix that stores all pairwise particle separations: r_j - r_i + dx = x.T - x + dy = y.T - y + dz = z.T - z + + # matrix that stores 1/r for all particle pairwise particle separations + inv_r = jnp.sqrt(dx**2 + dy**2 + dz**2) + inv_r = jnp.where(inv_r > 0, 1.0 / inv_r, inv_r) + + # sum over upper triangle, to count each interaction only once + # PE = G * np.sum(np.sum(np.triu(-(mass*mass.T)*inv_r,1))) + PE = G * jnp.sum(jnp.triu(-(mass * mass.T) * inv_r, 1)) + + return KE, PE + +@partial(jax.jit, static_argnums=(4,)) +def nbody(mass, pos, vel, N, Nt, dt, G, softening): + + # Convert to Center-of-Mass frame + vel -= jnp.mean(mass * vel, axis=0) / jnp.mean(mass) + + # calculate initial gravitational accelerations + acc = getAcc(pos, mass, G, softening) + + # calculate initial energy of system + KE = jnp.zeros(Nt + 1, dtype=jnp.float64) + PE = jnp.zeros(Nt + 1, dtype=jnp.float64) + ke, pe = getEnergy(pos, vel, mass, G) + KE = KE.at[0].set(ke) + PE = PE.at[0].set(pe) + + t = 0.0 + + def loop_body(i, loop_vars): + pos, vel, acc, KE, PE, t = loop_vars + + # (1/2) kick + vel += acc * dt / 2.0 + + # drift + pos += vel * dt + + # update accelerations + acc = getAcc(pos, mass, G, softening) + + # (1/2) kick + vel += acc * dt / 2.0 + + # update time + t += dt + + # get energy of system + ke, pe = getEnergy(pos, vel, mass, G) + + KE = KE.at[i + 1].set(ke) + PE = PE.at[i + 1].set(pe) + + return pos, vel, acc, KE, PE, t + + # Simulation Main Loop + pos, vel, acc, KE, PE, t = lax.fori_loop(0, Nt, loop_body, (pos, vel, acc, KE, PE, t)) + + return KE, PE From aa22c518e346e955accf13c0566c0746c35dfd9b Mon Sep 17 00:00:00 2001 From: Filip Jaksic Date: Fri, 22 Nov 2024 21:43:29 +0100 Subject: [PATCH 075/106] Remove standalone comments and some type hints to not inflate "lines changes" metric Removed this to keep it identical to the NumPy implementations. 
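
For example, the atax kernel below only sheds its annotations and now matches
the NumPy signature line for line:

    @jax.jit
    def kernel(A, x):
        return (A @ x) @ A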
--- npbench/benchmarks/polybench/adi/adi_jax.py | 14 +++----------- npbench/benchmarks/polybench/adi/adi_numpy.py | 2 ++ npbench/benchmarks/polybench/atax/atax_jax.py | 2 +- .../benchmarks/polybench/deriche/deriche_jax.py | 1 - .../benchmarks/polybench/doitgen/doitgen_jax.py | 2 +- npbench/benchmarks/polybench/gemm/gemm_jax.py | 2 +- npbench/benchmarks/polybench/k2mm/k2mm_jax.py | 2 +- .../scattering_self_energies_jax.py | 10 +++------- 8 files changed, 12 insertions(+), 23 deletions(-) diff --git a/npbench/benchmarks/polybench/adi/adi_jax.py b/npbench/benchmarks/polybench/adi/adi_jax.py index 65b3c15..827a577 100644 --- a/npbench/benchmarks/polybench/adi/adi_jax.py +++ b/npbench/benchmarks/polybench/adi/adi_jax.py @@ -3,12 +3,11 @@ from jax import lax def kernel(TSTEPS, N, u): - # Initialize arrays + v = jnp.zeros_like(u) p = jnp.zeros_like(u) q = jnp.zeros_like(u) - # Constants DX = 1.0 / N DY = 1.0 / N DT = 1.0 / TSTEPS @@ -34,7 +33,7 @@ def first_j_loop_body(j, carry): def first_backward_j_loop_body(j, carry): v, p, q = carry - idx = N-2-j # Calculate the actual index for backward iteration + idx = N-2-j v = v.at[idx, 1:N-1].set(p[1:N-1, idx] * v[idx+1, 1:N-1] + q[1:N-1, idx]) return (v, p, q) @@ -48,42 +47,35 @@ def second_j_loop_body(j, carry): def second_backward_j_loop_body(j, carry): u, p, q = carry - idx = N-2-j # Calculate the actual index for backward iteration + idx = N-2-j u = u.at[1:N-1, idx].set(p[1:N-1, idx] * u[1:N-1, idx+1] + q[1:N-1, idx]) return (u, p, q) def time_step_body(t, carry): u, v, p, q = carry - # First part v = v.at[0, 1:N-1].set(1.0) p = p.at[1:N-1, 0].set(0.0) q = q.at[1:N-1, 0].set(v[0, 1:N-1]) - # First j loop p, q, u = lax.fori_loop(1, N-1, first_j_loop_body, (p, q, u)) v = v.at[N-1, 1:N-1].set(1.0) - # First backward j loop v, p, q = lax.fori_loop(0, N-2, first_backward_j_loop_body, (v, p, q)) - # Second part u = u.at[1:N-1, 0].set(1.0) p = p.at[1:N-1, 0].set(0.0) q = q.at[1:N-1, 0].set(u[1:N-1, 0]) - # Second j loop p, q, v = lax.fori_loop(1, N-1, second_j_loop_body, (p, q, v)) u = u.at[1:N-1, N-1].set(1.0) - # Second backward j loop u, p, q = lax.fori_loop(0, N-2, second_backward_j_loop_body, (u, p, q)) return (u, v, p, q) - # Main time loop u, v, p, q = lax.fori_loop(1, TSTEPS + 1, time_step_body, (u, v, p, q)) return u \ No newline at end of file diff --git a/npbench/benchmarks/polybench/adi/adi_numpy.py b/npbench/benchmarks/polybench/adi/adi_numpy.py index 24920ba..c5df9a7 100644 --- a/npbench/benchmarks/polybench/adi/adi_numpy.py +++ b/npbench/benchmarks/polybench/adi/adi_numpy.py @@ -50,3 +50,5 @@ def kernel(TSTEPS, N, u): u[1:N - 1, N - 1] = 1.0 for j in range(N - 2, 0, -1): u[1:N - 1, j] = p[1:N - 1, j] * u[1:N - 1, j + 1] + q[1:N - 1, j] + + return u \ No newline at end of file diff --git a/npbench/benchmarks/polybench/atax/atax_jax.py b/npbench/benchmarks/polybench/atax/atax_jax.py index 26ed193..a0c0dda 100644 --- a/npbench/benchmarks/polybench/atax/atax_jax.py +++ b/npbench/benchmarks/polybench/atax/atax_jax.py @@ -2,5 +2,5 @@ import jax.numpy as jnp @jax.jit -def kernel(A: jax.Array, x: jax.Array): +def kernel(A, x): return (A @ x) @ A diff --git a/npbench/benchmarks/polybench/deriche/deriche_jax.py b/npbench/benchmarks/polybench/deriche/deriche_jax.py index a390a33..978c64f 100644 --- a/npbench/benchmarks/polybench/deriche/deriche_jax.py +++ b/npbench/benchmarks/polybench/deriche/deriche_jax.py @@ -42,7 +42,6 @@ def horizontal_backward(j, y2): imgOut = c1 * (y1 + y2) - # First vertical pass y1 = jnp.empty_like(imgOut) y1 = y1.at[0, 
:].set(a5 * imgOut[0, :]) y1 = y1.at[1, :].set(a5 * imgOut[1, :] + a6 * imgOut[0, :] + b1 * y1[0, :]) diff --git a/npbench/benchmarks/polybench/doitgen/doitgen_jax.py b/npbench/benchmarks/polybench/doitgen/doitgen_jax.py index 8e6caf9..2d77255 100644 --- a/npbench/benchmarks/polybench/doitgen/doitgen_jax.py +++ b/npbench/benchmarks/polybench/doitgen/doitgen_jax.py @@ -4,7 +4,7 @@ from functools import partial @partial(jax.jit, static_argnums=(0, 1, 2)) -def kernel(NR: int, NQ: int, NP: int, A:jax.Array, C4:jax.Array): +def kernel(NR, NQ, NP, A, C4): # for r in range(NR): # for q in range(NQ): diff --git a/npbench/benchmarks/polybench/gemm/gemm_jax.py b/npbench/benchmarks/polybench/gemm/gemm_jax.py index 6dfe60a..b7254f0 100644 --- a/npbench/benchmarks/polybench/gemm/gemm_jax.py +++ b/npbench/benchmarks/polybench/gemm/gemm_jax.py @@ -2,7 +2,7 @@ import jax.numpy as jnp @jax.jit -def kernel(alpha: jnp.float64, beta: jnp.float64, C:jax.Array, A:jax.Array, B:jax.Array): +def kernel(alpha, beta, C, A, B): C = C.at[:].set(alpha * A @ B + beta * C) return C diff --git a/npbench/benchmarks/polybench/k2mm/k2mm_jax.py b/npbench/benchmarks/polybench/k2mm/k2mm_jax.py index db2a15f..fcb801c 100644 --- a/npbench/benchmarks/polybench/k2mm/k2mm_jax.py +++ b/npbench/benchmarks/polybench/k2mm/k2mm_jax.py @@ -3,7 +3,7 @@ @jax.jit -def kernel(alpha: jnp.float64, beta: jnp.float64, A: jax.Array, B: jax.Array, C: jax.Array, D: jax.Array): +def kernel(alpha, beta, A, B, C, D): D = D.at[:].set(alpha * A @ B @ C + beta * D) return D diff --git a/npbench/benchmarks/scattering_self_energies/scattering_self_energies_jax.py b/npbench/benchmarks/scattering_self_energies/scattering_self_energies_jax.py index a8f34cc..8bcc1a9 100644 --- a/npbench/benchmarks/scattering_self_energies/scattering_self_energies_jax.py +++ b/npbench/benchmarks/scattering_self_energies/scattering_self_energies_jax.py @@ -13,7 +13,6 @@ def body_fun(sigma, idx): return sigma.at[k, E, a].add(update), None - # Create all possible index combinations k_range = jnp.arange(G.shape[0]) E_range = jnp.arange(G.shape[1]) q_range = jnp.arange(D.shape[0]) @@ -23,17 +22,14 @@ def body_fun(sigma, idx): i_range = jnp.arange(D.shape[-2]) j_range = jnp.arange(D.shape[-1]) - # Create meshgrid of indices - indices = jnp.meshgrid( + indices = jnp.meshgrid( # Create meshgrid of indices k_range, E_range, q_range, w_range, a_range, b_range, i_range, j_range, indexing='ij' ) - # Reshape indices into a single array of 8-tuples - indices = jnp.stack([idx.ravel() for idx in indices], axis=1) + indices = jnp.stack([idx.ravel() for idx in indices], axis=1) # Reshape indices into a single array of 8-tuples - # Use scan to iterate over all index combinations - result, _ = jax.lax.scan(body_fun, Sigma, indices) + result, _ = jax.lax.scan(body_fun, Sigma, indices) # Use scan to iterate over all index combinations return result From c9b9f61385b2105598fced6cb33acf1ae4b70040 Mon Sep 17 00:00:00 2001 From: Sushant S Date: Sat, 23 Nov 2024 11:29:44 +0100 Subject: [PATCH 076/106] update correlation --- npbench/benchmarks/polybench/correlation/correlation_jax.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/npbench/benchmarks/polybench/correlation/correlation_jax.py b/npbench/benchmarks/polybench/correlation/correlation_jax.py index 2dde7a8..b7db052 100644 --- a/npbench/benchmarks/polybench/correlation/correlation_jax.py +++ b/npbench/benchmarks/polybench/correlation/correlation_jax.py @@ -9,8 +9,10 @@ def kernel(M, float_n, data): def loop_body(i, loop_vars): 
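+        # The carried `data` is the precomputed Gram matrix data.T @ data
+        # (data_mul at the call site): row i and column i of corr take its
+        # entries past the diagonal and keep the existing corr values
+        # (unit diagonal, previously filled entries) elsewhere.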
corr, data = loop_vars - corr = lax.dynamic_update_slice(corr, jnp.roll(data[i, None], -(i + 1), axis=1), (i, i + 1)) - corr = lax.dynamic_update_slice(corr, jnp.roll(data[i, None], -(i + 1), axis=1).T, (i + 1, i)) + corr_update_x = jnp.where(jnp.arange(data.shape[0]) > i, data[i], corr[i]) + corr_update_y = jnp.where(jnp.arange(data.shape[0]) > i, data[i], corr[:, i]) + corr = lax.dynamic_update_slice(corr, corr_update_x[None, :], (i, 0)) + corr = lax.dynamic_update_slice(corr, corr_update_y[:, None], (0, i)) return corr, data From 4b045b8196be9fcc412262d15deb80198972852c Mon Sep 17 00:00:00 2001 From: Sushant S Date: Sat, 23 Nov 2024 14:35:16 +0100 Subject: [PATCH 077/106] add hdiff --- .../weather_stencils/hdiff/hdiff_jax.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 npbench/benchmarks/weather_stencils/hdiff/hdiff_jax.py diff --git a/npbench/benchmarks/weather_stencils/hdiff/hdiff_jax.py b/npbench/benchmarks/weather_stencils/hdiff/hdiff_jax.py new file mode 100644 index 0000000..fa472fa --- /dev/null +++ b/npbench/benchmarks/weather_stencils/hdiff/hdiff_jax.py @@ -0,0 +1,32 @@ +import jax +import jax.numpy as jnp + +# Adapted from https://github.com/GridTools/gt4py/blob/1caca893034a18d5df1522ed251486659f846589/tests/test_integration/stencil_definitions.py#L194 +@jax.jit +def hdiff(in_field, out_field, coeff): + I, J, K = out_field.shape[0], out_field.shape[1], out_field.shape[2] + lap_field = 4.0 * in_field[1:I + 3, 1:J + 3, :] - ( + in_field[2:I + 4, 1:J + 3, :] + in_field[0:I + 2, 1:J + 3, :] + + in_field[1:I + 3, 2:J + 4, :] + in_field[1:I + 3, 0:J + 2, :]) + + res = lap_field[1:, 1:J + 1, :] - lap_field[:-1, 1:J + 1, :] + flx_field = jnp.where( + (res * + (in_field[2:I + 3, 2:J + 2, :] - in_field[1:I + 2, 2:J + 2, :])) > 0, + 0, + res, + ) + + res = lap_field[1:I + 1, 1:, :] - lap_field[1:I + 1, :-1, :] + fly_field = jnp.where( + (res * + (in_field[2:I + 2, 2:J + 3, :] - in_field[2:I + 2, 1:J + 2, :])) > 0, + 0, + res, + ) + + out_field = out_field.at[:, :, :].set(in_field[2:I + 2, 2:J + 2, :] - coeff[:, :, :] * ( + flx_field[1:, :, :] - flx_field[:-1, :, :] + fly_field[:, 1:, :] - + fly_field[:, :-1, :])) + + return out_field \ No newline at end of file From 448e92d0c48ca9a2905afa71c21415102ea05247 Mon Sep 17 00:00:00 2001 From: Filip Jaksic Date: Sat, 23 Nov 2024 14:45:04 +0100 Subject: [PATCH 078/106] Add mandelbrot2 Jax implementation --- .../benchmarks/mandelbrot2/mandelbrot2_jax.py | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 npbench/benchmarks/mandelbrot2/mandelbrot2_jax.py diff --git a/npbench/benchmarks/mandelbrot2/mandelbrot2_jax.py b/npbench/benchmarks/mandelbrot2/mandelbrot2_jax.py new file mode 100644 index 0000000..4e42958 --- /dev/null +++ b/npbench/benchmarks/mandelbrot2/mandelbrot2_jax.py @@ -0,0 +1,52 @@ +# ----------------------------------------------------------------------------- +# From Numpy to Python +# Copyright (2017) Nicolas P. 
Rougier - BSD license +# More information at https://github.com/rougier/numpy-book +# ----------------------------------------------------------------------------- + +import jax +import jax.numpy as jnp +from functools import partial + +@partial(jax.jit, static_argnames=["xn", "yn", "itermax"]) +def mandelbrot(xmin, xmax, ymin, ymax, xn, yn, itermax, horizon=2.0): + # Adapted from + # https://thesamovar.wordpress.com/2009/03/22/fast-fractals-with-python-and-numpy/ + Xi, Yi = jnp.mgrid[0:xn, 0:yn] + X = jnp.linspace(xmin, xmax, xn, dtype=jnp.float64)[Xi] + Y = jnp.linspace(ymin, ymax, yn, dtype=jnp.float64)[Yi] + C = X + Y * 1j + N_ = jnp.zeros(C.shape, dtype=jnp.int64) + Z_ = jnp.zeros(C.shape, dtype=jnp.complex128) + + original_shape = C.shape + Xi = Xi.reshape(-1) + Yi = Yi.reshape(-1) + C = C.reshape(-1) + + def body_fun(i, state): + Z, Xi, Yi, C, N_, Z_, mask = state + # Compute for relevant points only + Z = Z * Z + C + + # Failed convergence + I = abs(Z) > horizon + I = I & mask # Only consider points that haven't diverged yet + + N_ = jnp.where(I, i + 1, N_) + Z_ = jnp.where(I, Z, Z_) + + # Keep going with those who have not diverged yet + mask = mask & ~I + Z = jnp.where(mask, Z, 0) + + return (Z, Xi, Yi, C, N_, Z_, mask) + + init_state = (jnp.zeros_like(C, dtype=jnp.complex128), Xi, Yi, C, + N_.reshape(-1), Z_.reshape(-1), jnp.ones_like(C, dtype=bool)) + _, _, _, _, N_, Z_, _ = jax.lax.fori_loop(0, itermax, body_fun, init_state) + + Z_ = Z_.reshape(original_shape) # Reshape results back to original shape + N_ = N_.reshape(original_shape) + + return Z_.T, N_.T From 578e440bd58eeccedb7a58e59e92f2dfb88a0a3d Mon Sep 17 00:00:00 2001 From: Sushant S Date: Sat, 23 Nov 2024 14:55:53 +0100 Subject: [PATCH 079/106] add vadv --- .../weather_stencils/vadv/vadv_jax.py | 108 ++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 npbench/benchmarks/weather_stencils/vadv/vadv_jax.py diff --git a/npbench/benchmarks/weather_stencils/vadv/vadv_jax.py b/npbench/benchmarks/weather_stencils/vadv/vadv_jax.py new file mode 100644 index 0000000..7642fac --- /dev/null +++ b/npbench/benchmarks/weather_stencils/vadv/vadv_jax.py @@ -0,0 +1,108 @@ +import jax +import jax.numpy as jnp +from jax import lax + +# Sample constants +BET_M = 0.5 +BET_P = 0.5 + +# Adapted from https://github.com/GridTools/gt4py/blob/1caca893034a18d5df1522ed251486659f846589/tests/test_integration/stencil_definitions.py#L111 +@jax.jit +def vadv(utens_stage, u_stage, wcon, u_pos, utens, dtr_stage): + I, J, K = utens_stage.shape[0], utens_stage.shape[1], utens_stage.shape[2] + ccol = jnp.empty((I, J, K), dtype=utens_stage.dtype) + dcol = jnp.empty((I, J, K), dtype=utens_stage.dtype) + data_col = jnp.empty((I, J), dtype=utens_stage.dtype) + + def loop1(k, loop_vars): + wcon, u_stage, u_pos, utens, utens_stage, dtr_stage, ccol, dcol = loop_vars + gcv = 0.25 * (wcon[1:, :, 0 + 1] + wcon[:-1, :, 0 + 1]) + cs = gcv * BET_M + ccol = ccol.at[:, :, k].set(gcv * BET_P) + bcol = dtr_stage - ccol[:, :, k] + + # update the d column + correction_term = -cs * (u_stage[:, :, k + 1] - u_stage[:, :, k]) + dcol = dcol.at[:, :, k].set((dtr_stage * u_pos[:, :, k] + utens[:, :, k] + + utens_stage[:, :, k] + correction_term)) + + # Thomas forward + divided = 1.0 / bcol + ccol = ccol.at[:, :, k].set(ccol[:, :, k] * divided) + dcol = dcol.at[:, :, k].set(dcol[:, :, k] * divided) + + return wcon, u_stage, u_pos, utens, utens_stage, dtr_stage, ccol, dcol + + wcon, u_stage, u_pos, utens, utens_stage, dtr_stage, ccol, dcol = lax.fori_loop(0, 1, 
loop1, (wcon, u_stage, u_pos, utens, utens_stage, dtr_stage, ccol, dcol)) + + def loop2(k, loop_vars): + wcon, u_stage, u_pos, utens, utens_stage, dtr_stage, ccol, dcol = loop_vars + gav = -0.25 * (wcon[1:, :, k] + wcon[:-1, :, k]) + gcv = 0.25 * (wcon[1:, :, k + 1] + wcon[:-1, :, k + 1]) + + as_ = gav * BET_M + cs = gcv * BET_M + + acol = gav * BET_P + ccol = ccol.at[:, :, k].set(gcv * BET_P) + bcol = dtr_stage - acol - ccol[:, :, k] + + # update the d column + correction_term = -as_ * (u_stage[:, :, k - 1] - + u_stage[:, :, k]) - cs * ( + u_stage[:, :, k + 1] - u_stage[:, :, k]) + dcol = dcol.at[:, :, k].set((dtr_stage * u_pos[:, :, k] + utens[:, :, k] + + utens_stage[:, :, k] + correction_term)) + + # Thomas forward + divided = 1.0 / (bcol - ccol[:, :, k - 1] * acol) + ccol = ccol.at[:, :, k].set(ccol[:, :, k] * divided) + dcol = dcol.at[:, :, k].set((dcol[:, :, k] - (dcol[:, :, k - 1]) * acol) * divided) + + return wcon, u_stage, u_pos, utens, utens_stage, dtr_stage, ccol, dcol + + wcon, u_stage, u_pos, utens, utens_stage, dtr_stage, ccol, dcol = lax.fori_loop(1, K - 1, loop2, (wcon, u_stage, u_pos, utens, utens_stage, dtr_stage, ccol, dcol)) + + def loop3(k, loop_vars): + wcon, u_stage, u_pos, utens, utens_stage, dtr_stage, ccol, dcol = loop_vars + gav = -0.25 * (wcon[1:, :, k] + wcon[:-1, :, k]) + as_ = gav * BET_M + acol = gav * BET_P + bcol = dtr_stage - acol + + # update the d column + correction_term = -as_ * (u_stage[:, :, k - 1] - u_stage[:, :, k]) + dcol = dcol.at[:, :, k].set((dtr_stage * u_pos[:, :, k] + utens[:, :, k] + + utens_stage[:, :, k] + correction_term)) + + # Thomas forward + divided = 1.0 / (bcol - ccol[:, :, k - 1] * acol) + dcol = dcol.at[:, :, k].set((dcol[:, :, k] - (dcol[:, :, k - 1]) * acol) * divided) + + return wcon, u_stage, u_pos, utens, utens_stage, dtr_stage, ccol, dcol + + wcon, u_stage, u_pos, utens, utens_stage, dtr_stage, ccol, dcol = lax.fori_loop(K - 1, K, loop3, (wcon, u_stage, u_pos, utens, utens_stage, dtr_stage, ccol, dcol)) + + def loop4(k, loop_vars): + dcol, data_col, utens_stage, u_pos, dtr_stage = loop_vars + datacol = dcol[:, :, k] + data_col = data_col.at[:].set(datacol) + utens_stage = utens_stage.at[:, :, k].set(dtr_stage * (datacol - u_pos[:, :, k])) + + return dcol, data_col, utens_stage, u_pos, dtr_stage + + dcol, data_col, utens_stage, u_pos, dtr_stage = lax.fori_loop(K - 1, K, loop4, (dcol, data_col, utens_stage, u_pos, dtr_stage)) + + def loop5(k, loop_vars): + dcol, data_col, utens_stage, u_pos, dtr_stage = loop_vars + K = utens_stage.shape[2] + k = K - 2 - k + datacol = dcol[:, :, k] - ccol[:, :, k] * data_col[:, :] + data_col = data_col.at[:].set(datacol) + utens_stage = utens_stage.at[:, :, k].set(dtr_stage * (datacol - u_pos[:, :, k])) + + return dcol, data_col, utens_stage, u_pos, dtr_stage + + dcol, data_col, utens_stage, u_pos, dtr_stage = lax.fori_loop(0, K - 1, loop5, (dcol, data_col, utens_stage, u_pos, dtr_stage)) + + return ccol, dcol, data_col, utens_stage From 107584bbb93eefca3befe839e4ddad3cd263cf5a Mon Sep 17 00:00:00 2001 From: Sushant S Date: Sat, 23 Nov 2024 14:56:01 +0100 Subject: [PATCH 080/106] update nbody --- npbench/benchmarks/nbody/nbody_jax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/npbench/benchmarks/nbody/nbody_jax.py b/npbench/benchmarks/nbody/nbody_jax.py index d34dfba..ca6a635 100644 --- a/npbench/benchmarks/nbody/nbody_jax.py +++ b/npbench/benchmarks/nbody/nbody_jax.py @@ -92,8 +92,8 @@ def nbody(mass, pos, vel, N, Nt, dt, G, softening): acc = getAcc(pos, mass, G, 
softening) # calculate initial energy of system - KE = jnp.zeros(Nt + 1, dtype=jnp.float64) - PE = jnp.zeros(Nt + 1, dtype=jnp.float64) + KE = jnp.empty(Nt + 1, dtype=jnp.float64) + PE = jnp.empty(Nt + 1, dtype=jnp.float64) ke, pe = getEnergy(pos, vel, mass, G) KE = KE.at[0].set(ke) PE = PE.at[0].set(pe) From a2d4464cb44fe25040a5585ff2b990c35b628eac Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Sat, 23 Nov 2024 16:51:55 +0100 Subject: [PATCH 081/106] Add nussinov jax implementation --- .../polybench/nussinov/nussinov_jax.py | 64 +++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 npbench/benchmarks/polybench/nussinov/nussinov_jax.py diff --git a/npbench/benchmarks/polybench/nussinov/nussinov_jax.py b/npbench/benchmarks/polybench/nussinov/nussinov_jax.py new file mode 100644 index 0000000..b59e387 --- /dev/null +++ b/npbench/benchmarks/polybench/nussinov/nussinov_jax.py @@ -0,0 +1,64 @@ +import jax.numpy as jnp +import jax +from jax import lax +from functools import partial + +@jax.jit +def match(b1, b2): + return jnp.where(b1 + b2 == 3, 1, 0) + + +@partial(jax.jit, static_argnums=(0,)) +def kernel(N, seq): + + table = jnp.zeros((N, N), jnp.int32) + + def func_i(i, table): + i = N - 1 - i + def func_j(j, table): + table = table.at[i, j].set( + jnp.where( + j - 1 >= 0, + jnp.maximum(table[i, j], table[i, j - 1]), + table[i, j] + ) + ) + table = table.at[i, j].set( + jnp.where( + i + 1 < N, + jnp.maximum(table[i, j], table[i + 1, j]), + table[i, j] + ) + ) + table = table.at[i, j].set( + jnp.where( + (j - 1 >= 0) & (i + 1 < N) & (i < j - 1), + jnp.maximum(table[i, j], table[i + 1, j - 1] + match(seq[i], seq[j])), + table[i, j] + ) + ) + table = table.at[i, j].set( + jnp.where( + (j - 1 >= 0) & (i + 1 < N) & (i >= j - 1), + jnp.maximum(table[i, j], table[i + 1, j - 1]), + table[i, j] + ) + ) + + def func_k(k, table): + table = table.at[i, j].set( + jnp.maximum( + table[i, j], + table[i, k] + table[k + 1, j] + ) + ) + return table + + table = lax.fori_loop(i + 1, j, func_k, table) + return table + + table = lax.fori_loop(i + 1, N, func_j, table) + return table + + table = lax.fori_loop(0, N, func_i, table) + return table From fe87b835400267eac8570749691619b235bfad93 Mon Sep 17 00:00:00 2001 From: Filip Jaksic Date: Sun, 24 Nov 2024 16:47:40 +0100 Subject: [PATCH 082/106] Fix(cavity_flow): return result arrays --- npbench/benchmarks/cavity_flow/cavity_flow_jax.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/npbench/benchmarks/cavity_flow/cavity_flow_jax.py b/npbench/benchmarks/cavity_flow/cavity_flow_jax.py index 47c407f..4fbda96 100644 --- a/npbench/benchmarks/cavity_flow/cavity_flow_jax.py +++ b/npbench/benchmarks/cavity_flow/cavity_flow_jax.py @@ -98,3 +98,5 @@ def body_func(array_vals, _): out_vals, _ = lax.scan(body_func, array_vals, jnp.arange(nt)) u, v, p, b = out_vals + + return u, v, p From 23919b3922ff89951bb9b8c7fd3bf135fe7a6dd2 Mon Sep 17 00:00:00 2001 From: Sushant S Date: Sun, 15 Dec 2024 21:57:43 +0100 Subject: [PATCH 083/106] add lib_implementation to jax_framework --- npbench/infrastructure/jax_framework.py | 65 +++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/npbench/infrastructure/jax_framework.py b/npbench/infrastructure/jax_framework.py index dac358c..f76bcd5 100644 --- a/npbench/infrastructure/jax_framework.py +++ b/npbench/infrastructure/jax_framework.py @@ -1,4 +1,5 @@ # Copyright 2021 ETH Zurich and the NPBench authors. All rights reserved. 
+import pathlib import jax.numpy as jnp import jax jax.config.update("jax_enable_x64", True) @@ -7,6 +8,10 @@ from typing import Any, Callable, Dict +_impl = { + 'lib-implementation': 'lib' +} + class JaxFramework(Framework): """ A class for reading and processing framework information. """ @@ -25,6 +30,66 @@ def copy_func(self) -> Callable: for copying the benchmark arguments. """ return jnp.array + def impl_files(self, bench: Benchmark): + """ Returns the framework's implementation files for a particular + benchmark. + :param bench: A benchmark. + :returns: A list of the benchmark implementation files. + """ + + parent_folder = pathlib.Path(__file__).parent.absolute() + implementations = [] + for impl_name, impl_postfix in _impl.items(): + pymod_path = parent_folder.joinpath( + "..", "..", "npbench", "benchmarks", bench.info["relative_path"], + bench.info["module_name"] + "_" + self.info["postfix"] + "_" + impl_postfix + ".py") + implementations.append((pymod_path, impl_name)) + + # appending the default implementation + pymod_path = parent_folder.joinpath("..", "..", "npbench", "benchmarks", bench.info["relative_path"], + bench.info["module_name"] + "_" + self.info["postfix"] + ".py") + + implementations.append((pymod_path, 'default')) + return implementations + + def implementations(self, bench: Benchmark): + """ Returns the framework's implementations for a particular benchmark. + :param bench: A benchmark. + :returns: A list of the benchmark implementations. + """ + + module_pypath = "npbench.benchmarks.{r}.{m}".format(r=bench.info["relative_path"].replace('/', '.'), + m=bench.info["module_name"]) + if "postfix" in self.info.keys(): + postfix = self.info["postfix"] + else: + postfix = self.fname + module_str = "{m}_{p}".format(m=module_pypath, p=postfix) + func_str = bench.info["func_name"] + + implementations = [] + for impl_name, impl_postfix in _impl.items(): + ldict = dict() + try: + exec("from {m}_{p} import {f} as impl".format(m=module_str, p=impl_postfix, f=func_str), ldict) + implementations.append((ldict['impl'], impl_name)) + except ImportError: + continue + except Exception: + print("Failed to load the {r} {f} implementation.".format(r=self.info["full_name"], f=impl_name)) + continue + + # appending the default implementation + try: + ldict = dict() + exec("from {m} import {f} as impl".format(m=module_str, f=func_str), ldict) + implementations.append((ldict['impl'], 'default')) + except Exception as e: + print("Failed to load the {r} {f} implementation.".format(r=self.info["full_name"], f=func_str)) + raise e + + return implementations + def exec_str(self, bench: Benchmark, impl: Callable = None): """ Generates the execution-string that should be used to call the benchmark implementation. 
From c016a09953dc7f5eebcfbf4bc648641452601b7e Mon Sep 17 00:00:00 2001 From: Sushant S Date: Sun, 15 Dec 2024 21:58:05 +0100 Subject: [PATCH 084/106] separate cov lib and cov default --- .../polybench/covariance/covariance_jax.py | 26 ++++++++++++++++--- .../covariance/covariance_jax_lib.py | 10 +++++++ 2 files changed, 33 insertions(+), 3 deletions(-) create mode 100644 npbench/benchmarks/polybench/covariance/covariance_jax_lib.py diff --git a/npbench/benchmarks/polybench/covariance/covariance_jax.py b/npbench/benchmarks/polybench/covariance/covariance_jax.py index 57a51a2..c36c3e7 100644 --- a/npbench/benchmarks/polybench/covariance/covariance_jax.py +++ b/npbench/benchmarks/polybench/covariance/covariance_jax.py @@ -1,10 +1,30 @@ import jax import jax.numpy as jnp +from jax import lax +from functools import partial - -@jax.jit +@partial(jax.jit, static_argnums=(0,)) def kernel(M, float_n, data): - cov = jnp.cov(data, rowvar=False) + mean = jnp.mean(data, axis=0) + data -= mean + + cov = jnp.zeros((M, M), dtype=data.dtype) + + def loop_body(i, loop_vars): + data, cov = loop_vars + cov_slice1 = data[:, i] + cov_slice2 = jnp.where(jnp.arange(M) >= i, data, 0.0) + + ans = cov_slice1 @ cov_slice2 / (float_n - 1.0) + row_update_slice = jnp.where(jnp.arange(M) >= i, ans, cov[i, :]) + col_update_slice = jnp.where(jnp.arange(M) >= i, ans, cov[:, i]) + cov = cov.at[i, :].set(row_update_slice) + cov = cov.at[:, i].set(col_update_slice) + + return data, cov + + _, cov = lax.fori_loop(0, M, loop_body, (data, cov)) + return cov diff --git a/npbench/benchmarks/polybench/covariance/covariance_jax_lib.py b/npbench/benchmarks/polybench/covariance/covariance_jax_lib.py new file mode 100644 index 0000000..e788823 --- /dev/null +++ b/npbench/benchmarks/polybench/covariance/covariance_jax_lib.py @@ -0,0 +1,10 @@ +import jax +import jax.numpy as jnp + + +@jax.jit +def kernel(M, float_n, data): + + cov = jnp.cov(data, rowvar=False) + + return cov \ No newline at end of file From 53ac30ed8aa143b7c70437503c75b2c3fc7bb418 Mon Sep 17 00:00:00 2001 From: Sushant S Date: Mon, 16 Dec 2024 06:40:41 +0100 Subject: [PATCH 085/106] fix comments --- npbench/infrastructure/jax_framework.py | 31 ++++++++++++++----------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/npbench/infrastructure/jax_framework.py b/npbench/infrastructure/jax_framework.py index f76bcd5..b58625d 100644 --- a/npbench/infrastructure/jax_framework.py +++ b/npbench/infrastructure/jax_framework.py @@ -39,17 +39,19 @@ def impl_files(self, bench: Benchmark): parent_folder = pathlib.Path(__file__).parent.absolute() implementations = [] + + # appending the default implementation + pymod_path = parent_folder.joinpath("..", "..", "npbench", "benchmarks", bench.info["relative_path"], + bench.info["module_name"] + "_" + self.info["postfix"] + ".py") + + implementations.append((pymod_path, 'default')) + for impl_name, impl_postfix in _impl.items(): pymod_path = parent_folder.joinpath( "..", "..", "npbench", "benchmarks", bench.info["relative_path"], bench.info["module_name"] + "_" + self.info["postfix"] + "_" + impl_postfix + ".py") implementations.append((pymod_path, impl_name)) - # appending the default implementation - pymod_path = parent_folder.joinpath("..", "..", "npbench", "benchmarks", bench.info["relative_path"], - bench.info["module_name"] + "_" + self.info["postfix"] + ".py") - - implementations.append((pymod_path, 'default')) return implementations def implementations(self, bench: Benchmark): @@ -68,6 +70,16 @@ def 
implementations(self, bench: Benchmark): func_str = bench.info["func_name"] implementations = [] + + # appending the default implementation + try: + ldict = dict() + exec("from {m} import {f} as impl".format(m=module_str, f=func_str), ldict) + implementations.append((ldict['impl'], 'default')) + except Exception as e: + print("Failed to load the {r} {f} implementation.".format(r=self.info["full_name"], f=func_str)) + raise e + for impl_name, impl_postfix in _impl.items(): ldict = dict() try: @@ -79,15 +91,6 @@ def implementations(self, bench: Benchmark): print("Failed to load the {r} {f} implementation.".format(r=self.info["full_name"], f=impl_name)) continue - # appending the default implementation - try: - ldict = dict() - exec("from {m} import {f} as impl".format(m=module_str, f=func_str), ldict) - implementations.append((ldict['impl'], 'default')) - except Exception as e: - print("Failed to load the {r} {f} implementation.".format(r=self.info["full_name"], f=func_str)) - raise e - return implementations def exec_str(self, bench: Benchmark, impl: Callable = None): From c92ec2f60b5fa2872d385163583892a3c3104fad Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Tue, 17 Dec 2024 11:22:27 +0100 Subject: [PATCH 086/106] minor fix in frmwrk-name for validation --- npbench/infrastructure/test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/npbench/infrastructure/test.py b/npbench/infrastructure/test.py index da0eafd..cd409b5 100644 --- a/npbench/infrastructure/test.py +++ b/npbench/infrastructure/test.py @@ -101,7 +101,7 @@ def first_execution(impl, impl_name): valid = True if validate and np_out is not None: try: - frmwrk_name = self.frmwrk.info["full_name"] + frmwrk_name = self.frmwrk.info["full_name"] + " - " + impl_name rtol = 1e-5 if not 'rtol' in self.bench.info else self.bench.info['rtol'] atol = 1e-8 if not 'atol' in self.bench.info else self.bench.info['atol'] From 970c3fd307a605a66783f825b5d8263e74bd3a7f Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Tue, 17 Dec 2024 15:26:28 +0100 Subject: [PATCH 087/106] Update go_fast jax implementation --- npbench/benchmarks/go_fast/go_fast_jax.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/npbench/benchmarks/go_fast/go_fast_jax.py b/npbench/benchmarks/go_fast/go_fast_jax.py index 4db63ab..4cebf26 100644 --- a/npbench/benchmarks/go_fast/go_fast_jax.py +++ b/npbench/benchmarks/go_fast/go_fast_jax.py @@ -5,8 +5,9 @@ @jax.jit def go_fast(a: jax.Array): - # Calculate the trace of the tanh of the diagonal elements - trace = jnp.sum(jnp.tanh(jnp.diag(a))) - - # Add the result to the original matrix + trace = 0.0 + def body_fn(i, trace): + trace += jnp.tanh(a[i, i]) + return trace + trace = jax.lax.fori_loop(0, a.shape[0], body_fn, trace) return a + trace From 78230557799a1994b9d229c2157f1356ca95194d Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Tue, 17 Dec 2024 15:27:16 +0100 Subject: [PATCH 088/106] Add go_fast jax_lib implementation --- npbench/benchmarks/go_fast/go_fast_jax_lib.py | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 npbench/benchmarks/go_fast/go_fast_jax_lib.py diff --git a/npbench/benchmarks/go_fast/go_fast_jax_lib.py b/npbench/benchmarks/go_fast/go_fast_jax_lib.py new file mode 100644 index 0000000..67cdf3f --- /dev/null +++ b/npbench/benchmarks/go_fast/go_fast_jax_lib.py @@ -0,0 +1,9 @@ +# https://numba.readthedocs.io/en/stable/user/5minguide.html + +import jax +import jax.numpy as jnp + +@jax.jit +def go_fast(a: jax.Array): + trace = 
jnp.sum(jnp.tanh(jnp.diag(a))) + return a + trace From bcf0336ef5e846ea2a6d3c955341077b0f46d025 Mon Sep 17 00:00:00 2001 From: Sushant S Date: Thu, 19 Dec 2024 15:08:35 +0100 Subject: [PATCH 089/106] add trisolv lib implementation --- npbench/benchmarks/polybench/trisolv/trisolv_jax_lib.py | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 npbench/benchmarks/polybench/trisolv/trisolv_jax_lib.py diff --git a/npbench/benchmarks/polybench/trisolv/trisolv_jax_lib.py b/npbench/benchmarks/polybench/trisolv/trisolv_jax_lib.py new file mode 100644 index 0000000..287dfd4 --- /dev/null +++ b/npbench/benchmarks/polybench/trisolv/trisolv_jax_lib.py @@ -0,0 +1,9 @@ +import jax +import jax.numpy as jnp +from jax import lax + +@jax.jit +def kernel(L, x, b): + + x = jax.scipy.linalg.solve_triangular(L, b, lower=True) + return x From 51c1f80f40d8f3ed983a1b3c155f94b3439f2a1f Mon Sep 17 00:00:00 2001 From: Sushant S Date: Thu, 19 Dec 2024 15:08:44 +0100 Subject: [PATCH 090/106] update trisolv --- .../polybench/trisolv/trisolv_jax.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/npbench/benchmarks/polybench/trisolv/trisolv_jax.py b/npbench/benchmarks/polybench/trisolv/trisolv_jax.py index 60b2ed3..35480d4 100644 --- a/npbench/benchmarks/polybench/trisolv/trisolv_jax.py +++ b/npbench/benchmarks/polybench/trisolv/trisolv_jax.py @@ -1,17 +1,15 @@ import jax import jax.numpy as jnp - +from jax import lax @jax.jit def kernel(L, x, b): + + def loop_body(i, x): + mask = jnp.arange(L.shape[1]) < i + L_slice = jnp.where(mask, L[i, :], 0.0) + x_slice = jnp.where(mask, x, 0.0) - def loop_body(i, loop_vars): - L, x, b = loop_vars - L_slice = jnp.where(jnp.arange(L.shape[1]) < i, L[i, :], 0.0) - x_slice = jnp.where(jnp.arange(x.shape[0]) < i, x, 0.0) - x = x.at[i].set((b[i] - L_slice @ x_slice) / L[i, i]) - return L, x, b - - _, x, _ = jax.lax.fori_loop(0, x.shape[0], loop_body, (L, x, b)) + return x.at[i].set((b[i] - jnp.dot(L_slice, x_slice)) / L[i, i]) - return x + return lax.fori_loop(0, x.shape[0], loop_body, x) From f8c14ad0eb1673aa54ef3f4750ebdf82ecfa7e45 Mon Sep 17 00:00:00 2001 From: Sushant S Date: Thu, 19 Dec 2024 15:45:37 +0100 Subject: [PATCH 091/106] remove excessive loop vars --- .../weather_stencils/vadv/vadv_jax.py | 31 +++++++++---------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/npbench/benchmarks/weather_stencils/vadv/vadv_jax.py b/npbench/benchmarks/weather_stencils/vadv/vadv_jax.py index 7642fac..9addbe3 100644 --- a/npbench/benchmarks/weather_stencils/vadv/vadv_jax.py +++ b/npbench/benchmarks/weather_stencils/vadv/vadv_jax.py @@ -15,7 +15,7 @@ def vadv(utens_stage, u_stage, wcon, u_pos, utens, dtr_stage): data_col = jnp.empty((I, J), dtype=utens_stage.dtype) def loop1(k, loop_vars): - wcon, u_stage, u_pos, utens, utens_stage, dtr_stage, ccol, dcol = loop_vars + ccol, dcol = loop_vars gcv = 0.25 * (wcon[1:, :, 0 + 1] + wcon[:-1, :, 0 + 1]) cs = gcv * BET_M ccol = ccol.at[:, :, k].set(gcv * BET_P) @@ -31,12 +31,12 @@ def loop1(k, loop_vars): ccol = ccol.at[:, :, k].set(ccol[:, :, k] * divided) dcol = dcol.at[:, :, k].set(dcol[:, :, k] * divided) - return wcon, u_stage, u_pos, utens, utens_stage, dtr_stage, ccol, dcol + return ccol, dcol - wcon, u_stage, u_pos, utens, utens_stage, dtr_stage, ccol, dcol = lax.fori_loop(0, 1, loop1, (wcon, u_stage, u_pos, utens, utens_stage, dtr_stage, ccol, dcol)) + ccol, dcol = lax.fori_loop(0, 1, loop1, (ccol, dcol)) def loop2(k, loop_vars): - wcon, u_stage, u_pos, utens, utens_stage, 
dtr_stage, ccol, dcol = loop_vars + ccol, dcol = loop_vars gav = -0.25 * (wcon[1:, :, k] + wcon[:-1, :, k]) gcv = 0.25 * (wcon[1:, :, k + 1] + wcon[:-1, :, k + 1]) @@ -59,12 +59,11 @@ def loop2(k, loop_vars): ccol = ccol.at[:, :, k].set(ccol[:, :, k] * divided) dcol = dcol.at[:, :, k].set((dcol[:, :, k] - (dcol[:, :, k - 1]) * acol) * divided) - return wcon, u_stage, u_pos, utens, utens_stage, dtr_stage, ccol, dcol + return ccol, dcol - wcon, u_stage, u_pos, utens, utens_stage, dtr_stage, ccol, dcol = lax.fori_loop(1, K - 1, loop2, (wcon, u_stage, u_pos, utens, utens_stage, dtr_stage, ccol, dcol)) + ccol, dcol = lax.fori_loop(1, K - 1, loop2, (ccol, dcol)) - def loop3(k, loop_vars): - wcon, u_stage, u_pos, utens, utens_stage, dtr_stage, ccol, dcol = loop_vars + def loop3(k, dcol): gav = -0.25 * (wcon[1:, :, k] + wcon[:-1, :, k]) as_ = gav * BET_M acol = gav * BET_P @@ -79,30 +78,30 @@ def loop3(k, loop_vars): divided = 1.0 / (bcol - ccol[:, :, k - 1] * acol) dcol = dcol.at[:, :, k].set((dcol[:, :, k] - (dcol[:, :, k - 1]) * acol) * divided) - return wcon, u_stage, u_pos, utens, utens_stage, dtr_stage, ccol, dcol + return dcol - wcon, u_stage, u_pos, utens, utens_stage, dtr_stage, ccol, dcol = lax.fori_loop(K - 1, K, loop3, (wcon, u_stage, u_pos, utens, utens_stage, dtr_stage, ccol, dcol)) + dcol = lax.fori_loop(K - 1, K, loop3, dcol) def loop4(k, loop_vars): - dcol, data_col, utens_stage, u_pos, dtr_stage = loop_vars + data_col, utens_stage = loop_vars datacol = dcol[:, :, k] data_col = data_col.at[:].set(datacol) utens_stage = utens_stage.at[:, :, k].set(dtr_stage * (datacol - u_pos[:, :, k])) - return dcol, data_col, utens_stage, u_pos, dtr_stage + return data_col, utens_stage - dcol, data_col, utens_stage, u_pos, dtr_stage = lax.fori_loop(K - 1, K, loop4, (dcol, data_col, utens_stage, u_pos, dtr_stage)) + data_col, utens_stage = lax.fori_loop(K - 1, K, loop4, (data_col, utens_stage)) def loop5(k, loop_vars): - dcol, data_col, utens_stage, u_pos, dtr_stage = loop_vars + data_col, utens_stage = loop_vars K = utens_stage.shape[2] k = K - 2 - k datacol = dcol[:, :, k] - ccol[:, :, k] * data_col[:, :] data_col = data_col.at[:].set(datacol) utens_stage = utens_stage.at[:, :, k].set(dtr_stage * (datacol - u_pos[:, :, k])) - return dcol, data_col, utens_stage, u_pos, dtr_stage + return data_col, utens_stage - dcol, data_col, utens_stage, u_pos, dtr_stage = lax.fori_loop(0, K - 1, loop5, (dcol, data_col, utens_stage, u_pos, dtr_stage)) + data_col, utens_stage = lax.fori_loop(0, K - 1, loop5, (data_col, utens_stage)) return ccol, dcol, data_col, utens_stage From c526ad7b88e8fa3e142c2860c7db87b7a1958e02 Mon Sep 17 00:00:00 2001 From: Filip Jaksic Date: Thu, 19 Dec 2024 17:25:48 +0100 Subject: [PATCH 092/106] Make cholesky go brrr Previously >10x slower, now around same runtime as numpy. 
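The main pattern, shown in isolation below, is replacing dynamic-length slices
(which jit cannot trace inside a fori_loop, since `A[i, :j]` with a traced `j`
has no static shape) with a boolean mask over the full row followed by a
reduction. A minimal sketch of the idea — `prefix_dot`, `a`, `b`, and `k` are
illustrative names, not taken from this diff:

```python
import jax
import jax.numpy as jnp

@jax.jit
def prefix_dot(a, b, k):
    # Equivalent to a[:k] @ b[:k], but every intermediate keeps the
    # full static shape, so k may be a traced loop index.
    mask = jnp.arange(a.shape[0]) < k
    return jnp.sum(jnp.where(mask, a * b, 0.0))
```

Compared with dotting dynamic slices, this trades a few wasted multiplies for
shapes the compiler can specialize once.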
--- .../polybench/cholesky/cholesky_jax.py | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/npbench/benchmarks/polybench/cholesky/cholesky_jax.py b/npbench/benchmarks/polybench/cholesky/cholesky_jax.py index 0633886..d756f43 100644 --- a/npbench/benchmarks/polybench/cholesky/cholesky_jax.py +++ b/npbench/benchmarks/polybench/cholesky/cholesky_jax.py @@ -3,30 +3,26 @@ from jax import lax @jax.jit -def kernel(A: jax.Array): - +def kernel(A): + A = A.at[0, 0].set(jnp.sqrt(A[0, 0])) def row_update(i, A): - + def col_update(j, A): mask = jnp.arange(A.shape[1]) < j - - A_i_slice = jnp.where(mask, A[i, :], 0) - A_j_slice = jnp.where(mask, A[j, :], 0) - - dot_product = jnp.dot(A_i_slice, A_j_slice) - A = A.at[i, j].set(A[i, j] - dot_product) - A = A.at[i, j].divide(A[j, j]) + products = jnp.where(mask, A[i, :] * A[j, :], 0.0) + dot_prod = jnp.sum(products) + A = A.at[i, j].set((A[i, j] - dot_prod) / A[j, j]) return A A = lax.fori_loop(0, i, col_update, A) - A_i_slice = jnp.where(jnp.arange(A.shape[1]) < i, A[i, :], 0) - dot_product = jnp.dot(A_i_slice, A_i_slice) - A = A.at[i, i].set(A[i, i] - dot_product) - A = A.at[i, i].set(jnp.sqrt(A[i, i])) + mask = jnp.arange(A.shape[1]) < i + products = jnp.where(mask, A[i, :] * A[i, :], 0) + dot_product = jnp.sum(products) + A = A.at[i, i].set(jnp.sqrt(A[i, i] - dot_product)) return A From 1dd397756438edc77edccee76b261ae733b1958c Mon Sep 17 00:00:00 2001 From: Filip Jaksic Date: Thu, 19 Dec 2024 17:55:14 +0100 Subject: [PATCH 093/106] Make lu go brrr Previously was up to 70x slower, now it's faster for smaller sizes and comparable to numpy for bigger ones. --- npbench/benchmarks/polybench/lu/lu_jax.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/npbench/benchmarks/polybench/lu/lu_jax.py b/npbench/benchmarks/polybench/lu/lu_jax.py index 72ac80b..8957077 100644 --- a/npbench/benchmarks/polybench/lu/lu_jax.py +++ b/npbench/benchmarks/polybench/lu/lu_jax.py @@ -7,17 +7,18 @@ def kernel(A): def loop_body(i, A): - def inner_loop_1(j, A): - A_slice_1 = jnp.where(jnp.arange(A.shape[1]) < j, A[i, :], 0.0) - A_slice_2 = jnp.where(jnp.arange(A.shape[0]) < j, A[:, j], 0.0) - A = A.at[i, j].set(A[i, j] - A_slice_1 @ A_slice_2) - A = A.at[i, j].set(A[i, j] / A[j, j]) + mask = jnp.arange(A.shape[0]) < j + A_slice_1 = jnp.where(mask, A[i, :], 0.0) + A_slice_2 = jnp.where(mask, A[:, j], 0.0) + + A = A.at[i, j].set((A[i, j] - A_slice_1 @ A_slice_2) / A[j, j]) return A def inner_loop_2(j, A): - A_slice_1 = jnp.where(jnp.arange(A.shape[1]) < i, A[i, :], 0.0) - A_slice_2 = jnp.where(jnp.arange(A.shape[0]) < i, A[:, j], 0.0) + mask = jnp.arange(A.shape[0]) < i + A_slice_1 = jnp.where(mask, A[i, :], 0.0) + A_slice_2 = jnp.where(mask, A[:, j], 0.0) A = A.at[i, j].set(A[i, j] - A_slice_1 @ A_slice_2) return A @@ -28,4 +29,4 @@ def inner_loop_2(j, A): A = lax.fori_loop(0, A.shape[0], loop_body, A) - return A + return A \ No newline at end of file From babb413eb8cad80cba575b8f3603b251b6ea90e1 Mon Sep 17 00:00:00 2001 From: Filip Jaksic Date: Thu, 19 Dec 2024 19:19:22 +0100 Subject: [PATCH 094/106] Make spmv go brrr Previously was up to 90x slower, now it's up to 40x faster than numpy. 
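The conversion happens once, on the host and outside any jit-compiled code,
and the matvec itself becomes a single sparse-times-dense product. A
standalone sketch of the approach (the sizes and density below are arbitrary,
not the benchmark's presets):

```python
import numpy as np
import scipy.sparse
from jax.experimental import sparse as jax_sparse

rng = np.random.default_rng(0)

# Build a small CSR matrix on the host, convert it to JAX's BCOO format
# once, then multiply; BCOO @ dense dispatches to a sparse kernel.
csr = scipy.sparse.random(64, 64, density=0.05, format="csr", random_state=rng)
x = rng.standard_normal(64)

y = jax_sparse.BCOO.from_scipy_sparse(csr) @ x
```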
--- npbench/benchmarks/spmv/spmv_jax.py | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/npbench/benchmarks/spmv/spmv_jax.py b/npbench/benchmarks/spmv/spmv_jax.py index 0a8e109..525ea47 100644 --- a/npbench/benchmarks/spmv/spmv_jax.py +++ b/npbench/benchmarks/spmv/spmv_jax.py @@ -1,22 +1,12 @@ # Sparse Matrix-Vector Multiplication (SpMV) -import jax.numpy as jnp -import jax -from jax import lax +from jax.experimental import sparse +import scipy # Matrix-Vector Multiplication with the matrix given in Compressed Sparse Row # (CSR) format -@jax.jit def spmv(A_row, A_col, A_val, x): - y = jnp.empty(A_row.size - 1, dtype=A_val.dtype) + dim = A_row.size - 1 # needed because for the "paper" test size, scipy auto-infers the dims wrong + csr_m = scipy.sparse.csr_matrix((A_val, A_col, A_row), shape=(dim, dim)) - def row_update(i, y): - mask = (jnp.arange(A_col.size) >= A_row[i]) & (jnp.arange(A_col.size) < A_row[i + 1]) - - cols = jnp.where(mask, A_col, 0) - vals = jnp.where(mask, A_val, 0) - y = y.at[i].set(vals @ x[cols]) - - return y - - y = lax.fori_loop(0, A_row.size - 1, row_update, y) - return y + bcoo_m = sparse.BCOO.from_scipy_sparse(csr_m) + return bcoo_m @ x From 67f6569a641908461a6fcf787aed48dd91e55b6c Mon Sep 17 00:00:00 2001 From: Filip Jaksic Date: Thu, 19 Dec 2024 20:05:59 +0100 Subject: [PATCH 095/106] Make durbin go brrr (a bit) Previously was up to 3x slower, now comparable to numpy. --- npbench/benchmarks/polybench/durbin/durbin_jax.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/npbench/benchmarks/polybench/durbin/durbin_jax.py b/npbench/benchmarks/polybench/durbin/durbin_jax.py index e3f6e4d..491ccac 100644 --- a/npbench/benchmarks/polybench/durbin/durbin_jax.py +++ b/npbench/benchmarks/polybench/durbin/durbin_jax.py @@ -11,16 +11,16 @@ def kernel(r): beta = 1.0 y = y.at[0].set(-r[0]) - def loop_body(k, loop_vars): alpha, beta, y, r = loop_vars beta *= 1.0 - alpha * alpha - - r_slice = jnp.where(jnp.arange(r.shape[0]) < k, jnp.roll(jnp.flip(r), [k], 0), 0.0) - y_slice = jnp.where(jnp.arange(y.shape[0]) < k, y, 0.0) - alpha = -(r[k] + jnp.dot(r_slice, y_slice)) / beta + mask = jnp.arange(r.shape[0]) < k + + products = jnp.where(mask, y * jnp.roll(jnp.flip(r), [k], 0),0.0) + dot_prod = jnp.sum(products) + alpha = -(r[k] + dot_prod) / beta - y_update_slice = jnp.where(jnp.arange(y.shape[0]) < k, jnp.roll(jnp.flip(y), [k], 0) * alpha, 0.0) + y_update_slice = jnp.where(mask, jnp.roll(jnp.flip(y), [k], 0) * alpha, 0.0) y += y_update_slice y = y.at[k].set(alpha) From 441e1395e872f78e81dd0093b46c04a1926650b1 Mon Sep 17 00:00:00 2001 From: Filip Jaksic Date: Thu, 19 Dec 2024 23:01:44 +0100 Subject: [PATCH 096/106] Rename vars in spmv to be more intuitive --- npbench/benchmarks/spmv/spmv_jax.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/npbench/benchmarks/spmv/spmv_jax.py b/npbench/benchmarks/spmv/spmv_jax.py index 525ea47..c4f4645 100644 --- a/npbench/benchmarks/spmv/spmv_jax.py +++ b/npbench/benchmarks/spmv/spmv_jax.py @@ -1,12 +1,12 @@ # Sparse Matrix-Vector Multiplication (SpMV) -from jax.experimental import sparse +from jax.experimental import sparse as jax_sparse import scipy # Matrix-Vector Multiplication with the matrix given in Compressed Sparse Row # (CSR) format def spmv(A_row, A_col, A_val, x): dim = A_row.size - 1 # needed because for the "paper" test size, scipy auto-infers the dims wrong - csr_m = scipy.sparse.csr_matrix((A_val, A_col, A_row), shape=(dim, dim)) + 
matrix_in_csr_format = scipy.sparse.csr_matrix((A_val, A_col, A_row), shape=(dim, dim)) + matrix_in_bcoo_format = jax_sparse.BCOO.from_scipy_sparse(matrix_in_csr_format) - bcoo_m = sparse.BCOO.from_scipy_sparse(csr_m) - return bcoo_m @ x + return matrix_in_bcoo_format @ x From 0cb56c06083f43d2925004e69dd8eae98177afe5 Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Thu, 19 Dec 2024 23:44:55 +0100 Subject: [PATCH 097/106] Update seidel_2d jax implementation --- npbench/benchmarks/polybench/seidel_2d/seidel_2d_jax.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/npbench/benchmarks/polybench/seidel_2d/seidel_2d_jax.py b/npbench/benchmarks/polybench/seidel_2d/seidel_2d_jax.py index a5b0ccb..5f0131a 100644 --- a/npbench/benchmarks/polybench/seidel_2d/seidel_2d_jax.py +++ b/npbench/benchmarks/polybench/seidel_2d/seidel_2d_jax.py @@ -11,8 +11,7 @@ def loop1(t, A): def loop2(i, A): def loop3(j, A): - A = A.at[i, j].set(A[i, j] + A[i, j - 1]) - A = A.at[i, j].set(A[i, j] / 9.0) + A = A.at[i, j].set((A[i, j] + A[i, j - 1]) / 9.0) return A A = A.at[i, 1:-1].set( From 88c809e512e2a1395372f2639bfbb845a9bda904 Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Fri, 20 Dec 2024 00:07:50 +0100 Subject: [PATCH 098/106] Update ludcmp jax implementation --- npbench/benchmarks/polybench/ludcmp/ludcmp_jax.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/npbench/benchmarks/polybench/ludcmp/ludcmp_jax.py b/npbench/benchmarks/polybench/ludcmp/ludcmp_jax.py index 528fe59..6c741f8 100644 --- a/npbench/benchmarks/polybench/ludcmp/ludcmp_jax.py +++ b/npbench/benchmarks/polybench/ludcmp/ludcmp_jax.py @@ -13,8 +13,7 @@ def inner_loop_1(j, A): A_slice_1 = jnp.where(jnp.arange(A.shape[1]) < j, A[i, :], 0.0) A_slice_2 = jnp.where(jnp.arange(A.shape[0]) < j, A[:, j], 0.0) - A = A.at[i, j].set(A[i, j] - A_slice_1 @ A_slice_2) - A = A.at[i, j].set(A[i, j] / A[j, j]) + A = A.at[i, j].set((A[i, j] - A_slice_1 @ A_slice_2) / A[j, j]) return A def inner_loop_2(j, A): From 5d2d00fda4d9c8a9ae5c0051e97a8789fdd8141f Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Fri, 20 Dec 2024 00:25:31 +0100 Subject: [PATCH 099/106] Update cholesky jax implementation --- .../benchmarks/polybench/cholesky/cholesky_jax.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/npbench/benchmarks/polybench/cholesky/cholesky_jax.py b/npbench/benchmarks/polybench/cholesky/cholesky_jax.py index d756f43..55e6a3d 100644 --- a/npbench/benchmarks/polybench/cholesky/cholesky_jax.py +++ b/npbench/benchmarks/polybench/cholesky/cholesky_jax.py @@ -11,17 +11,19 @@ def row_update(i, A): def col_update(j, A): mask = jnp.arange(A.shape[1]) < j - products = jnp.where(mask, A[i, :] * A[j, :], 0.0) - dot_prod = jnp.sum(products) - A = A.at[i, j].set((A[i, j] - dot_prod) / A[j, j]) + + A_i_slice = jnp.where(mask, A[i, :], 0) + A_j_slice = jnp.where(mask, A[j, :], 0) + + dot_product = jnp.dot(A_i_slice, A_j_slice) + A = A.at[i, j].set((A[i, j] - dot_product) / A[j, j]) return A A = lax.fori_loop(0, i, col_update, A) - mask = jnp.arange(A.shape[1]) < i - products = jnp.where(mask, A[i, :] * A[i, :], 0) - dot_product = jnp.sum(products) + A_i_slice = jnp.where(jnp.arange(A.shape[1]) < i, A[i, :], 0) + dot_product = jnp.dot(A_i_slice, A_i_slice) A = A.at[i, i].set(jnp.sqrt(A[i, i] - dot_product)) return A From 35490a5d1ff84791ef9f0cbdb954318927b32919 Mon Sep 17 00:00:00 2001 From: hardik01shah Date: Fri, 20 Dec 2024 00:51:37 +0100 Subject: [PATCH 100/106] Update vadv jax implementation --- 
.../weather_stencils/vadv/vadv_jax.py | 28 ++++++++----------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/npbench/benchmarks/weather_stencils/vadv/vadv_jax.py b/npbench/benchmarks/weather_stencils/vadv/vadv_jax.py index 9addbe3..a6d6a19 100644 --- a/npbench/benchmarks/weather_stencils/vadv/vadv_jax.py +++ b/npbench/benchmarks/weather_stencils/vadv/vadv_jax.py @@ -18,19 +18,17 @@ def loop1(k, loop_vars): ccol, dcol = loop_vars gcv = 0.25 * (wcon[1:, :, 0 + 1] + wcon[:-1, :, 0 + 1]) cs = gcv * BET_M - ccol = ccol.at[:, :, k].set(gcv * BET_P) - bcol = dtr_stage - ccol[:, :, k] + bs = gcv * BET_P + bcol = dtr_stage - bs # update the d column correction_term = -cs * (u_stage[:, :, k + 1] - u_stage[:, :, k]) - dcol = dcol.at[:, :, k].set((dtr_stage * u_pos[:, :, k] + utens[:, :, k] + - utens_stage[:, :, k] + correction_term)) # Thomas forward divided = 1.0 / bcol - ccol = ccol.at[:, :, k].set(ccol[:, :, k] * divided) - dcol = dcol.at[:, :, k].set(dcol[:, :, k] * divided) - + ccol = ccol.at[:, :, k].set(bs * divided) + dcol = dcol.at[:, :, k].set((dtr_stage * u_pos[:, :, k] + utens[:, :, k] + + utens_stage[:, :, k] + correction_term) * divided) return ccol, dcol ccol, dcol = lax.fori_loop(0, 1, loop1, (ccol, dcol)) @@ -42,22 +40,21 @@ def loop2(k, loop_vars): as_ = gav * BET_M cs = gcv * BET_M + bs = gcv * BET_P acol = gav * BET_P - ccol = ccol.at[:, :, k].set(gcv * BET_P) - bcol = dtr_stage - acol - ccol[:, :, k] + bcol = dtr_stage - acol - bs # update the d column correction_term = -as_ * (u_stage[:, :, k - 1] - u_stage[:, :, k]) - cs * ( u_stage[:, :, k + 1] - u_stage[:, :, k]) - dcol = dcol.at[:, :, k].set((dtr_stage * u_pos[:, :, k] + utens[:, :, k] + - utens_stage[:, :, k] + correction_term)) # Thomas forward divided = 1.0 / (bcol - ccol[:, :, k - 1] * acol) - ccol = ccol.at[:, :, k].set(ccol[:, :, k] * divided) - dcol = dcol.at[:, :, k].set((dcol[:, :, k] - (dcol[:, :, k - 1]) * acol) * divided) + ccol = ccol.at[:, :, k].set(bs * divided) + dcol = dcol.at[:, :, k].set(((dtr_stage * u_pos[:, :, k] + utens[:, :, k] + + utens_stage[:, :, k] + correction_term) - (dcol[:, :, k - 1]) * acol) * divided) return ccol, dcol @@ -71,12 +68,11 @@ def loop3(k, dcol): # update the d column correction_term = -as_ * (u_stage[:, :, k - 1] - u_stage[:, :, k]) - dcol = dcol.at[:, :, k].set((dtr_stage * u_pos[:, :, k] + utens[:, :, k] + - utens_stage[:, :, k] + correction_term)) # Thomas forward divided = 1.0 / (bcol - ccol[:, :, k - 1] * acol) - dcol = dcol.at[:, :, k].set((dcol[:, :, k] - (dcol[:, :, k - 1]) * acol) * divided) + dcol = dcol.at[:, :, k].set(((dtr_stage * u_pos[:, :, k] + utens[:, :, k] + + utens_stage[:, :, k] + correction_term) - (dcol[:, :, k - 1]) * acol) * divided) return dcol From 312971c7b06648238be07021a430507f91629676 Mon Sep 17 00:00:00 2001 From: Sushant S Date: Fri, 20 Dec 2024 07:18:02 +0100 Subject: [PATCH 101/106] update trisolv --- .../polybench/trisolv/trisolv_jax.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/npbench/benchmarks/polybench/trisolv/trisolv_jax.py b/npbench/benchmarks/polybench/trisolv/trisolv_jax.py index 35480d4..30eab9c 100644 --- a/npbench/benchmarks/polybench/trisolv/trisolv_jax.py +++ b/npbench/benchmarks/polybench/trisolv/trisolv_jax.py @@ -2,14 +2,18 @@ import jax.numpy as jnp from jax import lax + @jax.jit def kernel(L, x, b): - - def loop_body(i, x): - mask = jnp.arange(L.shape[1]) < i - L_slice = jnp.where(mask, L[i, :], 0.0) - x_slice = jnp.where(mask, x, 0.0) - return x.at[i].set((b[i] - 
jnp.dot(L_slice, x_slice)) / L[i, i]) + def loop_body(i, loop_vars): + L, x, b = loop_vars + mask = jnp.arange(x.shape[0]) < i + products = jnp.where(mask, L[i, :] * x, 0.0) + dot_product = jnp.sum(products) + x.at[i].set((b[i] - dot_product) / L[i, i]) + return L, x, b + + _, x, _ = lax.fori_loop(0, x.shape[0], loop_body, (L, x, b)) - return lax.fori_loop(0, x.shape[0], loop_body, x) + return x From 52e0f768ef624a7d97dd03932280c20460c95d5e Mon Sep 17 00:00:00 2001 From: Sushant S Date: Fri, 20 Dec 2024 11:58:21 +0100 Subject: [PATCH 102/106] update covariance --- .../polybench/covariance/covariance_jax.py | 20 +------------------ 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/npbench/benchmarks/polybench/covariance/covariance_jax.py b/npbench/benchmarks/polybench/covariance/covariance_jax.py index c36c3e7..7ac7ba9 100644 --- a/npbench/benchmarks/polybench/covariance/covariance_jax.py +++ b/npbench/benchmarks/polybench/covariance/covariance_jax.py @@ -8,23 +8,5 @@ def kernel(M, float_n, data): mean = jnp.mean(data, axis=0) data -= mean - - cov = jnp.zeros((M, M), dtype=data.dtype) - - def loop_body(i, loop_vars): - data, cov = loop_vars - cov_slice1 = data[:, i] - cov_slice2 = jnp.where(jnp.arange(M) >= i, data, 0.0) - - ans = cov_slice1 @ cov_slice2 / (float_n - 1.0) - row_update_slice = jnp.where(jnp.arange(M) >= i, ans, cov[i, :]) - col_update_slice = jnp.where(jnp.arange(M) >= i, ans, cov[:, i]) - - cov = cov.at[i, :].set(row_update_slice) - cov = cov.at[:, i].set(col_update_slice) - - return data, cov - - _, cov = lax.fori_loop(0, M, loop_body, (data, cov)) - + cov = data.T @ data / (float_n - 1.0) return cov From 61fd0a632ae7865e12be8c4fb7183933bf1db73c Mon Sep 17 00:00:00 2001 From: Sushant S Date: Fri, 20 Dec 2024 11:58:34 +0100 Subject: [PATCH 103/106] update correlation --- .../polybench/correlation/correlation_jax.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/npbench/benchmarks/polybench/correlation/correlation_jax.py b/npbench/benchmarks/polybench/correlation/correlation_jax.py index b7db052..ec03a79 100644 --- a/npbench/benchmarks/polybench/correlation/correlation_jax.py +++ b/npbench/benchmarks/polybench/correlation/correlation_jax.py @@ -6,25 +6,18 @@ @partial(jax.jit, static_argnums=(0,)) def kernel(M, float_n, data): - - def loop_body(i, loop_vars): - corr, data = loop_vars - corr_update_x = jnp.where(jnp.arange(data.shape[0]) > i, data[i], corr[i]) - corr_update_y = jnp.where(jnp.arange(data.shape[0]) > i, data[i], corr[:, i]) - corr = lax.dynamic_update_slice(corr, corr_update_x[None, :], (i, 0)) - corr = lax.dynamic_update_slice(corr, corr_update_y[:, None], (0, i)) - - return corr, data + def loop_body(i, corr): + corr.at[i, i].set(1) + return corr mean = jnp.mean(data, axis=0) stddev = jnp.std(data, axis=0) stddev = jnp.where(stddev <= 0.1, 1.0, stddev) data = data - mean data = data / (jnp.sqrt(float_n) * stddev) - corr = jnp.eye(M, dtype=data.dtype) - data_mul = jnp.dot(data.T, data) + corr = jnp.dot(data.T, data) - corr, _ = lax.fori_loop(0, M - 1, loop_body, (corr, data_mul)) + corr = lax.fori_loop(0, M, loop_body, corr) return corr From 58f2cdb995ee5957a25b1d74b7875a603d6af630 Mon Sep 17 00:00:00 2001 From: Sushant S Date: Fri, 20 Dec 2024 12:23:03 +0100 Subject: [PATCH 104/106] fix trisolv --- npbench/benchmarks/polybench/trisolv/trisolv_jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/npbench/benchmarks/polybench/trisolv/trisolv_jax.py 
b/npbench/benchmarks/polybench/trisolv/trisolv_jax.py index 30eab9c..192f8d3 100644 --- a/npbench/benchmarks/polybench/trisolv/trisolv_jax.py +++ b/npbench/benchmarks/polybench/trisolv/trisolv_jax.py @@ -11,7 +11,7 @@ def loop_body(i, loop_vars): mask = jnp.arange(x.shape[0]) < i products = jnp.where(mask, L[i, :] * x, 0.0) dot_product = jnp.sum(products) - x.at[i].set((b[i] - dot_product) / L[i, i]) + x = x.at[i].set((b[i] - dot_product) / L[i, i]) return L, x, b _, x, _ = lax.fori_loop(0, x.shape[0], loop_body, (L, x, b)) From 941215fb9f4ac233cf56f31ce237eea9a77f9fca Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Sun, 22 Dec 2024 19:43:01 +0100 Subject: [PATCH 105/106] Update README.md Add JAX as a supported framework --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index e41313c..4c8e9b6 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ python plot_results.py Currently, the following frameworks are supported (in alphabetical order): - CuPy - DaCe +- JAX - Numba - NumPy - Pythran From 9bce82faa187c337a80e919332eed11b8f98ce76 Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Sun, 12 Jan 2025 14:54:52 +0100 Subject: [PATCH 106/106] Update README.md with JAX installation instructions --- README.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/README.md b/README.md index 4c8e9b6..28431da 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,24 @@ However, you may want to install the latest version from the [GitHub repository] To run NPBench with DaCe, you have to select as framework (see details below) either `dace_cpu` or `dace_gpu`. +### Jax + +JAX can be installed with pip: +- CPU-only (Linux/macOS/Windows) + ```sh + pip install -U jax + ``` +- GPU (NVIDIA, CUDA 12) + ```sh + pip install -U "jax[cuda12]" + ``` +- TPU (Google Cloud TPU VM) + ```sh + pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + ``` +For more installation options, please consult the JAX [installation guide](https://jax.readthedocs.io/en/latest/installation.html#installation). + + ### Numba Numba can be installed with pip:
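A quick sanity check that a fresh install works before benchmarking — a
minimal sketch, independent of the NPBench harness; the x64 flag matches the
setting the JAX framework wrapper enables, and `block_until_ready()` is needed
because JAX dispatches work asynchronously:

```python
import time
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # the benchmarks run in float64

x = jnp.linspace(0.0, 1.0, 1_000_000)

# Warm up once so compilation time is excluded, then time with an
# explicit sync; without block_until_ready() only dispatch is measured.
jnp.dot(x, x).block_until_ready()
t0 = time.perf_counter()
jnp.dot(x, x).block_until_ready()
print(jax.devices(), time.perf_counter() - t0)
```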