Skip to content

Commit 1002bd0

Browse files
Merge pull request #2242 from OceanParcels/subsecond_dt_support
Adding support for subsecond dt values in pset.execute
2 parents 65be282 + 061f996 commit 1002bd0

File tree

2 files changed

+46
-2
lines changed

2 files changed

+46
-2
lines changed

parcels/particleset.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,19 @@ def _num_error_particles(self):
425425
"""
426426
return np.sum(np.isin(self._data["state"], [StatusCode.Success, StatusCode.Evaluate], invert=True))
427427

428+
def update_dt_dtype(self, dt_dtype: np.dtype):
429+
"""Update the dtype of dt
430+
431+
Parameters
432+
----------
433+
dt_dtype : np.dtype
434+
New dtype for dt.
435+
"""
436+
if dt_dtype not in [np.timedelta64, "timedelta64[ns]", "timedelta64[ms]", "timedelta64[s]"]:
437+
raise ValueError(f"dt_dtype must be a numpy timedelta64 dtype. Got {dt_dtype=!r}")
438+
439+
self._data["dt"] = self._data["dt"].astype(dt_dtype)
440+
428441
def set_variable_write_status(self, var, write_status):
429442
"""Method to set the write status of a Variable.
430443
@@ -500,6 +513,17 @@ def execute(
500513
except (ValueError, AssertionError) as e:
501514
raise ValueError(f"dt must be a non-zero datetime.timedelta or np.timedelta64 object, got {dt=!r}") from e
502515

516+
# Check if particle dt has finer resolution than input dt
517+
particle_resolution = np.timedelta64(1, np.datetime_data(self._data["dt"].dtype))
518+
input_resolution = np.timedelta64(1, np.datetime_data(dt.dtype))
519+
520+
if input_resolution >= particle_resolution:
521+
self._data["dt"][:] = dt
522+
else:
523+
raise ValueError(
524+
f"The dtype of dt ({dt.dtype}) is coarser than the dtype of the particle dt ({self._data['dt'].dtype}). Please use ParticleSet.set_dt_dtype() to provide a dt with at least the same precision as the particle dt."
525+
)
526+
503527
if runtime is not None:
504528
try:
505529
runtime = maybe_convert_python_timedelta_to_numpy(runtime)
@@ -508,8 +532,6 @@ def execute(
508532
f"The runtime must be a datetime.timedelta or np.timedelta64 object. Got {type(runtime)}"
509533
) from e
510534

511-
self._data["dt"][:] = dt
512-
513535
start_time, end_time = _get_simulation_start_and_end_times(
514536
self.fieldset.time_interval, self._data["time_nextloop"], runtime, endtime, sign_dt
515537
)

tests/v4/test_particleset_execute.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
ParticleSet,
1313
StatusCode,
1414
UXPiecewiseConstantFace,
15+
Variable,
1516
VectorField,
1617
)
1718
from parcels._datasets.structured.generated import simple_UV_dataset
@@ -150,6 +151,27 @@ def test_particleset_endtime_type(fieldset, endtime, expectation):
150151
pset.execute(endtime=endtime, dt=np.timedelta64(10, "m"), pyfunc=DoNothing)
151152

152153

154+
@pytest.mark.parametrize(
155+
"dt", [np.timedelta64(1, "s"), np.timedelta64(1, "ms"), np.timedelta64(10, "ms"), np.timedelta64(1, "ns")]
156+
)
157+
def test_pset_execute_subsecond_dt(fieldset, dt):
158+
def AddDt(particles, fieldset): # pragma: no cover
159+
dt = particles.dt / np.timedelta64(1, "s")
160+
particles.added_dt += dt
161+
162+
pclass = Particle.add_variable(Variable("added_dt", dtype=np.float32, initial=0))
163+
pset = ParticleSet(fieldset, pclass=pclass, lon=0, lat=0)
164+
pset.update_dt_dtype(dt.dtype)
165+
pset.execute(AddDt, runtime=dt * 10, dt=dt)
166+
np.testing.assert_allclose(pset[0].added_dt, 10.0 * dt / np.timedelta64(1, "s"), atol=1e-5)
167+
168+
169+
def test_pset_execute_subsecond_dt_error(fieldset):
170+
pset = ParticleSet(fieldset, lon=0, lat=0)
171+
with pytest.raises(ValueError, match="The dtype of dt"):
172+
pset.execute(DoNothing, runtime=np.timedelta64(10, "ms"), dt=np.timedelta64(1, "ms"))
173+
174+
153175
def test_pset_remove_particle_in_kernel(fieldset):
154176
npart = 100
155177
pset = ParticleSet(fieldset, lon=np.linspace(0, 1, npart), lat=np.linspace(1, 0, npart))

0 commit comments

Comments
 (0)