diff --git a/parcels/particleset.py b/parcels/particleset.py index 828f85c5e..f67b6b34e 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -1,6 +1,7 @@ import sys import warnings from collections.abc import Iterable +from typing import Literal import numpy as np import xarray as xr @@ -506,59 +507,17 @@ def execute( if output_file: output_file.metadata["parcels_kernels"] = self._kernel.name - if (dt is not None) and (not isinstance(dt, np.timedelta64)): - raise TypeError("dt must be a np.timedelta64 object") - if dt is None or np.isnat(dt): + if dt is None: dt = np.timedelta64(1, "s") - self._data["dt"][:] = dt - sign_dt = np.sign(dt).astype(int) - if sign_dt not in [-1, 1]: - raise ValueError("dt must be a positive or negative np.timedelta64 object") - if self.fieldset.time_interval is None: - start_time = np.timedelta64(0, "s") # For the execution loop, we need a start time as a timedelta object - if runtime is None: - raise TypeError("The runtime must be provided when the time_interval is not defined for a fieldset.") + if not isinstance(dt, np.timedelta64) or np.isnat(dt) or (sign_dt := np.sign(dt).astype(int)) not in [-1, 1]: + raise ValueError(f"dt must be a positive or negative np.timedelta64 object, got {dt=!r}") - else: - if isinstance(runtime, np.timedelta64): - end_time = runtime - else: - raise TypeError("The runtime must be a np.timedelta64 object") + self._data["dt"][:] = dt - else: - if not np.isnat(self.time_nextloop).any(): - if sign_dt > 0: - start_time = self.time_nextloop.min() - else: - start_time = self.time_nextloop.max() - else: - if sign_dt > 0: - start_time = self.fieldset.time_interval.left - else: - start_time = self.fieldset.time_interval.right - - if runtime is None: - if endtime is None: - raise ValueError( - "Must provide either runtime or endtime when time_interval is defined for a fieldset." - ) - # Ensure that the endtime uses the same type as the start_time - if isinstance(endtime, self.fieldset.time_interval.left.__class__): - if sign_dt > 0: - if endtime < self.fieldset.time_interval.left: - raise ValueError("The endtime must be after the start time of the fieldset.time_interval") - end_time = min(endtime, self.fieldset.time_interval.right) - else: - if endtime > self.fieldset.time_interval.right: - raise ValueError( - "The endtime must be before the end time of the fieldset.time_interval when dt < 0" - ) - end_time = max(endtime, self.fieldset.time_interval.left) - else: - raise TypeError("The endtime must be of the same type as the fieldset.time_interval start time.") - else: - end_time = start_time + runtime * sign_dt + start_time, end_time = _get_simulation_start_and_end_times( + self.fieldset.time_interval, self._data["time_nextloop"], runtime, endtime, sign_dt + ) # Set the time of the particles if it hadn't been set on initialisation if np.isnat(self._data["time"]).any(): @@ -619,15 +578,69 @@ def _warn_particle_times_outside_fieldset_time_bounds(release_times: np.ndarray, if isinstance(time.left, np.datetime64) and isinstance(release_times[0], np.timedelta64): release_times = np.array([t + time.left for t in release_times]) - if np.any(release_times < time.left): + if np.any(release_times < time.left) or np.any(release_times > time.right): warnings.warn( "Some particles are set to be released outside the FieldSet's executable time domain.", ParticleSetWarning, stacklevel=2, ) - if np.any(release_times > time.right): - warnings.warn( - "Some particles are set to be released after the fieldset's last time and the fields are not constant in time.", - ParticleSetWarning, - stacklevel=2, + + +def _get_simulation_start_and_end_times( + time_interval: TimeInterval, + particle_release_times: np.ndarray, + runtime: np.timedelta64 | None, + endtime: np.datetime64 | None, + sign_dt: Literal[-1, 1], +) -> tuple[np.datetime64, np.datetime64]: + if runtime is not None and endtime is not None: + raise ValueError( + f"runtime and endtime are mutually exclusive - provide one or the other. Got {runtime=!r}, {endtime=!r}" ) + + if runtime is None and time_interval is None: + raise ValueError("The runtime must be provided when the time_interval is not defined for a fieldset.") + + if sign_dt == 1: + first_release_time = particle_release_times.min() + else: + first_release_time = particle_release_times.max() + + start_time = _get_start_time(first_release_time, time_interval, sign_dt, runtime) + + if endtime is None: + if not isinstance(runtime, np.timedelta64): + raise ValueError(f"The runtime must be a np.timedelta64 object. Got {type(runtime)}") + + endtime = start_time + sign_dt * runtime + + if time_interval is not None: + if type(endtime) != type(time_interval.left): # noqa: E721 + raise ValueError( + f"The endtime must be of the same type as the fieldset.time_interval start time. Got {endtime=!r} with {time_interval=!r}" + ) + if endtime not in time_interval: + msg = ( + f"Calculated/provided end time of {endtime!r} is not in fieldset time interval {time_interval!r}. Either reduce your runtime, modify your " + "provided endtime, or change your release timing." + "Important info:\n" + f" First particle release: {first_release_time!r}\n" + f" runtime: {runtime!r}\n" + f" (calculated) endtime: {endtime!r}" + ) + raise ValueError(msg) + + return start_time, endtime + + +def _get_start_time(first_release_time, time_interval, sign_dt, runtime): + if time_interval is None: + time_interval = TimeInterval(left=np.timedelta64(0, "s"), right=runtime) + + if sign_dt == 1: + fieldset_start = time_interval.left + else: + fieldset_start = time_interval.right + + start_time = first_release_time if not np.isnat(first_release_time) else fieldset_start + return start_time diff --git a/tests/v4/test_particleset.py b/tests/v4/test_particleset.py index 7c3f2058c..f3f2c48fa 100644 --- a/tests/v4/test_particleset.py +++ b/tests/v4/test_particleset.py @@ -114,21 +114,6 @@ def test_pset_create_outside_time(fieldset): ParticleSet(fieldset, pclass=Particle, lon=[0] * len(time), lat=[0] * len(time), time=time) -@pytest.mark.parametrize( - "dt, expectation", - [ - (np.timedelta64(5, "s"), does_not_raise()), - (5.0, pytest.raises(TypeError)), - (np.datetime64("2000-01-02T00:00:00"), pytest.raises(TypeError)), - (timedelta(seconds=2), pytest.raises(TypeError)), - ], -) -def test_particleset_dt_type(fieldset, dt, expectation): - pset = ParticleSet(fieldset, lon=[0.2], lat=[5.0], depth=[50.0], pclass=Particle) - with expectation: - pset.execute(runtime=np.timedelta64(10, "s"), dt=dt, pyfunc=DoNothing) - - def test_pset_starttime_not_multiple_dt(fieldset): times = [0, 1, 2] datetimes = [fieldset.time_interval.left + np.timedelta64(t, "s") for t in times] @@ -141,38 +126,6 @@ def Addlon(particle, fieldset, time): # pragma: no cover assert np.allclose([p.lon_nextloop for p in pset], [8 - t for t in times]) -@pytest.mark.parametrize( - "runtime, expectation", - [ - (np.timedelta64(5, "s"), does_not_raise()), - (5.0, pytest.raises(TypeError)), - (timedelta(seconds=2), pytest.raises(TypeError)), - (np.datetime64("2001-01-02T00:00:00"), pytest.raises(TypeError)), - (datetime(2000, 1, 2, 0, 0, 0), pytest.raises(TypeError)), - ], -) -def test_particleset_runtime_type(fieldset, runtime, expectation): - pset = ParticleSet(fieldset, lon=[0.2], lat=[5.0], depth=[50.0], pclass=Particle) - with expectation: - pset.execute(runtime=runtime, dt=np.timedelta64(10, "s"), pyfunc=DoNothing) - - -@pytest.mark.parametrize( - "endtime, expectation", - [ - (np.datetime64("2000-01-02T00:00:00"), does_not_raise()), - (5.0, pytest.raises(TypeError)), - (np.timedelta64(5, "s"), pytest.raises(TypeError)), - (timedelta(seconds=2), pytest.raises(TypeError)), - (datetime(2000, 1, 2, 0, 0, 0), pytest.raises(TypeError)), - ], -) -def test_particleset_endtime_type(fieldset, endtime, expectation): - pset = ParticleSet(fieldset, lon=[0.2], lat=[5.0], depth=[50.0], pclass=Particle) - with expectation: - pset.execute(endtime=endtime, dt=np.timedelta64(10, "m"), pyfunc=DoNothing) - - def test_pset_add_explicit(fieldset): npart = 11 lon = np.linspace(0, 1, npart) diff --git a/tests/v4/test_particleset_execute.py b/tests/v4/test_particleset_execute.py index 68a699626..af631da29 100644 --- a/tests/v4/test_particleset_execute.py +++ b/tests/v4/test_particleset_execute.py @@ -1,3 +1,6 @@ +from contextlib import nullcontext as does_not_raise +from datetime import datetime, timedelta + import numpy as np import pytest @@ -27,7 +30,20 @@ def fieldset() -> FieldSet: grid = XGrid.from_dataset(ds, mesh="flat") U = Field("U", ds["U (A grid)"], grid) V = Field("V", ds["V (A grid)"], grid) - return FieldSet([U, V]) + UV = VectorField("UV", U, V) + return FieldSet([U, V, UV]) + + +@pytest.fixture +def fieldset_no_time_interval() -> FieldSet: + # i.e., no time variation + ds = datasets_structured["ds_2d_left"].isel(time=0).drop("time") + + grid = XGrid.from_dataset(ds, mesh="flat") + U = Field("U", ds["U (A grid)"], grid) + V = Field("V", ds["V (A grid)"], grid) + UV = VectorField("UV", U, V) + return FieldSet([U, V, UV]) @pytest.fixture @@ -41,6 +57,98 @@ def zonal_flow_fieldset() -> FieldSet: return FieldSet([U, V, UV]) +def test_pset_execute_implicit_dt_one_second(fieldset): + pset = ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle) + pset.execute(DoNothing, runtime=np.timedelta64(1, "s")) + + time = pset.time.copy() + + pset.execute(DoNothing, runtime=np.timedelta64(1, "s")) + np.testing.assert_array_equal(pset.time, time + np.timedelta64(1, "s")) + + +def test_pset_execute_invalid_arguments(fieldset, fieldset_no_time_interval): + for dt in [1, np.timedelta64(0, "s"), np.timedelta64(None)]: + with pytest.raises( + ValueError, + match="dt must be a positive or negative np.timedelta64 object, got .*", + ): + ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(dt=dt) + + with pytest.raises( + ValueError, + match="runtime and endtime are mutually exclusive - provide one or the other. Got .*", + ): + ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute( + runtime=np.timedelta64(1, "s"), endtime=np.datetime64("2100-01-01") + ) + + with pytest.raises( + ValueError, + match="The runtime must be a np.timedelta64 object. Got .*", + ): + ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(runtime=1) + + msg = """Calculated/provided end time of .* is not in fieldset time interval .* Either reduce your runtime, modify your provided endtime, or change your release timing.*""" + with pytest.raises( + ValueError, + match=msg, + ): + ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(endtime=np.datetime64("1990-01-01")) + + with pytest.raises( + ValueError, + match=msg, + ): + ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute( + endtime=np.datetime64("2100-01-01"), dt=np.timedelta64(-1, "s") + ) + + with pytest.raises( + ValueError, + match="The endtime must be of the same type as the fieldset.time_interval start time. Got .*", + ): + ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(endtime=12345) + + with pytest.raises( + ValueError, + match="The runtime must be provided when the time_interval is not defined for a fieldset.", + ): + ParticleSet(fieldset_no_time_interval, lon=[0.2], lat=[5.0], pclass=Particle).execute() + + +@pytest.mark.parametrize( + "runtime, expectation", + [ + (np.timedelta64(5, "s"), does_not_raise()), + (5.0, pytest.raises(ValueError)), + (timedelta(seconds=2), pytest.raises(ValueError)), + (np.datetime64("2001-01-02T00:00:00"), pytest.raises(ValueError)), + (datetime(2000, 1, 2, 0, 0, 0), pytest.raises(ValueError)), + ], +) +def test_particleset_runtime_type(fieldset, runtime, expectation): + pset = ParticleSet(fieldset, lon=[0.2], lat=[5.0], depth=[50.0], pclass=Particle) + with expectation: + pset.execute(runtime=runtime, dt=np.timedelta64(10, "s"), pyfunc=DoNothing) + + +@pytest.mark.parametrize( + "endtime, expectation", + [ + (np.datetime64("2000-01-02T00:00:00"), does_not_raise()), + (5.0, pytest.raises(ValueError)), + (np.timedelta64(5, "s"), pytest.raises(ValueError)), + (timedelta(seconds=2), pytest.raises(ValueError)), + (datetime(2000, 1, 2, 0, 0, 0), pytest.raises(ValueError)), + ], +) +def test_particleset_endtime_type(fieldset, endtime, expectation): + pset = ParticleSet(fieldset, lon=[0.2], lat=[5.0], depth=[50.0], pclass=Particle) + with expectation: + pset.execute(endtime=endtime, dt=np.timedelta64(10, "m"), pyfunc=DoNothing) + + def test_pset_remove_particle_in_kernel(fieldset): npart = 100 pset = ParticleSet(fieldset, lon=np.linspace(0, 1, npart), lat=np.linspace(1, 0, npart)) @@ -92,7 +200,8 @@ def AddLat(particle, fieldset, time): # pragma: no cover def test_execution_endtime(fieldset, starttime, endtime, dt): starttime = fieldset.time_interval.left + np.timedelta64(starttime, "s") endtime = fieldset.time_interval.left + np.timedelta64(endtime, "s") - dt = np.timedelta64(dt, "s") + if dt is not None: + dt = np.timedelta64(dt, "s") pset = ParticleSet(fieldset, time=starttime, lon=0, lat=0) pset.execute(DoNothing, endtime=endtime, dt=dt) assert abs(pset.time_nextloop - endtime) < np.timedelta64(1, "ms") @@ -152,10 +261,10 @@ def test_some_particles_throw_outoftime(fieldset): pset = ParticleSet(fieldset, lon=np.zeros_like(time), lat=np.zeros_like(time), time=time) def FieldAccessOutsideTime(particle, fieldset, time): # pragma: no cover - fieldset.U[particle.time + np.timedelta64(1, "D"), particle.depth, particle.lat, particle.lon, particle] + fieldset.U[particle.time + np.timedelta64(400, "D"), particle.depth, particle.lat, particle.lon, particle] with pytest.raises(TimeExtrapolationError): - pset.execute(FieldAccessOutsideTime, runtime=np.timedelta64(400, "D"), dt=np.timedelta64(10, "D")) + pset.execute(FieldAccessOutsideTime, runtime=np.timedelta64(1, "D"), dt=np.timedelta64(10, "D")) def test_execution_check_stopallexecution(fieldset): @@ -200,7 +309,8 @@ def test_execution_runtime(fieldset, starttime, runtime, dt, npart): starttime = fieldset.time_interval.left + np.timedelta64(starttime, "s") runtime = np.timedelta64(runtime, "s") sign_dt = 1 if dt is None else np.sign(dt) - dt = np.timedelta64(dt, "s") + if dt is not None: + dt = np.timedelta64(dt, "s") pset = ParticleSet(fieldset, time=starttime, lon=np.zeros(npart), lat=np.zeros(npart)) pset.execute(DoNothing, runtime=runtime, dt=dt) assert all([abs(p.time_nextloop - starttime - runtime * sign_dt) < np.timedelta64(1, "ms") for p in pset])