|
5 | 5 |
|
6 | 6 | import numpy as np |
7 | 7 |
|
8 | | -from parcels._typing import ( |
9 | | - GridIndexingType, |
10 | | - InterpMethodOption, |
11 | | - Mesh, |
12 | | -) |
13 | 8 | from parcels.tools.statuscodes import ( |
14 | | - FieldOutOfBoundError, |
15 | | - FieldOutOfBoundSurfaceError, |
16 | 9 | _raise_field_out_of_bound_error, |
17 | | - _raise_field_out_of_bound_surface_error, |
18 | 10 | _raise_field_sampling_error, |
19 | 11 | _raise_time_extrapolation_error, |
20 | 12 | ) |
21 | 13 |
|
22 | | -from .basegrid import GridType |
23 | | - |
24 | 14 | if TYPE_CHECKING: |
25 | 15 | from parcels.xgrid import XGrid |
26 | 16 |
|
@@ -50,206 +40,6 @@ def _search_time_index(field: Field, time: datetime): |
50 | 40 | return np.atleast_1d(tau), np.atleast_1d(ti) |
51 | 41 |
|
52 | 42 |
|
53 | | -def search_indices_vertical_z(depth, gridindexingtype: GridIndexingType, z: float): |
54 | | - if depth[-1] > depth[0]: |
55 | | - if z < depth[0]: |
56 | | - # Since MOM5 is indexed at cell bottom, allow z at depth[0] - dz where dz = (depth[1] - depth[0]) |
57 | | - if gridindexingtype == "mom5" and z > 2 * depth[0] - depth[1]: |
58 | | - return (-1, z / depth[0]) |
59 | | - else: |
60 | | - _raise_field_out_of_bound_surface_error(z, None, None) |
61 | | - elif z > depth[-1]: |
62 | | - # In case of CROCO, allow particles in last (uppermost) layer using depth[-1] |
63 | | - if gridindexingtype in ["croco"] and z < 0: |
64 | | - return (-2, 1) |
65 | | - _raise_field_out_of_bound_error(z, None, None) |
66 | | - depth_indices = depth < z |
67 | | - if z >= depth[-1]: |
68 | | - zi = len(depth) - 2 |
69 | | - else: |
70 | | - zi = depth_indices.argmin() - 1 if z > depth[0] else 0 |
71 | | - else: |
72 | | - if z > depth[0]: |
73 | | - _raise_field_out_of_bound_surface_error(z, None, None) |
74 | | - elif z < depth[-1]: |
75 | | - _raise_field_out_of_bound_error(z, None, None) |
76 | | - depth_indices = depth > z |
77 | | - if z <= depth[-1]: |
78 | | - zi = len(depth) - 2 |
79 | | - else: |
80 | | - zi = depth_indices.argmin() - 1 if z < depth[0] else 0 |
81 | | - zeta = (z - depth[zi]) / (depth[zi + 1] - depth[zi]) |
82 | | - while zeta > 1: |
83 | | - zi += 1 |
84 | | - zeta = (z - depth[zi]) / (depth[zi + 1] - depth[zi]) |
85 | | - while zeta < 0: |
86 | | - zi -= 1 |
87 | | - zeta = (z - depth[zi]) / (depth[zi + 1] - depth[zi]) |
88 | | - return (zi, zeta) |
89 | | - |
90 | | - |
91 | | -## TODO : Still need to implement the search_indices_vertical_s function |
92 | | -def search_indices_vertical_s( |
93 | | - field: Field, |
94 | | - interp_method: InterpMethodOption, |
95 | | - time: float, |
96 | | - z: float, |
97 | | - y: float, |
98 | | - x: float, |
99 | | - ti: int, |
100 | | - yi: int, |
101 | | - xi: int, |
102 | | - eta: float, |
103 | | - xsi: float, |
104 | | -): |
105 | | - if interp_method in ["bgrid_velocity", "bgrid_w_velocity", "bgrid_tracer"]: |
106 | | - xsi = 1 |
107 | | - eta = 1 |
108 | | - if time < field.time[ti]: |
109 | | - ti -= 1 |
110 | | - if field._z4d: # type: ignore[attr-defined] |
111 | | - if ti == len(field.time) - 1: |
112 | | - depth_vector = ( |
113 | | - (1 - xsi) * (1 - eta) * field.depth[-1, :, yi, xi] |
114 | | - + xsi * (1 - eta) * field.depth[-1, :, yi, xi + 1] |
115 | | - + xsi * eta * field.depth[-1, :, yi + 1, xi + 1] |
116 | | - + (1 - xsi) * eta * field.depth[-1, :, yi + 1, xi] |
117 | | - ) |
118 | | - else: |
119 | | - dv2 = ( |
120 | | - (1 - xsi) * (1 - eta) * field.depth[ti : ti + 2, :, yi, xi] |
121 | | - + xsi * (1 - eta) * field.depth[ti : ti + 2, :, yi, xi + 1] |
122 | | - + xsi * eta * field.depth[ti : ti + 2, :, yi + 1, xi + 1] |
123 | | - + (1 - xsi) * eta * field.depth[ti : ti + 2, :, yi + 1, xi] |
124 | | - ) |
125 | | - tt = (time - field.time[ti]) / (field.time[ti + 1] - field.time[ti]) |
126 | | - assert tt >= 0 and tt <= 1, "Vertical s grid is being wrongly interpolated in time" |
127 | | - depth_vector = dv2[0, :] * (1 - tt) + dv2[1, :] * tt |
128 | | - else: |
129 | | - depth_vector = ( |
130 | | - (1 - xsi) * (1 - eta) * field.depth[:, yi, xi] |
131 | | - + xsi * (1 - eta) * field.depth[:, yi, xi + 1] |
132 | | - + xsi * eta * field.depth[:, yi + 1, xi + 1] |
133 | | - + (1 - xsi) * eta * field.depth[:, yi + 1, xi] |
134 | | - ) |
135 | | - z = np.float32(z) # type: ignore # TODO: remove type ignore once we migrate to float64 |
136 | | - |
137 | | - if depth_vector[-1] > depth_vector[0]: |
138 | | - if z < depth_vector[0]: |
139 | | - _raise_field_out_of_bound_error(z, None, None) |
140 | | - elif z > depth_vector[-1]: |
141 | | - _raise_field_out_of_bound_error(z, None, None) |
142 | | - depth_indices = depth_vector < z |
143 | | - if z >= depth_vector[-1]: |
144 | | - zi = len(depth_vector) - 2 |
145 | | - else: |
146 | | - zi = depth_indices.argmin() - 1 if z > depth_vector[0] else 0 |
147 | | - else: |
148 | | - if z > depth_vector[0]: |
149 | | - _raise_field_out_of_bound_error(z, None, None) |
150 | | - elif z < depth_vector[-1]: |
151 | | - _raise_field_out_of_bound_error(z, None, None) |
152 | | - depth_indices = depth_vector > z |
153 | | - if z <= depth_vector[-1]: |
154 | | - zi = len(depth_vector) - 2 |
155 | | - else: |
156 | | - zi = depth_indices.argmin() - 1 if z < depth_vector[0] else 0 |
157 | | - zeta = (z - depth_vector[zi]) / (depth_vector[zi + 1] - depth_vector[zi]) |
158 | | - while zeta > 1: |
159 | | - zi += 1 |
160 | | - zeta = (z - depth_vector[zi]) / (depth_vector[zi + 1] - depth_vector[zi]) |
161 | | - while zeta < 0: |
162 | | - zi -= 1 |
163 | | - zeta = (z - depth_vector[zi]) / (depth_vector[zi + 1] - depth_vector[zi]) |
164 | | - return (zi, zeta) |
165 | | - |
166 | | - |
167 | | -def _search_indices_rectilinear( |
168 | | - field: Field, time: datetime, z: float, y: float, x: float, ti: int, ei: int | None = None, search2D=False |
169 | | -): |
170 | | - # TODO : If ei is provided, check if particle is in the same cell |
171 | | - if field.xdim > 1 and (not field.zonal_periodic): |
172 | | - if x < field.lonlat_minmax[0] or x > field.lonlat_minmax[1]: |
173 | | - _raise_field_out_of_bound_error(z, y, x) |
174 | | - if field.ydim > 1 and (y < field.lonlat_minmax[2] or y > field.lonlat_minmax[3]): |
175 | | - _raise_field_out_of_bound_error(z, y, x) |
176 | | - |
177 | | - if field.xdim > 1: |
178 | | - if field._mesh != "spherical": |
179 | | - lon_index = field.lon < x |
180 | | - if lon_index.all(): |
181 | | - xi = len(field.lon) - 2 |
182 | | - else: |
183 | | - xi = lon_index.argmin() - 1 if lon_index.any() else 0 |
184 | | - xsi = (x - field.lon[xi]) / (field.lon[xi + 1] - field.lon[xi]) |
185 | | - if xsi < 0: |
186 | | - xi -= 1 |
187 | | - xsi = (x - field.lon[xi]) / (field.lon[xi + 1] - field.lon[xi]) |
188 | | - elif xsi > 1: |
189 | | - xi += 1 |
190 | | - xsi = (x - field.lon[xi]) / (field.lon[xi + 1] - field.lon[xi]) |
191 | | - else: |
192 | | - lon_fixed = field.lon.copy() |
193 | | - indices = lon_fixed >= lon_fixed[0] |
194 | | - if not indices.all(): |
195 | | - lon_fixed[indices.argmin() :] += 360 |
196 | | - if x < lon_fixed[0]: |
197 | | - lon_fixed -= 360 |
198 | | - |
199 | | - lon_index = lon_fixed < x |
200 | | - if lon_index.all(): |
201 | | - xi = len(lon_fixed) - 2 |
202 | | - else: |
203 | | - xi = lon_index.argmin() - 1 if lon_index.any() else 0 |
204 | | - xsi = (x - lon_fixed[xi]) / (lon_fixed[xi + 1] - lon_fixed[xi]) |
205 | | - if xsi < 0: |
206 | | - xi -= 1 |
207 | | - xsi = (x - lon_fixed[xi]) / (lon_fixed[xi + 1] - lon_fixed[xi]) |
208 | | - elif xsi > 1: |
209 | | - xi += 1 |
210 | | - xsi = (x - lon_fixed[xi]) / (lon_fixed[xi + 1] - lon_fixed[xi]) |
211 | | - else: |
212 | | - xi, xsi = -1, 0 |
213 | | - |
214 | | - if field.ydim > 1: |
215 | | - lat_index = field.lat < y |
216 | | - if lat_index.all(): |
217 | | - yi = len(field.lat) - 2 |
218 | | - else: |
219 | | - yi = lat_index.argmin() - 1 if lat_index.any() else 0 |
220 | | - |
221 | | - eta = (y - field.lat[yi]) / (field.lat[yi + 1] - field.lat[yi]) |
222 | | - if eta < 0: |
223 | | - yi -= 1 |
224 | | - eta = (y - field.lat[yi]) / (field.lat[yi + 1] - field.lat[yi]) |
225 | | - elif eta > 1: |
226 | | - yi += 1 |
227 | | - eta = (y - field.lat[yi]) / (field.lat[yi + 1] - field.lat[yi]) |
228 | | - else: |
229 | | - yi, eta = -1, 0 |
230 | | - |
231 | | - if field.zdim > 1 and not search2D: |
232 | | - if field._gtype == GridType.RectilinearZGrid: |
233 | | - try: |
234 | | - (zi, zeta) = search_indices_vertical_z(field.grid, field.gridindexingtype, z) |
235 | | - except FieldOutOfBoundError: |
236 | | - _raise_field_out_of_bound_error(z, y, x) |
237 | | - except FieldOutOfBoundSurfaceError: |
238 | | - _raise_field_out_of_bound_surface_error(z, y, x) |
239 | | - elif field._gtype == GridType.RectilinearSGrid: |
240 | | - ## TODO : Still need to implement the search_indices_vertical_s function |
241 | | - (zi, zeta) = search_indices_vertical_s(field.grid, field.interp_method, time, z, y, x, ti, yi, xi, eta, xsi) |
242 | | - else: |
243 | | - zi, zeta = -1, 0 |
244 | | - |
245 | | - if not ((0 <= xsi <= 1) and (0 <= eta <= 1) and (0 <= zeta <= 1)): |
246 | | - _raise_field_sampling_error(z, y, x) |
247 | | - |
248 | | - _ei = field.ravel_index(zi, yi, xi) |
249 | | - |
250 | | - return (zeta, eta, xsi, _ei) |
251 | | - |
252 | | - |
253 | 43 | def _search_indices_curvilinear_2d( |
254 | 44 | grid: XGrid, y: float, x: float, yi_guess: int | None = None, xi_guess: int | None = None |
255 | 45 | ): # TODO fix typing instructions to make clear that y, x etc need to be ndarrays |
@@ -318,103 +108,21 @@ def _search_indices_curvilinear_2d( |
318 | 108 | return (yi, eta, xi, xsi) |
319 | 109 |
|
320 | 110 |
|
321 | | -## TODO : Still need to implement the search_indices_curvilinear |
322 | | -def _search_indices_curvilinear(field, time, z, y, x, ti, particle=None, search2D=False): |
323 | | - if particle: |
324 | | - zi, yi, xi = field.unravel_index(particle.ei) |
325 | | - else: |
326 | | - xi = int(field.xdim / 2) - 1 |
327 | | - yi = int(field.ydim / 2) - 1 |
328 | | - xsi = eta = -1.0 |
329 | | - grid = field.grid |
330 | | - invA = np.array([[1, 0, 0, 0], [-1, 1, 0, 0], [-1, 0, 0, 1], [1, -1, 1, -1]]) |
331 | | - maxIterSearch = 1e6 |
332 | | - it = 0 |
333 | | - tol = 1.0e-10 |
334 | | - if not grid.zonal_periodic: |
335 | | - if x < field.lonlat_minmax[0] or x > field.lonlat_minmax[1]: |
336 | | - if grid.lon[0, 0] < grid.lon[0, -1]: |
337 | | - _raise_field_out_of_bound_error(z, y, x) |
338 | | - elif x < grid.lon[0, 0] and x > grid.lon[0, -1]: # This prevents from crashing in [160, -160] |
339 | | - _raise_field_out_of_bound_error(z, y, x) |
340 | | - if y < field.lonlat_minmax[2] or y > field.lonlat_minmax[3]: |
341 | | - _raise_field_out_of_bound_error(z, y, x) |
342 | | - |
343 | | - while xsi < -tol or xsi > 1 + tol or eta < -tol or eta > 1 + tol: |
344 | | - px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]]) |
345 | | - if grid.mesh == "spherical": |
346 | | - px[0] = px[0] + 360 if px[0] < x - 225 else px[0] |
347 | | - px[0] = px[0] - 360 if px[0] > x + 225 else px[0] |
348 | | - px[1:] = np.where(px[1:] - px[0] > 180, px[1:] - 360, px[1:]) |
349 | | - px[1:] = np.where(-px[1:] + px[0] > 180, px[1:] + 360, px[1:]) |
350 | | - py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]]) |
351 | | - a = np.dot(invA, px) |
352 | | - b = np.dot(invA, py) |
353 | | - |
354 | | - aa = a[3] * b[2] - a[2] * b[3] |
355 | | - bb = a[3] * b[0] - a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + x * b[3] - y * a[3] |
356 | | - cc = a[1] * b[0] - a[0] * b[1] + x * b[1] - y * a[1] |
357 | | - if abs(aa) < 1e-12: # Rectilinear cell, or quasi |
358 | | - eta = -cc / bb |
| 111 | +def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, sphere_mesh: bool): |
| 112 | + if xi < 0: |
| 113 | + if sphere_mesh: |
| 114 | + xi = xdim - 2 |
359 | 115 | else: |
360 | | - det2 = bb * bb - 4 * aa * cc |
361 | | - if det2 > 0: # so, if det is nan we keep the xsi, eta from previous iter |
362 | | - det = np.sqrt(det2) |
363 | | - eta = (-bb + det) / (2 * aa) |
364 | | - if abs(a[1] + a[3] * eta) < 1e-12: # this happens when recti cell rotated of 90deg |
365 | | - xsi = ((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5 |
| 116 | + xi = 0 |
| 117 | + if xi > xdim - 2: |
| 118 | + if sphere_mesh: |
| 119 | + xi = 0 |
366 | 120 | else: |
367 | | - xsi = (x - a[0] - a[2] * eta) / (a[1] + a[3] * eta) |
368 | | - if xsi < 0 and eta < 0 and xi == 0 and yi == 0: |
369 | | - _raise_field_out_of_bound_error(0, y, x) |
370 | | - if xsi > 1 and eta > 1 and xi == grid.xdim - 1 and yi == grid.ydim - 1: |
371 | | - _raise_field_out_of_bound_error(0, y, x) |
372 | | - if xsi < -tol: |
373 | | - xi -= 1 |
374 | | - elif xsi > 1 + tol: |
375 | | - xi += 1 |
376 | | - if eta < -tol: |
377 | | - yi -= 1 |
378 | | - elif eta > 1 + tol: |
379 | | - yi += 1 |
380 | | - (yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid.mesh) |
381 | | - it += 1 |
382 | | - if it > maxIterSearch: |
383 | | - print(f"Correct cell not found after {maxIterSearch} iterations") |
384 | | - _raise_field_out_of_bound_error(0, y, x) |
385 | | - xsi = max(0.0, xsi) |
386 | | - eta = max(0.0, eta) |
387 | | - xsi = min(1.0, xsi) |
388 | | - eta = min(1.0, eta) |
389 | | - |
390 | | - if grid.zdim > 1 and not search2D: |
391 | | - if grid._gtype == GridType.CurvilinearZGrid: |
392 | | - try: |
393 | | - (zi, zeta) = search_indices_vertical_z(field.grid, field.gridindexingtype, z) |
394 | | - except FieldOutOfBoundError: |
395 | | - _raise_field_out_of_bound_error(z, y, x) |
396 | | - elif grid._gtype == GridType.CurvilinearSGrid: |
397 | | - (zi, zeta) = search_indices_vertical_s(field.grid, field.interp_method, time, z, y, x, ti, yi, xi, eta, xsi) |
398 | | - else: |
399 | | - zi = -1 |
400 | | - zeta = 0 |
401 | | - |
402 | | - if not ((0 <= xsi <= 1) and (0 <= eta <= 1) and (0 <= zeta <= 1)): |
403 | | - _raise_field_sampling_error(z, y, x) |
404 | | - |
405 | | - if particle: |
406 | | - particle.ei[field.igrid] = field.ravel_index(zi, yi, xi) |
407 | | - |
408 | | - return (zeta, eta, xsi, zi, yi, xi) |
409 | | - |
410 | | - |
411 | | -def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, mesh: Mesh): |
412 | | - xi = np.where(xi < 0, (xdim - 2) if mesh == "spherical" else 0, xi) |
413 | | - xi = np.where(xi > xdim - 2, 0 if mesh == "spherical" else (xdim - 2), xi) |
414 | | - |
415 | | - xi = np.where(yi > ydim - 2, xdim - xi if mesh == "spherical" else xi, xi) |
416 | | - |
417 | | - yi = np.where(yi < 0, 0, yi) |
418 | | - yi = np.where(yi > ydim - 2, ydim - 2, yi) |
419 | | - |
| 121 | + xi = xdim - 2 |
| 122 | + if yi < 0: |
| 123 | + yi = 0 |
| 124 | + if yi > ydim - 2: |
| 125 | + yi = ydim - 2 |
| 126 | + if sphere_mesh: |
| 127 | + xi = xdim - xi |
420 | 128 | return yi, xi |
0 commit comments