Skip to content

Commit bacd187

Browse files
Merge pull request #2177 from OceanParcels/curvilinear_index_search_without_while_loop
Further cleaning the handling of grid searching errors
2 parents 8c4aaed + 031eb46 commit bacd187

File tree

7 files changed

+407
-262
lines changed

7 files changed

+407
-262
lines changed

parcels/_index_search.py

Lines changed: 57 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,17 @@
55

66
import numpy as np
77

8-
from parcels._typing import Mesh
9-
from parcels.tools.statuscodes import (
10-
_raise_grid_searching_error,
11-
_raise_time_extrapolation_error,
12-
)
8+
from parcels.tools.statuscodes import _raise_time_extrapolation_error
139

1410
if TYPE_CHECKING:
1511
from parcels.xgrid import XGrid
1612

1713
from .field import Field
1814

1915

16+
GRID_SEARCH_ERROR = -3
17+
18+
2019
def _search_time_index(field: Field, time: datetime):
2120
"""Find and return the index and relative coordinate in the time array associated with a given time.
2221
@@ -40,13 +39,7 @@ def _search_time_index(field: Field, time: datetime):
4039
return np.atleast_1d(tau), np.atleast_1d(ti)
4140

4241

43-
def _search_indices_curvilinear_2d(
44-
grid: XGrid, y: float, x: float, yi_guess: int | None = None, xi_guess: int | None = None
45-
): # TODO fix typing instructions to make clear that y, x etc need to be ndarrays
46-
yi, xi = yi_guess, xi_guess
47-
if yi is None or xi is None:
48-
yi, xi = grid.get_spatial_hash().query(y, x)
49-
42+
def curvilinear_point_in_cell(grid, y: np.ndarray, x: np.ndarray, yi: np.ndarray, xi: np.ndarray):
5043
xsi = eta = -1.0 * np.ones(len(x), dtype=float)
5144
invA = np.array(
5245
[
@@ -56,67 +49,60 @@ def _search_indices_curvilinear_2d(
5649
[1, -1, 1, -1],
5750
]
5851
)
59-
maxIterSearch = 1e6
60-
it = 0
61-
tol = 1.0e-10
62-
63-
# # ! Error handling for out of bounds
64-
# TODO: Re-enable in some capacity
65-
# if x < field.lonlat_minmax[0] or x > field.lonlat_minmax[1]:
66-
# if grid.lon[0, 0] < grid.lon[0, -1]:
67-
# _raise_grid_searching_error(y, x)
68-
# elif x < grid.lon[0, 0] and x > grid.lon[0, -1]: # This prevents from crashing in [160, -160]
69-
# _raise_grid_searching_error(z, y, x)
70-
71-
# if y < field.lonlat_minmax[2] or y > field.lonlat_minmax[3]:
72-
# _raise_grid_searching_error(z, y, x)
73-
74-
while np.any(xsi < -tol) or np.any(xsi > 1 + tol) or np.any(eta < -tol) or np.any(eta > 1 + tol):
75-
px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]])
76-
77-
py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]])
78-
a = np.dot(invA, px)
79-
b = np.dot(invA, py)
80-
81-
aa = a[3] * b[2] - a[2] * b[3]
82-
bb = a[3] * b[0] - a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + x * b[3] - y * a[3]
83-
cc = a[1] * b[0] - a[0] * b[1] + x * b[1] - y * a[1]
84-
85-
det2 = bb * bb - 4 * aa * cc
86-
with np.errstate(divide="ignore", invalid="ignore"):
87-
det = np.where(det2 > 0, np.sqrt(det2), eta)
88-
89-
eta = np.where(abs(aa) < 1e-12, -cc / bb, np.where(det2 > 0, (-bb + det) / (2 * aa), eta))
90-
91-
xsi = np.where(
92-
abs(a[1] + a[3] * eta) < 1e-12,
93-
((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5,
94-
(x - a[0] - a[2] * eta) / (a[1] + a[3] * eta),
95-
)
96-
97-
xi = np.where(xsi < -tol, xi - 1, np.where(xsi > 1 + tol, xi + 1, xi))
98-
yi = np.where(eta < -tol, yi - 1, np.where(eta > 1 + tol, yi + 1, yi))
99-
100-
(yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid._mesh)
101-
it += 1
102-
if it > maxIterSearch:
103-
print(f"Correct cell not found after {maxIterSearch} iterations")
104-
_raise_grid_searching_error(0, y, x)
105-
xsi = np.where(xsi < 0.0, 0.0, np.where(xsi > 1.0, 1.0, xsi))
106-
eta = np.where(eta < 0.0, 0.0, np.where(eta > 1.0, 1.0, eta))
107-
108-
if np.any((xsi < 0) | (xsi > 1) | (eta < 0) | (eta > 1)):
109-
_raise_grid_searching_error(y, x)
110-
return (yi, eta, xi, xsi)
11152

53+
px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]])
54+
py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]])
55+
56+
a, b = np.dot(invA, px), np.dot(invA, py)
57+
aa = a[3] * b[2] - a[2] * b[3]
58+
bb = a[3] * b[0] - a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + x * b[3] - y * a[3]
59+
cc = a[1] * b[0] - a[0] * b[1] + x * b[1] - y * a[1]
60+
det2 = bb * bb - 4 * aa * cc
11261

113-
def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, mesh: Mesh):
114-
xi = np.where(xi < 0, (xdim - 2) if mesh == "spherical" else 0, xi)
115-
xi = np.where(xi > xdim - 2, 0 if mesh == "spherical" else (xdim - 2), xi)
62+
with np.errstate(divide="ignore", invalid="ignore"):
63+
det = np.where(det2 > 0, np.sqrt(det2), eta)
64+
eta = np.where(abs(aa) < 1e-12, -cc / bb, np.where(det2 > 0, (-bb + det) / (2 * aa), eta))
11665

117-
xi = np.where(yi > ydim - 2, xdim - xi if mesh == "spherical" else xi, xi)
66+
xsi = np.where(
67+
abs(a[1] + a[3] * eta) < 1e-12,
68+
((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5,
69+
(x - a[0] - a[2] * eta) / (a[1] + a[3] * eta),
70+
)
11871

119-
yi = np.where(yi < 0, 0, yi)
120-
yi = np.where(yi > ydim - 2, ydim - 2, yi)
72+
is_in_cell = np.where((xsi >= 0) & (xsi <= 1) & (eta >= 0) & (eta <= 1), 1, 0)
12173

122-
return yi, xi
74+
return is_in_cell, np.column_stack((xsi, eta))
75+
76+
77+
def _search_indices_curvilinear_2d(
78+
grid: XGrid, y: np.ndarray, x: np.ndarray, yi_guess: np.ndarray | None = None, xi_guess: np.ndarray | None = None
79+
):
80+
yi_guess = np.array(yi_guess)
81+
xi_guess = np.array(xi_guess)
82+
xi = np.full(len(x), GRID_SEARCH_ERROR, dtype=np.int32)
83+
yi = np.full(len(y), GRID_SEARCH_ERROR, dtype=np.int32)
84+
if np.any(xi_guess):
85+
# If an initial guess is provided, we first perform a point in cell check for all guessed indices
86+
is_in_cell, coords = curvilinear_point_in_cell(grid, y, x, yi_guess, xi_guess)
87+
y_check = y[is_in_cell == 0]
88+
x_check = x[is_in_cell == 0]
89+
zero_indices = np.where(is_in_cell == 0)[0]
90+
else:
91+
# Otherwise, we need to check all points
92+
y_check = y
93+
x_check = x
94+
coords = -1.0 * np.ones((len(y), 2), dtype=np.float32)
95+
zero_indices = np.arange(len(y))
96+
97+
# If there are any points that were not found in the first step, we query the spatial hash for those points
98+
if len(zero_indices) > 0:
99+
yi_q, xi_q, coords_q = grid.get_spatial_hash().query(y_check, x_check)
100+
# Only those points that were not found in the first step are updated
101+
coords[zero_indices, :] = coords_q
102+
yi[zero_indices] = yi_q
103+
xi[zero_indices] = xi_q
104+
105+
xsi = coords[:, 0]
106+
eta = coords[:, 1]
107+
108+
return (yi, eta, xi, xsi)

parcels/field.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from parcels.uxgrid import UxGrid
3131
from parcels.xgrid import LEFT_OUT_OF_BOUNDS, RIGHT_OUT_OF_BOUNDS, XGrid, _transpose_xfield_data_to_tzyx
3232

33-
from ._index_search import _search_time_index
33+
from ._index_search import GRID_SEARCH_ERROR, _search_time_index
3434

3535
__all__ = ["Field", "VectorField"]
3636

@@ -341,21 +341,25 @@ def __getitem__(self, key):
341341

342342
def _update_particle_states_position(particle, position):
343343
"""Update the particle states based on the position dictionary."""
344-
if particle and "X" in position: # TODO also support uxgrid search
345-
particle.state = np.maximum(
346-
np.where(position["X"][0] < 0, StatusCode.ErrorOutOfBounds, particle.state), particle.state
347-
)
348-
particle.state = np.maximum(
349-
np.where(position["Y"][0] < 0, StatusCode.ErrorOutOfBounds, particle.state), particle.state
350-
)
351-
particle.state = np.maximum(
352-
np.where(position["Z"][0] == RIGHT_OUT_OF_BOUNDS, StatusCode.ErrorOutOfBounds, particle.state),
353-
particle.state,
354-
)
355-
particle.state = np.maximum(
356-
np.where(position["Z"][0] == LEFT_OUT_OF_BOUNDS, StatusCode.ErrorThroughSurface, particle.state),
357-
particle.state,
358-
)
344+
if particle: # TODO also support uxgrid search
345+
for dim in ["X", "Y"]:
346+
if dim in position:
347+
particle.state = np.maximum(
348+
np.where(position[dim][0] == -1, StatusCode.ErrorOutOfBounds, particle.state), particle.state
349+
)
350+
particle.state = np.maximum(
351+
np.where(position[dim][0] == GRID_SEARCH_ERROR, StatusCode.ErrorGridSearching, particle.state),
352+
particle.state,
353+
)
354+
if "Z" in position:
355+
particle.state = np.maximum(
356+
np.where(position["Z"][0] == RIGHT_OUT_OF_BOUNDS, StatusCode.ErrorOutOfBounds, particle.state),
357+
particle.state,
358+
)
359+
particle.state = np.maximum(
360+
np.where(position["Z"][0] == LEFT_OUT_OF_BOUNDS, StatusCode.ErrorThroughSurface, particle.state),
361+
particle.state,
362+
)
359363

360364

361365
def _update_particle_states_interp_value(particle, value):

0 commit comments

Comments
 (0)