Skip to content

Commit 85bbeff

Browse files
Merge branch 'v4-dev' into spatial_slip_interpolation
2 parents 0ed524d + 25fb9d0 commit 85bbeff

File tree

5 files changed

+363
-504
lines changed

5 files changed

+363
-504
lines changed

parcels/_index_search.py

Lines changed: 16 additions & 309 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,12 @@
55

66
import numpy as np
77

8-
from parcels._typing import (
9-
GridIndexingType,
10-
InterpMethodOption,
11-
Mesh,
12-
)
138
from parcels.tools.statuscodes import (
14-
FieldOutOfBoundError,
15-
FieldOutOfBoundSurfaceError,
169
_raise_field_out_of_bound_error,
17-
_raise_field_out_of_bound_surface_error,
1810
_raise_field_sampling_error,
1911
_raise_time_extrapolation_error,
2012
)
2113

22-
from .basegrid import GridType
23-
2414
if TYPE_CHECKING:
2515
from parcels.xgrid import XGrid
2616

@@ -50,213 +40,12 @@ def _search_time_index(field: Field, time: datetime):
5040
return np.atleast_1d(tau), np.atleast_1d(ti)
5141

5242

53-
def search_indices_vertical_z(depth, gridindexingtype: GridIndexingType, z: float):
54-
if depth[-1] > depth[0]:
55-
if z < depth[0]:
56-
# Since MOM5 is indexed at cell bottom, allow z at depth[0] - dz where dz = (depth[1] - depth[0])
57-
if gridindexingtype == "mom5" and z > 2 * depth[0] - depth[1]:
58-
return (-1, z / depth[0])
59-
else:
60-
_raise_field_out_of_bound_surface_error(z, None, None)
61-
elif z > depth[-1]:
62-
# In case of CROCO, allow particles in last (uppermost) layer using depth[-1]
63-
if gridindexingtype in ["croco"] and z < 0:
64-
return (-2, 1)
65-
_raise_field_out_of_bound_error(z, None, None)
66-
depth_indices = depth < z
67-
if z >= depth[-1]:
68-
zi = len(depth) - 2
69-
else:
70-
zi = depth_indices.argmin() - 1 if z > depth[0] else 0
71-
else:
72-
if z > depth[0]:
73-
_raise_field_out_of_bound_surface_error(z, None, None)
74-
elif z < depth[-1]:
75-
_raise_field_out_of_bound_error(z, None, None)
76-
depth_indices = depth > z
77-
if z <= depth[-1]:
78-
zi = len(depth) - 2
79-
else:
80-
zi = depth_indices.argmin() - 1 if z < depth[0] else 0
81-
zeta = (z - depth[zi]) / (depth[zi + 1] - depth[zi])
82-
while zeta > 1:
83-
zi += 1
84-
zeta = (z - depth[zi]) / (depth[zi + 1] - depth[zi])
85-
while zeta < 0:
86-
zi -= 1
87-
zeta = (z - depth[zi]) / (depth[zi + 1] - depth[zi])
88-
return (zi, zeta)
89-
90-
91-
## TODO : Still need to implement the search_indices_vertical_s function
92-
def search_indices_vertical_s(
93-
field: Field,
94-
interp_method: InterpMethodOption,
95-
time: float,
96-
z: float,
97-
y: float,
98-
x: float,
99-
ti: int,
100-
yi: int,
101-
xi: int,
102-
eta: float,
103-
xsi: float,
104-
):
105-
if interp_method in ["bgrid_velocity", "bgrid_w_velocity", "bgrid_tracer"]:
106-
xsi = 1
107-
eta = 1
108-
if time < field.time[ti]:
109-
ti -= 1
110-
if field._z4d: # type: ignore[attr-defined]
111-
if ti == len(field.time) - 1:
112-
depth_vector = (
113-
(1 - xsi) * (1 - eta) * field.depth[-1, :, yi, xi]
114-
+ xsi * (1 - eta) * field.depth[-1, :, yi, xi + 1]
115-
+ xsi * eta * field.depth[-1, :, yi + 1, xi + 1]
116-
+ (1 - xsi) * eta * field.depth[-1, :, yi + 1, xi]
117-
)
118-
else:
119-
dv2 = (
120-
(1 - xsi) * (1 - eta) * field.depth[ti : ti + 2, :, yi, xi]
121-
+ xsi * (1 - eta) * field.depth[ti : ti + 2, :, yi, xi + 1]
122-
+ xsi * eta * field.depth[ti : ti + 2, :, yi + 1, xi + 1]
123-
+ (1 - xsi) * eta * field.depth[ti : ti + 2, :, yi + 1, xi]
124-
)
125-
tt = (time - field.time[ti]) / (field.time[ti + 1] - field.time[ti])
126-
assert tt >= 0 and tt <= 1, "Vertical s grid is being wrongly interpolated in time"
127-
depth_vector = dv2[0, :] * (1 - tt) + dv2[1, :] * tt
128-
else:
129-
depth_vector = (
130-
(1 - xsi) * (1 - eta) * field.depth[:, yi, xi]
131-
+ xsi * (1 - eta) * field.depth[:, yi, xi + 1]
132-
+ xsi * eta * field.depth[:, yi + 1, xi + 1]
133-
+ (1 - xsi) * eta * field.depth[:, yi + 1, xi]
134-
)
135-
z = np.float32(z) # type: ignore # TODO: remove type ignore once we migrate to float64
136-
137-
if depth_vector[-1] > depth_vector[0]:
138-
if z < depth_vector[0]:
139-
_raise_field_out_of_bound_error(z, None, None)
140-
elif z > depth_vector[-1]:
141-
_raise_field_out_of_bound_error(z, None, None)
142-
depth_indices = depth_vector < z
143-
if z >= depth_vector[-1]:
144-
zi = len(depth_vector) - 2
145-
else:
146-
zi = depth_indices.argmin() - 1 if z > depth_vector[0] else 0
147-
else:
148-
if z > depth_vector[0]:
149-
_raise_field_out_of_bound_error(z, None, None)
150-
elif z < depth_vector[-1]:
151-
_raise_field_out_of_bound_error(z, None, None)
152-
depth_indices = depth_vector > z
153-
if z <= depth_vector[-1]:
154-
zi = len(depth_vector) - 2
155-
else:
156-
zi = depth_indices.argmin() - 1 if z < depth_vector[0] else 0
157-
zeta = (z - depth_vector[zi]) / (depth_vector[zi + 1] - depth_vector[zi])
158-
while zeta > 1:
159-
zi += 1
160-
zeta = (z - depth_vector[zi]) / (depth_vector[zi + 1] - depth_vector[zi])
161-
while zeta < 0:
162-
zi -= 1
163-
zeta = (z - depth_vector[zi]) / (depth_vector[zi + 1] - depth_vector[zi])
164-
return (zi, zeta)
165-
166-
167-
def _search_indices_rectilinear(
168-
field: Field, time: datetime, z: float, y: float, x: float, ti: int, ei: int | None = None, search2D=False
169-
):
170-
# TODO : If ei is provided, check if particle is in the same cell
171-
if field.xdim > 1 and (not field.zonal_periodic):
172-
if x < field.lonlat_minmax[0] or x > field.lonlat_minmax[1]:
173-
_raise_field_out_of_bound_error(z, y, x)
174-
if field.ydim > 1 and (y < field.lonlat_minmax[2] or y > field.lonlat_minmax[3]):
175-
_raise_field_out_of_bound_error(z, y, x)
176-
177-
if field.xdim > 1:
178-
if field._mesh != "spherical":
179-
lon_index = field.lon < x
180-
if lon_index.all():
181-
xi = len(field.lon) - 2
182-
else:
183-
xi = lon_index.argmin() - 1 if lon_index.any() else 0
184-
xsi = (x - field.lon[xi]) / (field.lon[xi + 1] - field.lon[xi])
185-
if xsi < 0:
186-
xi -= 1
187-
xsi = (x - field.lon[xi]) / (field.lon[xi + 1] - field.lon[xi])
188-
elif xsi > 1:
189-
xi += 1
190-
xsi = (x - field.lon[xi]) / (field.lon[xi + 1] - field.lon[xi])
191-
else:
192-
lon_fixed = field.lon.copy()
193-
indices = lon_fixed >= lon_fixed[0]
194-
if not indices.all():
195-
lon_fixed[indices.argmin() :] += 360
196-
if x < lon_fixed[0]:
197-
lon_fixed -= 360
198-
199-
lon_index = lon_fixed < x
200-
if lon_index.all():
201-
xi = len(lon_fixed) - 2
202-
else:
203-
xi = lon_index.argmin() - 1 if lon_index.any() else 0
204-
xsi = (x - lon_fixed[xi]) / (lon_fixed[xi + 1] - lon_fixed[xi])
205-
if xsi < 0:
206-
xi -= 1
207-
xsi = (x - lon_fixed[xi]) / (lon_fixed[xi + 1] - lon_fixed[xi])
208-
elif xsi > 1:
209-
xi += 1
210-
xsi = (x - lon_fixed[xi]) / (lon_fixed[xi + 1] - lon_fixed[xi])
211-
else:
212-
xi, xsi = -1, 0
213-
214-
if field.ydim > 1:
215-
lat_index = field.lat < y
216-
if lat_index.all():
217-
yi = len(field.lat) - 2
218-
else:
219-
yi = lat_index.argmin() - 1 if lat_index.any() else 0
220-
221-
eta = (y - field.lat[yi]) / (field.lat[yi + 1] - field.lat[yi])
222-
if eta < 0:
223-
yi -= 1
224-
eta = (y - field.lat[yi]) / (field.lat[yi + 1] - field.lat[yi])
225-
elif eta > 1:
226-
yi += 1
227-
eta = (y - field.lat[yi]) / (field.lat[yi + 1] - field.lat[yi])
228-
else:
229-
yi, eta = -1, 0
230-
231-
if field.zdim > 1 and not search2D:
232-
if field._gtype == GridType.RectilinearZGrid:
233-
try:
234-
(zi, zeta) = search_indices_vertical_z(field.grid, field.gridindexingtype, z)
235-
except FieldOutOfBoundError:
236-
_raise_field_out_of_bound_error(z, y, x)
237-
except FieldOutOfBoundSurfaceError:
238-
_raise_field_out_of_bound_surface_error(z, y, x)
239-
elif field._gtype == GridType.RectilinearSGrid:
240-
## TODO : Still need to implement the search_indices_vertical_s function
241-
(zi, zeta) = search_indices_vertical_s(field.grid, field.interp_method, time, z, y, x, ti, yi, xi, eta, xsi)
242-
else:
243-
zi, zeta = -1, 0
244-
245-
if not ((0 <= xsi <= 1) and (0 <= eta <= 1) and (0 <= zeta <= 1)):
246-
_raise_field_sampling_error(z, y, x)
247-
248-
_ei = field.ravel_index(zi, yi, xi)
249-
250-
return (zeta, eta, xsi, _ei)
251-
252-
25343
def _search_indices_curvilinear_2d(
25444
grid: XGrid, y: float, x: float, yi_guess: int | None = None, xi_guess: int | None = None
25545
): # TODO fix typing instructions to make clear that y, x etc need to be ndarrays
25646
yi, xi = yi_guess, xi_guess
25747
if yi is None or xi is None:
258-
faces = grid.get_spatial_hash().query(np.column_stack((y, x)))
259-
yi, xi = faces[0]
48+
yi, xi = grid.get_spatial_hash().query(y, x)
26049

26150
xsi = eta = -1.0 * np.ones(len(x), dtype=float)
26251
invA = np.array(
@@ -319,103 +108,21 @@ def _search_indices_curvilinear_2d(
319108
return (yi, eta, xi, xsi)
320109

321110

322-
## TODO : Still need to implement the search_indices_curvilinear
323-
def _search_indices_curvilinear(field, time, z, y, x, ti, particle=None, search2D=False):
324-
if particle:
325-
zi, yi, xi = field.unravel_index(particle.ei)
326-
else:
327-
xi = int(field.xdim / 2) - 1
328-
yi = int(field.ydim / 2) - 1
329-
xsi = eta = -1.0
330-
grid = field.grid
331-
invA = np.array([[1, 0, 0, 0], [-1, 1, 0, 0], [-1, 0, 0, 1], [1, -1, 1, -1]])
332-
maxIterSearch = 1e6
333-
it = 0
334-
tol = 1.0e-10
335-
if not grid.zonal_periodic:
336-
if x < field.lonlat_minmax[0] or x > field.lonlat_minmax[1]:
337-
if grid.lon[0, 0] < grid.lon[0, -1]:
338-
_raise_field_out_of_bound_error(z, y, x)
339-
elif x < grid.lon[0, 0] and x > grid.lon[0, -1]: # This prevents from crashing in [160, -160]
340-
_raise_field_out_of_bound_error(z, y, x)
341-
if y < field.lonlat_minmax[2] or y > field.lonlat_minmax[3]:
342-
_raise_field_out_of_bound_error(z, y, x)
343-
344-
while xsi < -tol or xsi > 1 + tol or eta < -tol or eta > 1 + tol:
345-
px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]])
346-
if grid.mesh == "spherical":
347-
px[0] = px[0] + 360 if px[0] < x - 225 else px[0]
348-
px[0] = px[0] - 360 if px[0] > x + 225 else px[0]
349-
px[1:] = np.where(px[1:] - px[0] > 180, px[1:] - 360, px[1:])
350-
px[1:] = np.where(-px[1:] + px[0] > 180, px[1:] + 360, px[1:])
351-
py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]])
352-
a = np.dot(invA, px)
353-
b = np.dot(invA, py)
354-
355-
aa = a[3] * b[2] - a[2] * b[3]
356-
bb = a[3] * b[0] - a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + x * b[3] - y * a[3]
357-
cc = a[1] * b[0] - a[0] * b[1] + x * b[1] - y * a[1]
358-
if abs(aa) < 1e-12: # Rectilinear cell, or quasi
359-
eta = -cc / bb
111+
def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, sphere_mesh: bool):
112+
if xi < 0:
113+
if sphere_mesh:
114+
xi = xdim - 2
360115
else:
361-
det2 = bb * bb - 4 * aa * cc
362-
if det2 > 0: # so, if det is nan we keep the xsi, eta from previous iter
363-
det = np.sqrt(det2)
364-
eta = (-bb + det) / (2 * aa)
365-
if abs(a[1] + a[3] * eta) < 1e-12: # this happens when recti cell rotated of 90deg
366-
xsi = ((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5
116+
xi = 0
117+
if xi > xdim - 2:
118+
if sphere_mesh:
119+
xi = 0
367120
else:
368-
xsi = (x - a[0] - a[2] * eta) / (a[1] + a[3] * eta)
369-
if xsi < 0 and eta < 0 and xi == 0 and yi == 0:
370-
_raise_field_out_of_bound_error(0, y, x)
371-
if xsi > 1 and eta > 1 and xi == grid.xdim - 1 and yi == grid.ydim - 1:
372-
_raise_field_out_of_bound_error(0, y, x)
373-
if xsi < -tol:
374-
xi -= 1
375-
elif xsi > 1 + tol:
376-
xi += 1
377-
if eta < -tol:
378-
yi -= 1
379-
elif eta > 1 + tol:
380-
yi += 1
381-
(yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid.mesh)
382-
it += 1
383-
if it > maxIterSearch:
384-
print(f"Correct cell not found after {maxIterSearch} iterations")
385-
_raise_field_out_of_bound_error(0, y, x)
386-
xsi = max(0.0, xsi)
387-
eta = max(0.0, eta)
388-
xsi = min(1.0, xsi)
389-
eta = min(1.0, eta)
390-
391-
if grid.zdim > 1 and not search2D:
392-
if grid._gtype == GridType.CurvilinearZGrid:
393-
try:
394-
(zi, zeta) = search_indices_vertical_z(field.grid, field.gridindexingtype, z)
395-
except FieldOutOfBoundError:
396-
_raise_field_out_of_bound_error(z, y, x)
397-
elif grid._gtype == GridType.CurvilinearSGrid:
398-
(zi, zeta) = search_indices_vertical_s(field.grid, field.interp_method, time, z, y, x, ti, yi, xi, eta, xsi)
399-
else:
400-
zi = -1
401-
zeta = 0
402-
403-
if not ((0 <= xsi <= 1) and (0 <= eta <= 1) and (0 <= zeta <= 1)):
404-
_raise_field_sampling_error(z, y, x)
405-
406-
if particle:
407-
particle.ei[field.igrid] = field.ravel_index(zi, yi, xi)
408-
409-
return (zeta, eta, xsi, zi, yi, xi)
410-
411-
412-
def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, mesh: Mesh):
413-
xi = np.where(xi < 0, (xdim - 2) if mesh == "spherical" else 0, xi)
414-
xi = np.where(xi > xdim - 2, 0 if mesh == "spherical" else (xdim - 2), xi)
415-
416-
xi = np.where(yi > ydim - 2, xdim - xi if mesh == "spherical" else xi, xi)
417-
418-
yi = np.where(yi < 0, 0, yi)
419-
yi = np.where(yi > ydim - 2, ydim - 2, yi)
420-
121+
xi = xdim - 2
122+
if yi < 0:
123+
yi = 0
124+
if yi > ydim - 2:
125+
yi = ydim - 2
126+
if sphere_mesh:
127+
xi = xdim - xi
421128
return yi, xi

parcels/particleset.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,11 @@ def __init__(
9696
assert lon.size == lat.size and lon.size == depth.size, "lon, lat, depth don't all have the same lenghts"
9797

9898
if time is None or len(time) == 0:
99-
time = type(fieldset.U.data.time[0].values)(
100-
"NaT", "ns"
101-
) # do not set a time yet (because sign_dt not known)
99+
# do not set a time yet (because sign_dt not known)
100+
if fieldset.time_interval is None:
101+
time = np.timedelta64("NaT", "ns")
102+
else:
103+
time = type(fieldset.time_interval.left)("NaT", "ns")
102104
elif type(time[0]) in [np.datetime64, np.timedelta64]:
103105
pass # already in the right format
104106
else:

0 commit comments

Comments
 (0)