Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JAX Support to NPBench and Implement JAX Benchmarks #31

Open
wants to merge 114 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
114 commits
Select commit Hold shift + click to select a range
f1e1e99
Add .dacecache to gitignore
hardik01shah Oct 29, 2024
b16a0af
Add Jax framework
jaksicf Oct 29, 2024
a331c45
Add Jax "compute" benchmark
jaksicf Oct 29, 2024
3de6ec7
Add jax_framework
hardik01shah Nov 2, 2024
76cd00a
Add jax_framework
hardik01shah Nov 2, 2024
8462fce
Merge branch 'main' of https://github.com/hardik01shah/npbench into main
hardik01shah Nov 2, 2024
cacf7f5
Add azimint_hist jax implementation
hardik01shah Nov 2, 2024
e9ae934
Add go_fast jax implementation
hardik01shah Nov 2, 2024
40dd641
Add bicg jax implementation
hardik01shah Nov 2, 2024
053d87b
Add cholesky2 jax implementation
hardik01shah Nov 2, 2024
7ddec27
Add gesummv jax implementation
hardik01shah Nov 2, 2024
a69848c
Add jacobi_1d jax implementation
hardik01shah Nov 2, 2024
aea5a70
Add jacobi_2d jax implementation
hardik01shah Nov 2, 2024
24b0622
Add cholesky jax implementation
hardik01shah Nov 2, 2024
3d43c3b
Update go_fast jax implementation
hardik01shah Nov 3, 2024
5d23224
Update cholesky2 jax implementation
hardik01shah Nov 3, 2024
069029d
Update gesummv jax implementation
hardik01shah Nov 3, 2024
cc88ecd
Update jacobi_1d jax implementation
hardik01shah Nov 3, 2024
7d9e7f2
Update jacobi_2d jax implementation
hardik01shah Nov 3, 2024
dc24a9f
add covariance jax
sushant1212 Nov 4, 2024
6ffc18d
Add atax Jax implementation
jaksicf Nov 4, 2024
de3d8f6
Add doitgen Jax implementation
jaksicf Nov 4, 2024
3967a00
Add gemm Jax implementation
jaksicf Nov 4, 2024
eca1dc8
Add k2mm Jax implementation
jaksicf Nov 4, 2024
da9e2b5
Add k3mm Jax implementation
jaksicf Nov 4, 2024
7264536
Add mvt Jax implementation
jaksicf Nov 4, 2024
92cdc38
Add softmax Jax implementation
jaksicf Nov 4, 2024
058faf9
Add trmm Jax implementation
jaksicf Nov 4, 2024
eb854cb
Merge branch 'main' of github.com:hardik01shah/npbench
sushant1212 Nov 5, 2024
f877c2b
add fdtd_2d
sushant1212 Nov 5, 2024
b39be28
add floyd_warshall
sushant1212 Nov 5, 2024
84559e7
add lu
sushant1212 Nov 5, 2024
30635b6
add seidel_2d
sushant1212 Nov 5, 2024
0d4e625
add syrk
sushant1212 Nov 5, 2024
d298aa9
add syr2k
sushant1212 Nov 5, 2024
f0fd399
add trisolv
sushant1212 Nov 5, 2024
4b871d8
Update block_until_ready() call in exec_str so returning modified arr…
hardik01shah Nov 6, 2024
27337cf
update floyd_warshall
sushant1212 Nov 10, 2024
fea3c9c
add azimint_naive
sushant1212 Nov 12, 2024
b7a140d
add correlation
sushant1212 Nov 12, 2024
c23ec04
Add symm jax implementation
hardik01shah Nov 16, 2024
6e53257
Add conv2d jax implementation
hardik01shah Nov 16, 2024
e46e344
Add heat_3d jax implementation
hardik01shah Nov 16, 2024
57d7aa4
Add mlp jax implementation
hardik01shah Nov 16, 2024
09d404a
Fix symm jax implementation
hardik01shah Nov 16, 2024
a01496e
Add contour_integral jax implementation
hardik01shah Nov 16, 2024
87df016
Fix jacobi_1d jax implementation
hardik01shah Nov 16, 2024
17e00ab
Add gramschmidt jax implementation
hardik01shah Nov 16, 2024
34db9e1
Fix cholesky jax implementation
hardik01shah Nov 17, 2024
591f79f
Add spmv jax implementation
hardik01shah Nov 17, 2024
330511d
add durbin
sushant1212 Nov 18, 2024
6986abe
add lud_cmp
sushant1212 Nov 18, 2024
e6b37d1
update lu
sushant1212 Nov 18, 2024
bb8661d
update trisolv
sushant1212 Nov 18, 2024
936a90a
update syr2k
sushant1212 Nov 18, 2024
4c15b77
update syrk
sushant1212 Nov 18, 2024
be2c320
Add mandelbrot1 Jax implementation
jaksicf Nov 18, 2024
855778d
Add stockham_fft implementation
jaksicf Nov 18, 2024
8d6af1a
Add scattering_self_energies Jax implementation
jaksicf Nov 18, 2024
0991340
Add deriche Jax implementation
jaksicf Nov 18, 2024
b670765
Add adi Jax implementation
jaksicf Nov 19, 2024
731f0d8
update durbin
sushant1212 Nov 21, 2024
091b295
update_correlation
sushant1212 Nov 21, 2024
eed0974
update azimint_naive
sushant1212 Nov 21, 2024
2a2a981
add crc16
sushant1212 Nov 21, 2024
f99a816
add arc_distance
sushant1212 Nov 21, 2024
caf0141
add gemver
sushant1212 Nov 21, 2024
d707312
add gemver
sushant1212 Nov 21, 2024
bbc546e
Update conv2d_jax with lax.scan
hardik01shah Nov 21, 2024
1d213cb
Add lenet jax implementation
hardik01shah Nov 21, 2024
00128b7
Add resnet jax implementation
hardik01shah Nov 21, 2024
6626d60
Add channel_flow jax implementation
hardik01shah Nov 21, 2024
8caad35
Add cavity_flow jax implementation
hardik01shah Nov 21, 2024
5d568b6
Update channel_flow_jax
hardik01shah Nov 21, 2024
1c42505
Update cavity_flow_jax: remove return of modified arrays used for val…
hardik01shah Nov 21, 2024
fb10d47
add nbody
sushant1212 Nov 21, 2024
ea2986d
Merge branch 'main' of github.com:hardik01shah/npbench
sushant1212 Nov 21, 2024
aa22c51
Remove standalone comments and some type hints to not inflate "lines …
jaksicf Nov 22, 2024
c9b9f61
update correlation
sushant1212 Nov 23, 2024
4b045b8
add hdiff
sushant1212 Nov 23, 2024
75fc1c9
Merge branch 'main' of github.com:hardik01shah/npbench
sushant1212 Nov 23, 2024
448e92d
Add mandelbrot2 Jax implementation
jaksicf Nov 23, 2024
578e440
add vadv
sushant1212 Nov 23, 2024
107584b
update nbody
sushant1212 Nov 23, 2024
bb972c2
Merge branch 'main' of github.com:hardik01shah/npbench
sushant1212 Nov 23, 2024
a2d4464
Add nussinov jax implementation
hardik01shah Nov 23, 2024
fe87b83
Fix(cavity_flow): return result arrays
jaksicf Nov 24, 2024
23919b3
add lib_implementation to jax_framework
sushant1212 Dec 15, 2024
c016a09
separate cov lib and cov default
sushant1212 Dec 15, 2024
53ac30e
fix comments
sushant1212 Dec 16, 2024
0810e9e
Merge pull request #1 from hardik01shah/sushant/lib-implementation
sushant1212 Dec 16, 2024
c92ec2f
minor fix in frmwrk-name for validation
hardik01shah Dec 17, 2024
970c3fd
Update go_fast jax implementation
hardik01shah Dec 17, 2024
7823055
Add go_fast jax_lib implementation
hardik01shah Dec 17, 2024
bcf0336
add trisolv lib implementation
sushant1212 Dec 19, 2024
51c1f80
update trisolv
sushant1212 Dec 19, 2024
f8c14ad
remove excessive loop vars
sushant1212 Dec 19, 2024
c526ad7
Make cholesky go brrr
jaksicf Dec 19, 2024
1dd3977
Make lu go brrr
jaksicf Dec 19, 2024
babb413
Make spmv go brrr
jaksicf Dec 19, 2024
67f6569
Make durbin go brrr (a bit)
jaksicf Dec 19, 2024
19aaaf7
Merge branch 'main' of github.com:hardik01shah/npbench
sushant1212 Dec 19, 2024
441e139
Rename vars in spmv to be more intuitive
jaksicf Dec 19, 2024
0cb56c0
Update seidel_2d jax implementation
hardik01shah Dec 19, 2024
732bbbe
Merge branch 'main' of https://github.com/hardik01shah/npbench into main
hardik01shah Dec 19, 2024
88c809e
Update ludcmp jax implementation
hardik01shah Dec 19, 2024
5d2d00f
Update cholesky jax implementation
hardik01shah Dec 19, 2024
35490a5
Update vadv jax implementation
hardik01shah Dec 19, 2024
312971c
update trisolv
sushant1212 Dec 20, 2024
52e0f76
update covariance
sushant1212 Dec 20, 2024
61fd0a6
update correlation
sushant1212 Dec 20, 2024
58f2cdb
fix trisolv
sushant1212 Dec 20, 2024
941215f
Update README.md
hardik01shah Dec 22, 2024
9bce82f
Update README.md with JAX installation instructions
hardik01shah Jan 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,6 @@ dmypy.json

# Pyre type checker
.pyre/

# dace
.dacecache/
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ python plot_results.py
Currently, the following frameworks are supported (in alphabetical order):
- CuPy
- DaCe
- JAX
- Numba
- NumPy
- Pythran
Expand Down Expand Up @@ -55,6 +56,24 @@ However, you may want to install the latest version from the [GitHub repository]
To run NPBench with DaCe, you have to select as framework (see details below)
either `dace_cpu` or `dace_gpu`.

### Jax

JAX can be installed with pip:
- CPU-only (Linux/macOS/Windows)
```sh
pip install -U jax
```
- GPU (NVIDIA, CUDA 12)
```sh
pip install -U "jax[cuda12]"
```
- TPU (Google Cloud TPU VM)
```sh
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
For more installation options, please consult the JAX [installation guide](https://jax.readthedocs.io/en/latest/installation.html#installation).


### Numba

Numba can be installed with pip:
Expand Down
10 changes: 10 additions & 0 deletions framework_info/jax.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"framework": {
"simple_name": "jax",
"full_name": "Jax",
"prefix": "jax",
"postfix": "jax",
"class": "JaxFramework",
"arch": "cpu"
}
}
19 changes: 19 additions & 0 deletions npbench/benchmarks/azimint_hist/azimint_hist_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright 2014 Jérôme Kieffer et al.
# This is an open-access article distributed under the terms of the
# Creative Commons Attribution License, which permits unrestricted use,
# distribution, and reproduction in any medium, provided the original author
# and source are credited.
# http://creativecommons.org/licenses/by/3.0/
# Jérôme Kieffer and Giannis Ashiotis. Pyfai: a python library for
# high performance azimuthal integration on gpu, 2014. In Proceedings of the
# 7th European Conference on Python in Science (EuroSciPy 2014).

import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, static_argnums=(2,))
def azimint_hist(data: jax.Array, radius: jax.Array, npt):
histu = jnp.histogram(radius, npt)[0]
histw = jnp.histogram(radius, npt, weights=data)[0]
return histw / histu
32 changes: 32 additions & 0 deletions npbench/benchmarks/azimint_naive/azimint_naive_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright 2014 Jérôme Kieffer et al.
# This is an open-access article distributed under the terms of the
# Creative Commons Attribution License, which permits unrestricted use,
# distribution, and reproduction in any medium, provided the original author
# and source are credited.
# http://creativecommons.org/licenses/by/3.0/
# Jérôme Kieffer and Giannis Ashiotis. Pyfai: a python library for
# high performance azimuthal integration on gpu, 2014. In Proceedings of the
# 7th European Conference on Python in Science (EuroSciPy 2014).

import jax
import jax.numpy as jnp
from jax import lax
from functools import partial


@partial(jax.jit, static_argnums=(2,))
def azimint_naive(data, radius, npt):
rmax = radius.max()
res = jnp.zeros(npt, dtype=jnp.float64)

def loop_body(i, res):
r1 = rmax * i / npt
r2 = rmax * (i + 1) / npt
mask_r12 = jnp.logical_and((r1 <= radius), (radius < r2))
mean = jnp.where(mask_r12, data, 0).mean(where=mask_r12)
res = res.at[i].set(mean)
return res

res = lax.fori_loop(0, npt, loop_body, res)

return res
102 changes: 102 additions & 0 deletions npbench/benchmarks/cavity_flow/cavity_flow_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Barba, Lorena A., and Forsyth, Gilbert F. (2018).
# CFD Python: the 12 steps to Navier-Stokes equations.
# Journal of Open Source Education, 1(9), 21,
# https://doi.org/10.21105/jose.00021
# TODO: License
# (c) 2017 Lorena A. Barba, Gilbert F. Forsyth.
# All content is under Creative Commons Attribution CC-BY 4.0,
# and all code is under BSD-3 clause (previously under MIT, and changed on March 8, 2018).

import jax.numpy as jnp
import jax
from jax import lax
from functools import partial


@partial(jax.jit, static_argnums=(1,))
def build_up_b(b, rho, dt, u, v, dx, dy):

b = b.at[1:-1,
1:-1].set(rho * (1 / dt * ((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx) +
(v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy)) -
((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx))**2 - 2 *
((u[2:, 1:-1] - u[0:-2, 1:-1]) / (2 * dy) *
(v[1:-1, 2:] - v[1:-1, 0:-2]) / (2 * dx)) -
((v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy))**2))

return b


@partial(jax.jit, static_argnums=(0,))
def pressure_poisson(nit, p, dx, dy, b):
def body_func(p, _):
pn = p.copy()
p = p.at[1:-1, 1:-1].set(((pn[1:-1, 2:] + pn[1:-1, 0:-2]) * dy**2 +
(pn[2:, 1:-1] + pn[0:-2, 1:-1]) * dx**2) /
(2 * (dx**2 + dy**2)) - dx**2 * dy**2 /
(2 * (dx**2 + dy**2)) * b[1:-1, 1:-1])

p = p.at[:, -1].set(p[:, -2]) # dp/dx = 0 at x = 2
p = p.at[0, :].set(p[1, :]) # dp/dy = 0 at y = 0
p = p.at[:, 0].set(p[:, 1]) # dp/dx = 0 at x = 0
p = p.at[-1, :].set(0) # p = 0 at y = 2

return p, None

p, _ = lax.scan(body_func, p, jnp.arange(nit))

return p


@partial(jax.jit, static_argnums=(0,1,2,3,10,11,))
def cavity_flow(nx, ny, nt, nit, u, v, dt, dx, dy, p, rho, nu):
b = jnp.zeros((ny, nx))
array_vals = (u, v, p, b)

def body_func(array_vals, _):

u, v, p, b = array_vals

un = u.copy()
vn = v.copy()

b = build_up_b(b, rho, dt, u, v, dx, dy)
p = pressure_poisson(nit, p, dx, dy, b)

u = u.at[1:-1,
1:-1].set(un[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx *
(un[1:-1, 1:-1] - un[1:-1, 0:-2]) -
vn[1:-1, 1:-1] * dt / dy *
(un[1:-1, 1:-1] - un[0:-2, 1:-1]) - dt / (2 * rho * dx) *
(p[1:-1, 2:] - p[1:-1, 0:-2]) + nu *
(dt / dx**2 *
(un[1:-1, 2:] - 2 * un[1:-1, 1:-1] + un[1:-1, 0:-2]) +
dt / dy**2 *
(un[2:, 1:-1] - 2 * un[1:-1, 1:-1] + un[0:-2, 1:-1])))

v = v.at[1:-1,
1:-1].set(vn[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx *
(vn[1:-1, 1:-1] - vn[1:-1, 0:-2]) -
vn[1:-1, 1:-1] * dt / dy *
(vn[1:-1, 1:-1] - vn[0:-2, 1:-1]) - dt / (2 * rho * dy) *
(p[2:, 1:-1] - p[0:-2, 1:-1]) + nu *
(dt / dx**2 *
(vn[1:-1, 2:] - 2 * vn[1:-1, 1:-1] + vn[1:-1, 0:-2]) +
dt / dy**2 *
(vn[2:, 1:-1] - 2 * vn[1:-1, 1:-1] + vn[0:-2, 1:-1])))

u = u.at[0, :].set(0)
u = u.at[:, 0].set(0)
u = u.at[:, -1].set(0)
u = u.at[-1, :].set(1) # set velocity on cavity lid equal to 1
v = v.at[0, :].set(0)
v = v.at[-1, :].set(0)
v = v.at[:, 0].set(0)
v = v.at[:, -1].set(0)

return (u, v, p, b), None

out_vals, _ = lax.scan(body_func, array_vals, jnp.arange(nt))
u, v, p, b = out_vals

return u, v, p
172 changes: 172 additions & 0 deletions npbench/benchmarks/channel_flow/channel_flow_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Barba, Lorena A., and Forsyth, Gilbert F. (2018).
# CFD Python: the 12 steps to Navier-Stokes equations.
# Journal of Open Source Education, 1(9), 21,
# https://doi.org/10.21105/jose.00021
# TODO: License
# (c) 2017 Lorena A. Barba, Gilbert F. Forsyth.
# All content is under Creative Commons Attribution CC-BY 4.0,
# and all code is under BSD-3 clause (previously under MIT, and changed on March 8, 2018).

import jax.numpy as jnp
import jax
from jax import lax
from functools import partial


@partial(jax.jit, static_argnums=(0,))
def build_up_b(rho, dt, dx, dy, u, v):
b = jnp.zeros_like(u)
b = b.at[1:-1,
1:-1].set((rho * (1 / dt * ((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx) +
(v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy)) -
((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx))**2 - 2 *
((u[2:, 1:-1] - u[0:-2, 1:-1]) / (2 * dy) *
(v[1:-1, 2:] - v[1:-1, 0:-2]) / (2 * dx)) -
((v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy))**2)))

# Periodic BC Pressure @ x = 2
b = b.at[1:-1, -1].set((rho * (1 / dt * ((u[1:-1, 0] - u[1:-1, -2]) / (2 * dx) +
(v[2:, -1] - v[0:-2, -1]) / (2 * dy)) -
((u[1:-1, 0] - u[1:-1, -2]) / (2 * dx))**2 - 2 *
((u[2:, -1] - u[0:-2, -1]) / (2 * dy) *
(v[1:-1, 0] - v[1:-1, -2]) / (2 * dx)) -
((v[2:, -1] - v[0:-2, -1]) / (2 * dy))**2)))

# Periodic BC Pressure @ x = 0
b = b.at[1:-1, 0].set((rho * (1 / dt * ((u[1:-1, 1] - u[1:-1, -1]) / (2 * dx) +
(v[2:, 0] - v[0:-2, 0]) / (2 * dy)) -
((u[1:-1, 1] - u[1:-1, -1]) / (2 * dx))**2 - 2 *
((u[2:, 0] - u[0:-2, 0]) / (2 * dy) *
(v[1:-1, 1] - v[1:-1, -1]) /
(2 * dx)) - ((v[2:, 0] - v[0:-2, 0]) / (2 * dy))**2)))

return b

@partial(jax.jit, static_argnums=(0,))
def pressure_poisson_periodic(nit, p, dx, dy, b):

def body_func(p, q):
pn = p.copy()
p = p.at[1:-1, 1:-1].set(((pn[1:-1, 2:] + pn[1:-1, 0:-2]) * dy**2 +
(pn[2:, 1:-1] + pn[0:-2, 1:-1]) * dx**2) /
(2 * (dx**2 + dy**2)) - dx**2 * dy**2 /
(2 * (dx**2 + dy**2)) * b[1:-1, 1:-1])

# Periodic BC Pressure @ x = 2
p = p.at[1:-1, -1].set(((pn[1:-1, 0] + pn[1:-1, -2]) * dy**2 +
(pn[2:, -1] + pn[0:-2, -1]) * dx**2) /
(2 * (dx**2 + dy**2)) - dx**2 * dy**2 /
(2 * (dx**2 + dy**2)) * b[1:-1, -1])

# Periodic BC Pressure @ x = 0
p = p.at[1:-1,
0].set((((pn[1:-1, 1] + pn[1:-1, -1]) * dy**2 +
(pn[2:, 0] + pn[0:-2, 0]) * dx**2) / (2 * (dx**2 + dy**2)) -
dx**2 * dy**2 / (2 * (dx**2 + dy**2)) * b[1:-1, 0]))

# Wall boundary conditions, pressure
p = p.at[-1, :].set(p[-2, :]) # dp/dy = 0 at y = 2
p = p.at[0, :].set(p[1, :]) # dp/dy = 0 at y = 0

return p, None

p, _ = lax.scan(body_func, p, jnp.arange(nit))


@partial(jax.jit, static_argnums=(0,7,8,9))
def channel_flow(nit, u, v, dt, dx, dy, p, rho, nu, F):
udiff = 1
stepcount = 0

array_vals = (udiff, stepcount, u, v, p)

def conf_func(array_vals):
udiff, _, _, _ , _ = array_vals
return udiff > .001

def body_func(array_vals):
_, stepcount, u, v, p = array_vals

un = u.copy()
vn = v.copy()

b = build_up_b(rho, dt, dx, dy, u, v)
pressure_poisson_periodic(nit, p, dx, dy, b)

u = u.at[1:-1,
1:-1].set(un[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx *
(un[1:-1, 1:-1] - un[1:-1, 0:-2]) -
vn[1:-1, 1:-1] * dt / dy *
(un[1:-1, 1:-1] - un[0:-2, 1:-1]) - dt / (2 * rho * dx) *
(p[1:-1, 2:] - p[1:-1, 0:-2]) + nu *
(dt / dx**2 *
(un[1:-1, 2:] - 2 * un[1:-1, 1:-1] + un[1:-1, 0:-2]) +
dt / dy**2 *
(un[2:, 1:-1] - 2 * un[1:-1, 1:-1] + un[0:-2, 1:-1])) +
F * dt)

v = v.at[1:-1,
1:-1].set(vn[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx *
(vn[1:-1, 1:-1] - vn[1:-1, 0:-2]) -
vn[1:-1, 1:-1] * dt / dy *
(vn[1:-1, 1:-1] - vn[0:-2, 1:-1]) - dt / (2 * rho * dy) *
(p[2:, 1:-1] - p[0:-2, 1:-1]) + nu *
(dt / dx**2 *
(vn[1:-1, 2:] - 2 * vn[1:-1, 1:-1] + vn[1:-1, 0:-2]) +
dt / dy**2 *
(vn[2:, 1:-1] - 2 * vn[1:-1, 1:-1] + vn[0:-2, 1:-1])))

# Periodic BC u @ x = 2
u = u.at[1:-1, -1].set(
un[1:-1, -1] - un[1:-1, -1] * dt / dx *
(un[1:-1, -1] - un[1:-1, -2]) - vn[1:-1, -1] * dt / dy *
(un[1:-1, -1] - un[0:-2, -1]) - dt / (2 * rho * dx) *
(p[1:-1, 0] - p[1:-1, -2]) + nu *
(dt / dx**2 *
(un[1:-1, 0] - 2 * un[1:-1, -1] + un[1:-1, -2]) + dt / dy**2 *
(un[2:, -1] - 2 * un[1:-1, -1] + un[0:-2, -1])) + F * dt)

# Periodic BC u @ x = 0
u = u.at[1:-1,
0].set(un[1:-1, 0] - un[1:-1, 0] * dt / dx *
(un[1:-1, 0] - un[1:-1, -1]) - vn[1:-1, 0] * dt / dy *
(un[1:-1, 0] - un[0:-2, 0]) - dt / (2 * rho * dx) *
(p[1:-1, 1] - p[1:-1, -1]) + nu *
(dt / dx**2 *
(un[1:-1, 1] - 2 * un[1:-1, 0] + un[1:-1, -1]) + dt / dy**2 *
(un[2:, 0] - 2 * un[1:-1, 0] + un[0:-2, 0])) + F * dt)

# Periodic BC v @ x = 2
v = v.at[1:-1, -1].set(
vn[1:-1, -1] - un[1:-1, -1] * dt / dx *
(vn[1:-1, -1] - vn[1:-1, -2]) - vn[1:-1, -1] * dt / dy *
(vn[1:-1, -1] - vn[0:-2, -1]) - dt / (2 * rho * dy) *
(p[2:, -1] - p[0:-2, -1]) + nu *
(dt / dx**2 *
(vn[1:-1, 0] - 2 * vn[1:-1, -1] + vn[1:-1, -2]) + dt / dy**2 *
(vn[2:, -1] - 2 * vn[1:-1, -1] + vn[0:-2, -1])))

# Periodic BC v @ x = 0
v = v.at[1:-1,
0].set(vn[1:-1, 0] - un[1:-1, 0] * dt / dx *
(vn[1:-1, 0] - vn[1:-1, -1]) - vn[1:-1, 0] * dt / dy *
(vn[1:-1, 0] - vn[0:-2, 0]) - dt / (2 * rho * dy) *
(p[2:, 0] - p[0:-2, 0]) + nu *
(dt / dx**2 *
(vn[1:-1, 1] - 2 * vn[1:-1, 0] + vn[1:-1, -1]) + dt / dy**2 *
(vn[2:, 0] - 2 * vn[1:-1, 0] + vn[0:-2, 0])))

# Wall BC: u,v = 0 @ y = 0,2
u = u.at[0, :].set(0)
u = u.at[-1, :].set(0)
v = v.at[0, :].set(0)
v = v.at[-1, :].set(0)

udiff = (jnp.sum(u) - jnp.sum(un)) / jnp.sum(u)
stepcount += 1

return (udiff, stepcount, u, v, p)

_, stepcount, _, _, _ = lax.while_loop(conf_func, body_func, array_vals)

return stepcount
8 changes: 8 additions & 0 deletions npbench/benchmarks/compute/compute_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# https://cython.readthedocs.io/en/latest/src/userguide/numpy_tutorial.html

import jax.numpy as jnp
import jax

@jax.jit
def compute(array_1, array_2, a, b, c):
return jnp.clip(array_1, 2, 10) * a + array_2 * b + c
Loading