Skip to content

Commit a47b9f6

Browse files
committed
Restrict outputdt to be datetime or timedelta
1 parent 9f1d24f commit a47b9f6

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

parcels/particlefile.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@ class ParticleFile:
5151
"""
5252

5353
def __init__(self, store, outputdt, chunks=None, create_new_zarrfile=True):
54-
self._outputdt = timedelta_to_float(outputdt)
54+
if not isinstance(outputdt, (np.datetime64, np.timedelta64)):
55+
raise ValueError(f"Expected outputdt to be a np.timedelta64 or datetime64, got {type(outputdt)}")
56+
57+
self._outputdt = outputdt
5558

5659
_assert_valid_chunks_tuple(chunks)
5760
self._chunks = chunks

tests/v4/test_particlefile.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ def fieldset() -> FieldSet: # TODO v4: Move into a `conftest.py` file and remov
2222
"""Fixture to create a FieldSet object for testing."""
2323
ds = datasets["ds_2d_left"]
2424
grid = XGrid.from_dataset(ds)
25-
U = Field("U", ds["U (A grid)"], grid, mesh_type="flat")
26-
V = Field("V", ds["V (A grid)"], grid, mesh_type="flat")
25+
U = Field("U", ds["U (A grid)"], grid)
26+
V = Field("V", ds["V (A grid)"], grid)
2727
UV = VectorField("UV", U, V)
2828

2929
return FieldSet(
@@ -191,11 +191,18 @@ def IncrLon(particle, fieldset, time): # pragma: no cover
191191

192192
def test_write_timebackward(fieldset, tmp_zarrfile):
193193
def Update_lon(particle, fieldset, time): # pragma: no cover
194-
particle.dlon -= 0.1 * particle.dt
194+
dt = particle.dt / np.timedelta64(1, "s")
195+
particle.dlon -= 0.1 * dt
195196

196-
pset = ParticleSet(fieldset, pclass=Particle, lat=np.linspace(0, 1, 3), lon=[0, 0, 0], time=[1, 2, 3])
197-
pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1.0)
198-
pset.execute(pset.Kernel(Update_lon), runtime=4, dt=-1.0, output_file=pfile)
197+
pset = ParticleSet(
198+
fieldset,
199+
pclass=Particle,
200+
lat=np.linspace(0, 1, 3),
201+
lon=[0, 0, 0],
202+
time=np.array([np.datetime64("2000-01-01") for _ in range(3)]),
203+
)
204+
pfile = pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s"))
205+
pset.execute(pset.Kernel(Update_lon), runtime=np.timedelta64(1, "s"), dt=-np.timedelta64(1, "s"), output_file=pfile)
199206
ds = xr.open_zarr(tmp_zarrfile)
200207
trajs = ds["trajectory"][:]
201208
assert trajs.values.dtype == "int64"

0 commit comments

Comments
 (0)