Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions parcels/particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,19 @@ def _num_error_particles(self):
"""
return np.sum(np.isin(self._data["state"], [StatusCode.Success, StatusCode.Evaluate], invert=True))

def update_dt_dtype(self, dt_dtype: np.dtype):
"""Update the dtype of dt

Parameters
----------
dt_dtype : np.dtype
New dtype for dt.
"""
if dt_dtype not in [np.timedelta64, "timedelta64[ns]", "timedelta64[ms]", "timedelta64[s]"]:
raise ValueError(f"dt_dtype must be a numpy timedelta64 dtype. Got {dt_dtype=!r}")

self._data["dt"] = self._data["dt"].astype(dt_dtype)

def set_variable_write_status(self, var, write_status):
"""Method to set the write status of a Variable.

Expand Down Expand Up @@ -500,6 +513,17 @@ def execute(
except (ValueError, AssertionError) as e:
raise ValueError(f"dt must be a non-zero datetime.timedelta or np.timedelta64 object, got {dt=!r}") from e

# Check if particle dt has finer resolution than input dt
particle_resolution = np.timedelta64(1, np.datetime_data(self._data["dt"].dtype))
input_resolution = np.timedelta64(1, np.datetime_data(dt.dtype))

if input_resolution >= particle_resolution:
self._data["dt"][:] = dt
else:
raise ValueError(
f"The dtype of dt ({dt.dtype}) is coarser than the dtype of the particle dt ({self._data['dt'].dtype}). Please use ParticleSet.set_dt_dtype() to provide a dt with at least the same precision as the particle dt."
)

if runtime is not None:
try:
runtime = maybe_convert_python_timedelta_to_numpy(runtime)
Expand All @@ -508,8 +532,6 @@ def execute(
f"The runtime must be a datetime.timedelta or np.timedelta64 object. Got {type(runtime)}"
) from e

self._data["dt"][:] = dt

start_time, end_time = _get_simulation_start_and_end_times(
self.fieldset.time_interval, self._data["time_nextloop"], runtime, endtime, sign_dt
)
Expand Down
22 changes: 22 additions & 0 deletions tests/v4/test_particleset_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ParticleSet,
StatusCode,
UXPiecewiseConstantFace,
Variable,
VectorField,
)
from parcels._datasets.structured.generated import simple_UV_dataset
Expand Down Expand Up @@ -150,6 +151,27 @@ def test_particleset_endtime_type(fieldset, endtime, expectation):
pset.execute(endtime=endtime, dt=np.timedelta64(10, "m"), pyfunc=DoNothing)


@pytest.mark.parametrize(
"dt", [np.timedelta64(1, "s"), np.timedelta64(1, "ms"), np.timedelta64(10, "ms"), np.timedelta64(1, "ns")]
)
def test_pset_execute_subsecond_dt(fieldset, dt):
def AddDt(particles, fieldset): # pragma: no cover
dt = particles.dt / np.timedelta64(1, "s")
particles.added_dt += dt

pclass = Particle.add_variable(Variable("added_dt", dtype=np.float32, initial=0))
pset = ParticleSet(fieldset, pclass=pclass, lon=0, lat=0)
pset.update_dt_dtype(dt.dtype)
pset.execute(AddDt, runtime=dt * 10, dt=dt)
np.testing.assert_allclose(pset[0].added_dt, 10.0 * dt / np.timedelta64(1, "s"), atol=1e-5)


def test_pset_execute_subsecond_dt_error(fieldset):
pset = ParticleSet(fieldset, lon=0, lat=0)
with pytest.raises(ValueError, match="The dtype of dt"):
pset.execute(DoNothing, runtime=np.timedelta64(10, "ms"), dt=np.timedelta64(1, "ms"))


def test_pset_remove_particle_in_kernel(fieldset):
npart = 100
pset = ParticleSet(fieldset, lon=np.linspace(0, 1, npart), lat=np.linspace(1, 0, npart))
Expand Down
Loading