55
66import numpy as np
77
8- from parcels ._typing import Mesh
9- from parcels .tools .statuscodes import (
10- _raise_grid_searching_error ,
11- _raise_time_extrapolation_error ,
12- )
8+ from parcels .tools .statuscodes import _raise_time_extrapolation_error
139
1410if TYPE_CHECKING :
1511 from parcels .xgrid import XGrid
1612
1713 from .field import Field
1814
1915
16+ GRID_SEARCH_ERROR = - 3
17+
18+
2019def _search_time_index (field : Field , time : datetime ):
2120 """Find and return the index and relative coordinate in the time array associated with a given time.
2221
@@ -40,13 +39,7 @@ def _search_time_index(field: Field, time: datetime):
4039 return np .atleast_1d (tau ), np .atleast_1d (ti )
4140
4241
43- def _search_indices_curvilinear_2d (
44- grid : XGrid , y : float , x : float , yi_guess : int | None = None , xi_guess : int | None = None
45- ): # TODO fix typing instructions to make clear that y, x etc need to be ndarrays
46- yi , xi = yi_guess , xi_guess
47- if yi is None or xi is None :
48- yi , xi = grid .get_spatial_hash ().query (y , x )
49-
42+ def curvilinear_point_in_cell (grid , y : np .ndarray , x : np .ndarray , yi : np .ndarray , xi : np .ndarray ):
5043 xsi = eta = - 1.0 * np .ones (len (x ), dtype = float )
5144 invA = np .array (
5245 [
@@ -56,67 +49,60 @@ def _search_indices_curvilinear_2d(
5649 [1 , - 1 , 1 , - 1 ],
5750 ]
5851 )
59- maxIterSearch = 1e6
60- it = 0
61- tol = 1.0e-10
62-
63- # # ! Error handling for out of bounds
64- # TODO: Re-enable in some capacity
65- # if x < field.lonlat_minmax[0] or x > field.lonlat_minmax[1]:
66- # if grid.lon[0, 0] < grid.lon[0, -1]:
67- # _raise_grid_searching_error(y, x)
68- # elif x < grid.lon[0, 0] and x > grid.lon[0, -1]: # This prevents from crashing in [160, -160]
69- # _raise_grid_searching_error(z, y, x)
70-
71- # if y < field.lonlat_minmax[2] or y > field.lonlat_minmax[3]:
72- # _raise_grid_searching_error(z, y, x)
73-
74- while np .any (xsi < - tol ) or np .any (xsi > 1 + tol ) or np .any (eta < - tol ) or np .any (eta > 1 + tol ):
75- px = np .array ([grid .lon [yi , xi ], grid .lon [yi , xi + 1 ], grid .lon [yi + 1 , xi + 1 ], grid .lon [yi + 1 , xi ]])
76-
77- py = np .array ([grid .lat [yi , xi ], grid .lat [yi , xi + 1 ], grid .lat [yi + 1 , xi + 1 ], grid .lat [yi + 1 , xi ]])
78- a = np .dot (invA , px )
79- b = np .dot (invA , py )
80-
81- aa = a [3 ] * b [2 ] - a [2 ] * b [3 ]
82- bb = a [3 ] * b [0 ] - a [0 ] * b [3 ] + a [1 ] * b [2 ] - a [2 ] * b [1 ] + x * b [3 ] - y * a [3 ]
83- cc = a [1 ] * b [0 ] - a [0 ] * b [1 ] + x * b [1 ] - y * a [1 ]
84-
85- det2 = bb * bb - 4 * aa * cc
86- with np .errstate (divide = "ignore" , invalid = "ignore" ):
87- det = np .where (det2 > 0 , np .sqrt (det2 ), eta )
88-
89- eta = np .where (abs (aa ) < 1e-12 , - cc / bb , np .where (det2 > 0 , (- bb + det ) / (2 * aa ), eta ))
90-
91- xsi = np .where (
92- abs (a [1 ] + a [3 ] * eta ) < 1e-12 ,
93- ((y - py [0 ]) / (py [1 ] - py [0 ]) + (y - py [3 ]) / (py [2 ] - py [3 ])) * 0.5 ,
94- (x - a [0 ] - a [2 ] * eta ) / (a [1 ] + a [3 ] * eta ),
95- )
96-
97- xi = np .where (xsi < - tol , xi - 1 , np .where (xsi > 1 + tol , xi + 1 , xi ))
98- yi = np .where (eta < - tol , yi - 1 , np .where (eta > 1 + tol , yi + 1 , yi ))
99-
100- (yi , xi ) = _reconnect_bnd_indices (yi , xi , grid .ydim , grid .xdim , grid ._mesh )
101- it += 1
102- if it > maxIterSearch :
103- print (f"Correct cell not found after { maxIterSearch } iterations" )
104- _raise_grid_searching_error (0 , y , x )
105- xsi = np .where (xsi < 0.0 , 0.0 , np .where (xsi > 1.0 , 1.0 , xsi ))
106- eta = np .where (eta < 0.0 , 0.0 , np .where (eta > 1.0 , 1.0 , eta ))
107-
108- if np .any ((xsi < 0 ) | (xsi > 1 ) | (eta < 0 ) | (eta > 1 )):
109- _raise_grid_searching_error (y , x )
110- return (yi , eta , xi , xsi )
11152
53+ px = np .array ([grid .lon [yi , xi ], grid .lon [yi , xi + 1 ], grid .lon [yi + 1 , xi + 1 ], grid .lon [yi + 1 , xi ]])
54+ py = np .array ([grid .lat [yi , xi ], grid .lat [yi , xi + 1 ], grid .lat [yi + 1 , xi + 1 ], grid .lat [yi + 1 , xi ]])
55+
56+ a , b = np .dot (invA , px ), np .dot (invA , py )
57+ aa = a [3 ] * b [2 ] - a [2 ] * b [3 ]
58+ bb = a [3 ] * b [0 ] - a [0 ] * b [3 ] + a [1 ] * b [2 ] - a [2 ] * b [1 ] + x * b [3 ] - y * a [3 ]
59+ cc = a [1 ] * b [0 ] - a [0 ] * b [1 ] + x * b [1 ] - y * a [1 ]
60+ det2 = bb * bb - 4 * aa * cc
11261
113- def _reconnect_bnd_indices ( yi : int , xi : int , ydim : int , xdim : int , mesh : Mesh ):
114- xi = np .where (xi < 0 , ( xdim - 2 ) if mesh == "spherical" else 0 , xi )
115- xi = np .where (xi > xdim - 2 , 0 if mesh == "spherical" else ( xdim - 2 ), xi )
62+ with np . errstate ( divide = "ignore" , invalid = "ignore" ):
63+ det = np .where (det2 > 0 , np . sqrt ( det2 ), eta )
64+ eta = np .where (abs ( aa ) < 1e-12 , - cc / bb , np . where ( det2 > 0 , ( - bb + det ) / ( 2 * aa ), eta ) )
11665
117- xi = np .where (yi > ydim - 2 , xdim - xi if mesh == "spherical" else xi , xi )
66+ xsi = np .where (
67+ abs (a [1 ] + a [3 ] * eta ) < 1e-12 ,
68+ ((y - py [0 ]) / (py [1 ] - py [0 ]) + (y - py [3 ]) / (py [2 ] - py [3 ])) * 0.5 ,
69+ (x - a [0 ] - a [2 ] * eta ) / (a [1 ] + a [3 ] * eta ),
70+ )
11871
119- yi = np .where (yi < 0 , 0 , yi )
120- yi = np .where (yi > ydim - 2 , ydim - 2 , yi )
72+ is_in_cell = np .where ((xsi >= 0 ) & (xsi <= 1 ) & (eta >= 0 ) & (eta <= 1 ), 1 , 0 )
12173
122- return yi , xi
74+ return is_in_cell , np .column_stack ((xsi , eta ))
75+
76+
77+ def _search_indices_curvilinear_2d (
78+ grid : XGrid , y : np .ndarray , x : np .ndarray , yi_guess : np .ndarray | None = None , xi_guess : np .ndarray | None = None
79+ ):
80+ yi_guess = np .array (yi_guess )
81+ xi_guess = np .array (xi_guess )
82+ xi = np .full (len (x ), GRID_SEARCH_ERROR , dtype = np .int32 )
83+ yi = np .full (len (y ), GRID_SEARCH_ERROR , dtype = np .int32 )
84+ if np .any (xi_guess ):
85+ # If an initial guess is provided, we first perform a point in cell check for all guessed indices
86+ is_in_cell , coords = curvilinear_point_in_cell (grid , y , x , yi_guess , xi_guess )
87+ y_check = y [is_in_cell == 0 ]
88+ x_check = x [is_in_cell == 0 ]
89+ zero_indices = np .where (is_in_cell == 0 )[0 ]
90+ else :
91+ # Otherwise, we need to check all points
92+ y_check = y
93+ x_check = x
94+ coords = - 1.0 * np .ones ((len (y ), 2 ), dtype = np .float32 )
95+ zero_indices = np .arange (len (y ))
96+
97+ # If there are any points that were not found in the first step, we query the spatial hash for those points
98+ if len (zero_indices ) > 0 :
99+ yi_q , xi_q , coords_q = grid .get_spatial_hash ().query (y_check , x_check )
100+ # Only those points that were not found in the first step are updated
101+ coords [zero_indices , :] = coords_q
102+ yi [zero_indices ] = yi_q
103+ xi [zero_indices ] = xi_q
104+
105+ xsi = coords [:, 0 ]
106+ eta = coords [:, 1 ]
107+
108+ return (yi , eta , xi , xsi )
0 commit comments