Skip to content

Commit 0952df3

Browse files
committed
shallow-water: simplify reduction calls
1 parent 57b40c1 commit 0952df3

File tree

1 file changed

+9
-11
lines changed

1 file changed

+9
-11
lines changed

examples/shallow_water.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ def run(n, backend, datatype, benchmark_mode):
6161
def transpose(a):
6262
return np.permute_dims(a, [1, 0])
6363

64-
all_axes = [0, 1]
6564
init(False)
6665

6766
elif backend == "numpy":
@@ -76,7 +75,6 @@ def transpose(a):
7675
transpose = np.transpose
7776

7877
fini = sync = lambda x=None: None
79-
all_axes = None
8078
else:
8179
raise ValueError(f'Unknown backend: "{backend}"')
8280

@@ -207,11 +205,11 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
207205
# set bathymetry
208206
h[:, :] = bathymetry(x_t_2d, y_t_2d, lx, ly)
209207
# steady state potential energy
210-
pe_offset = 0.5 * g * float(np.sum(h**2.0, all_axes)) / nx / ny
208+
pe_offset = 0.5 * g * float(np.sum(h**2.0)) / nx / ny
211209

212210
# compute time step
213211
alpha = 0.5
214-
h_max = float(np.max(h, all_axes))
212+
h_max = float(np.max(h))
215213
c = (g * h_max) ** 0.5
216214
dt = alpha * dx / c
217215
dt = t_export / int(math.ceil(t_export / dt))
@@ -344,22 +342,22 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
344342
t = i * dt
345343

346344
if t >= next_t_export - 1e-8:
347-
_elev_max = np.max(e, all_axes)
348-
_u_max = np.max(u, all_axes)
349-
_q_max = np.max(q, all_axes)
350-
_total_v = np.sum(e + h, all_axes)
345+
_elev_max = np.max(e)
346+
_u_max = np.max(u)
347+
_q_max = np.max(q)
348+
_total_v = np.sum(e + h)
351349

352350
# potential energy
353351
_pe = 0.5 * g * (e + h) * (e - h) + pe_offset
354-
_total_pe = np.sum(_pe, all_axes)
352+
_total_pe = np.sum(_pe)
355353

356354
# kinetic energy
357355
u2 = u * u
358356
v2 = v * v
359357
u2_at_t = 0.5 * (u2[1:, :] + u2[:-1, :])
360358
v2_at_t = 0.5 * (v2[:, 1:] + v2[:, :-1])
361359
_ke = 0.5 * (u2_at_t + v2_at_t) * (e + h)
362-
_total_ke = np.sum(_ke, all_axes)
360+
_total_ke = np.sum(_ke)
363361

364362
total_pe = float(_total_pe) * dx * dy
365363
total_ke = float(_total_ke) * dx * dy
@@ -406,7 +404,7 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
406404
2
407405
]
408406
err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
409-
err_L2 = math.sqrt(float(np.sum(err2, all_axes)))
407+
err_L2 = math.sqrt(float(np.sum(err2)))
410408
info(f"L2 error: {err_L2:7.15e}")
411409

412410
if nx < 128 or ny < 128:

0 commit comments

Comments
 (0)