Skip to content

Commit 5bcf378

Browse files
Also supporting outputdt as float
1 parent 3d21467 commit 5bcf378

File tree

3 files changed

+28
-9
lines changed

3 files changed

+28
-9
lines changed

src/parcels/_core/particlefile.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class ParticleFile:
5252
outputdt :
5353
Interval which dictates the update frequency of file output
5454
while ParticleFile is given as an argument of ParticleSet.execute()
55-
It is either a timedelta object or a positive double.
55+
It is either a numpy.timedelta64, a datimetime.timedelta object or a positive float (in seconds).
5656
chunks :
5757
Tuple (trajs, obs) to control the size of chunks in the zarr output.
5858
create_new_zarrfile : bool
@@ -65,14 +65,15 @@ class ParticleFile:
6565
"""
6666

6767
def __init__(self, store, outputdt, chunks=None, create_new_zarrfile=True):
68-
if isinstance(outputdt, timedelta):
69-
outputdt = np.timedelta64(int(outputdt.total_seconds()), "s")
68+
if not isinstance(outputdt, (np.timedelta64, timedelta, float)):
69+
raise ValueError(
70+
f"Expected outputdt to be a np.timedelta64, datetime.timedelta or float (in seconds), got {type(outputdt)}"
71+
)
7072

71-
if not isinstance(outputdt, np.timedelta64):
72-
raise ValueError(f"Expected outputdt to be a np.timedelta64 or datetime.timedelta, got {type(outputdt)}")
73+
outputdt = timedelta_to_float(outputdt)
7374

74-
if outputdt <= np.timedelta64(0, "s"):
75-
raise ValueError(f"outputdt must be a positive non-zero timedelta. Got {outputdt=!r}")
75+
if outputdt <= 0:
76+
raise ValueError(f"outputdt must be positive/non-zero. Got {outputdt=!r}")
7677

7778
self._outputdt = outputdt
7879

src/parcels/_core/particleset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ def execute(
492492
if np.isnan(self._data["time"]).any():
493493
self._data["time"][:] = start_time
494494

495-
outputdt = timedelta_to_float(output_file.outputdt) if output_file else None
495+
outputdt = output_file.outputdt if output_file else None
496496
_warn_outputdt_release_desync(outputdt, start_time, self._data["time"][:])
497497

498498
# Set up pbar

tests/test_particlefile.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import tempfile
3-
from datetime import timedelta
3+
from contextlib import nullcontext as does_not_raise
4+
from datetime import datetime, timedelta
45

56
import numpy as np
67
import pytest
@@ -231,6 +232,23 @@ def test_file_warnings(fieldset, tmp_zarrfile):
231232
pset.execute(AdvectionRK4, runtime=3, dt=1, output_file=pfile)
232233

233234

235+
@pytest.mark.parametrize(
236+
"outputdt, expectation",
237+
[
238+
(np.timedelta64(5, "s"), does_not_raise()),
239+
(timedelta(seconds=2), does_not_raise()),
240+
(5.0, does_not_raise()),
241+
(np.datetime64("2001-01-02T00:00:00"), pytest.raises(ValueError)),
242+
(datetime(2000, 1, 2, 0, 0, 0), pytest.raises(ValueError)),
243+
(-np.timedelta64(5, "s"), pytest.raises(ValueError)),
244+
],
245+
)
246+
def test_outputdt_types(outputdt, expectation, tmp_zarrfile):
247+
with expectation:
248+
pfile = ParticleFile(tmp_zarrfile, outputdt=outputdt)
249+
assert pfile.outputdt == timedelta_to_float(outputdt)
250+
251+
234252
def test_write_timebackward(fieldset, tmp_zarrfile):
235253
release_time = fieldset.time_interval.left + [np.timedelta64(i + 1, "s") for i in range(3)]
236254
pset = ParticleSet(fieldset, lat=[0, 1, 2], lon=[0, 0, 0], time=release_time)

0 commit comments

Comments
 (0)