diff --git a/.gitignore b/.gitignore index b6e4761..8bac872 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,6 @@ dmypy.json # Pyre type checker .pyre/ + +# dace +.dacecache/ diff --git a/README.md b/README.md index e41313c..28431da 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ python plot_results.py Currently, the following frameworks are supported (in alphabetical order): - CuPy - DaCe +- JAX - Numba - NumPy - Pythran @@ -55,6 +56,24 @@ However, you may want to install the latest version from the [GitHub repository] To run NPBench with DaCe, you have to select as framework (see details below) either `dace_cpu` or `dace_gpu`. +### JAX + +JAX can be installed with pip: +- CPU-only (Linux/macOS/Windows) + ```sh + pip install -U jax + ``` +- GPU (NVIDIA, CUDA 12) + ```sh + pip install -U "jax[cuda12]" + ``` +- TPU (Google Cloud TPU VM) + ```sh + pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + ``` +For more installation options, please consult the JAX [installation guide](https://jax.readthedocs.io/en/latest/installation.html#installation). + + ### Numba Numba can be installed with pip: diff --git a/framework_info/jax.json b/framework_info/jax.json new file mode 100644 index 0000000..c3506b8 --- /dev/null +++ b/framework_info/jax.json @@ -0,0 +1,10 @@ +{ + "framework": { + "simple_name": "jax", + "full_name": "JAX", + "prefix": "jax", + "postfix": "jax", + "class": "JaxFramework", + "arch": "cpu" + } +} \ No newline at end of file diff --git a/npbench/benchmarks/azimint_hist/azimint_hist_jax.py b/npbench/benchmarks/azimint_hist/azimint_hist_jax.py new file mode 100644 index 0000000..2103621 --- /dev/null +++ b/npbench/benchmarks/azimint_hist/azimint_hist_jax.py @@ -0,0 +1,19 @@ +# Copyright 2014 Jérôme Kieffer et al. +# This is an open-access article distributed under the terms of the +# Creative Commons Attribution License, which permits unrestricted use, +# distribution, and reproduction in any medium, provided the original author +# and source are credited. +# http://creativecommons.org/licenses/by/3.0/ +# Jérôme Kieffer and Giannis Ashiotis. Pyfai: a python library for +# high performance azimuthal integration on gpu, 2014. In Proceedings of the +# 7th European Conference on Python in Science (EuroSciPy 2014). + +import jax +import jax.numpy as jnp +from functools import partial + +@partial(jax.jit, static_argnums=(2,)) +def azimint_hist(data: jax.Array, radius: jax.Array, npt): + histu = jnp.histogram(radius, npt)[0] + histw = jnp.histogram(radius, npt, weights=data)[0] + return histw / histu diff --git a/npbench/benchmarks/azimint_naive/azimint_naive_jax.py b/npbench/benchmarks/azimint_naive/azimint_naive_jax.py new file mode 100644 index 0000000..5523a2d --- /dev/null +++ b/npbench/benchmarks/azimint_naive/azimint_naive_jax.py @@ -0,0 +1,35 @@ +# Copyright 2014 Jérôme Kieffer et al. +# This is an open-access article distributed under the terms of the +# Creative Commons Attribution License, which permits unrestricted use, +# distribution, and reproduction in any medium, provided the original author +# and source are credited. +# http://creativecommons.org/licenses/by/3.0/ +# Jérôme Kieffer and Giannis Ashiotis. Pyfai: a python library for +# high performance azimuthal integration on gpu, 2014. In Proceedings of the +# 7th European Conference on Python in Science (EuroSciPy 2014). 
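+# Port note (editorial): the NumPy reference selects each radial bin with a boolean mask +# whose size is data-dependent; jax.jit requires static shapes, so the loop below keeps the +# full-size arrays and takes a masked mean (jnp.where + mean(where=mask)) inside lax.fori_loop.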
+ +import jax +import jax.numpy as jnp +from jax import lax +from functools import partial + + +@partial(jax.jit, static_argnums=(2,)) +def azimint_naive(data, radius, npt): + rmax = radius.max() + res = jnp.zeros(npt, dtype=jnp.float64) + + def loop_body(i, res): + r1 = rmax * i / npt + r2 = rmax * (i + 1) / npt + mask_r12 = jnp.logical_and((r1 <= radius), (radius < r2)) + mean = jnp.where(mask_r12, data, 0).mean(where=mask_r12) + res = res.at[i].set(mean) + return res + + res = lax.fori_loop(0, npt, loop_body, res) + + return res diff --git a/npbench/benchmarks/cavity_flow/cavity_flow_jax.py b/npbench/benchmarks/cavity_flow/cavity_flow_jax.py new file mode 100644 index 0000000..4fbda96 --- /dev/null +++ b/npbench/benchmarks/cavity_flow/cavity_flow_jax.py @@ -0,0 +1,102 @@ +# Barba, Lorena A., and Forsyth, Gilbert F. (2018). +# CFD Python: the 12 steps to Navier-Stokes equations. +# Journal of Open Source Education, 1(9), 21, +# https://doi.org/10.21105/jose.00021 +# TODO: License +# (c) 2017 Lorena A. Barba, Gilbert F. Forsyth. +# All content is under Creative Commons Attribution CC-BY 4.0, +# and all code is under BSD-3 clause (previously under MIT, and changed on March 8, 2018). + +import jax.numpy as jnp +import jax +from jax import lax +from functools import partial + + +@partial(jax.jit, static_argnums=(1,)) +def build_up_b(b, rho, dt, u, v, dx, dy): + + b = b.at[1:-1, + 1:-1].set(rho * (1 / dt * ((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx) + + (v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy)) - + ((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx))**2 - 2 * + ((u[2:, 1:-1] - u[0:-2, 1:-1]) / (2 * dy) * + (v[1:-1, 2:] - v[1:-1, 0:-2]) / (2 * dx)) - + ((v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy))**2)) + + return b + + +@partial(jax.jit, static_argnums=(0,)) +def pressure_poisson(nit, p, dx, dy, b): + def body_func(p, _): + pn = p.copy() + p = p.at[1:-1, 1:-1].set(((pn[1:-1, 2:] + pn[1:-1, 0:-2]) * dy**2 + + (pn[2:, 1:-1] + pn[0:-2, 1:-1]) * dx**2) / + (2 * (dx**2 + dy**2)) - dx**2 * dy**2 / + (2 * (dx**2 + dy**2)) * b[1:-1, 1:-1]) + + p = p.at[:, -1].set(p[:, -2]) # dp/dx = 0 at x = 2 + p = p.at[0, :].set(p[1, :]) # dp/dy = 0 at y = 0 + p = p.at[:, 0].set(p[:, 1]) # dp/dx = 0 at x = 0 + p = p.at[-1, :].set(0) # p = 0 at y = 2 + + return p, None + + p, _ = lax.scan(body_func, p, jnp.arange(nit)) + + return p + + +@partial(jax.jit, static_argnums=(0,1,2,3,10,11,)) +def cavity_flow(nx, ny, nt, nit, u, v, dt, dx, dy, p, rho, nu): + b = jnp.zeros((ny, nx)) + array_vals = (u, v, p, b) + + def body_func(array_vals, _): + + u, v, p, b = array_vals + + un = u.copy() + vn = v.copy() + + b = build_up_b(b, rho, dt, u, v, dx, dy) + p = pressure_poisson(nit, p, dx, dy, b) + + u = u.at[1:-1, + 1:-1].set(un[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx * + (un[1:-1, 1:-1] - un[1:-1, 0:-2]) - + vn[1:-1, 1:-1] * dt / dy * + (un[1:-1, 1:-1] - un[0:-2, 1:-1]) - dt / (2 * rho * dx) * + (p[1:-1, 2:] - p[1:-1, 0:-2]) + nu * + (dt / dx**2 * + (un[1:-1, 2:] - 2 * un[1:-1, 1:-1] + un[1:-1, 0:-2]) + + dt / dy**2 * + (un[2:, 1:-1] - 2 * un[1:-1, 1:-1] + un[0:-2, 1:-1]))) + + v = v.at[1:-1, + 1:-1].set(vn[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx * + (vn[1:-1, 1:-1] - vn[1:-1, 0:-2]) - + vn[1:-1, 1:-1] * dt / dy * + (vn[1:-1, 1:-1] - vn[0:-2, 1:-1]) - dt / (2 * rho * dy) * + (p[2:, 1:-1] - p[0:-2, 1:-1]) + nu * + (dt / dx**2 * + (vn[1:-1, 2:] - 2 * vn[1:-1, 1:-1] + vn[1:-1, 0:-2]) + + dt / dy**2 * + (vn[2:, 1:-1] - 2 * vn[1:-1, 1:-1] + vn[0:-2, 1:-1]))) + + u = u.at[0, :].set(0) + u = u.at[:, 0].set(0) + u = u.at[:, -1].set(0) + u = u.at[-1, 
:].set(1) # set velocity on cavity lid equal to 1 + v = v.at[0, :].set(0) + v = v.at[-1, :].set(0) + v = v.at[:, 0].set(0) + v = v.at[:, -1].set(0) + + return (u, v, p, b), None + + out_vals, _ = lax.scan(body_func, array_vals, jnp.arange(nt)) + u, v, p, b = out_vals + + return u, v, p diff --git a/npbench/benchmarks/channel_flow/channel_flow_jax.py b/npbench/benchmarks/channel_flow/channel_flow_jax.py new file mode 100644 index 0000000..eb5ffbc --- /dev/null +++ b/npbench/benchmarks/channel_flow/channel_flow_jax.py @@ -0,0 +1,173 @@ +# Barba, Lorena A., and Forsyth, Gilbert F. (2018). +# CFD Python: the 12 steps to Navier-Stokes equations. +# Journal of Open Source Education, 1(9), 21, +# https://doi.org/10.21105/jose.00021 +# TODO: License +# (c) 2017 Lorena A. Barba, Gilbert F. Forsyth. +# All content is under Creative Commons Attribution CC-BY 4.0, +# and all code is under BSD-3 clause (previously under MIT, and changed on March 8, 2018). + +import jax.numpy as jnp +import jax +from jax import lax +from functools import partial + + +@partial(jax.jit, static_argnums=(0,)) +def build_up_b(rho, dt, dx, dy, u, v): + b = jnp.zeros_like(u) + b = b.at[1:-1, + 1:-1].set((rho * (1 / dt * ((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx) + + (v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy)) - + ((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx))**2 - 2 * + ((u[2:, 1:-1] - u[0:-2, 1:-1]) / (2 * dy) * + (v[1:-1, 2:] - v[1:-1, 0:-2]) / (2 * dx)) - + ((v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy))**2))) + + # Periodic BC Pressure @ x = 2 + b = b.at[1:-1, -1].set((rho * (1 / dt * ((u[1:-1, 0] - u[1:-1, -2]) / (2 * dx) + + (v[2:, -1] - v[0:-2, -1]) / (2 * dy)) - + ((u[1:-1, 0] - u[1:-1, -2]) / (2 * dx))**2 - 2 * + ((u[2:, -1] - u[0:-2, -1]) / (2 * dy) * + (v[1:-1, 0] - v[1:-1, -2]) / (2 * dx)) - + ((v[2:, -1] - v[0:-2, -1]) / (2 * dy))**2))) + + # Periodic BC Pressure @ x = 0 + b = b.at[1:-1, 0].set((rho * (1 / dt * ((u[1:-1, 1] - u[1:-1, -1]) / (2 * dx) + + (v[2:, 0] - v[0:-2, 0]) / (2 * dy)) - + ((u[1:-1, 1] - u[1:-1, -1]) / (2 * dx))**2 - 2 * + ((u[2:, 0] - u[0:-2, 0]) / (2 * dy) * + (v[1:-1, 1] - v[1:-1, -1]) / + (2 * dx)) - ((v[2:, 0] - v[0:-2, 0]) / (2 * dy))**2))) + + return b + +@partial(jax.jit, static_argnums=(0,)) +def pressure_poisson_periodic(nit, p, dx, dy, b): + + def body_func(p, q): + pn = p.copy() + p = p.at[1:-1, 1:-1].set(((pn[1:-1, 2:] + pn[1:-1, 0:-2]) * dy**2 + + (pn[2:, 1:-1] + pn[0:-2, 1:-1]) * dx**2) / + (2 * (dx**2 + dy**2)) - dx**2 * dy**2 / + (2 * (dx**2 + dy**2)) * b[1:-1, 1:-1]) + + # Periodic BC Pressure @ x = 2 + p = p.at[1:-1, -1].set(((pn[1:-1, 0] + pn[1:-1, -2]) * dy**2 + + (pn[2:, -1] + pn[0:-2, -1]) * dx**2) / + (2 * (dx**2 + dy**2)) - dx**2 * dy**2 / + (2 * (dx**2 + dy**2)) * b[1:-1, -1]) + + # Periodic BC Pressure @ x = 0 + p = p.at[1:-1, + 0].set((((pn[1:-1, 1] + pn[1:-1, -1]) * dy**2 + + (pn[2:, 0] + pn[0:-2, 0]) * dx**2) / (2 * (dx**2 + dy**2)) - + dx**2 * dy**2 / (2 * (dx**2 + dy**2)) * b[1:-1, 0])) + + # Wall boundary conditions, pressure + p = p.at[-1, :].set(p[-2, :]) # dp/dy = 0 at y = 2 + p = p.at[0, :].set(p[1, :]) # dp/dy = 0 at y = 0 + + return p, None + + p, _ = lax.scan(body_func, p, jnp.arange(nit)) + + return p + + +@partial(jax.jit, static_argnums=(0,7,8,9)) +def channel_flow(nit, u, v, dt, dx, dy, p, rho, nu, F): + udiff = jnp.array(1.0, dtype=u.dtype) # float carry, so the while_loop carry keeps a consistent dtype + stepcount = 0 + + array_vals = (udiff, stepcount, u, v, p) + + def conf_func(array_vals): + udiff, _, _, _, _ = array_vals + return udiff > .001 + + def body_func(array_vals): + _, stepcount, u, v, p = array_vals + + un = u.copy() + vn = v.copy() + + b = build_up_b(rho, dt, dx, dy, u, v) + p = pressure_poisson_periodic(nit, p, dx, dy, b) + + u = u.at[1:-1, + 1:-1].set(un[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx * + (un[1:-1, 1:-1] - un[1:-1, 0:-2]) - + vn[1:-1, 1:-1] * dt / dy * + (un[1:-1, 1:-1] - un[0:-2, 1:-1]) - dt / (2 * rho * dx) * + (p[1:-1, 2:] - p[1:-1, 0:-2]) + nu * + (dt / dx**2 * + (un[1:-1, 2:] - 2 * un[1:-1, 1:-1] + un[1:-1, 0:-2]) + + dt / dy**2 * + (un[2:, 1:-1] - 2 * un[1:-1, 1:-1] + un[0:-2, 1:-1])) + + F * dt) + + v = v.at[1:-1, + 1:-1].set(vn[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx * + (vn[1:-1, 1:-1] - vn[1:-1, 0:-2]) - + vn[1:-1, 1:-1] * dt / dy * + (vn[1:-1, 1:-1] - vn[0:-2, 1:-1]) - dt / (2 * rho * dy) * + (p[2:, 1:-1] - p[0:-2, 1:-1]) + nu * + (dt / dx**2 * + (vn[1:-1, 2:] - 2 * vn[1:-1, 1:-1] + vn[1:-1, 0:-2]) + + dt / dy**2 * + (vn[2:, 1:-1] - 2 * vn[1:-1, 1:-1] + vn[0:-2, 1:-1]))) + + # Periodic BC u @ x = 2 + u = u.at[1:-1, -1].set( + un[1:-1, -1] - un[1:-1, -1] * dt / dx * + (un[1:-1, -1] - un[1:-1, -2]) - vn[1:-1, -1] * dt / dy * + (un[1:-1, -1] - un[0:-2, -1]) - dt / (2 * rho * dx) * + (p[1:-1, 0] - p[1:-1, -2]) + nu * + (dt / dx**2 * + (un[1:-1, 0] - 2 * un[1:-1, -1] + un[1:-1, -2]) + dt / dy**2 * + (un[2:, -1] - 2 * un[1:-1, -1] + un[0:-2, -1])) + F * dt) + + # Periodic BC u @ x = 0 + u = u.at[1:-1, + 0].set(un[1:-1, 0] - un[1:-1, 0] * dt / dx * + (un[1:-1, 0] - un[1:-1, -1]) - vn[1:-1, 0] * dt / dy * + (un[1:-1, 0] - un[0:-2, 0]) - dt / (2 * rho * dx) * + (p[1:-1, 1] - p[1:-1, -1]) + nu * + (dt / dx**2 * + (un[1:-1, 1] - 2 * un[1:-1, 0] + un[1:-1, -1]) + dt / dy**2 * + (un[2:, 0] - 2 * un[1:-1, 0] + un[0:-2, 0])) + F * dt) + + # Periodic BC v @ x = 2 + v = v.at[1:-1, -1].set( + vn[1:-1, -1] - un[1:-1, -1] * dt / dx * + (vn[1:-1, -1] - vn[1:-1, -2]) - vn[1:-1, -1] * dt / dy * + (vn[1:-1, -1] - vn[0:-2, -1]) - dt / (2 * rho * dy) * + (p[2:, -1] - p[0:-2, -1]) + nu * + (dt / dx**2 * + (vn[1:-1, 0] - 2 * vn[1:-1, -1] + vn[1:-1, -2]) + dt / dy**2 * + (vn[2:, -1] - 2 * vn[1:-1, -1] + vn[0:-2, -1]))) + + # Periodic BC v @ x = 0 + v = v.at[1:-1, + 0].set(vn[1:-1, 0] - un[1:-1, 0] * dt / dx * + (vn[1:-1, 0] - vn[1:-1, -1]) - vn[1:-1, 0] * dt / dy * + (vn[1:-1, 0] - vn[0:-2, 0]) - dt / (2 * rho * dy) * + (p[2:, 0] - p[0:-2, 0]) + nu * + (dt / dx**2 * + (vn[1:-1, 1] - 2 * vn[1:-1, 0] + vn[1:-1, -1]) + dt / dy**2 * + (vn[2:, 0] - 2 * vn[1:-1, 0] + vn[0:-2, 0]))) + + # Wall BC: u,v = 0 @ y = 0,2 + u = u.at[0, :].set(0) + u = u.at[-1, :].set(0) + v = v.at[0, :].set(0) + v = v.at[-1, :].set(0) + + udiff = (jnp.sum(u) - jnp.sum(un)) / jnp.sum(u) + stepcount += 1 + + return (udiff, stepcount, u, v, p) + + _, stepcount, _, _, _ = lax.while_loop(conf_func, body_func, array_vals) + + return stepcount diff --git a/npbench/benchmarks/compute/compute_jax.py b/npbench/benchmarks/compute/compute_jax.py new file mode 100644 index 0000000..93d23f0 --- /dev/null +++ b/npbench/benchmarks/compute/compute_jax.py @@ -0,0 +1,8 @@ +# https://cython.readthedocs.io/en/latest/src/userguide/numpy_tutorial.html + +import jax.numpy as jnp +import jax + +@jax.jit +def compute(array_1, array_2, a, b, c): + return jnp.clip(array_1, 2, 10) * a + array_2 * b + c diff --git a/npbench/benchmarks/contour_integral/contour_integral_jax.py b/npbench/benchmarks/contour_integral/contour_integral_jax.py new file mode 100644 index 0000000..b319bda --- /dev/null +++ b/npbench/benchmarks/contour_integral/contour_integral_jax.py @@ -0,0 +1,35 @@ +import jax +import jax.numpy as jnp +from functools import partial + +@partial(jax.jit, static_argnums=(0, 1, 2)) +def 
contour_integral(NR, NM, slab_per_bc, Ham, int_pts, Y): + P0 = jnp.zeros((NR, NM), dtype=jnp.complex128) + P1 = jnp.zeros((NR, NM), dtype=jnp.complex128) + + def body_fun(i, accum): + P0, P1 = accum + z = int_pts[i] + Tz = jnp.zeros((NR, NR), dtype=jnp.complex128) + + def compute_Tz(n, Tz): + zz = jnp.power(z, slab_per_bc / 2 - n) + Tz += zz * Ham[n] + return Tz + + Tz = jax.lax.fori_loop(0, slab_per_bc + 1, compute_Tz, Tz) + + if NR == NM: + X = jnp.linalg.inv(Tz) + else: + X = jnp.linalg.solve(Tz, Y) + X = jax.lax.cond(abs(z) < 1.0, lambda x: -x, lambda x: x, X) + + P0 += X + P1 += z * X + + return P0, P1 + + P0, P1 = jax.lax.fori_loop(0, int_pts.shape[0], body_fun, (P0, P1)) + + return P0, P1 diff --git a/npbench/benchmarks/crc16/crc16_jax.py b/npbench/benchmarks/crc16/crc16_jax.py new file mode 100644 index 0000000..7c630ce --- /dev/null +++ b/npbench/benchmarks/crc16/crc16_jax.py @@ -0,0 +1,33 @@ +import jax +import jax.numpy as jnp +from jax import lax + +@jax.jit +def crc16(data, poly=0x8408): + ''' + CRC-16-CCITT Algorithm + ''' + crc = 0xFFFF + + def loop_body(crc, b): + cur_byte = 0xFF & b + + def inner_loop_body(carry, data): + crc, cur_byte = carry + xor_flag = (crc & 0x0001) ^ (cur_byte & 0x0001) + crc = lax.select(xor_flag == 1, (crc >> 1) ^ poly, crc >> 1) # lax.select expects a boolean predicate + cur_byte >>= 1 + + return (crc, cur_byte), None + + (crc, cur_byte), _ = lax.scan(inner_loop_body, (crc, cur_byte), jnp.arange(8)) + + return crc, None + + crc, _ = lax.scan(loop_body, crc, data) + + crc = (~crc & 0xFFFF) + crc = (crc << 8) | ((crc >> 8) & 0xFF) + + + return crc & 0xFFFF diff --git a/npbench/benchmarks/deep_learning/conv2d_bias/conv2d_jax.py b/npbench/benchmarks/deep_learning/conv2d_bias/conv2d_jax.py new file mode 100644 index 0000000..6a50551 --- /dev/null +++ b/npbench/benchmarks/deep_learning/conv2d_bias/conv2d_jax.py @@ -0,0 +1,44 @@ +import jax.numpy as jnp +import jax +from jax import lax + + +# Deep learning convolutional operator (stride = 1) +@jax.jit +def conv2d(input, weights): + K = weights.shape[0] # Assuming square kernel + N = input.shape[0] + H_out = input.shape[1] - K + 1 + W_out = input.shape[2] - K + 1 + C_out = weights.shape[3] + output = jnp.empty((N, H_out, W_out, C_out), dtype=jnp.float32) + + def row_update(output, i): + def col_update(output, j): + input_slice = lax.dynamic_slice( + input, + (0, i, j, 0), + (N, K, K, input.shape[-1]) + ) + conv_result = jnp.sum( + input_slice[:, :, :, :, None] * weights[None, :, :, :], + axis=(1, 2, 3) + ) + output = lax.dynamic_update_slice( + output, + conv_result[:, None, None, :], + (0, i, j, 0) + ) + return output, None + + output, _ = lax.scan(col_update, output, jnp.arange(W_out)) + return output, None + + output, _ = lax.scan(row_update, output, jnp.arange(H_out)) + + return output + + +@jax.jit +def conv2d_bias(input, weights, bias): + return conv2d(input, weights) + bias diff --git a/npbench/benchmarks/deep_learning/lenet/lenet_jax.py b/npbench/benchmarks/deep_learning/lenet/lenet_jax.py new file mode 100644 index 0000000..f1dbbd4 --- /dev/null +++ b/npbench/benchmarks/deep_learning/lenet/lenet_jax.py @@ -0,0 +1,87 @@ +import jax.numpy as jnp +import jax +from jax import lax +from functools import partial + +@jax.jit +def relu(x): + return jnp.maximum(x, 0) + + +# Deep learning convolutional operator (stride = 1) +@jax.jit +def conv2d(input, weights): + K = weights.shape[0] # Assuming square kernel + N = input.shape[0] + H_out = input.shape[1] - K + 1 + W_out = input.shape[2] - K + 1 + C_out = weights.shape[3] + output = jnp.empty((N, 
H_out, W_out, C_out), dtype=jnp.float32) + + def row_update(output, i): + def col_update(output, j): + input_slice = lax.dynamic_slice( + input, + (0, i, j, 0), + (N, K, K, input.shape[-1]) + ) + conv_result = jnp.sum( + input_slice[:, :, :, :, None] * weights[None, :, :, :], + axis=(1, 2, 3) + ) + output = lax.dynamic_update_slice( + output, + conv_result[:, None, None, :], + (0, i, j, 0) + ) + return output, None + + output, _ = lax.scan(col_update, output, jnp.arange(W_out)) + return output, None + + output, _ = lax.scan(row_update, output, jnp.arange(H_out)) + + return output + + +# 2x2 maxpool operator, as used in LeNet-5 +@jax.jit +def maxpool2d(x): + output = jnp.empty( + [x.shape[0], x.shape[1] // 2, x.shape[2] // 2, x.shape[3]], + dtype=x.dtype) + + def row_update(output, i): + def col_update(output, j): + input_slice = lax.dynamic_slice( + x, + (0, 2 * i, 2 * j, 0), + (x.shape[0], 2, 2, x.shape[3]) + ) + output = lax.dynamic_update_slice( + output, + jnp.max(input_slice, axis=(1, 2))[:, None, None, :], + (0, i, j, 0) + ) + return output, None + + output, _ = lax.scan(col_update, output, jnp.arange(x.shape[2] // 2)) + return output, None + + output, _ = lax.scan(row_update, output, jnp.arange(x.shape[1] // 2)) + + return output + + +# LeNet-5 Convolutional Neural Network (inference mode) +@partial(jax.jit, static_argnums=(11, 12)) +def lenet5(input, conv1, conv1bias, conv2, conv2bias, fc1w, fc1b, fc2w, fc2b, + fc3w, fc3b, N, C_before_fc1): + x = relu(conv2d(input, conv1) + conv1bias) + x = maxpool2d(x) + x = relu(conv2d(x, conv2) + conv2bias) + x = maxpool2d(x) + x = jnp.reshape(x, (N, C_before_fc1)) + x = relu(x @ fc1w + fc1b) + x = relu(x @ fc2w + fc2b) + return x @ fc3w + fc3b diff --git a/npbench/benchmarks/deep_learning/mlp/mlp_jax.py b/npbench/benchmarks/deep_learning/mlp/mlp_jax.py new file mode 100644 index 0000000..57b8c53 --- /dev/null +++ b/npbench/benchmarks/deep_learning/mlp/mlp_jax.py @@ -0,0 +1,24 @@ +import jax.numpy as jnp +import jax + +@jax.jit +def relu(x): + return jnp.maximum(x, 0) + + +# Numerically-stable version of softmax +@jax.jit +def softmax(x): + tmp_max = jnp.max(x, axis=-1, keepdims=True) + tmp_out = jnp.exp(x - tmp_max) + tmp_sum = jnp.sum(tmp_out, axis=-1, keepdims=True) + return tmp_out / tmp_sum + + +# 3-layer MLP +@jax.jit +def mlp(input, w1, b1, w2, b2, w3, b3): + x = relu(input @ w1 + b1) + x = relu(x @ w2 + b2) + x = softmax(x @ w3 + b3) # Softmax call can be omitted if necessary + return x diff --git a/npbench/benchmarks/deep_learning/resnet/resnet_jax.py b/npbench/benchmarks/deep_learning/resnet/resnet_jax.py new file mode 100644 index 0000000..e573da9 --- /dev/null +++ b/npbench/benchmarks/deep_learning/resnet/resnet_jax.py @@ -0,0 +1,73 @@ +import jax.numpy as jnp +import jax +from jax import lax + +@jax.jit +def relu(x): + return jnp.maximum(x, 0) + + +# Deep learning convolutional operator (stride = 1) +@jax.jit +def conv2d(input, weights): + K = weights.shape[0] # Assuming square kernel + N = input.shape[0] + H_out = input.shape[1] - K + 1 + W_out = input.shape[2] - K + 1 + C_out = weights.shape[3] + output = jnp.empty((N, H_out, W_out, C_out), dtype=jnp.float32) + + def row_update(output, i): + def col_update(output, j): + input_slice = lax.dynamic_slice( + input, + (0, i, j, 0), + (N, K, K, input.shape[-1]) + ) + conv_result = jnp.sum( + input_slice[:, :, :, :, None] * weights[None, :, :, :], + axis=(1, 2, 3) + ) + output = lax.dynamic_update_slice( + output, + conv_result[:, None, None, :], + (0, i, j, 0) + ) + return output, None 
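+ # Editorial note: the scans below thread `output` through as the loop carry; + # lax.dynamic_slice / lax.dynamic_update_slice are used because i and j are + # traced scan indices, which cannot appear in ordinary Python slicing under jit.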
+ + output, _ = lax.scan(col_update, output, jnp.arange(W_out)) + return output, None + + output, _ = lax.scan(row_update, output, jnp.arange(H_out)) + return output + + +# Batch normalization operator, as used in ResNet +@jax.jit +def batchnorm2d(x, eps=1e-5): + mean = jnp.mean(x, axis=0, keepdims=True) + std = jnp.std(x, axis=0, keepdims=True) + return (x - mean) / jnp.sqrt(std + eps) + + +# Bottleneck residual block (after initial convolution, without downsampling) +# in the ResNet-50 CNN (inference) +@jax.jit +def resnet_basicblock(input, conv1, conv2, conv3): + # Pad output of first convolution for second convolution + padded = jnp.zeros((input.shape[0], input.shape[1] + 2, input.shape[2] + 2, + conv1.shape[3]), dtype=jnp.float32) + padded = lax.dynamic_update_slice( + padded, + conv2d(input, conv1), + (0, 1, 1, 0) + ) + x = batchnorm2d(padded) + x = relu(x) + + x = conv2d(x, conv2) + x = batchnorm2d(x) + x = relu(x) + x = conv2d(x, conv3) + x = batchnorm2d(x) + return relu(x + input) diff --git a/npbench/benchmarks/deep_learning/softmax/softmax_jax.py b/npbench/benchmarks/deep_learning/softmax/softmax_jax.py new file mode 100644 index 0000000..c9a4f76 --- /dev/null +++ b/npbench/benchmarks/deep_learning/softmax/softmax_jax.py @@ -0,0 +1,11 @@ +import jax +import jax.numpy as jnp + + +# Numerically-stable version of softmax +@jax.jit +def softmax(x): + tmp_max = jnp.max(x, axis=-1, keepdims=True) + tmp_out = jnp.exp(x - tmp_max) + tmp_sum = jnp.sum(tmp_out, axis=-1, keepdims=True) + return tmp_out / tmp_sum diff --git a/npbench/benchmarks/go_fast/go_fast_jax.py b/npbench/benchmarks/go_fast/go_fast_jax.py new file mode 100644 index 0000000..4cebf26 --- /dev/null +++ b/npbench/benchmarks/go_fast/go_fast_jax.py @@ -0,0 +1,13 @@ +# https://numba.readthedocs.io/en/stable/user/5minguide.html + +import jax +import jax.numpy as jnp + +@jax.jit +def go_fast(a: jax.Array): + trace = 0.0 + def body_fn(i, trace): + trace += jnp.tanh(a[i, i]) + return trace + trace = jax.lax.fori_loop(0, a.shape[0], body_fn, trace) + return a + trace diff --git a/npbench/benchmarks/go_fast/go_fast_jax_lib.py b/npbench/benchmarks/go_fast/go_fast_jax_lib.py new file mode 100644 index 0000000..67cdf3f --- /dev/null +++ b/npbench/benchmarks/go_fast/go_fast_jax_lib.py @@ -0,0 +1,9 @@ +# https://numba.readthedocs.io/en/stable/user/5minguide.html + +import jax +import jax.numpy as jnp + +@jax.jit +def go_fast(a: jax.Array): + trace = jnp.sum(jnp.tanh(jnp.diag(a))) + return a + trace diff --git a/npbench/benchmarks/mandelbrot1/mandelbrot1_jax.py b/npbench/benchmarks/mandelbrot1/mandelbrot1_jax.py new file mode 100644 index 0000000..4953f02 --- /dev/null +++ b/npbench/benchmarks/mandelbrot1/mandelbrot1_jax.py @@ -0,0 +1,30 @@ +# ----------------------------------------------------------------------------- +# From Numpy to Python +# Copyright (2017) Nicolas P. Rougier - BSD license +# More information at https://github.com/rougier/numpy-book +# ----------------------------------------------------------------------------- + +import jax +import jax.numpy as jnp +from functools import partial + +@partial(jax.jit, static_argnames=["xn", "yn", "maxiter"]) +def mandelbrot(xmin, xmax, ymin, ymax, xn, yn, maxiter, horizon=2.0): + # Adapted from https://www.ibm.com/developerworks/community/blogs/jfp/... 
+ # .../entry/How_To_Compute_Mandelbrodt_Set_Quickly?lang=en + X = jnp.linspace(xmin, xmax, xn, dtype=jnp.float64) + Y = jnp.linspace(ymin, ymax, yn, dtype=jnp.float64) + C = X + Y[:, None] * 1j + N = jnp.zeros(C.shape, dtype=jnp.int64) + Z = jnp.zeros(C.shape, dtype=jnp.complex128) + + def body_fun(n, state): + Z, N = state + I = jnp.less(jnp.abs(Z), horizon) + new_N = jnp.where(I, n, N) + new_Z = jnp.where(I, Z**2 + C, Z) + return new_Z, new_N + + Z, N = jax.lax.fori_loop(0, maxiter, body_fun, (Z, N)) + N = jnp.where(N == maxiter-1, 0, N) + return Z, N diff --git a/npbench/benchmarks/mandelbrot2/mandelbrot2_jax.py b/npbench/benchmarks/mandelbrot2/mandelbrot2_jax.py new file mode 100644 index 0000000..4e42958 --- /dev/null +++ b/npbench/benchmarks/mandelbrot2/mandelbrot2_jax.py @@ -0,0 +1,52 @@ +# ----------------------------------------------------------------------------- +# From Numpy to Python +# Copyright (2017) Nicolas P. Rougier - BSD license +# More information at https://github.com/rougier/numpy-book +# ----------------------------------------------------------------------------- + +import jax +import jax.numpy as jnp +from functools import partial + +@partial(jax.jit, static_argnames=["xn", "yn", "itermax"]) +def mandelbrot(xmin, xmax, ymin, ymax, xn, yn, itermax, horizon=2.0): + # Adapted from + # https://thesamovar.wordpress.com/2009/03/22/fast-fractals-with-python-and-numpy/ + Xi, Yi = jnp.mgrid[0:xn, 0:yn] + X = jnp.linspace(xmin, xmax, xn, dtype=jnp.float64)[Xi] + Y = jnp.linspace(ymin, ymax, yn, dtype=jnp.float64)[Yi] + C = X + Y * 1j + N_ = jnp.zeros(C.shape, dtype=jnp.int64) + Z_ = jnp.zeros(C.shape, dtype=jnp.complex128) + + original_shape = C.shape + Xi = Xi.reshape(-1) + Yi = Yi.reshape(-1) + C = C.reshape(-1) + + def body_fun(i, state): + Z, Xi, Yi, C, N_, Z_, mask = state + # Compute for relevant points only + Z = Z * Z + C + + # Failed convergence + I = abs(Z) > horizon + I = I & mask # Only consider points that haven't diverged yet + + N_ = jnp.where(I, i + 1, N_) + Z_ = jnp.where(I, Z, Z_) + + # Keep going with those who have not diverged yet + mask = mask & ~I + Z = jnp.where(mask, Z, 0) + + return (Z, Xi, Yi, C, N_, Z_, mask) + + init_state = (jnp.zeros_like(C, dtype=jnp.complex128), Xi, Yi, C, + N_.reshape(-1), Z_.reshape(-1), jnp.ones_like(C, dtype=bool)) + _, _, _, _, N_, Z_, _ = jax.lax.fori_loop(0, itermax, body_fun, init_state) + + Z_ = Z_.reshape(original_shape) # Reshape results back to original shape + N_ = N_.reshape(original_shape) + + return Z_.T, N_.T diff --git a/npbench/benchmarks/nbody/nbody_jax.py b/npbench/benchmarks/nbody/nbody_jax.py new file mode 100644 index 0000000..ca6a635 --- /dev/null +++ b/npbench/benchmarks/nbody/nbody_jax.py @@ -0,0 +1,132 @@ +# Adapted from https://github.com/pmocz/nbody-python/blob/master/nbody.py +# TODO: Add GPL-3.0 License + +import jax +import jax.numpy as jnp +from jax import lax +from functools import partial +""" +Create Your Own N-body Simulation (With Python) +Philip Mocz (2020) Princeton Univeristy, @PMocz +Simulate orbits of stars interacting due to gravity +Code calculates pairwise forces according to Newton's Law of Gravity +""" + +@jax.jit +def getAcc(pos, mass, G, softening): + """ + Calculate the acceleration on each particle due to Newton's Law + pos is an N x 3 matrix of positions + mass is an N x 1 vector of masses + G is Newton's Gravitational constant + softening is the softening length + a is N x 3 matrix of accelerations + """ + # positions r = [x,y,z] for all particles + x = pos[:, 0:1] + y 
= pos[:, 1:2] + z = pos[:, 2:3] + + # matrix that stores all pairwise particle separations: r_j - r_i + dx = x.T - x + dy = y.T - y + dz = z.T - z + + # matrix that stores 1/r^3 for all particle pairwise particle separations + inv_r3 = (dx**2 + dy**2 + dz**2 + softening**2) + inv_r3 = jnp.where(inv_r3 > 0, inv_r3**(-1.5), inv_r3) + + ax = G * (dx * inv_r3) @ mass + ay = G * (dy * inv_r3) @ mass + az = G * (dz * inv_r3) @ mass + + # pack together the acceleration components + a = jnp.hstack((ax, ay, az)) + + return a + +@jax.jit +def getEnergy(pos, vel, mass, G): + """ + Get kinetic energy (KE) and potential energy (PE) of simulation + pos is N x 3 matrix of positions + vel is N x 3 matrix of velocities + mass is an N x 1 vector of masses + G is Newton's Gravitational constant + KE is the kinetic energy of the system + PE is the potential energy of the system + """ + # Kinetic Energy: + # KE = 0.5 * np.sum(np.sum( mass * vel**2 )) + KE = 0.5 * jnp.sum(mass * vel**2) + + # Potential Energy: + + # positions r = [x,y,z] for all particles + x = pos[:, 0:1] + y = pos[:, 1:2] + z = pos[:, 2:3] + + # matrix that stores all pairwise particle separations: r_j - r_i + dx = x.T - x + dy = y.T - y + dz = z.T - z + + # matrix that stores 1/r for all particle pairwise particle separations + inv_r = jnp.sqrt(dx**2 + dy**2 + dz**2) + inv_r = jnp.where(inv_r > 0, 1.0 / inv_r, inv_r) + + # sum over upper triangle, to count each interaction only once + # PE = G * np.sum(np.sum(np.triu(-(mass*mass.T)*inv_r,1))) + PE = G * jnp.sum(jnp.triu(-(mass * mass.T) * inv_r, 1)) + + return KE, PE + +@partial(jax.jit, static_argnums=(4,)) +def nbody(mass, pos, vel, N, Nt, dt, G, softening): + + # Convert to Center-of-Mass frame + vel -= jnp.mean(mass * vel, axis=0) / jnp.mean(mass) + + # calculate initial gravitational accelerations + acc = getAcc(pos, mass, G, softening) + + # calculate initial energy of system + KE = jnp.empty(Nt + 1, dtype=jnp.float64) + PE = jnp.empty(Nt + 1, dtype=jnp.float64) + ke, pe = getEnergy(pos, vel, mass, G) + KE = KE.at[0].set(ke) + PE = PE.at[0].set(pe) + + t = 0.0 + + def loop_body(i, loop_vars): + pos, vel, acc, KE, PE, t = loop_vars + + # (1/2) kick + vel += acc * dt / 2.0 + + # drift + pos += vel * dt + + # update accelerations + acc = getAcc(pos, mass, G, softening) + + # (1/2) kick + vel += acc * dt / 2.0 + + # update time + t += dt + + # get energy of system + ke, pe = getEnergy(pos, vel, mass, G) + + KE = KE.at[i + 1].set(ke) + PE = PE.at[i + 1].set(pe) + + return pos, vel, acc, KE, PE, t + + # Simulation Main Loop + pos, vel, acc, KE, PE, t = lax.fori_loop(0, Nt, loop_body, (pos, vel, acc, KE, PE, t)) + + return KE, PE diff --git a/npbench/benchmarks/polybench/adi/adi_jax.py b/npbench/benchmarks/polybench/adi/adi_jax.py new file mode 100644 index 0000000..827a577 --- /dev/null +++ b/npbench/benchmarks/polybench/adi/adi_jax.py @@ -0,0 +1,81 @@ +import jax +import jax.numpy as jnp +from jax import lax + +def kernel(TSTEPS, N, u): + + v = jnp.zeros_like(u) + p = jnp.zeros_like(u) + q = jnp.zeros_like(u) + + DX = 1.0 / N + DY = 1.0 / N + DT = 1.0 / TSTEPS + B1 = 2.0 + B2 = 1.0 + mul1 = B1 * DT / (DX * DX) + mul2 = B2 * DT / (DY * DY) + + a = -mul1 / 2.0 + b = 1.0 + mul1 # pairs with mul1, as in the NumPy/PolyBench kernel + c = a + d = -mul2 / 2.0 + e = 1.0 + mul2 + f = d + + def first_j_loop_body(j, carry): + p, q, u = carry + p = p.at[1:N-1, j].set(-c / (a * p[1:N-1, j-1] + b)) + q = q.at[1:N-1, j].set( + (-d * u[j, 0:N-2] + (1.0 + 2.0 * d) * u[j, 1:N-1] - f * u[j, 2:N] - + a * q[1:N-1, j-1]) / (a * p[1:N-1, j-1] + b)) + return (p, q, 
u) + + def first_backward_j_loop_body(j, carry): + v, p, q = carry + idx = N-2-j + v = v.at[idx, 1:N-1].set(p[1:N-1, idx] * v[idx+1, 1:N-1] + q[1:N-1, idx]) + return (v, p, q) + + def second_j_loop_body(j, carry): + p, q, v = carry + p = p.at[1:N-1, j].set(-f / (d * p[1:N-1, j-1] + e)) + q = q.at[1:N-1, j].set( + (-a * v[0:N-2, j] + (1.0 + 2.0 * a) * v[1:N-1, j] - c * v[2:N, j] - + d * q[1:N-1, j-1]) / (d * p[1:N-1, j-1] + e)) + return (p, q, v) + + def second_backward_j_loop_body(j, carry): + u, p, q = carry + idx = N-2-j + u = u.at[1:N-1, idx].set(p[1:N-1, idx] * u[1:N-1, idx+1] + q[1:N-1, idx]) + return (u, p, q) + + def time_step_body(t, carry): + u, v, p, q = carry + + v = v.at[0, 1:N-1].set(1.0) + p = p.at[1:N-1, 0].set(0.0) + q = q.at[1:N-1, 0].set(v[0, 1:N-1]) + + p, q, u = lax.fori_loop(1, N-1, first_j_loop_body, (p, q, u)) + + v = v.at[N-1, 1:N-1].set(1.0) + + v, p, q = lax.fori_loop(0, N-2, first_backward_j_loop_body, (v, p, q)) + + u = u.at[1:N-1, 0].set(1.0) + p = p.at[1:N-1, 0].set(0.0) + q = q.at[1:N-1, 0].set(u[1:N-1, 0]) + + p, q, v = lax.fori_loop(1, N-1, second_j_loop_body, (p, q, v)) + + u = u.at[1:N-1, N-1].set(1.0) + + u, p, q = lax.fori_loop(0, N-2, second_backward_j_loop_body, (u, p, q)) + + return (u, v, p, q) + + u, v, p, q = lax.fori_loop(1, TSTEPS + 1, time_step_body, (u, v, p, q)) + + return u \ No newline at end of file diff --git a/npbench/benchmarks/polybench/adi/adi_numpy.py b/npbench/benchmarks/polybench/adi/adi_numpy.py index 24920ba..c5df9a7 100644 --- a/npbench/benchmarks/polybench/adi/adi_numpy.py +++ b/npbench/benchmarks/polybench/adi/adi_numpy.py @@ -50,3 +50,5 @@ def kernel(TSTEPS, N, u): u[1:N - 1, N - 1] = 1.0 for j in range(N - 2, 0, -1): u[1:N - 1, j] = p[1:N - 1, j] * u[1:N - 1, j + 1] + q[1:N - 1, j] + + return u \ No newline at end of file diff --git a/npbench/benchmarks/polybench/atax/atax_jax.py b/npbench/benchmarks/polybench/atax/atax_jax.py new file mode 100644 index 0000000..a0c0dda --- /dev/null +++ b/npbench/benchmarks/polybench/atax/atax_jax.py @@ -0,0 +1,6 @@ +import jax +import jax.numpy as jnp + +@jax.jit +def kernel(A, x): + return (A @ x) @ A diff --git a/npbench/benchmarks/polybench/bicg/bicg_jax.py b/npbench/benchmarks/polybench/bicg/bicg_jax.py new file mode 100644 index 0000000..9114b40 --- /dev/null +++ b/npbench/benchmarks/polybench/bicg/bicg_jax.py @@ -0,0 +1,7 @@ +import jax +import jax.numpy as jnp + +@jax.jit +def kernel(A: jax.Array, p: jax.Array, r: jax.Array): + + return r @ A, A @ p diff --git a/npbench/benchmarks/polybench/cholesky/cholesky_jax.py b/npbench/benchmarks/polybench/cholesky/cholesky_jax.py new file mode 100644 index 0000000..55e6a3d --- /dev/null +++ b/npbench/benchmarks/polybench/cholesky/cholesky_jax.py @@ -0,0 +1,33 @@ +import jax +import jax.numpy as jnp +from jax import lax + +@jax.jit +def kernel(A): + + A = A.at[0, 0].set(jnp.sqrt(A[0, 0])) + + def row_update(i, A): + + def col_update(j, A): + mask = jnp.arange(A.shape[1]) < j + + A_i_slice = jnp.where(mask, A[i, :], 0) + A_j_slice = jnp.where(mask, A[j, :], 0) + + dot_product = jnp.dot(A_i_slice, A_j_slice) + A = A.at[i, j].set((A[i, j] - dot_product) / A[j, j]) + + return A + + A = lax.fori_loop(0, i, col_update, A) + + A_i_slice = jnp.where(jnp.arange(A.shape[1]) < i, A[i, :], 0) + dot_product = jnp.dot(A_i_slice, A_i_slice) + A = A.at[i, i].set(jnp.sqrt(A[i, i] - dot_product)) + + return A + + A = lax.fori_loop(1, A.shape[0], row_update, A) + + return A diff --git a/npbench/benchmarks/polybench/cholesky2/cholesky2_jax.py 
b/npbench/benchmarks/polybench/cholesky2/cholesky2_jax.py new file mode 100644 index 0000000..fa30afb --- /dev/null +++ b/npbench/benchmarks/polybench/cholesky2/cholesky2_jax.py @@ -0,0 +1,10 @@ +import jax +import jax.numpy as jnp + +@jax.jit +def kernel(A: jax.Array): + + L = jnp.linalg.cholesky(A) + upper_A = jnp.triu(A, k=1) + + return L + upper_A diff --git a/npbench/benchmarks/polybench/correlation/correlation_jax.py b/npbench/benchmarks/polybench/correlation/correlation_jax.py new file mode 100644 index 0000000..ec03a79 --- /dev/null +++ b/npbench/benchmarks/polybench/correlation/correlation_jax.py @@ -0,0 +1,23 @@ +import jax +import jax.numpy as jnp +from jax import lax +from functools import partial + +@partial(jax.jit, static_argnums=(0,)) +def kernel(M, float_n, data): + + def loop_body(i, corr): + corr = corr.at[i, i].set(1) # .at[].set returns a new array, so the result must be rebound + return corr + + mean = jnp.mean(data, axis=0) + stddev = jnp.std(data, axis=0) + stddev = jnp.where(stddev <= 0.1, 1.0, stddev) + data = data - mean + data = data / (jnp.sqrt(float_n) * stddev) + + corr = jnp.dot(data.T, data) + + corr = lax.fori_loop(0, M, loop_body, corr) + + return corr diff --git a/npbench/benchmarks/polybench/covariance/covariance_jax.py b/npbench/benchmarks/polybench/covariance/covariance_jax.py new file mode 100644 index 0000000..7ac7ba9 --- /dev/null +++ b/npbench/benchmarks/polybench/covariance/covariance_jax.py @@ -0,0 +1,12 @@ +import jax +import jax.numpy as jnp +from jax import lax +from functools import partial + +@partial(jax.jit, static_argnums=(0,)) +def kernel(M, float_n, data): + + mean = jnp.mean(data, axis=0) + data -= mean + cov = data.T @ data / (float_n - 1.0) + return cov diff --git a/npbench/benchmarks/polybench/covariance/covariance_jax_lib.py b/npbench/benchmarks/polybench/covariance/covariance_jax_lib.py new file mode 100644 index 0000000..e788823 --- /dev/null +++ b/npbench/benchmarks/polybench/covariance/covariance_jax_lib.py @@ -0,0 +1,10 @@ +import jax +import jax.numpy as jnp + + +@jax.jit +def kernel(M, float_n, data): + + cov = jnp.cov(data, rowvar=False) + + return cov \ No newline at end of file diff --git a/npbench/benchmarks/polybench/deriche/deriche_jax.py b/npbench/benchmarks/polybench/deriche/deriche_jax.py new file mode 100644 index 0000000..978c64f --- /dev/null +++ b/npbench/benchmarks/polybench/deriche/deriche_jax.py @@ -0,0 +1,70 @@ +import jax +import jax.numpy as jnp +from jax import lax + +@jax.jit +def kernel(alpha, imgIn): + + k = (1.0 - jnp.exp(-alpha)) * (1.0 - jnp.exp(-alpha)) / ( + 1.0 + alpha * jnp.exp(-alpha) - jnp.exp(2.0 * alpha)) + a1 = a5 = k + a2 = a6 = k * jnp.exp(-alpha) * (alpha - 1.0) + a3 = a7 = k * jnp.exp(-alpha) * (alpha + 1.0) + a4 = a8 = -k * jnp.exp(-2.0 * alpha) + b1 = 2.0**(-alpha) + b2 = -jnp.exp(-2.0 * alpha) + c1 = c2 = 1 + + y1 = jnp.empty_like(imgIn) + y1 = y1.at[:, 0].set(a1 * imgIn[:, 0]) + y1 = y1.at[:, 1].set(a1 * imgIn[:, 1] + a2 * imgIn[:, 0] + b1 * y1[:, 0]) + + def horizontal_forward(j, y1): + return y1.at[:, j].set( + a1 * imgIn[:, j] + a2 * imgIn[:, j-1] + + b1 * y1[:, j-1] + b2 * y1[:, j-2] + ) + + y1 = lax.fori_loop(2, imgIn.shape[1], horizontal_forward, y1) + + y2 = jnp.empty_like(imgIn) + y2 = y2.at[:, -1].set(0.0) + y2 = y2.at[:, -2].set(a3 * imgIn[:, -1]) + + def horizontal_backward(j, y2): + idx = imgIn.shape[1] - 3 - j + return y2.at[:, idx].set( + a3 * imgIn[:, idx+1] + a4 * imgIn[:, idx+2] + + b1 * y2[:, idx+1] + b2 * y2[:, idx+2] + ) + + y2 = lax.fori_loop(0, imgIn.shape[1]-2, horizontal_backward, y2) + + imgOut = c1 * (y1 + y2) + + y1 = 
jnp.empty_like(imgOut) + y1 = y1.at[0, :].set(a5 * imgOut[0, :]) + y1 = y1.at[1, :].set(a5 * imgOut[1, :] + a6 * imgOut[0, :] + b1 * y1[0, :]) + + def vertical_forward(i, y1): + return y1.at[i, :].set( + a5 * imgOut[i, :] + a6 * imgOut[i-1, :] + + b1 * y1[i-1, :] + b2 * y1[i-2, :] + ) + + y1 = lax.fori_loop(2, imgIn.shape[0], vertical_forward, y1) + + y2 = jnp.empty_like(imgOut) + y2 = y2.at[-1, :].set(0.0) + y2 = y2.at[-2, :].set(a7 * imgOut[-1, :]) + + def vertical_backward(i, y2): + idx = imgIn.shape[0] - 3 - i + return y2.at[idx, :].set( + a7 * imgOut[idx+1, :] + a8 * imgOut[idx+2, :] + + b1 * y2[idx+1, :] + b2 * y2[idx+2, :] + ) + + y2 = lax.fori_loop(0, imgIn.shape[0]-2, vertical_backward, y2) + + return c2 * (y1 + y2) \ No newline at end of file diff --git a/npbench/benchmarks/polybench/doitgen/doitgen_jax.py b/npbench/benchmarks/polybench/doitgen/doitgen_jax.py new file mode 100644 index 0000000..2d77255 --- /dev/null +++ b/npbench/benchmarks/polybench/doitgen/doitgen_jax.py @@ -0,0 +1,14 @@ +import jax +import jax.numpy as jnp + +from functools import partial + +@partial(jax.jit, static_argnums=(0, 1, 2)) +def kernel(NR, NQ, NP, A, C4): + + # for r in range(NR): + # for q in range(NQ): + # sum[:] = A[r, q, :] @ C4 + # A[r, q, :] = sum + A = A.at[:].set(jnp.reshape(jnp.reshape(A, (NR, NQ, 1, NP)) @ C4, (NR, NQ, NP))) + return A diff --git a/npbench/benchmarks/polybench/durbin/durbin_jax.py b/npbench/benchmarks/polybench/durbin/durbin_jax.py new file mode 100644 index 0000000..491ccac --- /dev/null +++ b/npbench/benchmarks/polybench/durbin/durbin_jax.py @@ -0,0 +1,31 @@ +import jax +import jax.numpy as jnp +from jax import lax + + +@jax.jit +def kernel(r): + + y = jnp.empty_like(r) + alpha = -r[0] + beta = 1.0 + y = y.at[0].set(-r[0]) + + def loop_body(k, loop_vars): + alpha, beta, y, r = loop_vars + beta *= 1.0 - alpha * alpha + mask = jnp.arange(r.shape[0]) < k + + products = jnp.where(mask, y * jnp.roll(jnp.flip(r), [k], 0),0.0) + dot_prod = jnp.sum(products) + alpha = -(r[k] + dot_prod) / beta + + y_update_slice = jnp.where(mask, jnp.roll(jnp.flip(y), [k], 0) * alpha, 0.0) + y += y_update_slice + y = y.at[k].set(alpha) + + return alpha, beta, y, r + + _, _, y, _ = lax.fori_loop(1, r.shape[0], loop_body, (alpha, beta, y, r)) + + return y diff --git a/npbench/benchmarks/polybench/fdtd_2d/fdtd_2d_jax.py b/npbench/benchmarks/polybench/fdtd_2d/fdtd_2d_jax.py new file mode 100644 index 0000000..98cf9bb --- /dev/null +++ b/npbench/benchmarks/polybench/fdtd_2d/fdtd_2d_jax.py @@ -0,0 +1,19 @@ +import jax +import jax.numpy as jnp +from jax import lax + + +@jax.jit +def kernel(TMAX, ex, ey, hz, _fict_): + + def loop_body(t, loop_vars): + ex, ey, hz = loop_vars + ey = ey.at[0, :].set(_fict_[t]) + ey = ey.at[1:, :].set(ey[1:, :] - 0.5 * (hz[1:, :] - hz[:-1, :])) + ex = ex.at[:, 1:].set(ex[:, 1:] - 0.5 * (hz[:, 1:] - hz[:, :-1])) + hz = hz.at[:-1, :-1].set(hz[:-1, :-1] - 0.7 * (ex[:-1, 1:] - ex[:-1, :-1] + + ey[1:, :-1] - ey[:-1, :-1])) + return ex, ey, hz + + ex, ey, hz = lax.fori_loop(0, TMAX, loop_body, (ex, ey, hz)) + return ex, ey, hz, _fict_ diff --git a/npbench/benchmarks/polybench/floyd_warshall/floyd_warshall_jax.py b/npbench/benchmarks/polybench/floyd_warshall/floyd_warshall_jax.py new file mode 100644 index 0000000..d5cdc25 --- /dev/null +++ b/npbench/benchmarks/polybench/floyd_warshall/floyd_warshall_jax.py @@ -0,0 +1,15 @@ +import jax +import jax.numpy as jnp +from jax import lax + + +@jax.jit +def kernel(path): + + def loop_func(k, path): + path = 
path.at[:].set(jnp.minimum(path[:], jnp.add.outer(path[:, k], path[k, :]))) + return path + + path = lax.fori_loop(0, path.shape[0], loop_func, path) + + return path diff --git a/npbench/benchmarks/polybench/gemm/gemm_jax.py b/npbench/benchmarks/polybench/gemm/gemm_jax.py new file mode 100644 index 0000000..b7254f0 --- /dev/null +++ b/npbench/benchmarks/polybench/gemm/gemm_jax.py @@ -0,0 +1,8 @@ +import jax +import jax.numpy as jnp + +@jax.jit +def kernel(alpha, beta, C, A, B): + + C = C.at[:].set(alpha * A @ B + beta * C) + return C diff --git a/npbench/benchmarks/polybench/gemver/gemver_jax.py b/npbench/benchmarks/polybench/gemver/gemver_jax.py new file mode 100644 index 0000000..2e106cc --- /dev/null +++ b/npbench/benchmarks/polybench/gemver/gemver_jax.py @@ -0,0 +1,12 @@ +import jax +import jax.numpy as jnp + + +@jax.jit +def kernel(alpha, beta, A, u1, v1, u2, v2, w, x, y, z): + + A += jnp.outer(u1, v1) + jnp.outer(u2, v2) + x += beta * y @ A + z + w += alpha * A @ x + + return A, x, w diff --git a/npbench/benchmarks/polybench/gesummv/gesummv_jax.py b/npbench/benchmarks/polybench/gesummv/gesummv_jax.py new file mode 100644 index 0000000..f4f3aa2 --- /dev/null +++ b/npbench/benchmarks/polybench/gesummv/gesummv_jax.py @@ -0,0 +1,7 @@ +import jax +import jax.numpy as jnp + +@jax.jit +def kernel(alpha, beta, A: jax.Array, B: jax.Array, x: jax.Array): + + return (alpha * A + beta * B) @ x diff --git a/npbench/benchmarks/polybench/gramschmidt/gramschmidt_jax.py b/npbench/benchmarks/polybench/gramschmidt/gramschmidt_jax.py new file mode 100644 index 0000000..55d7dbf --- /dev/null +++ b/npbench/benchmarks/polybench/gramschmidt/gramschmidt_jax.py @@ -0,0 +1,28 @@ +import jax +import jax.numpy as jnp + +@jax.jit +def kernel(A): + + Q = jnp.zeros_like(A) + R = jnp.zeros((A.shape[1], A.shape[1]), dtype=A.dtype) + + def body_fun(k, arrays): + Q, R, A = arrays + + nrm = jnp.dot(A[:, k], A[:, k]) + R = R.at[k, k].set(jnp.sqrt(nrm)) + Q = Q.at[:, k].set(A[:, k] / R[k, k]) + + def inner_body_fun(j, arrays): + Q, R, A = arrays + R = R.at[k, j].set(jnp.dot(Q[:, k], A[:, j])) + A = A.at[:, j].add(-Q[:, k] * R[k, j]) + return Q, R, A + + Q, R, A = jax.lax.fori_loop(k + 1, A.shape[1], inner_body_fun, (Q, R, A)) + return Q, R, A + + Q, R, A = jax.lax.fori_loop(0, A.shape[1], body_fun, (Q, R, A)) + + return Q, R diff --git a/npbench/benchmarks/polybench/heat_3d/heat_3d_jax.py b/npbench/benchmarks/polybench/heat_3d/heat_3d_jax.py new file mode 100644 index 0000000..858b8fd --- /dev/null +++ b/npbench/benchmarks/polybench/heat_3d/heat_3d_jax.py @@ -0,0 +1,26 @@ +import jax +import jax.numpy as jnp +from jax import lax + +@jax.jit +def kernel(TSTEPS: int, A: jnp.ndarray, B: jnp.ndarray): + def time_step(t, arrays): + A, B = arrays + + B = B.at[1:-1, 1:-1, 1:-1].set( + 0.125 * (A[2:, 1:-1, 1:-1] - 2.0 * A[1:-1, 1:-1, 1:-1] + A[:-2, 1:-1, 1:-1]) + + 0.125 * (A[1:-1, 2:, 1:-1] - 2.0 * A[1:-1, 1:-1, 1:-1] + A[1:-1, :-2, 1:-1]) + + 0.125 * (A[1:-1, 1:-1, 2:] - 2.0 * A[1:-1, 1:-1, 1:-1] + A[1:-1, 1:-1, :-2]) + + A[1:-1, 1:-1, 1:-1] + ) + A = A.at[1:-1, 1:-1, 1:-1].set( + 0.125 * (B[2:, 1:-1, 1:-1] - 2.0 * B[1:-1, 1:-1, 1:-1] + B[:-2, 1:-1, 1:-1]) + + 0.125 * (B[1:-1, 2:, 1:-1] - 2.0 * B[1:-1, 1:-1, 1:-1] + B[1:-1, :-2, 1:-1]) + + 0.125 * (B[1:-1, 1:-1, 2:] - 2.0 * B[1:-1, 1:-1, 1:-1] + B[1:-1, 1:-1, :-2]) + + B[1:-1, 1:-1, 1:-1] + ) + + return A, B + + A, B = lax.fori_loop(1, TSTEPS, time_step, (A, B)) + return A, B diff --git a/npbench/benchmarks/polybench/jacobi_1d/jacobi_1d_jax.py 
b/npbench/benchmarks/polybench/jacobi_1d/jacobi_1d_jax.py new file mode 100644 index 0000000..600ea06 --- /dev/null +++ b/npbench/benchmarks/polybench/jacobi_1d/jacobi_1d_jax.py @@ -0,0 +1,14 @@ +import jax +from jax import lax + +@jax.jit +def kernel(TSTEPS: int, A: jax.Array, B: jax.Array): + + def body_fn(t, arrays): + A, B = arrays + B = B.at[1:-1].set(0.33333 * (A[:-2] + A[1:-1] + A[2:])) + A = A.at[1:-1].set(0.33333 * (B[:-2] + B[1:-1] + B[2:])) + return A, B + + A, B = lax.fori_loop(1, TSTEPS, body_fn, (A, B)) + return A, B \ No newline at end of file diff --git a/npbench/benchmarks/polybench/jacobi_2d/jacobi_2d_jax.py b/npbench/benchmarks/polybench/jacobi_2d/jacobi_2d_jax.py new file mode 100644 index 0000000..054e414 --- /dev/null +++ b/npbench/benchmarks/polybench/jacobi_2d/jacobi_2d_jax.py @@ -0,0 +1,22 @@ +import jax +import jax.numpy as jnp +from jax import lax +from functools import partial + +@partial(jax.jit, static_argnums=(0,)) +def kernel(TSTEPS: int, A: jax.Array, B: jax.Array): + + def body_fn(t, arrays): + A, B = arrays + # Update B based on A + B = B.at[1:-1, 1:-1].set(0.2 * (A[1:-1, 1:-1] + A[1:-1, :-2] + A[1:-1, 2:] + + A[2:, 1:-1] + A[:-2, 1:-1])) + # Update A based on the new B + A = A.at[1:-1, 1:-1].set(0.2 * (B[1:-1, 1:-1] + B[1:-1, :-2] + B[1:-1, 2:] + + B[2:, 1:-1] + B[:-2, 1:-1])) + return A, B + + # Execute the loop for TSTEPS iterations + A, B = lax.fori_loop(1, TSTEPS, body_fn, (A, B)) + + return A, B diff --git a/npbench/benchmarks/polybench/k2mm/k2mm_jax.py b/npbench/benchmarks/polybench/k2mm/k2mm_jax.py new file mode 100644 index 0000000..fcb801c --- /dev/null +++ b/npbench/benchmarks/polybench/k2mm/k2mm_jax.py @@ -0,0 +1,9 @@ +import jax +import jax.numpy as jnp + + +@jax.jit +def kernel(alpha, beta, A, B, C, D): + + D = D.at[:].set(alpha * A @ B @ C + beta * D) + return D diff --git a/npbench/benchmarks/polybench/k3mm/k3mm_jax.py b/npbench/benchmarks/polybench/k3mm/k3mm_jax.py new file mode 100644 index 0000000..354d9d0 --- /dev/null +++ b/npbench/benchmarks/polybench/k3mm/k3mm_jax.py @@ -0,0 +1,7 @@ +import jax +import jax.numpy as jnp + +@jax.jit +def kernel(A, B, C, D): + + return A @ B @ C @ D diff --git a/npbench/benchmarks/polybench/lu/lu_jax.py b/npbench/benchmarks/polybench/lu/lu_jax.py new file mode 100644 index 0000000..8957077 --- /dev/null +++ b/npbench/benchmarks/polybench/lu/lu_jax.py @@ -0,0 +1,32 @@ +import jax +import jax.numpy as jnp +from jax import lax + + +@jax.jit +def kernel(A): + + def loop_body(i, A): + def inner_loop_1(j, A): + mask = jnp.arange(A.shape[0]) < j + A_slice_1 = jnp.where(mask, A[i, :], 0.0) + A_slice_2 = jnp.where(mask, A[:, j], 0.0) + + A = A.at[i, j].set((A[i, j] - A_slice_1 @ A_slice_2) / A[j, j]) + return A + + def inner_loop_2(j, A): + mask = jnp.arange(A.shape[0]) < i + A_slice_1 = jnp.where(mask, A[i, :], 0.0) + A_slice_2 = jnp.where(mask, A[:, j], 0.0) + A = A.at[i, j].set(A[i, j] - A_slice_1 @ A_slice_2) + return A + + A = lax.fori_loop(0, i, inner_loop_1, A) + A = lax.fori_loop(i, A.shape[0], inner_loop_2, A) + + return A + + A = lax.fori_loop(0, A.shape[0], loop_body, A) + + return A \ No newline at end of file diff --git a/npbench/benchmarks/polybench/ludcmp/ludcmp_jax.py b/npbench/benchmarks/polybench/ludcmp/ludcmp_jax.py new file mode 100644 index 0000000..6c741f8 --- /dev/null +++ b/npbench/benchmarks/polybench/ludcmp/ludcmp_jax.py @@ -0,0 +1,49 @@ +import jax +import jax.numpy as jnp +from jax import lax + +@jax.jit +def kernel(A, b): + + x = jnp.zeros_like(b) + y = jnp.zeros_like(b) + + def 
loop_body_1(i, A): + def inner_loop_1(j, A): + A_slice_1 = jnp.where(jnp.arange(A.shape[1]) < j, A[i, :], 0.0) + A_slice_2 = jnp.where(jnp.arange(A.shape[0]) < j, A[:, j], 0.0) + + A = A.at[i, j].set((A[i, j] - A_slice_1 @ A_slice_2) / A[j, j]) + return A + + def inner_loop_2(j, A): + A_slice_1 = jnp.where(jnp.arange(A.shape[1]) < i, A[i, :], 0.0) + A_slice_2 = jnp.where(jnp.arange(A.shape[0]) < i, A[:, j], 0.0) + A = A.at[i, j].set(A[i, j] - A_slice_1 @ A_slice_2) + return A + + A = lax.fori_loop(0, i, inner_loop_1, A) + A = lax.fori_loop(i, A.shape[0], inner_loop_2, A) + + return A + + def loop_body_2(i, loop_vars): + A, y, b = loop_vars + A_slice = jnp.where(jnp.arange(A.shape[1]) < i, A[i, :], 0.0) + y_slice = jnp.where(jnp.arange(y.shape[0]) < i, y, 0.0) + y = y.at[i].set(b[i] - A_slice @ y_slice) + return A, y, b + + def loop_body_3(t, loop_vars): + A, x, y = loop_vars + i = A.shape[0] - 1 - t + A_slice = jnp.where(jnp.arange(A.shape[1]) > i, A[i, :], 0.0) + x_slice = jnp.where(jnp.arange(x.shape[0]) > i, x, 0.0) + x = x.at[i].set((y[i] - A_slice @ x_slice) / A[i, i]) + return A, x, y + + A = lax.fori_loop(0, A.shape[0], loop_body_1, A) + A, y, b = lax.fori_loop(0, A.shape[0], loop_body_2, (A, y, b)) + A, x, y = lax.fori_loop(0, A.shape[0], loop_body_3, (A, x, y)) + + return x, y diff --git a/npbench/benchmarks/polybench/mvt/mvt_jax.py b/npbench/benchmarks/polybench/mvt/mvt_jax.py new file mode 100644 index 0000000..732051f --- /dev/null +++ b/npbench/benchmarks/polybench/mvt/mvt_jax.py @@ -0,0 +1,10 @@ +import jax +import jax.numpy as jnp + +@jax.jit +def kernel(x1, x2, y_1, y_2, A): + + x1 += A @ y_1 + x2 += y_2 @ A + + return (x1, x2) diff --git a/npbench/benchmarks/polybench/nussinov/nussinov_jax.py b/npbench/benchmarks/polybench/nussinov/nussinov_jax.py new file mode 100644 index 0000000..b59e387 --- /dev/null +++ b/npbench/benchmarks/polybench/nussinov/nussinov_jax.py @@ -0,0 +1,64 @@ +import jax.numpy as jnp +import jax +from jax import lax +from functools import partial + +@jax.jit +def match(b1, b2): + return jnp.where(b1 + b2 == 3, 1, 0) + + +@partial(jax.jit, static_argnums=(0,)) +def kernel(N, seq): + + table = jnp.zeros((N, N), jnp.int32) + + def func_i(i, table): + i = N - 1 - i + def func_j(j, table): + table = table.at[i, j].set( + jnp.where( + j - 1 >= 0, + jnp.maximum(table[i, j], table[i, j - 1]), + table[i, j] + ) + ) + table = table.at[i, j].set( + jnp.where( + i + 1 < N, + jnp.maximum(table[i, j], table[i + 1, j]), + table[i, j] + ) + ) + table = table.at[i, j].set( + jnp.where( + (j - 1 >= 0) & (i + 1 < N) & (i < j - 1), + jnp.maximum(table[i, j], table[i + 1, j - 1] + match(seq[i], seq[j])), + table[i, j] + ) + ) + table = table.at[i, j].set( + jnp.where( + (j - 1 >= 0) & (i + 1 < N) & (i >= j - 1), + jnp.maximum(table[i, j], table[i + 1, j - 1]), + table[i, j] + ) + ) + + def func_k(k, table): + table = table.at[i, j].set( + jnp.maximum( + table[i, j], + table[i, k] + table[k + 1, j] + ) + ) + return table + + table = lax.fori_loop(i + 1, j, func_k, table) + return table + + table = lax.fori_loop(i + 1, N, func_j, table) + return table + + table = lax.fori_loop(0, N, func_i, table) + return table diff --git a/npbench/benchmarks/polybench/seidel_2d/seidel_2d_jax.py b/npbench/benchmarks/polybench/seidel_2d/seidel_2d_jax.py new file mode 100644 index 0000000..5f0131a --- /dev/null +++ b/npbench/benchmarks/polybench/seidel_2d/seidel_2d_jax.py @@ -0,0 +1,31 @@ +import jax +import jax.numpy as jnp +from jax import lax + + +@jax.jit +def kernel(TSTEPS, N, A): + + 
def loop1(t, A): + + def loop2(i, A): + + def loop3(j, A): + A = A.at[i, j].set((A[i, j] + A[i, j - 1]) / 9.0) + return A + + A = A.at[i, 1:-1].set( + A[i, 1:-1] + (A[i - 1, :-2] + A[i - 1, 1:-1] + A[i - 1, 2:] + + A[i, 2:] + A[i + 1, :-2] + A[i + 1, 1:-1] + + A[i + 1, 2:]) + ) + + A = lax.fori_loop(1, N - 1, loop3, A) + return A + + A = lax.fori_loop(1, N - 1, loop2, A) + return A + + A = lax.fori_loop(0, TSTEPS - 1, loop1, A) + + return A diff --git a/npbench/benchmarks/polybench/symm/symm_jax.py b/npbench/benchmarks/polybench/symm/symm_jax.py new file mode 100644 index 0000000..abeafce --- /dev/null +++ b/npbench/benchmarks/polybench/symm/symm_jax.py @@ -0,0 +1,34 @@ +import jax +import jax.numpy as jnp +from jax import lax + +@jax.jit +def kernel(alpha, beta, C: jax.Array, A: jax.Array, B: jax.Array): + + temp2 = jnp.empty((C.shape[1], ), dtype=C.dtype) + C *= beta + + def row_update(i, arrays): + C, temp2 = arrays + + def col_update(j, val): + C, temp2 = val + + A_slice = jnp.where(jnp.arange(A.shape[1]) < i, A[i, :], 0.0) + B_slice = jnp.where(jnp.arange(B.shape[0]) < i, B[:, j], 0.0) + + C = lax.dynamic_update_slice( + C, + (C[:,j] + (alpha * B[i, j] * A_slice))[:, None], + (0, j) + ) + temp2 = temp2.at[j].set(B_slice @ A_slice) + return C, temp2 + + C, temp2 = lax.fori_loop(0, C.shape[1], col_update, (C, temp2)) + C = C.at[i, :].add(alpha * B[i, :] * A[i, i] + alpha * temp2) + return C, temp2 + + C, _ = lax.fori_loop(0, C.shape[0], row_update, (C, temp2)) + return C + diff --git a/npbench/benchmarks/polybench/syr2k/syr2k_jax.py b/npbench/benchmarks/polybench/syr2k/syr2k_jax.py new file mode 100644 index 0000000..2458c69 --- /dev/null +++ b/npbench/benchmarks/polybench/syr2k/syr2k_jax.py @@ -0,0 +1,41 @@ +import jax +import jax.numpy as jnp +from jax import lax + + +@jax.jit +def kernel(alpha, beta, C, A, B): + + def loop_body(i, loop_vars): + + def inner_loop(k, loop_vars): + alpha, C, A, B = loop_vars + A_update_slice = jnp.where(jnp.arange(A.shape[0]) < i + 1, A[:, k], 0.0) + A_update_slice *= alpha * B[i, k] + + + B_update_slice = jnp.where(jnp.arange(B.shape[0]) < i + 1, B[:, k], 0.0) + B_update_slice *= alpha * A[i, k] + + C_update_slice = jnp.where(jnp.arange(C.shape[1]) < i + 1, C[i, :], 0.0) + C_update_slice += A_update_slice + B_update_slice + C_update_slice = jnp.where(jnp.arange(C.shape[1]) < i + 1, C_update_slice, C[i, :]) + + C = lax.dynamic_update_slice(C, C_update_slice[None, :], (i, 0)) + return alpha, C, A, B + + + alpha, beta, C, A, B = loop_vars + C_slice = jnp.where(jnp.arange(C.shape[1]) < i + 1, C[i, :], 0.0) + C_slice = C_slice * beta + C_slice = jnp.where(jnp.arange(C.shape[1]) < i + 1, C_slice, C[i, :]) + + C = lax.dynamic_update_slice(C, C_slice[None, :], (i, 0)) + + _, C, _, _ = lax.fori_loop(0, A.shape[1], inner_loop, (alpha, C, A, B)) + + return alpha, beta, C, A, B + + _, _, C, _, _ = lax.fori_loop(0, A.shape[0], loop_body, (alpha, beta, C, A, B)) + + return C diff --git a/npbench/benchmarks/polybench/syrk/syrk_jax.py b/npbench/benchmarks/polybench/syrk/syrk_jax.py new file mode 100644 index 0000000..8dcab55 --- /dev/null +++ b/npbench/benchmarks/polybench/syrk/syrk_jax.py @@ -0,0 +1,36 @@ +import jax +import jax.numpy as jnp +from jax import lax + + +@jax.jit +def kernel(alpha, beta, C, A): + + def loop_body(i, loop_vars): + + def inner_loop(k, loop_vars): + alpha, C, A = loop_vars + A_update_slice = jnp.where(jnp.arange(A.shape[0]) < i + 1, A[:, k], 0.0) + A_update_slice *= alpha * A[i, k] + + C_update_slice = jnp.where(jnp.arange(C.shape[1]) < i + 
1, C[i, :], 0.0) + C_update_slice += A_update_slice + C_update_slice = jnp.where(jnp.arange(C.shape[1]) < i + 1, C_update_slice, C[i, :]) + + C = lax.dynamic_update_slice(C, C_update_slice[None, :], (i, 0)) + return alpha, C, A + + alpha, beta, C, A = loop_vars + + C_slice = jnp.where(jnp.arange(C.shape[1]) < i + 1, C[i, :], 0.0) + C_slice = C_slice * beta + C_slice = jnp.where(jnp.arange(C.shape[1]) < i + 1, C_slice, C[i, :]) + C = lax.dynamic_update_slice(C, C_slice[None, :], (i, 0)) + + _, C, _ = lax.fori_loop(0, A.shape[1], inner_loop, (alpha, C, A)) + + return alpha, beta, C, A + + _, _, C, _ = lax.fori_loop(0, A.shape[0], loop_body, (alpha, beta, C, A)) + + return C diff --git a/npbench/benchmarks/polybench/trisolv/trisolv_jax.py b/npbench/benchmarks/polybench/trisolv/trisolv_jax.py new file mode 100644 index 0000000..192f8d3 --- /dev/null +++ b/npbench/benchmarks/polybench/trisolv/trisolv_jax.py @@ -0,0 +1,19 @@ +import jax +import jax.numpy as jnp +from jax import lax + + +@jax.jit +def kernel(L, x, b): + + def loop_body(i, loop_vars): + L, x, b = loop_vars + mask = jnp.arange(x.shape[0]) < i + products = jnp.where(mask, L[i, :] * x, 0.0) + dot_product = jnp.sum(products) + x = x.at[i].set((b[i] - dot_product) / L[i, i]) + return L, x, b + + _, x, _ = lax.fori_loop(0, x.shape[0], loop_body, (L, x, b)) + + return x diff --git a/npbench/benchmarks/polybench/trisolv/trisolv_jax_lib.py b/npbench/benchmarks/polybench/trisolv/trisolv_jax_lib.py new file mode 100644 index 0000000..287dfd4 --- /dev/null +++ b/npbench/benchmarks/polybench/trisolv/trisolv_jax_lib.py @@ -0,0 +1,9 @@ +import jax +import jax.numpy as jnp +from jax import lax + +@jax.jit +def kernel(L, x, b): + + x = jax.scipy.linalg.solve_triangular(L, b, lower=True) + return x diff --git a/npbench/benchmarks/polybench/trmm/trmm_jax.py b/npbench/benchmarks/polybench/trmm/trmm_jax.py new file mode 100644 index 0000000..cb27935 --- /dev/null +++ b/npbench/benchmarks/polybench/trmm/trmm_jax.py @@ -0,0 +1,12 @@ +import jax +import jax.numpy as jnp + + +@jax.jit +def kernel(alpha, A, B): + + L = jnp.tril(A, -1).T # strictly lower triangle of A, transposed, matching the NumPy kernel's A[i + 1:, i] @ B[i + 1:, j] + B += L @ B + B *= alpha + + return B diff --git a/npbench/benchmarks/pythran/arc_distance/arc_distance_jax.py b/npbench/benchmarks/pythran/arc_distance/arc_distance_jax.py new file mode 100644 index 0000000..883c7ac --- /dev/null +++ b/npbench/benchmarks/pythran/arc_distance/arc_distance_jax.py @@ -0,0 +1,42 @@ +# Copyright (c) 2019, Serge Guelton +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# Neither the name of HPCProject, Serge Guelton nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
diff --git a/npbench/benchmarks/pythran/arc_distance/arc_distance_jax.py b/npbench/benchmarks/pythran/arc_distance/arc_distance_jax.py
new file mode 100644
index 0000000..883c7ac
--- /dev/null
+++ b/npbench/benchmarks/pythran/arc_distance/arc_distance_jax.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2019, Serge Guelton
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# Neither the name of HPCProject, Serge Guelton nor the names of its
+# contributors may be used to endorse or promote products derived from this
+# software without specific prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import jax
+import jax.numpy as jnp
+
+
+@jax.jit
+def arc_distance(theta_1, phi_1, theta_2, phi_2):
+    """
+    Calculates the elementwise arc distance between the points (theta_1, phi_1) and (theta_2, phi_2).
+    """
+    temp = jnp.sin((theta_2 - theta_1) /
+                   2)**2 + jnp.cos(theta_1) * jnp.cos(theta_2) * jnp.sin(
+                       (phi_2 - phi_1) / 2)**2
+    distance_matrix = 2 * (jnp.arctan2(jnp.sqrt(temp), jnp.sqrt(1 - temp)))
+    return distance_matrix
\ No newline at end of file
diff --git a/npbench/benchmarks/scattering_self_energies/scattering_self_energies_jax.py b/npbench/benchmarks/scattering_self_energies/scattering_self_energies_jax.py
new file mode 100644
index 0000000..8bcc1a9
--- /dev/null
+++ b/npbench/benchmarks/scattering_self_energies/scattering_self_energies_jax.py
@@ -0,0 +1,38 @@
+import jax
+import jax.numpy as jnp
+
+@jax.jit
+def scattering_self_energies(neigh_idx, dH, G, D, Sigma):
+    def body_fun(sigma, idx):
+        k, E, q, w, a, b, i, j = idx
+
+        # E - w can be negative here; JAX wraps the index, and the
+        # jnp.where mask below discards those contributions.
+        dHG = G[k, E - w, neigh_idx[a, b]] @ dH[a, b, i]
+        dHD = dH[a, b, j] * D[q, w, a, b, i, j]
+
+        update = jnp.where(E >= w, dHG @ dHD, 0.0)
+
+        return sigma.at[k, E, a].add(update), None
+
+    k_range = jnp.arange(G.shape[0])
+    E_range = jnp.arange(G.shape[1])
+    q_range = jnp.arange(D.shape[0])
+    w_range = jnp.arange(D.shape[1])
+    a_range = jnp.arange(neigh_idx.shape[0])
+    b_range = jnp.arange(neigh_idx.shape[1])
+    i_range = jnp.arange(D.shape[-2])
+    j_range = jnp.arange(D.shape[-1])
+
+    # Build a meshgrid of all indices and flatten it into a single
+    # array of 8-tuples, so one scan covers every index combination.
+    indices = jnp.meshgrid(
+        k_range, E_range, q_range, w_range,
+        a_range, b_range, i_range, j_range,
+        indexing='ij'
+    )
+    indices = jnp.stack([idx.ravel() for idx in indices], axis=1)
+
+    result, _ = jax.lax.scan(body_fun, Sigma, indices)
+
+    return result
diff --git a/npbench/benchmarks/spmv/spmv_jax.py b/npbench/benchmarks/spmv/spmv_jax.py
new file mode 100644
index 0000000..c4f4645
--- /dev/null
+++ b/npbench/benchmarks/spmv/spmv_jax.py
@@ -0,0 +1,14 @@
+# Sparse Matrix-Vector Multiplication (SpMV)
+from jax.experimental import sparse as jax_sparse
+import scipy.sparse
+
+# Matrix-Vector Multiplication with the matrix given in Compressed Sparse Row
+# (CSR) format
+def spmv(A_row, A_col, A_val, x):
+    # Pass the shape explicitly: for the "paper" test size, scipy infers
+    # the wrong dimensions from the index arrays alone.
+    dim = A_row.size - 1
+    matrix_in_csr_format = scipy.sparse.csr_matrix((A_val, A_col, A_row), shape=(dim, dim))
+    matrix_in_bcoo_format = jax_sparse.BCOO.from_scipy_sparse(matrix_in_csr_format)
+
+    return matrix_in_bcoo_format @ x
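The SpMV port uses SciPy only to assemble the matrix; once converted, a `BCOO` operand composes with dense arrays through the normal `@` operator. A toy usage sketch with made-up data:

```python
import numpy as np
import scipy.sparse
from jax.experimental import sparse as jax_sparse

# A 3x3 matrix in CSR components: values, column indices, row pointers.
A_val = np.array([1.0, 2.0, 3.0])
A_col = np.array([0, 2, 1])
A_row = np.array([0, 1, 2, 3])

csr = scipy.sparse.csr_matrix((A_val, A_col, A_row), shape=(3, 3))
bcoo = jax_sparse.BCOO.from_scipy_sparse(csr)

x = np.ones(3)
print(bcoo @ x)  # same result as csr @ x -> [1. 2. 3.]
```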
diff --git a/npbench/benchmarks/stockham_fft/stockham_fft_jax.py b/npbench/benchmarks/stockham_fft/stockham_fft_jax.py
new file mode 100644
index 0000000..b222339
--- /dev/null
+++ b/npbench/benchmarks/stockham_fft/stockham_fft_jax.py
@@ -0,0 +1,30 @@
+import jax
+import jax.numpy as jnp
+from functools import partial
+
+
+@partial(jax.jit, static_argnames=["N", "R", "K"])
+def stockham_fft(N, R, K, x, y):
+
+    # Generate the DFT matrix for radix R.
+    i_coord, j_coord = jnp.mgrid[0:R, 0:R]
+    dft_mat = jnp.exp(-2.0j * jnp.pi * i_coord * j_coord / R)
+    y = x
+
+    ii_coord, jj_coord = jnp.mgrid[0:R, 0:R**K]
+
+    # Main Stockham loop (K is static, so this Python loop unrolls at trace time)
+    for i in range(K):
+        # Stride permutation
+        yv = jnp.reshape(y, (R**i, R, R**(K - i - 1)))
+        tmp_perm = jnp.transpose(yv, axes=(1, 0, 2))
+
+        # Twiddle factor multiplication
+        tmp = jnp.exp(-2.0j * jnp.pi * ii_coord[:, :R**i] * jj_coord[:, :R**i] / R**(i + 1))
+        D = jnp.repeat(jnp.reshape(tmp, (R, R**i, 1)), R**(K - i - 1), axis=2)
+        tmp_twid = jnp.reshape(tmp_perm, (N, )) * jnp.reshape(D, (N, ))
+
+        # Product with butterfly
+        y = jnp.reshape(dft_mat @ jnp.reshape(tmp_twid, (R, R**(K - 1))), (N, ))
+
+    return y
\ No newline at end of file
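Because `N`, `R`, and `K` are declared static, the `for i in range(K)` above is an ordinary Python loop that gets unrolled at trace time, and each distinct size triggers a fresh compilation. The same pattern in isolation (function and values are illustrative):

```python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, static_argnames=["steps"])
def repeated_double(x, steps):
    # 'steps' is a plain Python int at trace time, so the loop is
    # unrolled into 'steps' multiplications in the compiled program.
    for _ in range(steps):
        x = 2.0 * x
    return x

print(repeated_double(jnp.ones(4), steps=3))  # [8. 8. 8. 8.]
```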
diff --git a/npbench/benchmarks/weather_stencils/hdiff/hdiff_jax.py b/npbench/benchmarks/weather_stencils/hdiff/hdiff_jax.py
new file mode 100644
index 0000000..fa472fa
--- /dev/null
+++ b/npbench/benchmarks/weather_stencils/hdiff/hdiff_jax.py
@@ -0,0 +1,32 @@
+import jax
+import jax.numpy as jnp
+
+# Adapted from https://github.com/GridTools/gt4py/blob/1caca893034a18d5df1522ed251486659f846589/tests/test_integration/stencil_definitions.py#L194
+@jax.jit
+def hdiff(in_field, out_field, coeff):
+    I, J, K = out_field.shape
+    lap_field = 4.0 * in_field[1:I + 3, 1:J + 3, :] - (
+        in_field[2:I + 4, 1:J + 3, :] + in_field[0:I + 2, 1:J + 3, :] +
+        in_field[1:I + 3, 2:J + 4, :] + in_field[1:I + 3, 0:J + 2, :])
+
+    res = lap_field[1:, 1:J + 1, :] - lap_field[:-1, 1:J + 1, :]
+    flx_field = jnp.where(
+        (res *
+         (in_field[2:I + 3, 2:J + 2, :] - in_field[1:I + 2, 2:J + 2, :])) > 0,
+        0,
+        res,
+    )
+
+    res = lap_field[1:I + 1, 1:, :] - lap_field[1:I + 1, :-1, :]
+    fly_field = jnp.where(
+        (res *
+         (in_field[2:I + 2, 2:J + 3, :] - in_field[2:I + 2, 1:J + 2, :])) > 0,
+        0,
+        res,
+    )
+
+    out_field = out_field.at[:, :, :].set(in_field[2:I + 2, 2:J + 2, :] - coeff[:, :, :] * (
+        flx_field[1:, :, :] - flx_field[:-1, :, :] + fly_field[:, 1:, :] -
+        fly_field[:, :-1, :]))
+
+    return out_field
\ No newline at end of file
diff --git a/npbench/benchmarks/weather_stencils/vadv/vadv_jax.py b/npbench/benchmarks/weather_stencils/vadv/vadv_jax.py
new file mode 100644
index 0000000..a6d6a19
--- /dev/null
+++ b/npbench/benchmarks/weather_stencils/vadv/vadv_jax.py
@@ -0,0 +1,102 @@
+import jax
+import jax.numpy as jnp
+from jax import lax
+
+# Sample constants
+BET_M = 0.5
+BET_P = 0.5
+
+# Adapted from https://github.com/GridTools/gt4py/blob/1caca893034a18d5df1522ed251486659f846589/tests/test_integration/stencil_definitions.py#L111
+@jax.jit
+def vadv(utens_stage, u_stage, wcon, u_pos, utens, dtr_stage):
+    I, J, K = utens_stage.shape
+    ccol = jnp.empty((I, J, K), dtype=utens_stage.dtype)
+    dcol = jnp.empty((I, J, K), dtype=utens_stage.dtype)
+    data_col = jnp.empty((I, J), dtype=utens_stage.dtype)
+
+    def loop1(k, loop_vars):
+        ccol, dcol = loop_vars
+        gcv = 0.25 * (wcon[1:, :, k + 1] + wcon[:-1, :, k + 1])
+        cs = gcv * BET_M
+        bs = gcv * BET_P
+        bcol = dtr_stage - bs
+
+        # update the d column
+        correction_term = -cs * (u_stage[:, :, k + 1] - u_stage[:, :, k])
+
+        # Thomas forward
+        divided = 1.0 / bcol
+        ccol = ccol.at[:, :, k].set(bs * divided)
+        dcol = dcol.at[:, :, k].set((dtr_stage * u_pos[:, :, k] + utens[:, :, k] +
+                                     utens_stage[:, :, k] + correction_term) * divided)
+        return ccol, dcol
+
+    ccol, dcol = lax.fori_loop(0, 1, loop1, (ccol, dcol))
+
+    def loop2(k, loop_vars):
+        ccol, dcol = loop_vars
+        gav = -0.25 * (wcon[1:, :, k] + wcon[:-1, :, k])
+        gcv = 0.25 * (wcon[1:, :, k + 1] + wcon[:-1, :, k + 1])
+
+        as_ = gav * BET_M
+        cs = gcv * BET_M
+        bs = gcv * BET_P
+
+        acol = gav * BET_P
+        bcol = dtr_stage - acol - bs
+
+        # update the d column
+        correction_term = -as_ * (u_stage[:, :, k - 1] -
+                                  u_stage[:, :, k]) - cs * (
+                                      u_stage[:, :, k + 1] - u_stage[:, :, k])
+
+        # Thomas forward
+        divided = 1.0 / (bcol - ccol[:, :, k - 1] * acol)
+        ccol = ccol.at[:, :, k].set(bs * divided)
+        dcol = dcol.at[:, :, k].set(((dtr_stage * u_pos[:, :, k] + utens[:, :, k] +
+                                      utens_stage[:, :, k] + correction_term) - (dcol[:, :, k - 1]) * acol) * divided)
+
+        return ccol, dcol
+
+    ccol, dcol = lax.fori_loop(1, K - 1, loop2, (ccol, dcol))
+
+    def loop3(k, dcol):
+        gav = -0.25 * (wcon[1:, :, k] + wcon[:-1, :, k])
+        as_ = gav * BET_M
+        acol = gav * BET_P
+        bcol = dtr_stage - acol
+
+        # update the d column
+        correction_term = -as_ * (u_stage[:, :, k - 1] - u_stage[:, :, k])
+
+        # Thomas forward
+        divided = 1.0 / (bcol - ccol[:, :, k - 1] * acol)
+        dcol = dcol.at[:, :, k].set(((dtr_stage * u_pos[:, :, k] + utens[:, :, k] +
+                                      utens_stage[:, :, k] + correction_term) - (dcol[:, :, k - 1]) * acol) * divided)
+
+        return dcol
+
+    dcol = lax.fori_loop(K - 1, K, loop3, dcol)
+
+    def loop4(k, loop_vars):
+        data_col, utens_stage = loop_vars
+        datacol = dcol[:, :, k]
+        data_col = data_col.at[:].set(datacol)
+        utens_stage = utens_stage.at[:, :, k].set(dtr_stage * (datacol - u_pos[:, :, k]))
+
+        return data_col, utens_stage
+
+    data_col, utens_stage = lax.fori_loop(K - 1, K, loop4, (data_col, utens_stage))
+
+    def loop5(k, loop_vars):
+        data_col, utens_stage = loop_vars
+        k = K - 2 - k  # fori_loop counts up, so flip the index to sweep k = K-2, ..., 0
+        datacol = dcol[:, :, k] - ccol[:, :, k] * data_col[:, :]
+        data_col = data_col.at[:].set(datacol)
+        utens_stage = utens_stage.at[:, :, k].set(dtr_stage * (datacol - u_pos[:, :, k]))
+
+        return data_col, utens_stage
+
+    data_col, utens_stage = lax.fori_loop(0, K - 1, loop5, (data_col, utens_stage))
+
+    return ccol, dcol, data_col, utens_stage
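A detail worth calling out in vadv: `lax.fori_loop` only counts upward, which is why the backward substitution in loop5 flips the index inside the body. The trick in isolation, computing suffix sums back to front (toy data, illustrative names):

```python
import jax.numpy as jnp
from jax import lax

def suffix_sums(x):
    n = x.shape[0]

    def body(step, acc):
        i = n - 1 - step  # visit i = n-1, n-2, ..., 0
        # Read the already-computed entry to the right; the where mask
        # zeroes the read at the rightmost column (index wrapped via % n).
        nxt = jnp.where(i + 1 < n, acc[(i + 1) % n], 0.0)
        return acc.at[i].set(x[i] + nxt)

    return lax.fori_loop(0, n, body, jnp.zeros_like(x))

print(suffix_sums(jnp.array([1.0, 2.0, 3.0])))  # [6. 5. 3.]
```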
""" + return jnp.array + + def impl_files(self, bench: Benchmark): + """ Returns the framework's implementation files for a particular + benchmark. + :param bench: A benchmark. + :returns: A list of the benchmark implementation files. + """ + + parent_folder = pathlib.Path(__file__).parent.absolute() + implementations = [] + + # appending the default implementation + pymod_path = parent_folder.joinpath("..", "..", "npbench", "benchmarks", bench.info["relative_path"], + bench.info["module_name"] + "_" + self.info["postfix"] + ".py") + + implementations.append((pymod_path, 'default')) + + for impl_name, impl_postfix in _impl.items(): + pymod_path = parent_folder.joinpath( + "..", "..", "npbench", "benchmarks", bench.info["relative_path"], + bench.info["module_name"] + "_" + self.info["postfix"] + "_" + impl_postfix + ".py") + implementations.append((pymod_path, impl_name)) + + return implementations + + def implementations(self, bench: Benchmark): + """ Returns the framework's implementations for a particular benchmark. + :param bench: A benchmark. + :returns: A list of the benchmark implementations. + """ + + module_pypath = "npbench.benchmarks.{r}.{m}".format(r=bench.info["relative_path"].replace('/', '.'), + m=bench.info["module_name"]) + if "postfix" in self.info.keys(): + postfix = self.info["postfix"] + else: + postfix = self.fname + module_str = "{m}_{p}".format(m=module_pypath, p=postfix) + func_str = bench.info["func_name"] + + implementations = [] + + # appending the default implementation + try: + ldict = dict() + exec("from {m} import {f} as impl".format(m=module_str, f=func_str), ldict) + implementations.append((ldict['impl'], 'default')) + except Exception as e: + print("Failed to load the {r} {f} implementation.".format(r=self.info["full_name"], f=func_str)) + raise e + + for impl_name, impl_postfix in _impl.items(): + ldict = dict() + try: + exec("from {m}_{p} import {f} as impl".format(m=module_str, p=impl_postfix, f=func_str), ldict) + implementations.append((ldict['impl'], impl_name)) + except ImportError: + continue + except Exception: + print("Failed to load the {r} {f} implementation.".format(r=self.info["full_name"], f=impl_name)) + continue + + return implementations + + def exec_str(self, bench: Benchmark, impl: Callable = None): + """ Generates the execution-string that should be used to call + the benchmark implementation. + :param bench: A benchmark. + :param impl: A benchmark implementation. + """ + + arg_str = self.arg_str(bench, impl) + main_exec_str = "__npb_result = jax.block_until_ready(__npb_impl({a}))".format(a=arg_str) + + return main_exec_str diff --git a/npbench/infrastructure/test.py b/npbench/infrastructure/test.py index da0eafd..cd409b5 100644 --- a/npbench/infrastructure/test.py +++ b/npbench/infrastructure/test.py @@ -101,7 +101,7 @@ def first_execution(impl, impl_name): valid = True if validate and np_out is not None: try: - frmwrk_name = self.frmwrk.info["full_name"] + frmwrk_name = self.frmwrk.info["full_name"] + " - " + impl_name rtol = 1e-5 if not 'rtol' in self.bench.info else self.bench.info['rtol'] atol = 1e-8 if not 'atol' in self.bench.info else self.bench.info['atol']