Skip to content

Commit 4603797

Browse files
Merge pull request #1816 from OceanParcels/v/refactor-interp
Refactoring field interpolation and allow custom interpolation methods in Scipy mode
2 parents ecf9d80 + d029cf8 commit 4603797

File tree

15 files changed

+1022
-570
lines changed

15 files changed

+1022
-570
lines changed

parcels/_compat.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,17 @@
1717
from sklearn.cluster import KMeans # type: ignore[no-redef]
1818
except ModuleNotFoundError:
1919
pass
20+
21+
22+
def add_note(e: Exception, note: str, *, before=False) -> Exception: # TODO: Remove once py3.10 support is dropped
23+
"""Implements something similar to PEP 678 but for python <3.11.
24+
25+
https://stackoverflow.com/a/75549200/15545258
26+
"""
27+
args = e.args
28+
if not args:
29+
arg0 = note
30+
else:
31+
arg0 = f"{note}\n{args[0]}" if before else f"{args[0]}\n{note}"
32+
e.args = (arg0,) + args[1:]
33+
return e

parcels/_index_search.py

Lines changed: 337 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,337 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
import numpy as np
6+
7+
from parcels._typing import (
8+
GridIndexingType,
9+
InterpMethodOption,
10+
)
11+
from parcels.tools.statuscodes import (
12+
FieldOutOfBoundError,
13+
FieldOutOfBoundSurfaceError,
14+
_raise_field_out_of_bound_error,
15+
_raise_field_out_of_bound_surface_error,
16+
_raise_field_sampling_error,
17+
)
18+
19+
from .grid import GridType
20+
21+
if TYPE_CHECKING:
22+
from .field import Field
23+
from .grid import Grid
24+
25+
26+
def search_indices_vertical_z(grid: Grid, gridindexingtype: GridIndexingType, z: float):
27+
if grid.depth[-1] > grid.depth[0]:
28+
if z < grid.depth[0]:
29+
# Since MOM5 is indexed at cell bottom, allow z at depth[0] - dz where dz = (depth[1] - depth[0])
30+
if gridindexingtype == "mom5" and z > 2 * grid.depth[0] - grid.depth[1]:
31+
return (-1, z / grid.depth[0])
32+
else:
33+
_raise_field_out_of_bound_surface_error(z, None, None)
34+
elif z > grid.depth[-1]:
35+
# In case of CROCO, allow particles in last (uppermost) layer using depth[-1]
36+
if gridindexingtype in ["croco"] and z < 0:
37+
return (-2, 1)
38+
_raise_field_out_of_bound_error(z, None, None)
39+
depth_indices = grid.depth < z
40+
if z >= grid.depth[-1]:
41+
zi = len(grid.depth) - 2
42+
else:
43+
zi = depth_indices.argmin() - 1 if z > grid.depth[0] else 0
44+
else:
45+
if z > grid.depth[0]:
46+
_raise_field_out_of_bound_surface_error(z, None, None)
47+
elif z < grid.depth[-1]:
48+
_raise_field_out_of_bound_error(z, None, None)
49+
depth_indices = grid.depth > z
50+
if z <= grid.depth[-1]:
51+
zi = len(grid.depth) - 2
52+
else:
53+
zi = depth_indices.argmin() - 1 if z < grid.depth[0] else 0
54+
zeta = (z - grid.depth[zi]) / (grid.depth[zi + 1] - grid.depth[zi])
55+
while zeta > 1:
56+
zi += 1
57+
zeta = (z - grid.depth[zi]) / (grid.depth[zi + 1] - grid.depth[zi])
58+
while zeta < 0:
59+
zi -= 1
60+
zeta = (z - grid.depth[zi]) / (grid.depth[zi + 1] - grid.depth[zi])
61+
return (zi, zeta)
62+
63+
64+
def search_indices_vertical_s(
65+
grid: Grid,
66+
interp_method: InterpMethodOption,
67+
time: float,
68+
z: float,
69+
y: float,
70+
x: float,
71+
ti: int,
72+
yi: int,
73+
xi: int,
74+
eta: float,
75+
xsi: float,
76+
):
77+
if interp_method in ["bgrid_velocity", "bgrid_w_velocity", "bgrid_tracer"]:
78+
xsi = 1
79+
eta = 1
80+
if time < grid.time[ti]:
81+
ti -= 1
82+
if grid._z4d:
83+
if ti == len(grid.time) - 1:
84+
depth_vector = (
85+
(1 - xsi) * (1 - eta) * grid.depth[-1, :, yi, xi]
86+
+ xsi * (1 - eta) * grid.depth[-1, :, yi, xi + 1]
87+
+ xsi * eta * grid.depth[-1, :, yi + 1, xi + 1]
88+
+ (1 - xsi) * eta * grid.depth[-1, :, yi + 1, xi]
89+
)
90+
else:
91+
dv2 = (
92+
(1 - xsi) * (1 - eta) * grid.depth[ti : ti + 2, :, yi, xi]
93+
+ xsi * (1 - eta) * grid.depth[ti : ti + 2, :, yi, xi + 1]
94+
+ xsi * eta * grid.depth[ti : ti + 2, :, yi + 1, xi + 1]
95+
+ (1 - xsi) * eta * grid.depth[ti : ti + 2, :, yi + 1, xi]
96+
)
97+
tt = (time - grid.time[ti]) / (grid.time[ti + 1] - grid.time[ti])
98+
assert tt >= 0 and tt <= 1, "Vertical s grid is being wrongly interpolated in time"
99+
depth_vector = dv2[0, :] * (1 - tt) + dv2[1, :] * tt
100+
else:
101+
depth_vector = (
102+
(1 - xsi) * (1 - eta) * grid.depth[:, yi, xi]
103+
+ xsi * (1 - eta) * grid.depth[:, yi, xi + 1]
104+
+ xsi * eta * grid.depth[:, yi + 1, xi + 1]
105+
+ (1 - xsi) * eta * grid.depth[:, yi + 1, xi]
106+
)
107+
z = np.float32(z) # type: ignore # TODO: remove type ignore once we migrate to float64
108+
109+
if depth_vector[-1] > depth_vector[0]:
110+
if z < depth_vector[0]:
111+
_raise_field_out_of_bound_error(z, None, None)
112+
elif z > depth_vector[-1]:
113+
_raise_field_out_of_bound_error(z, None, None)
114+
depth_indices = depth_vector < z
115+
if z >= depth_vector[-1]:
116+
zi = len(depth_vector) - 2
117+
else:
118+
zi = depth_indices.argmin() - 1 if z > depth_vector[0] else 0
119+
else:
120+
if z > depth_vector[0]:
121+
_raise_field_out_of_bound_error(z, None, None)
122+
elif z < depth_vector[-1]:
123+
_raise_field_out_of_bound_error(z, None, None)
124+
depth_indices = depth_vector > z
125+
if z <= depth_vector[-1]:
126+
zi = len(depth_vector) - 2
127+
else:
128+
zi = depth_indices.argmin() - 1 if z < depth_vector[0] else 0
129+
zeta = (z - depth_vector[zi]) / (depth_vector[zi + 1] - depth_vector[zi])
130+
while zeta > 1:
131+
zi += 1
132+
zeta = (z - depth_vector[zi]) / (depth_vector[zi + 1] - depth_vector[zi])
133+
while zeta < 0:
134+
zi -= 1
135+
zeta = (z - depth_vector[zi]) / (depth_vector[zi + 1] - depth_vector[zi])
136+
return (zi, zeta)
137+
138+
139+
def _search_indices_rectilinear(
140+
field: Field, time: float, z: float, y: float, x: float, ti=-1, particle=None, search2D=False
141+
):
142+
grid = field.grid
143+
144+
if grid.xdim > 1 and (not grid.zonal_periodic):
145+
if x < grid.lonlat_minmax[0] or x > grid.lonlat_minmax[1]:
146+
_raise_field_out_of_bound_error(z, y, x)
147+
if grid.ydim > 1 and (y < grid.lonlat_minmax[2] or y > grid.lonlat_minmax[3]):
148+
_raise_field_out_of_bound_error(z, y, x)
149+
150+
if grid.xdim > 1:
151+
if grid.mesh != "spherical":
152+
lon_index = grid.lon < x
153+
if lon_index.all():
154+
xi = len(grid.lon) - 2
155+
else:
156+
xi = lon_index.argmin() - 1 if lon_index.any() else 0
157+
xsi = (x - grid.lon[xi]) / (grid.lon[xi + 1] - grid.lon[xi])
158+
if xsi < 0:
159+
xi -= 1
160+
xsi = (x - grid.lon[xi]) / (grid.lon[xi + 1] - grid.lon[xi])
161+
elif xsi > 1:
162+
xi += 1
163+
xsi = (x - grid.lon[xi]) / (grid.lon[xi + 1] - grid.lon[xi])
164+
else:
165+
lon_fixed = grid.lon.copy()
166+
indices = lon_fixed >= lon_fixed[0]
167+
if not indices.all():
168+
lon_fixed[indices.argmin() :] += 360
169+
if x < lon_fixed[0]:
170+
lon_fixed -= 360
171+
172+
lon_index = lon_fixed < x
173+
if lon_index.all():
174+
xi = len(lon_fixed) - 2
175+
else:
176+
xi = lon_index.argmin() - 1 if lon_index.any() else 0
177+
xsi = (x - lon_fixed[xi]) / (lon_fixed[xi + 1] - lon_fixed[xi])
178+
if xsi < 0:
179+
xi -= 1
180+
xsi = (x - lon_fixed[xi]) / (lon_fixed[xi + 1] - lon_fixed[xi])
181+
elif xsi > 1:
182+
xi += 1
183+
xsi = (x - lon_fixed[xi]) / (lon_fixed[xi + 1] - lon_fixed[xi])
184+
else:
185+
xi, xsi = -1, 0
186+
187+
if grid.ydim > 1:
188+
lat_index = grid.lat < y
189+
if lat_index.all():
190+
yi = len(grid.lat) - 2
191+
else:
192+
yi = lat_index.argmin() - 1 if lat_index.any() else 0
193+
194+
eta = (y - grid.lat[yi]) / (grid.lat[yi + 1] - grid.lat[yi])
195+
if eta < 0:
196+
yi -= 1
197+
eta = (y - grid.lat[yi]) / (grid.lat[yi + 1] - grid.lat[yi])
198+
elif eta > 1:
199+
yi += 1
200+
eta = (y - grid.lat[yi]) / (grid.lat[yi + 1] - grid.lat[yi])
201+
else:
202+
yi, eta = -1, 0
203+
204+
if grid.zdim > 1 and not search2D:
205+
if grid._gtype == GridType.RectilinearZGrid:
206+
try:
207+
(zi, zeta) = search_indices_vertical_z(field.grid, field.gridindexingtype, z)
208+
except FieldOutOfBoundError:
209+
_raise_field_out_of_bound_error(z, y, x)
210+
except FieldOutOfBoundSurfaceError:
211+
_raise_field_out_of_bound_surface_error(z, y, x)
212+
elif grid._gtype == GridType.RectilinearSGrid:
213+
(zi, zeta) = search_indices_vertical_s(field.grid, field.interp_method, time, z, y, x, ti, yi, xi, eta, xsi)
214+
else:
215+
zi, zeta = -1, 0
216+
217+
if not ((0 <= xsi <= 1) and (0 <= eta <= 1) and (0 <= zeta <= 1)):
218+
_raise_field_sampling_error(z, y, x)
219+
220+
if particle:
221+
particle.xi[field.igrid] = xi
222+
particle.yi[field.igrid] = yi
223+
particle.zi[field.igrid] = zi
224+
225+
return (zeta, eta, xsi, zi, yi, xi)
226+
227+
228+
def _search_indices_curvilinear(field: Field, time, z, y, x, ti=-1, particle=None, search2D=False):
229+
if particle:
230+
xi = particle.xi[field.igrid]
231+
yi = particle.yi[field.igrid]
232+
else:
233+
xi = int(field.grid.xdim / 2) - 1
234+
yi = int(field.grid.ydim / 2) - 1
235+
xsi = eta = -1
236+
grid = field.grid
237+
invA = np.array([[1, 0, 0, 0], [-1, 1, 0, 0], [-1, 0, 0, 1], [1, -1, 1, -1]])
238+
maxIterSearch = 1e6
239+
it = 0
240+
tol = 1.0e-10
241+
if not grid.zonal_periodic:
242+
if x < grid.lonlat_minmax[0] or x > grid.lonlat_minmax[1]:
243+
if grid.lon[0, 0] < grid.lon[0, -1]:
244+
_raise_field_out_of_bound_error(z, y, x)
245+
elif x < grid.lon[0, 0] and x > grid.lon[0, -1]: # This prevents from crashing in [160, -160]
246+
_raise_field_out_of_bound_error(z, y, x)
247+
if y < grid.lonlat_minmax[2] or y > grid.lonlat_minmax[3]:
248+
_raise_field_out_of_bound_error(z, y, x)
249+
250+
while xsi < -tol or xsi > 1 + tol or eta < -tol or eta > 1 + tol:
251+
px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]])
252+
if grid.mesh == "spherical":
253+
px[0] = px[0] + 360 if px[0] < x - 225 else px[0]
254+
px[0] = px[0] - 360 if px[0] > x + 225 else px[0]
255+
px[1:] = np.where(px[1:] - px[0] > 180, px[1:] - 360, px[1:])
256+
px[1:] = np.where(-px[1:] + px[0] > 180, px[1:] + 360, px[1:])
257+
py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]])
258+
a = np.dot(invA, px)
259+
b = np.dot(invA, py)
260+
261+
aa = a[3] * b[2] - a[2] * b[3]
262+
bb = a[3] * b[0] - a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + x * b[3] - y * a[3]
263+
cc = a[1] * b[0] - a[0] * b[1] + x * b[1] - y * a[1]
264+
if abs(aa) < 1e-12: # Rectilinear cell, or quasi
265+
eta = -cc / bb
266+
else:
267+
det2 = bb * bb - 4 * aa * cc
268+
if det2 > 0: # so, if det is nan we keep the xsi, eta from previous iter
269+
det = np.sqrt(det2)
270+
eta = (-bb + det) / (2 * aa)
271+
if abs(a[1] + a[3] * eta) < 1e-12: # this happens when recti cell rotated of 90deg
272+
xsi = ((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5
273+
else:
274+
xsi = (x - a[0] - a[2] * eta) / (a[1] + a[3] * eta)
275+
if xsi < 0 and eta < 0 and xi == 0 and yi == 0:
276+
_raise_field_out_of_bound_error(0, y, x)
277+
if xsi > 1 and eta > 1 and xi == grid.xdim - 1 and yi == grid.ydim - 1:
278+
_raise_field_out_of_bound_error(0, y, x)
279+
if xsi < -tol:
280+
xi -= 1
281+
elif xsi > 1 + tol:
282+
xi += 1
283+
if eta < -tol:
284+
yi -= 1
285+
elif eta > 1 + tol:
286+
yi += 1
287+
(yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid.mesh)
288+
it += 1
289+
if it > maxIterSearch:
290+
print(f"Correct cell not found after {maxIterSearch} iterations")
291+
_raise_field_out_of_bound_error(0, y, x)
292+
xsi = max(0.0, xsi)
293+
eta = max(0.0, eta)
294+
xsi = min(1.0, xsi)
295+
eta = min(1.0, eta)
296+
297+
if grid.zdim > 1 and not search2D:
298+
if grid._gtype == GridType.CurvilinearZGrid:
299+
try:
300+
(zi, zeta) = search_indices_vertical_z(field.grid, field.gridindexingtype, z)
301+
except FieldOutOfBoundError:
302+
_raise_field_out_of_bound_error(z, y, x)
303+
elif grid._gtype == GridType.CurvilinearSGrid:
304+
(zi, zeta) = search_indices_vertical_s(field.grid, field.interp_method, time, z, y, x, ti, yi, xi, eta, xsi)
305+
else:
306+
zi = -1
307+
zeta = 0
308+
309+
if not ((0 <= xsi <= 1) and (0 <= eta <= 1) and (0 <= zeta <= 1)):
310+
_raise_field_sampling_error(z, y, x)
311+
312+
if particle:
313+
particle.xi[field.igrid] = xi
314+
particle.yi[field.igrid] = yi
315+
particle.zi[field.igrid] = zi
316+
317+
return (zeta, eta, xsi, zi, yi, xi)
318+
319+
320+
def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, sphere_mesh: bool):
321+
if xi < 0:
322+
if sphere_mesh:
323+
xi = xdim - 2
324+
else:
325+
xi = 0
326+
if xi > xdim - 2:
327+
if sphere_mesh:
328+
xi = 0
329+
else:
330+
xi = xdim - 2
331+
if yi < 0:
332+
yi = 0
333+
if yi > ydim - 2:
334+
yi = ydim - 2
335+
if sphere_mesh:
336+
xi = xdim - xi
337+
return yi, xi

0 commit comments

Comments
 (0)