@@ -61,7 +61,6 @@ def run(n, backend, datatype, benchmark_mode):
61
61
def transpose (a ):
62
62
return np .permute_dims (a , [1 , 0 ])
63
63
64
- all_axes = [0 , 1 ]
65
64
init (False )
66
65
67
66
elif backend == "numpy" :
@@ -76,7 +75,6 @@ def transpose(a):
76
75
transpose = np .transpose
77
76
78
77
fini = sync = lambda x = None : None
79
- all_axes = None
80
78
else :
81
79
raise ValueError (f'Unknown backend: "{ backend } "' )
82
80
@@ -207,11 +205,11 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
207
205
# set bathymetry
208
206
h [:, :] = bathymetry (x_t_2d , y_t_2d , lx , ly )
209
207
# steady state potential energy
210
- pe_offset = 0.5 * g * float (np .sum (h ** 2.0 , all_axes )) / nx / ny
208
+ pe_offset = 0.5 * g * float (np .sum (h ** 2.0 )) / nx / ny
211
209
212
210
# compute time step
213
211
alpha = 0.5
214
- h_max = float (np .max (h , all_axes ))
212
+ h_max = float (np .max (h ))
215
213
c = (g * h_max ) ** 0.5
216
214
dt = alpha * dx / c
217
215
dt = t_export / int (math .ceil (t_export / dt ))
@@ -344,22 +342,22 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
344
342
t = i * dt
345
343
346
344
if t >= next_t_export - 1e-8 :
347
- _elev_max = np .max (e , all_axes )
348
- _u_max = np .max (u , all_axes )
349
- _q_max = np .max (q , all_axes )
350
- _total_v = np .sum (e + h , all_axes )
345
+ _elev_max = np .max (e )
346
+ _u_max = np .max (u )
347
+ _q_max = np .max (q )
348
+ _total_v = np .sum (e + h )
351
349
352
350
# potential energy
353
351
_pe = 0.5 * g * (e + h ) * (e - h ) + pe_offset
354
- _total_pe = np .sum (_pe , all_axes )
352
+ _total_pe = np .sum (_pe )
355
353
356
354
# kinetic energy
357
355
u2 = u * u
358
356
v2 = v * v
359
357
u2_at_t = 0.5 * (u2 [1 :, :] + u2 [:- 1 , :])
360
358
v2_at_t = 0.5 * (v2 [:, 1 :] + v2 [:, :- 1 ])
361
359
_ke = 0.5 * (u2_at_t + v2_at_t ) * (e + h )
362
- _total_ke = np .sum (_ke , all_axes )
360
+ _total_ke = np .sum (_ke )
363
361
364
362
total_pe = float (_total_pe ) * dx * dy
365
363
total_ke = float (_total_ke ) * dx * dy
@@ -406,7 +404,7 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
406
404
2
407
405
]
408
406
err2 = (e_exact - e ) * (e_exact - e ) * dx * dy / lx / ly
409
- err_L2 = math .sqrt (float (np .sum (err2 , all_axes )))
407
+ err_L2 = math .sqrt (float (np .sum (err2 )))
410
408
info (f"L2 error: { err_L2 :7.15e} " )
411
409
412
410
if nx < 128 or ny < 128 :
0 commit comments