Skip to content

Commit af8cff8

Browse files
Merge branch 'v4-dev' into spatial_slip_interpolation
2 parents 17ec41b + 089b8f9 commit af8cff8

File tree

10 files changed

+354
-175
lines changed

10 files changed

+354
-175
lines changed

docs/v4/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ The key goals of this update are
77
1. to support `Fields` on unstructured grids;
88
2. to allow for user-defined interpolation methods (somewhat similar to user-defined kernels);
99
3. to make the codebase more modular, easier to extend, and more maintainable;
10-
4. to align Parcels more with other tools in the [Pangeo ecosystemand](https://www.pangeo.io/#ecosystem), particularly by leveraging `xarray` more; and
10+
4. to align Parcels more with other tools in the [Pangeo ecosystem](https://www.pangeo.io/#ecosystem), particularly by leveraging `xarray` more; and
1111
5. to improve the performance of Parcels.
1212

1313
The timeline for the release of Parcels v4 is not yet fixed, but we are aiming for a release of an 'alpha' version in September 2025. This v4-alpha will have support for unstructured grids and user-defined interpolation methods, but is not yet performance-optimised.

parcels/_index_search.py

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55

66
import numpy as np
77

8+
from parcels._typing import Mesh
89
from parcels.tools.statuscodes import (
9-
_raise_field_out_of_bound_error,
10-
_raise_field_sampling_error,
10+
_raise_grid_searching_error,
1111
_raise_time_extrapolation_error,
1212
)
1313

@@ -64,12 +64,12 @@ def _search_indices_curvilinear_2d(
6464
# TODO: Re-enable in some capacity
6565
# if x < field.lonlat_minmax[0] or x > field.lonlat_minmax[1]:
6666
# if grid.lon[0, 0] < grid.lon[0, -1]:
67-
# _raise_field_out_of_bound_error(y, x)
67+
# _raise_grid_searching_error(y, x)
6868
# elif x < grid.lon[0, 0] and x > grid.lon[0, -1]: # This prevents from crashing in [160, -160]
69-
# _raise_field_out_of_bound_error(z, y, x)
69+
# _raise_grid_searching_error(z, y, x)
7070

7171
# if y < field.lonlat_minmax[2] or y > field.lonlat_minmax[3]:
72-
# _raise_field_out_of_bound_error(z, y, x)
72+
# _raise_grid_searching_error(z, y, x)
7373

7474
while np.any(xsi < -tol) or np.any(xsi > 1 + tol) or np.any(eta < -tol) or np.any(eta > 1 + tol):
7575
px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]])
@@ -99,30 +99,22 @@ def _search_indices_curvilinear_2d(
9999
it += 1
100100
if it > maxIterSearch:
101101
print(f"Correct cell not found after {maxIterSearch} iterations")
102-
_raise_field_out_of_bound_error(0, y, x)
102+
_raise_grid_searching_error(0, y, x)
103103
xsi = np.where(xsi < 0.0, 0.0, np.where(xsi > 1.0, 1.0, xsi))
104104
eta = np.where(eta < 0.0, 0.0, np.where(eta > 1.0, 1.0, eta))
105105

106106
if np.any((xsi < 0) | (xsi > 1) | (eta < 0) | (eta > 1)):
107-
_raise_field_sampling_error(y, x)
107+
_raise_grid_searching_error(y, x)
108108
return (yi, eta, xi, xsi)
109109

110110

111-
def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, sphere_mesh: bool):
112-
if xi < 0:
113-
if sphere_mesh:
114-
xi = xdim - 2
115-
else:
116-
xi = 0
117-
if xi > xdim - 2:
118-
if sphere_mesh:
119-
xi = 0
120-
else:
121-
xi = xdim - 2
122-
if yi < 0:
123-
yi = 0
124-
if yi > ydim - 2:
125-
yi = ydim - 2
126-
if sphere_mesh:
127-
xi = xdim - xi
111+
def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, mesh: Mesh):
112+
xi = np.where(xi < 0, (xdim - 2) if mesh == "spherical" else 0, xi)
113+
xi = np.where(xi > xdim - 2, 0 if mesh == "spherical" else (xdim - 2), xi)
114+
115+
xi = np.where(yi > ydim - 2, xdim - xi if mesh == "spherical" else xi, xi)
116+
117+
yi = np.where(yi < 0, 0, yi)
118+
yi = np.where(yi > ydim - 2, ydim - 2, yi)
119+
128120
return yi, xi

parcels/field.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -238,10 +238,12 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True):
238238

239239
tau, ti = _search_time_index(self, time)
240240
position = self.grid.search(z, y, x, ei=_ei)
241-
_update_particle_states(particle, position)
241+
_update_particle_states_position(particle, position)
242242

243243
value = self._interp_method(self, ti, position, tau, time, z, y, x)
244244

245+
_update_particle_states_interp_value(particle, value)
246+
245247
if applyConversion:
246248
value = self.units.to_target(value, z, y, x)
247249
return value
@@ -318,16 +320,21 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True):
318320

319321
tau, ti = _search_time_index(self.U, time)
320322
position = self.grid.search(z, y, x, ei=_ei)
321-
_update_particle_states(particle, position)
323+
_update_particle_states_position(particle, position)
322324

323325
if self._vector_interp_method is None:
324326
u = self.U._interp_method(self.U, ti, position, tau, time, z, y, x)
325327
v = self.V._interp_method(self.V, ti, position, tau, time, z, y, x)
326328
if "3D" in self.vector_type:
327329
w = self.W._interp_method(self.W, ti, position, tau, time, z, y, x)
330+
else:
331+
w = 0.0
328332
else:
329333
(u, v, w) = self._vector_interp_method(self, ti, position, tau, time, z, y, x)
330334

335+
for vel in (u, v, w):
336+
_update_particle_states_interp_value(particle, vel)
337+
331338
if applyConversion:
332339
u = self.U.units.to_target(u, z, y, x)
333340
v = self.V.units.to_target(v, z, y, x)
@@ -349,14 +356,30 @@ def __getitem__(self, key):
349356
return _deal_with_errors(error, key, vector_type=self.vector_type)
350357

351358

352-
def _update_particle_states(particle, position):
359+
def _update_particle_states_position(particle, position):
353360
"""Update the particle states based on the position dictionary."""
354361
if particle and "X" in position: # TODO also support uxgrid search
355-
particle.state = np.where(position["X"][0] < 0, StatusCode.ErrorOutOfBounds, particle.state)
356-
particle.state = np.where(position["Y"][0] < 0, StatusCode.ErrorOutOfBounds, particle.state)
357-
particle.state = np.where(position["Z"][0] == RIGHT_OUT_OF_BOUNDS, StatusCode.ErrorOutOfBounds, particle.state)
358-
particle.state = np.where(
359-
position["Z"][0] == LEFT_OUT_OF_BOUNDS, StatusCode.ErrorThroughSurface, particle.state
362+
particle.state = np.maximum(
363+
np.where(position["X"][0] < 0, StatusCode.ErrorOutOfBounds, particle.state), particle.state
364+
)
365+
particle.state = np.maximum(
366+
np.where(position["Y"][0] < 0, StatusCode.ErrorOutOfBounds, particle.state), particle.state
367+
)
368+
particle.state = np.maximum(
369+
np.where(position["Z"][0] == RIGHT_OUT_OF_BOUNDS, StatusCode.ErrorOutOfBounds, particle.state),
370+
particle.state,
371+
)
372+
particle.state = np.maximum(
373+
np.where(position["Z"][0] == LEFT_OUT_OF_BOUNDS, StatusCode.ErrorThroughSurface, particle.state),
374+
particle.state,
375+
)
376+
377+
378+
def _update_particle_states_interp_value(particle, value):
379+
"""Update the particle states based on the interpolated value, but only if state is not an Error already."""
380+
if particle:
381+
particle.state = np.maximum(
382+
np.where(np.isnan(value), StatusCode.ErrorInterpolation, particle.state), particle.state
360383
)
361384

362385

parcels/kernel.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
from parcels.basegrid import GridType
1616
from parcels.tools.statuscodes import (
1717
StatusCode,
18+
_raise_field_interpolation_error,
1819
_raise_field_out_of_bound_error,
1920
_raise_field_out_of_bound_surface_error,
20-
_raise_field_sampling_error,
21+
_raise_general_error,
22+
_raise_grid_searching_error,
2123
_raise_time_extrapolation_error,
2224
)
2325
from parcels.tools.warnings import KernelWarning
@@ -28,6 +30,16 @@
2830
__all__ = ["Kernel"]
2931

3032

33+
ErrorsToThrow = {
34+
StatusCode.ErrorTimeExtrapolation: _raise_time_extrapolation_error,
35+
StatusCode.ErrorOutOfBounds: _raise_field_out_of_bound_error,
36+
StatusCode.ErrorThroughSurface: _raise_field_out_of_bound_surface_error,
37+
StatusCode.ErrorInterpolation: _raise_field_interpolation_error,
38+
StatusCode.ErrorGridSearching: _raise_grid_searching_error,
39+
StatusCode.Error: _raise_general_error,
40+
}
41+
42+
3143
class Kernel:
3244
"""Kernel object that encapsulates auto-generated code.
3345
@@ -270,14 +282,7 @@ def execute(self, pset, endtime, dt):
270282
if np.any(pset.state == StatusCode.StopAllExecution):
271283
return StatusCode.StopAllExecution
272284

273-
errors_to_throw = {
274-
StatusCode.ErrorTimeExtrapolation: _raise_time_extrapolation_error,
275-
StatusCode.ErrorOutOfBounds: _raise_field_out_of_bound_error,
276-
StatusCode.ErrorThroughSurface: _raise_field_out_of_bound_surface_error,
277-
StatusCode.Error: _raise_field_sampling_error,
278-
}
279-
280-
for error_code, error_func in errors_to_throw.items():
285+
for error_code, error_func in ErrorsToThrow.items():
281286
if np.any(pset.state == error_code):
282287
inds = pset.state == error_code
283288
if error_code == StatusCode.ErrorTimeExtrapolation:

parcels/particleset.py

Lines changed: 68 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import sys
22
import warnings
33
from collections.abc import Iterable
4+
from typing import Literal
45

56
import numpy as np
67
import xarray as xr
@@ -506,59 +507,17 @@ def execute(
506507
if output_file:
507508
output_file.metadata["parcels_kernels"] = self._kernel.name
508509

509-
if (dt is not None) and (not isinstance(dt, np.timedelta64)):
510-
raise TypeError("dt must be a np.timedelta64 object")
511-
if dt is None or np.isnat(dt):
510+
if dt is None:
512511
dt = np.timedelta64(1, "s")
513-
self._data["dt"][:] = dt
514-
sign_dt = np.sign(dt).astype(int)
515-
if sign_dt not in [-1, 1]:
516-
raise ValueError("dt must be a positive or negative np.timedelta64 object")
517512

518-
if self.fieldset.time_interval is None:
519-
start_time = np.timedelta64(0, "s") # For the execution loop, we need a start time as a timedelta object
520-
if runtime is None:
521-
raise TypeError("The runtime must be provided when the time_interval is not defined for a fieldset.")
513+
if not isinstance(dt, np.timedelta64) or np.isnat(dt) or (sign_dt := np.sign(dt).astype(int)) not in [-1, 1]:
514+
raise ValueError(f"dt must be a positive or negative np.timedelta64 object, got {dt=!r}")
522515

523-
else:
524-
if isinstance(runtime, np.timedelta64):
525-
end_time = runtime
526-
else:
527-
raise TypeError("The runtime must be a np.timedelta64 object")
516+
self._data["dt"][:] = dt
528517

529-
else:
530-
if not np.isnat(self.time_nextloop).any():
531-
if sign_dt > 0:
532-
start_time = self.time_nextloop.min()
533-
else:
534-
start_time = self.time_nextloop.max()
535-
else:
536-
if sign_dt > 0:
537-
start_time = self.fieldset.time_interval.left
538-
else:
539-
start_time = self.fieldset.time_interval.right
540-
541-
if runtime is None:
542-
if endtime is None:
543-
raise ValueError(
544-
"Must provide either runtime or endtime when time_interval is defined for a fieldset."
545-
)
546-
# Ensure that the endtime uses the same type as the start_time
547-
if isinstance(endtime, self.fieldset.time_interval.left.__class__):
548-
if sign_dt > 0:
549-
if endtime < self.fieldset.time_interval.left:
550-
raise ValueError("The endtime must be after the start time of the fieldset.time_interval")
551-
end_time = min(endtime, self.fieldset.time_interval.right)
552-
else:
553-
if endtime > self.fieldset.time_interval.right:
554-
raise ValueError(
555-
"The endtime must be before the end time of the fieldset.time_interval when dt < 0"
556-
)
557-
end_time = max(endtime, self.fieldset.time_interval.left)
558-
else:
559-
raise TypeError("The endtime must be of the same type as the fieldset.time_interval start time.")
560-
else:
561-
end_time = start_time + runtime * sign_dt
518+
start_time, end_time = _get_simulation_start_and_end_times(
519+
self.fieldset.time_interval, self._data["time_nextloop"], runtime, endtime, sign_dt
520+
)
562521

563522
# Set the time of the particles if it hadn't been set on initialisation
564523
if np.isnat(self._data["time"]).any():
@@ -619,15 +578,69 @@ def _warn_particle_times_outside_fieldset_time_bounds(release_times: np.ndarray,
619578

620579
if isinstance(time.left, np.datetime64) and isinstance(release_times[0], np.timedelta64):
621580
release_times = np.array([t + time.left for t in release_times])
622-
if np.any(release_times < time.left):
581+
if np.any(release_times < time.left) or np.any(release_times > time.right):
623582
warnings.warn(
624583
"Some particles are set to be released outside the FieldSet's executable time domain.",
625584
ParticleSetWarning,
626585
stacklevel=2,
627586
)
628-
if np.any(release_times > time.right):
629-
warnings.warn(
630-
"Some particles are set to be released after the fieldset's last time and the fields are not constant in time.",
631-
ParticleSetWarning,
632-
stacklevel=2,
587+
588+
589+
def _get_simulation_start_and_end_times(
590+
time_interval: TimeInterval,
591+
particle_release_times: np.ndarray,
592+
runtime: np.timedelta64 | None,
593+
endtime: np.datetime64 | None,
594+
sign_dt: Literal[-1, 1],
595+
) -> tuple[np.datetime64, np.datetime64]:
596+
if runtime is not None and endtime is not None:
597+
raise ValueError(
598+
f"runtime and endtime are mutually exclusive - provide one or the other. Got {runtime=!r}, {endtime=!r}"
633599
)
600+
601+
if runtime is None and time_interval is None:
602+
raise ValueError("The runtime must be provided when the time_interval is not defined for a fieldset.")
603+
604+
if sign_dt == 1:
605+
first_release_time = particle_release_times.min()
606+
else:
607+
first_release_time = particle_release_times.max()
608+
609+
start_time = _get_start_time(first_release_time, time_interval, sign_dt, runtime)
610+
611+
if endtime is None:
612+
if not isinstance(runtime, np.timedelta64):
613+
raise ValueError(f"The runtime must be a np.timedelta64 object. Got {type(runtime)}")
614+
615+
endtime = start_time + sign_dt * runtime
616+
617+
if time_interval is not None:
618+
if type(endtime) != type(time_interval.left): # noqa: E721
619+
raise ValueError(
620+
f"The endtime must be of the same type as the fieldset.time_interval start time. Got {endtime=!r} with {time_interval=!r}"
621+
)
622+
if endtime not in time_interval:
623+
msg = (
624+
f"Calculated/provided end time of {endtime!r} is not in fieldset time interval {time_interval!r}. Either reduce your runtime, modify your "
625+
"provided endtime, or change your release timing."
626+
"Important info:\n"
627+
f" First particle release: {first_release_time!r}\n"
628+
f" runtime: {runtime!r}\n"
629+
f" (calculated) endtime: {endtime!r}"
630+
)
631+
raise ValueError(msg)
632+
633+
return start_time, endtime
634+
635+
636+
def _get_start_time(first_release_time, time_interval, sign_dt, runtime):
637+
if time_interval is None:
638+
time_interval = TimeInterval(left=np.timedelta64(0, "s"), right=runtime)
639+
640+
if sign_dt == 1:
641+
fieldset_start = time_interval.left
642+
else:
643+
fieldset_start = time_interval.right
644+
645+
start_time = first_release_time if not np.isnat(first_release_time) else fieldset_start
646+
return start_time

0 commit comments

Comments
 (0)