Skip to content

Commit 9ceca1d

Browse files
committed
WIP
1 parent 6dffa19 commit 9ceca1d

File tree

2 files changed

+54
-8
lines changed

2 files changed

+54
-8
lines changed

parcels/particlefile.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def _write_particle_data(self, *, particle_data, pclass, time_interval, time, in
148148
# stacklevel=2,
149149
# )
150150
# return
151-
151+
nparticles = len(particle_data["trajectory"])
152152
vars_to_write = _get_vars_to_write(pclass)
153153
if indices is None:
154154
indices_to_write = _to_write_particles(particle_data, time)
@@ -173,7 +173,7 @@ def _write_particle_data(self, *, particle_data, pclass, time_interval, time, in
173173
store = self.store
174174
if self.create_new_zarrfile:
175175
if self.chunks is None:
176-
self._chunks = (len(pset), 1)
176+
self._chunks = (nparticles, 1)
177177
if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]): # type: ignore[index]
178178
arrsize = (self._maxids, self.chunks[1]) # type: ignore[index]
179179
else:
@@ -280,10 +280,10 @@ def _to_write_particles(particle_data, time):
280280
return np.where(
281281
(
282282
np.less_equal(
283-
time - np.abs(particle_data["dt"] / 2), particle_data["time"], where=np.isfinite(particle_data["time"])
283+
time - np.abs(particle_data["dt"]), particle_data["time"], where=np.isfinite(particle_data["time"])
284284
)
285285
& np.greater_equal(
286-
time + np.abs(particle_data["dt"] / 2), particle_data["time"], where=np.isfinite(particle_data["time"])
286+
time + np.abs(particle_data["dt"]), particle_data["time"], where=np.isfinite(particle_data["time"])
287287
)
288288
| (
289289
(np.isnan(particle_data["dt"]))
@@ -307,7 +307,9 @@ def _convert_particle_data_time_to_float_seconds(particle_data, time_interval):
307307
def _maybe_convert_time_dtype(dtype: np.dtype | _SAME_AS_FIELDSET_TIME_INTERVAL) -> np.dtype:
308308
"""Convert the dtype of time to float64 if it is not already."""
309309
if dtype is _SAME_AS_FIELDSET_TIME_INTERVAL.VALUE:
310-
return np.dtype(np.float64)
310+
return np.dtype(
311+
np.uint64
312+
) #! We need to have here some proper mechanism for converting particle data to the data that is to be output to zarr (namely the time needs to be converted to float seconds by subtracting the time_interval.left)
311313
return dtype
312314

313315

tests/v4/test_particlefile.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99

1010
import parcels
1111
from parcels import AdvectionRK4, Field, FieldSet, Particle, ParticleSet, Variable, VectorField
12+
from parcels._core.utils.time import TimeInterval
1213
from parcels._datasets.structured.generic import datasets
14+
from parcels.particle import Particle, create_particle_data
1315
from parcels.particlefile import ParticleFile
14-
from parcels.particle import create_particle_data
15-
from parcels._core.utils.time import TimeInterval
1616
from parcels.xgrid import XGrid
1717
from tests.common_kernels import DoNothing
1818

@@ -366,4 +366,48 @@ def test_particlefile_init_invalid(store): # TODO: Add test for read only store
366366

367367

368368
@pytest.mark.new
369-
def test_particlefile_writing(store): ...
369+
def test_particlefile_write_particle_data(store):
370+
nparticles = 100
371+
372+
pfile = ParticleFile(store, outputdt=np.timedelta64(1, "s"), chunks=(nparticles, 40))
373+
pclass = Particle
374+
375+
left, right = np.datetime64("2020-01-01", "ns"), np.datetime64("2020-01-02", "ns")
376+
time_interval = TimeInterval(left=left, right=right)
377+
378+
initial_lon = np.linspace(0, 1, nparticles)
379+
data = create_particle_data(
380+
pclass=pclass,
381+
nparticles=nparticles,
382+
ngrids=4,
383+
time_interval=time_interval,
384+
initial={
385+
"time": np.full(nparticles, fill_value=left),
386+
"lon": initial_lon,
387+
"dt": np.full(nparticles, fill_value=1.0),
388+
"trajectory": np.arange(nparticles),
389+
},
390+
)
391+
np.testing.assert_array_equal(data["time"], left)
392+
pfile._write_particle_data(
393+
particle_data=data,
394+
pclass=pclass,
395+
time_interval=time_interval,
396+
time=left,
397+
)
398+
ds = xr.open_zarr(store, decode_cf=False)
399+
# ds.time.attrs
400+
401+
def patched_time_attrs(attrs):
402+
units = attrs["units"]
403+
units = units[: units.index("T")]
404+
attrs["units"] = units
405+
return attrs
406+
407+
ds.time.attrs = patched_time_attrs(ds.time.attrs)
408+
breakpoint()
409+
ds = xr.decode_cf(ds)
410+
assert ds.sizes["trajectory"] == nparticles
411+
assert ds.time.dtype == "datetime64[ns]"
412+
np.testing.assert_equal(ds["time"].isel(obs=0).values, left)
413+
np.testing.assert_equal(ds["lon"].isel(obs=0).values, initial_lon)

0 commit comments

Comments
 (0)