Skip to content

Commit 44928f6

Browse files
Merge pull request #2150 from OceanParcels/cleaning_index_search
Cleaning up the index-search file
2 parents 62a13d8 + 0ca39bd commit 44928f6

File tree

1 file changed

+15
-307
lines changed

1 file changed

+15
-307
lines changed

parcels/_index_search.py

Lines changed: 15 additions & 307 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,12 @@
55

66
import numpy as np
77

8-
from parcels._typing import (
9-
GridIndexingType,
10-
InterpMethodOption,
11-
Mesh,
12-
)
138
from parcels.tools.statuscodes import (
14-
FieldOutOfBoundError,
15-
FieldOutOfBoundSurfaceError,
169
_raise_field_out_of_bound_error,
17-
_raise_field_out_of_bound_surface_error,
1810
_raise_field_sampling_error,
1911
_raise_time_extrapolation_error,
2012
)
2113

22-
from .basegrid import GridType
23-
2414
if TYPE_CHECKING:
2515
from parcels.xgrid import XGrid
2616

@@ -50,206 +40,6 @@ def _search_time_index(field: Field, time: datetime):
5040
return np.atleast_1d(tau), np.atleast_1d(ti)
5141

5242

53-
def search_indices_vertical_z(depth, gridindexingtype: GridIndexingType, z: float):
54-
if depth[-1] > depth[0]:
55-
if z < depth[0]:
56-
# Since MOM5 is indexed at cell bottom, allow z at depth[0] - dz where dz = (depth[1] - depth[0])
57-
if gridindexingtype == "mom5" and z > 2 * depth[0] - depth[1]:
58-
return (-1, z / depth[0])
59-
else:
60-
_raise_field_out_of_bound_surface_error(z, None, None)
61-
elif z > depth[-1]:
62-
# In case of CROCO, allow particles in last (uppermost) layer using depth[-1]
63-
if gridindexingtype in ["croco"] and z < 0:
64-
return (-2, 1)
65-
_raise_field_out_of_bound_error(z, None, None)
66-
depth_indices = depth < z
67-
if z >= depth[-1]:
68-
zi = len(depth) - 2
69-
else:
70-
zi = depth_indices.argmin() - 1 if z > depth[0] else 0
71-
else:
72-
if z > depth[0]:
73-
_raise_field_out_of_bound_surface_error(z, None, None)
74-
elif z < depth[-1]:
75-
_raise_field_out_of_bound_error(z, None, None)
76-
depth_indices = depth > z
77-
if z <= depth[-1]:
78-
zi = len(depth) - 2
79-
else:
80-
zi = depth_indices.argmin() - 1 if z < depth[0] else 0
81-
zeta = (z - depth[zi]) / (depth[zi + 1] - depth[zi])
82-
while zeta > 1:
83-
zi += 1
84-
zeta = (z - depth[zi]) / (depth[zi + 1] - depth[zi])
85-
while zeta < 0:
86-
zi -= 1
87-
zeta = (z - depth[zi]) / (depth[zi + 1] - depth[zi])
88-
return (zi, zeta)
89-
90-
91-
## TODO : Still need to implement the search_indices_vertical_s function
92-
def search_indices_vertical_s(
93-
field: Field,
94-
interp_method: InterpMethodOption,
95-
time: float,
96-
z: float,
97-
y: float,
98-
x: float,
99-
ti: int,
100-
yi: int,
101-
xi: int,
102-
eta: float,
103-
xsi: float,
104-
):
105-
if interp_method in ["bgrid_velocity", "bgrid_w_velocity", "bgrid_tracer"]:
106-
xsi = 1
107-
eta = 1
108-
if time < field.time[ti]:
109-
ti -= 1
110-
if field._z4d: # type: ignore[attr-defined]
111-
if ti == len(field.time) - 1:
112-
depth_vector = (
113-
(1 - xsi) * (1 - eta) * field.depth[-1, :, yi, xi]
114-
+ xsi * (1 - eta) * field.depth[-1, :, yi, xi + 1]
115-
+ xsi * eta * field.depth[-1, :, yi + 1, xi + 1]
116-
+ (1 - xsi) * eta * field.depth[-1, :, yi + 1, xi]
117-
)
118-
else:
119-
dv2 = (
120-
(1 - xsi) * (1 - eta) * field.depth[ti : ti + 2, :, yi, xi]
121-
+ xsi * (1 - eta) * field.depth[ti : ti + 2, :, yi, xi + 1]
122-
+ xsi * eta * field.depth[ti : ti + 2, :, yi + 1, xi + 1]
123-
+ (1 - xsi) * eta * field.depth[ti : ti + 2, :, yi + 1, xi]
124-
)
125-
tt = (time - field.time[ti]) / (field.time[ti + 1] - field.time[ti])
126-
assert tt >= 0 and tt <= 1, "Vertical s grid is being wrongly interpolated in time"
127-
depth_vector = dv2[0, :] * (1 - tt) + dv2[1, :] * tt
128-
else:
129-
depth_vector = (
130-
(1 - xsi) * (1 - eta) * field.depth[:, yi, xi]
131-
+ xsi * (1 - eta) * field.depth[:, yi, xi + 1]
132-
+ xsi * eta * field.depth[:, yi + 1, xi + 1]
133-
+ (1 - xsi) * eta * field.depth[:, yi + 1, xi]
134-
)
135-
z = np.float32(z) # type: ignore # TODO: remove type ignore once we migrate to float64
136-
137-
if depth_vector[-1] > depth_vector[0]:
138-
if z < depth_vector[0]:
139-
_raise_field_out_of_bound_error(z, None, None)
140-
elif z > depth_vector[-1]:
141-
_raise_field_out_of_bound_error(z, None, None)
142-
depth_indices = depth_vector < z
143-
if z >= depth_vector[-1]:
144-
zi = len(depth_vector) - 2
145-
else:
146-
zi = depth_indices.argmin() - 1 if z > depth_vector[0] else 0
147-
else:
148-
if z > depth_vector[0]:
149-
_raise_field_out_of_bound_error(z, None, None)
150-
elif z < depth_vector[-1]:
151-
_raise_field_out_of_bound_error(z, None, None)
152-
depth_indices = depth_vector > z
153-
if z <= depth_vector[-1]:
154-
zi = len(depth_vector) - 2
155-
else:
156-
zi = depth_indices.argmin() - 1 if z < depth_vector[0] else 0
157-
zeta = (z - depth_vector[zi]) / (depth_vector[zi + 1] - depth_vector[zi])
158-
while zeta > 1:
159-
zi += 1
160-
zeta = (z - depth_vector[zi]) / (depth_vector[zi + 1] - depth_vector[zi])
161-
while zeta < 0:
162-
zi -= 1
163-
zeta = (z - depth_vector[zi]) / (depth_vector[zi + 1] - depth_vector[zi])
164-
return (zi, zeta)
165-
166-
167-
def _search_indices_rectilinear(
168-
field: Field, time: datetime, z: float, y: float, x: float, ti: int, ei: int | None = None, search2D=False
169-
):
170-
# TODO : If ei is provided, check if particle is in the same cell
171-
if field.xdim > 1 and (not field.zonal_periodic):
172-
if x < field.lonlat_minmax[0] or x > field.lonlat_minmax[1]:
173-
_raise_field_out_of_bound_error(z, y, x)
174-
if field.ydim > 1 and (y < field.lonlat_minmax[2] or y > field.lonlat_minmax[3]):
175-
_raise_field_out_of_bound_error(z, y, x)
176-
177-
if field.xdim > 1:
178-
if field._mesh != "spherical":
179-
lon_index = field.lon < x
180-
if lon_index.all():
181-
xi = len(field.lon) - 2
182-
else:
183-
xi = lon_index.argmin() - 1 if lon_index.any() else 0
184-
xsi = (x - field.lon[xi]) / (field.lon[xi + 1] - field.lon[xi])
185-
if xsi < 0:
186-
xi -= 1
187-
xsi = (x - field.lon[xi]) / (field.lon[xi + 1] - field.lon[xi])
188-
elif xsi > 1:
189-
xi += 1
190-
xsi = (x - field.lon[xi]) / (field.lon[xi + 1] - field.lon[xi])
191-
else:
192-
lon_fixed = field.lon.copy()
193-
indices = lon_fixed >= lon_fixed[0]
194-
if not indices.all():
195-
lon_fixed[indices.argmin() :] += 360
196-
if x < lon_fixed[0]:
197-
lon_fixed -= 360
198-
199-
lon_index = lon_fixed < x
200-
if lon_index.all():
201-
xi = len(lon_fixed) - 2
202-
else:
203-
xi = lon_index.argmin() - 1 if lon_index.any() else 0
204-
xsi = (x - lon_fixed[xi]) / (lon_fixed[xi + 1] - lon_fixed[xi])
205-
if xsi < 0:
206-
xi -= 1
207-
xsi = (x - lon_fixed[xi]) / (lon_fixed[xi + 1] - lon_fixed[xi])
208-
elif xsi > 1:
209-
xi += 1
210-
xsi = (x - lon_fixed[xi]) / (lon_fixed[xi + 1] - lon_fixed[xi])
211-
else:
212-
xi, xsi = -1, 0
213-
214-
if field.ydim > 1:
215-
lat_index = field.lat < y
216-
if lat_index.all():
217-
yi = len(field.lat) - 2
218-
else:
219-
yi = lat_index.argmin() - 1 if lat_index.any() else 0
220-
221-
eta = (y - field.lat[yi]) / (field.lat[yi + 1] - field.lat[yi])
222-
if eta < 0:
223-
yi -= 1
224-
eta = (y - field.lat[yi]) / (field.lat[yi + 1] - field.lat[yi])
225-
elif eta > 1:
226-
yi += 1
227-
eta = (y - field.lat[yi]) / (field.lat[yi + 1] - field.lat[yi])
228-
else:
229-
yi, eta = -1, 0
230-
231-
if field.zdim > 1 and not search2D:
232-
if field._gtype == GridType.RectilinearZGrid:
233-
try:
234-
(zi, zeta) = search_indices_vertical_z(field.grid, field.gridindexingtype, z)
235-
except FieldOutOfBoundError:
236-
_raise_field_out_of_bound_error(z, y, x)
237-
except FieldOutOfBoundSurfaceError:
238-
_raise_field_out_of_bound_surface_error(z, y, x)
239-
elif field._gtype == GridType.RectilinearSGrid:
240-
## TODO : Still need to implement the search_indices_vertical_s function
241-
(zi, zeta) = search_indices_vertical_s(field.grid, field.interp_method, time, z, y, x, ti, yi, xi, eta, xsi)
242-
else:
243-
zi, zeta = -1, 0
244-
245-
if not ((0 <= xsi <= 1) and (0 <= eta <= 1) and (0 <= zeta <= 1)):
246-
_raise_field_sampling_error(z, y, x)
247-
248-
_ei = field.ravel_index(zi, yi, xi)
249-
250-
return (zeta, eta, xsi, _ei)
251-
252-
25343
def _search_indices_curvilinear_2d(
25444
grid: XGrid, y: float, x: float, yi_guess: int | None = None, xi_guess: int | None = None
25545
): # TODO fix typing instructions to make clear that y, x etc need to be ndarrays
@@ -318,103 +108,21 @@ def _search_indices_curvilinear_2d(
318108
return (yi, eta, xi, xsi)
319109

320110

321-
## TODO : Still need to implement the search_indices_curvilinear
322-
def _search_indices_curvilinear(field, time, z, y, x, ti, particle=None, search2D=False):
323-
if particle:
324-
zi, yi, xi = field.unravel_index(particle.ei)
325-
else:
326-
xi = int(field.xdim / 2) - 1
327-
yi = int(field.ydim / 2) - 1
328-
xsi = eta = -1.0
329-
grid = field.grid
330-
invA = np.array([[1, 0, 0, 0], [-1, 1, 0, 0], [-1, 0, 0, 1], [1, -1, 1, -1]])
331-
maxIterSearch = 1e6
332-
it = 0
333-
tol = 1.0e-10
334-
if not grid.zonal_periodic:
335-
if x < field.lonlat_minmax[0] or x > field.lonlat_minmax[1]:
336-
if grid.lon[0, 0] < grid.lon[0, -1]:
337-
_raise_field_out_of_bound_error(z, y, x)
338-
elif x < grid.lon[0, 0] and x > grid.lon[0, -1]: # This prevents from crashing in [160, -160]
339-
_raise_field_out_of_bound_error(z, y, x)
340-
if y < field.lonlat_minmax[2] or y > field.lonlat_minmax[3]:
341-
_raise_field_out_of_bound_error(z, y, x)
342-
343-
while xsi < -tol or xsi > 1 + tol or eta < -tol or eta > 1 + tol:
344-
px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]])
345-
if grid.mesh == "spherical":
346-
px[0] = px[0] + 360 if px[0] < x - 225 else px[0]
347-
px[0] = px[0] - 360 if px[0] > x + 225 else px[0]
348-
px[1:] = np.where(px[1:] - px[0] > 180, px[1:] - 360, px[1:])
349-
px[1:] = np.where(-px[1:] + px[0] > 180, px[1:] + 360, px[1:])
350-
py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]])
351-
a = np.dot(invA, px)
352-
b = np.dot(invA, py)
353-
354-
aa = a[3] * b[2] - a[2] * b[3]
355-
bb = a[3] * b[0] - a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + x * b[3] - y * a[3]
356-
cc = a[1] * b[0] - a[0] * b[1] + x * b[1] - y * a[1]
357-
if abs(aa) < 1e-12: # Rectilinear cell, or quasi
358-
eta = -cc / bb
111+
def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, sphere_mesh: bool):
112+
if xi < 0:
113+
if sphere_mesh:
114+
xi = xdim - 2
359115
else:
360-
det2 = bb * bb - 4 * aa * cc
361-
if det2 > 0: # so, if det is nan we keep the xsi, eta from previous iter
362-
det = np.sqrt(det2)
363-
eta = (-bb + det) / (2 * aa)
364-
if abs(a[1] + a[3] * eta) < 1e-12: # this happens when recti cell rotated of 90deg
365-
xsi = ((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5
116+
xi = 0
117+
if xi > xdim - 2:
118+
if sphere_mesh:
119+
xi = 0
366120
else:
367-
xsi = (x - a[0] - a[2] * eta) / (a[1] + a[3] * eta)
368-
if xsi < 0 and eta < 0 and xi == 0 and yi == 0:
369-
_raise_field_out_of_bound_error(0, y, x)
370-
if xsi > 1 and eta > 1 and xi == grid.xdim - 1 and yi == grid.ydim - 1:
371-
_raise_field_out_of_bound_error(0, y, x)
372-
if xsi < -tol:
373-
xi -= 1
374-
elif xsi > 1 + tol:
375-
xi += 1
376-
if eta < -tol:
377-
yi -= 1
378-
elif eta > 1 + tol:
379-
yi += 1
380-
(yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid.mesh)
381-
it += 1
382-
if it > maxIterSearch:
383-
print(f"Correct cell not found after {maxIterSearch} iterations")
384-
_raise_field_out_of_bound_error(0, y, x)
385-
xsi = max(0.0, xsi)
386-
eta = max(0.0, eta)
387-
xsi = min(1.0, xsi)
388-
eta = min(1.0, eta)
389-
390-
if grid.zdim > 1 and not search2D:
391-
if grid._gtype == GridType.CurvilinearZGrid:
392-
try:
393-
(zi, zeta) = search_indices_vertical_z(field.grid, field.gridindexingtype, z)
394-
except FieldOutOfBoundError:
395-
_raise_field_out_of_bound_error(z, y, x)
396-
elif grid._gtype == GridType.CurvilinearSGrid:
397-
(zi, zeta) = search_indices_vertical_s(field.grid, field.interp_method, time, z, y, x, ti, yi, xi, eta, xsi)
398-
else:
399-
zi = -1
400-
zeta = 0
401-
402-
if not ((0 <= xsi <= 1) and (0 <= eta <= 1) and (0 <= zeta <= 1)):
403-
_raise_field_sampling_error(z, y, x)
404-
405-
if particle:
406-
particle.ei[field.igrid] = field.ravel_index(zi, yi, xi)
407-
408-
return (zeta, eta, xsi, zi, yi, xi)
409-
410-
411-
def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, mesh: Mesh):
412-
xi = np.where(xi < 0, (xdim - 2) if mesh == "spherical" else 0, xi)
413-
xi = np.where(xi > xdim - 2, 0 if mesh == "spherical" else (xdim - 2), xi)
414-
415-
xi = np.where(yi > ydim - 2, xdim - xi if mesh == "spherical" else xi, xi)
416-
417-
yi = np.where(yi < 0, 0, yi)
418-
yi = np.where(yi > ydim - 2, ydim - 2, yi)
419-
121+
xi = xdim - 2
122+
if yi < 0:
123+
yi = 0
124+
if yi > ydim - 2:
125+
yi = ydim - 2
126+
if sphere_mesh:
127+
xi = xdim - xi
420128
return yi, xi

0 commit comments

Comments
 (0)