From 9e0c7ce1f2b4ee60db51b135333c75df6f327c3c Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 5 Aug 2025 13:16:49 +0200 Subject: [PATCH 01/40] Reorganising particlefile.py --- parcels/particlefile.py | 78 +++++++++++++++++++---------------------- 1 file changed, 36 insertions(+), 42 deletions(-) diff --git a/parcels/particlefile.py b/parcels/particlefile.py index 33451caae..d7318e977 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -5,7 +5,7 @@ import os import warnings from datetime import timedelta -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal import numpy as np import xarray as xr @@ -14,11 +14,14 @@ import parcels from parcels._reprs import default_repr +from parcels.particle import ParticleClass from parcels.tools._helpers import timedelta_to_float if TYPE_CHECKING: from pathlib import Path + from parcels.particleset import ParticleSet + __all__ = ["ParticleFile"] _DATATYPES_TO_FILL_VALUES = { @@ -65,29 +68,10 @@ def __init__(self, store, particleset, outputdt, chunks=None, create_new_zarrfil self._outputdt = timedelta_to_float(outputdt) self._chunks = chunks self._particleset = particleset - self._parcels_mesh = "spherical" - if self.particleset.fieldset is not None: - self._parcels_mesh = self.particleset.fieldset.gridset[0].mesh - self.lonlatdepth_dtype = self.particleset.particledata.lonlatdepth_dtype self._maxids = 0 self._pids_written = {} - self._create_new_zarrfile = create_new_zarrfile - self._vars_to_write = {} - for var in self.particleset.particledata.ptype.variables: - if var.to_write: - self.vars_to_write[var.name] = var.dtype - self.particleset.fieldset._particlefile = self - - # Reset obs_written of each particle, in case new ParticleFile created for a ParticleSet - particleset.particledata.setallvardata("obs_written", 0) - - self.metadata = { - "feature_type": "trajectory", - "Conventions": "CF-1.6/CF-1.7", - "ncei_template_version": "NCEI_NetCDF_Trajectory_Template_v2.0", - "parcels_version": parcels.__version__, - "parcels_mesh": self._parcels_mesh, - } + self.metadata = None + self.create_new_zarrfile = create_new_zarrfile if isinstance(store, zarr.storage.Store): self.store = store @@ -109,9 +93,14 @@ def __repr__(self) -> str: f"create_new_zarrfile={self.create_new_zarrfile!r})" ) - @property - def create_new_zarrfile(self): - return self._create_new_zarrfile + def set_metadata(self, parcels_mesh: Literal["spherical", "flat"]): + self.metadata = { + "feature_type": "trajectory", + "Conventions": "CF-1.6/CF-1.7", + "ncei_template_version": "NCEI_NetCDF_Trajectory_Template_v2.0", + "parcels_version": parcels.__version__, + "parcels_mesh": parcels_mesh, + } @property def outputdt(self): @@ -129,11 +118,7 @@ def particleset(self): def fname(self): return self._fname - @property - def vars_to_write(self): - return self._vars_to_write - - def _create_variables_attribute_dict(self): + def _create_variables_attribute_dict(self, particle: ParticleClass): """Creates the dictionary with variable attributes. Notes @@ -190,7 +175,7 @@ def _extend_zarr_dims(self, Z, store, dtype, axis): Z.append(a, axis=axis) zarr.consolidate_metadata(store) - def write(self, pset, time: float | timedelta | np.timedelta64 | None, indices=None): + def write(self, pset: ParticleSet, time: float | timedelta | np.timedelta64 | None, indices=None): """Write all data from one time step to the zarr file, before the particle locations are updated. @@ -202,7 +187,8 @@ def write(self, pset, time: float | timedelta | np.timedelta64 | None, indices=N Time at which to write ParticleSet """ time = timedelta_to_float(time) if time is not None else None - + pclass = pset._ptype + vars_to_write = _get_vars_to_write(pclass) if pset.particledata._ncount == 0: warnings.warn( f"ParticleSet is empty on writing as array at time {time:g}", @@ -245,40 +231,38 @@ def write(self, pset, time: float | timedelta | np.timedelta64 | None, indices=N ) attrs = self._create_variables_attribute_dict() obs = np.zeros((self._maxids), dtype=np.int32) - for var in self.vars_to_write: + for var in vars_to_write: varout = self._convert_varout_name(var) if varout not in ["trajectory"]: # because 'trajectory' is written as coordinate if self._write_once(var): data = np.full( (arrsize[0],), - _DATATYPES_TO_FILL_VALUES[self.vars_to_write[var]], - dtype=self.vars_to_write[var], + _DATATYPES_TO_FILL_VALUES[vars_to_write[var]], + dtype=vars_to_write[var], ) data[ids_once] = pset.particledata.getvardata(var, indices_to_write_once) dims = ["trajectory"] else: - data = np.full( - arrsize, _DATATYPES_TO_FILL_VALUES[self.vars_to_write[var]], dtype=self.vars_to_write[var] - ) + data = np.full(arrsize, _DATATYPES_TO_FILL_VALUES[vars_to_write[var]], dtype=vars_to_write[var]) data[ids, 0] = pset.particledata.getvardata(var, indices_to_write) dims = ["trajectory", "obs"] ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[varout]) ds[varout].encoding["chunks"] = self.chunks[0] if self._write_once(var) else self.chunks # type: ignore[index] ds.to_zarr(store, mode="w") - self._create_new_zarrfile = False + self.create_new_zarrfile = False else: Z = zarr.group(store=store, overwrite=False) obs = pset.particledata.getvardata("obs_written", indices_to_write) - for var in self.vars_to_write: + for var in vars_to_write: varout = self._convert_varout_name(var) if self._maxids > Z[varout].shape[0]: - self._extend_zarr_dims(Z[varout], store, dtype=self.vars_to_write[var], axis=0) + self._extend_zarr_dims(Z[varout], store, dtype=vars_to_write[var], axis=0) if self._write_once(var): if len(once_ids) > 0: Z[varout].vindex[ids_once] = pset.particledata.getvardata(var, indices_to_write_once) else: if max(obs) >= Z[varout].shape[1]: # type: ignore[type-var] - self._extend_zarr_dims(Z[varout], store, dtype=self.vars_to_write[var], axis=1) + self._extend_zarr_dims(Z[varout], store, dtype=vars_to_write[var], axis=1) Z[varout].vindex[ids, obs] = pset.particledata.getvardata(var, indices_to_write) pset.particledata.setvardata("obs_written", indices_to_write, obs + 1) @@ -308,3 +292,13 @@ def _get_store_from_pathlike(path: Path | str) -> DirectoryStore: raise ValueError(f"ParticleFile name must end with '.zarr' extension. Got path {path!r}.") return DirectoryStore(path) + + +def _get_vars_to_write(particle: ParticleClass): + ret = {} + for var in particle.variables: + if var.to_write is False: + continue + ret[var.name] = var.dtype + + return ret From 9cb867c5ac1328777a29ea1889ef5017c633507f Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 5 Aug 2025 13:48:01 +0200 Subject: [PATCH 02/40] Move particle attr metadata to particle class --- parcels/_constants.py | 16 ++++++++++++ parcels/particle.py | 34 ++++++++++++++++++++---- parcels/particlefile.py | 57 ++++++++++++----------------------------- 3 files changed, 61 insertions(+), 46 deletions(-) create mode 100644 parcels/_constants.py diff --git a/parcels/_constants.py b/parcels/_constants.py new file mode 100644 index 000000000..b49b1f2f5 --- /dev/null +++ b/parcels/_constants.py @@ -0,0 +1,16 @@ +import numpy as np + +DATATYPES_TO_FILL_VALUES = { + np.float16: np.nan, + np.float32: np.nan, + np.float64: np.nan, + np.bool_: np.iinfo(np.int8).max, + np.int8: np.iinfo(np.int8).max, + np.int16: np.iinfo(np.int16).max, + np.int32: np.iinfo(np.int32).max, + np.int64: np.iinfo(np.int64).max, + np.uint8: np.iinfo(np.uint8).max, + np.uint16: np.iinfo(np.uint16).max, + np.uint32: np.iinfo(np.uint32).max, + np.uint64: np.iinfo(np.uint64).max, +} diff --git a/parcels/particle.py b/parcels/particle.py index 8a51e3680..50ba66114 100644 --- a/parcels/particle.py +++ b/parcels/particle.py @@ -165,18 +165,42 @@ def get_default_particle(spatial_dtype: np.float32 | np.float64) -> ParticleClas return ParticleClass( variables=[ - Variable("lon", dtype=spatial_dtype), + Variable( + "lon", + dtype=spatial_dtype, + attrs={"long_name": "", "standard_name": "longitude", "units": "degrees_east", "axis": "X"}, + ), Variable("lon_nextloop", dtype=spatial_dtype, to_write=False), - Variable("lat", dtype=spatial_dtype), + Variable( + "lat", + dtype=spatial_dtype, + attrs={"long_name": "", "standard_name": "latitude", "units": "degrees_north", "axis": "Y"}, + ), Variable("lat_nextloop", dtype=spatial_dtype, to_write=False), - Variable("depth", dtype=spatial_dtype), + Variable( + "depth", + dtype=spatial_dtype, + attrs={"long_name": "", "standard_name": "depth", "units": "m", "positive": "down"}, + ), Variable("dlon", dtype=spatial_dtype, to_write=False), Variable("dlat", dtype=spatial_dtype, to_write=False), Variable("ddepth", dtype=spatial_dtype, to_write=False), Variable("depth_nextloop", dtype=spatial_dtype, to_write=False), - Variable("time", dtype=_SAME_AS_FIELDSET_TIME_INTERVAL.VALUE), + Variable( + "time", + dtype=_SAME_AS_FIELDSET_TIME_INTERVAL.VALUE, + attrs={"long_name": "", "standard_name": "time", "units": "seconds", "axis": "T"}, + ), Variable("time_nextloop", dtype=_SAME_AS_FIELDSET_TIME_INTERVAL.VALUE, to_write=False), - Variable("trajectory", dtype=np.int64, to_write="once"), + Variable( + "trajectory", + dtype=np.int64, + to_write="once", + attrs={ + "long_name": "Unique identifier for each particle", + "cf_role": "trajectory_id", + }, + ), Variable("obs_written", dtype=np.int32, initial=0, to_write=False), Variable("dt", dtype="timedelta64[s]", initial=np.timedelta64(1, "s"), to_write=False), Variable("state", dtype=np.int32, initial=StatusCode.Evaluate, to_write=False), diff --git a/parcels/particlefile.py b/parcels/particlefile.py index d7318e977..72e12508a 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -13,8 +13,9 @@ from zarr.storage import DirectoryStore import parcels +from parcels._constants import DATATYPES_TO_FILL_VALUES from parcels._reprs import default_repr -from parcels.particle import ParticleClass +from parcels.particle import _SAME_AS_FIELDSET_TIME_INTERVAL, ParticleClass from parcels.tools._helpers import timedelta_to_float if TYPE_CHECKING: @@ -24,21 +25,6 @@ __all__ = ["ParticleFile"] -_DATATYPES_TO_FILL_VALUES = { - np.float16: np.nan, - np.float32: np.nan, - np.float64: np.nan, - np.bool_: np.iinfo(np.int8).max, - np.int8: np.iinfo(np.int8).max, - np.int16: np.iinfo(np.int16).max, - np.int32: np.iinfo(np.int32).max, - np.int64: np.iinfo(np.int64).max, - np.uint8: np.iinfo(np.uint8).max, - np.uint16: np.iinfo(np.uint16).max, - np.uint32: np.iinfo(np.uint32).max, - np.uint64: np.iinfo(np.uint64).max, -} - class ParticleFile: """Initialise trajectory output. @@ -125,30 +111,19 @@ def _create_variables_attribute_dict(self, particle: ParticleClass): ----- For ParticleSet structures other than SoA, and structures where ID != index, this has to be overridden. """ - attrs = { - "z": {"long_name": "", "standard_name": "depth", "units": "m", "positive": "down"}, - "trajectory": { - "long_name": "Unique identifier for each particle", - "cf_role": "trajectory_id", - "_FillValue": _DATATYPES_TO_FILL_VALUES[np.int64], - }, - "time": {"long_name": "", "standard_name": "time", "units": "seconds", "axis": "T"}, - "lon": {"long_name": "", "standard_name": "longitude", "units": "degrees_east", "axis": "X"}, - "lat": {"long_name": "", "standard_name": "latitude", "units": "degrees_north", "axis": "Y"}, - } + attrs = {} + + vars = [var for var in particle.variables if var.to_write is not False] + for var in vars: + fill_value = {} + if var.dtype is not _SAME_AS_FIELDSET_TIME_INTERVAL.VALUE: + fill_value = {"_FillValue": DATATYPES_TO_FILL_VALUES[var.dtype]} + + attrs[var.name] = {**var.attrs, **fill_value} attrs["time"]["units"] = "seconds since " # TODO fix units attrs["time"]["calendar"] = "None" # TODO fix calendar - for vname in self.vars_to_write: - if vname not in ["time", "lat", "lon", "depth", "trajectory"]: - attrs[vname] = { - "_FillValue": _DATATYPES_TO_FILL_VALUES[self.vars_to_write[vname]], - "long_name": "", - "standard_name": vname, - "units": "unknown", - } - return attrs def _convert_varout_name(self, var): @@ -162,16 +137,16 @@ def _write_once(self, var): def _extend_zarr_dims(self, Z, store, dtype, axis): if axis == 1: - a = np.full((Z.shape[0], self.chunks[1]), _DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype) + a = np.full((Z.shape[0], self.chunks[1]), DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype) obs = zarr.group(store=store, overwrite=False)["obs"] if len(obs) == Z.shape[1]: obs.append(np.arange(self.chunks[1]) + obs[-1] + 1) else: extra_trajs = self._maxids - Z.shape[0] if len(Z.shape) == 2: - a = np.full((extra_trajs, Z.shape[1]), _DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype) + a = np.full((extra_trajs, Z.shape[1]), DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype) else: - a = np.full((extra_trajs,), _DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype) + a = np.full((extra_trajs,), DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype) Z.append(a, axis=axis) zarr.consolidate_metadata(store) @@ -237,13 +212,13 @@ def write(self, pset: ParticleSet, time: float | timedelta | np.timedelta64 | No if self._write_once(var): data = np.full( (arrsize[0],), - _DATATYPES_TO_FILL_VALUES[vars_to_write[var]], + DATATYPES_TO_FILL_VALUES[vars_to_write[var]], dtype=vars_to_write[var], ) data[ids_once] = pset.particledata.getvardata(var, indices_to_write_once) dims = ["trajectory"] else: - data = np.full(arrsize, _DATATYPES_TO_FILL_VALUES[vars_to_write[var]], dtype=vars_to_write[var]) + data = np.full(arrsize, DATATYPES_TO_FILL_VALUES[vars_to_write[var]], dtype=vars_to_write[var]) data[ids, 0] = pset.particledata.getvardata(var, indices_to_write) dims = ["trajectory", "obs"] ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[varout]) From caae32c52de3e636d71f285f662eafe1bf663d89 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 5 Aug 2025 13:51:46 +0200 Subject: [PATCH 03/40] Reduce dependency on particleset in particlefile --- parcels/particlefile.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/parcels/particlefile.py b/parcels/particlefile.py index 72e12508a..bfe4155e8 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -14,7 +14,6 @@ import parcels from parcels._constants import DATATYPES_TO_FILL_VALUES -from parcels._reprs import default_repr from parcels.particle import _SAME_AS_FIELDSET_TIME_INTERVAL, ParticleClass from parcels.tools._helpers import timedelta_to_float @@ -53,7 +52,7 @@ class ParticleFile: def __init__(self, store, particleset, outputdt, chunks=None, create_new_zarrfile=True): self._outputdt = timedelta_to_float(outputdt) self._chunks = chunks - self._particleset = particleset + self.particleset = particleset self._maxids = 0 self._pids_written = {} self.metadata = None @@ -73,7 +72,6 @@ def __repr__(self) -> str: return ( f"{type(self).__name__}(" f"name={self.fname!r}, " - f"particleset={default_repr(self.particleset)}, " f"outputdt={self.outputdt!r}, " f"chunks={self.chunks!r}, " f"create_new_zarrfile={self.create_new_zarrfile!r})" @@ -96,10 +94,6 @@ def outputdt(self): def chunks(self): return self._chunks - @property - def particleset(self): - return self._particleset - @property def fname(self): return self._fname From fca6d5798f1ab200e5e2a07960153a31a62b8010 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 5 Aug 2025 14:01:12 +0200 Subject: [PATCH 04/40] Remove pset from particlefile.py init --- parcels/_constants.py | 24 +++++------ parcels/particlefile.py | 80 ++++++++++++++++------------------- tests/v4/test_particlefile.py | 15 +++++++ 3 files changed, 63 insertions(+), 56 deletions(-) create mode 100644 tests/v4/test_particlefile.py diff --git a/parcels/_constants.py b/parcels/_constants.py index b49b1f2f5..784d44676 100644 --- a/parcels/_constants.py +++ b/parcels/_constants.py @@ -1,16 +1,16 @@ import numpy as np DATATYPES_TO_FILL_VALUES = { - np.float16: np.nan, - np.float32: np.nan, - np.float64: np.nan, - np.bool_: np.iinfo(np.int8).max, - np.int8: np.iinfo(np.int8).max, - np.int16: np.iinfo(np.int16).max, - np.int32: np.iinfo(np.int32).max, - np.int64: np.iinfo(np.int64).max, - np.uint8: np.iinfo(np.uint8).max, - np.uint16: np.iinfo(np.uint16).max, - np.uint32: np.iinfo(np.uint32).max, - np.uint64: np.iinfo(np.uint64).max, + np.dtype(np.float16): np.nan, + np.dtype(np.float32): np.nan, + np.dtype(np.float64): np.nan, + np.dtype(np.bool_): np.iinfo(np.int8).max, + np.dtype(np.int8): np.iinfo(np.int8).max, + np.dtype(np.int16): np.iinfo(np.int16).max, + np.dtype(np.int32): np.iinfo(np.int32).max, + np.dtype(np.int64): np.iinfo(np.int64).max, + np.dtype(np.uint8): np.iinfo(np.uint8).max, + np.dtype(np.uint16): np.iinfo(np.uint16).max, + np.dtype(np.uint32): np.iinfo(np.uint32).max, + np.dtype(np.uint64): np.iinfo(np.uint64).max, } diff --git a/parcels/particlefile.py b/parcels/particlefile.py index bfe4155e8..b9d646287 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: from pathlib import Path + from parcels.particle import Variable from parcels.particleset import ParticleSet __all__ = ["ParticleFile"] @@ -49,10 +50,9 @@ class ParticleFile: ParticleFile object that can be used to write particle data to file """ - def __init__(self, store, particleset, outputdt, chunks=None, create_new_zarrfile=True): + def __init__(self, store, outputdt, chunks=None, create_new_zarrfile=True): self._outputdt = timedelta_to_float(outputdt) self._chunks = chunks - self.particleset = particleset self._maxids = 0 self._pids_written = {} self.metadata = None @@ -98,37 +98,12 @@ def chunks(self): def fname(self): return self._fname - def _create_variables_attribute_dict(self, particle: ParticleClass): - """Creates the dictionary with variable attributes. - - Notes - ----- - For ParticleSet structures other than SoA, and structures where ID != index, this has to be overridden. - """ - attrs = {} - - vars = [var for var in particle.variables if var.to_write is not False] - for var in vars: - fill_value = {} - if var.dtype is not _SAME_AS_FIELDSET_TIME_INTERVAL.VALUE: - fill_value = {"_FillValue": DATATYPES_TO_FILL_VALUES[var.dtype]} - - attrs[var.name] = {**var.attrs, **fill_value} - - attrs["time"]["units"] = "seconds since " # TODO fix units - attrs["time"]["calendar"] = "None" # TODO fix calendar - - return attrs - def _convert_varout_name(self, var): if var == "depth": return "z" else: return var - def _write_once(self, var): - return self.particleset.particledata.ptype[var].to_write == "once" - def _extend_zarr_dims(self, Z, store, dtype, axis): if axis == 1: a = np.full((Z.shape[0], self.chunks[1]), DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype) @@ -198,40 +173,40 @@ def write(self, pset: ParticleSet, time: float | timedelta | np.timedelta64 | No attrs=self.metadata, coords={"trajectory": ("trajectory", pids), "obs": ("obs", np.arange(arrsize[1], dtype=np.int32))}, ) - attrs = self._create_variables_attribute_dict() + attrs = _create_variables_attribute_dict(pclass) obs = np.zeros((self._maxids), dtype=np.int32) for var in vars_to_write: varout = self._convert_varout_name(var) if varout not in ["trajectory"]: # because 'trajectory' is written as coordinate - if self._write_once(var): + if var.to_write == "once": data = np.full( (arrsize[0],), - DATATYPES_TO_FILL_VALUES[vars_to_write[var]], - dtype=vars_to_write[var], + DATATYPES_TO_FILL_VALUES[vars_to_write.dtype], + dtype=var.dtype, ) data[ids_once] = pset.particledata.getvardata(var, indices_to_write_once) dims = ["trajectory"] else: - data = np.full(arrsize, DATATYPES_TO_FILL_VALUES[vars_to_write[var]], dtype=vars_to_write[var]) + data = np.full(arrsize, DATATYPES_TO_FILL_VALUES[var.dtype], dtype=var.dtype) data[ids, 0] = pset.particledata.getvardata(var, indices_to_write) dims = ["trajectory", "obs"] ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[varout]) - ds[varout].encoding["chunks"] = self.chunks[0] if self._write_once(var) else self.chunks # type: ignore[index] + ds[varout].encoding["chunks"] = self.chunks[0] if var.to_write == "once" else self.chunks # type: ignore[index] ds.to_zarr(store, mode="w") self.create_new_zarrfile = False else: Z = zarr.group(store=store, overwrite=False) obs = pset.particledata.getvardata("obs_written", indices_to_write) for var in vars_to_write: - varout = self._convert_varout_name(var) + varout = self._convert_varout_name(var.name) if self._maxids > Z[varout].shape[0]: - self._extend_zarr_dims(Z[varout], store, dtype=vars_to_write[var], axis=0) - if self._write_once(var): + self._extend_zarr_dims(Z[varout], store, dtype=var.dtype, axis=0) + if var.to_write == "once": if len(once_ids) > 0: Z[varout].vindex[ids_once] = pset.particledata.getvardata(var, indices_to_write_once) else: if max(obs) >= Z[varout].shape[1]: # type: ignore[type-var] - self._extend_zarr_dims(Z[varout], store, dtype=vars_to_write[var], axis=1) + self._extend_zarr_dims(Z[varout], store, dtype=var.dtype, axis=1) Z[varout].vindex[ids, obs] = pset.particledata.getvardata(var, indices_to_write) pset.particledata.setvardata("obs_written", indices_to_write, obs + 1) @@ -263,11 +238,28 @@ def _get_store_from_pathlike(path: Path | str) -> DirectoryStore: return DirectoryStore(path) -def _get_vars_to_write(particle: ParticleClass): - ret = {} - for var in particle.variables: - if var.to_write is False: - continue - ret[var.name] = var.dtype +def _get_vars_to_write(particle: ParticleClass) -> list[Variable]: + return [v for v in particle.variables if v.to_write is not False] + + +def _create_variables_attribute_dict(particle: ParticleClass) -> dict: + """Creates the dictionary with variable attributes. + + Notes + ----- + For ParticleSet structures other than SoA, and structures where ID != index, this has to be overridden. + """ + attrs = {} + + vars = [var for var in particle.variables if var.to_write is not False] + for var in vars: + fill_value = {} + if var.dtype is not _SAME_AS_FIELDSET_TIME_INTERVAL.VALUE: + fill_value = {"_FillValue": DATATYPES_TO_FILL_VALUES[var.dtype]} + + attrs[var.name] = {**var.attrs, **fill_value} + + attrs["time"]["units"] = "seconds since " # TODO fix units + attrs["time"]["calendar"] = "None" # TODO fix calendar - return ret + return attrs diff --git a/tests/v4/test_particlefile.py b/tests/v4/test_particlefile.py new file mode 100644 index 000000000..0679cb0f8 --- /dev/null +++ b/tests/v4/test_particlefile.py @@ -0,0 +1,15 @@ +from parcels.particle import Particle +from parcels.particlefile import _create_variables_attribute_dict + + +def test_particlefile_init(): ... + + +def test_particlefile_init_read_only_store(): ... + + +def test_particlefile_init_no_zarr_extension(): ... + + +def test_create_variables_attribute_dict(): + _create_variables_attribute_dict(Particle) From 7d86dd65048ce145a1ba1f646a451b75479bab44 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 6 Aug 2025 18:28:14 +0200 Subject: [PATCH 05/40] Move test_particlefile.py to v4 test suite --- tests/test_particlefile.py | 406 ---------------------------------- tests/v4/test_particlefile.py | 405 ++++++++++++++++++++++++++++++++- 2 files changed, 398 insertions(+), 413 deletions(-) delete mode 100755 tests/test_particlefile.py mode change 100644 => 100755 tests/v4/test_particlefile.py diff --git a/tests/test_particlefile.py b/tests/test_particlefile.py deleted file mode 100755 index 39cee4ae1..000000000 --- a/tests/test_particlefile.py +++ /dev/null @@ -1,406 +0,0 @@ -import os -import tempfile -from datetime import timedelta - -import numpy as np -import pytest -import xarray as xr -from zarr.storage import MemoryStore - -import parcels -from parcels import ( - AdvectionRK4, - Field, - FieldSet, - Particle, - ParticleSet, - Variable, -) -from tests.common_kernels import DoNothing -from tests.utils import create_fieldset_zeros_simple - - -@pytest.fixture -def fieldset(): - return create_fieldset_zeros_simple() - - -def test_metadata(fieldset, tmp_zarrfile): - pset = ParticleSet(fieldset, pclass=Particle, lon=0, lat=0) - - pset.execute(DoNothing, runtime=1, output_file=pset.ParticleFile(tmp_zarrfile, outputdt=1)) - - ds = xr.open_zarr(tmp_zarrfile) - assert ds.attrs["parcels_kernels"].lower() == "ParticleDoNothing".lower() - - -def test_pfile_array_write_zarr_memorystore(fieldset): - """Check that writing to a Zarr MemoryStore works.""" - npart = 10 - zarr_store = MemoryStore() - pset = ParticleSet(fieldset, pclass=Particle, lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart), time=0) - pfile = pset.ParticleFile(zarr_store, outputdt=1) - pfile.write(pset, 0) - - ds = xr.open_zarr(zarr_store) - assert ds.sizes["trajectory"] == npart - ds.close() - - -def test_pfile_array_remove_particles(fieldset, tmp_zarrfile): - npart = 10 - pset = ParticleSet(fieldset, pclass=Particle, lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart), time=0) - pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1) - pfile.write(pset, 0) - pset.remove_indices(3) - for p in pset: - p.time = 1 - pfile.write(pset, 1) - - ds = xr.open_zarr(tmp_zarrfile) - timearr = ds["time"][:] - assert (np.isnat(timearr[3, 1])) and (np.isfinite(timearr[3, 0])) - ds.close() - - -def test_pfile_set_towrite_False(fieldset, tmp_zarrfile): - npart = 10 - pset = ParticleSet(fieldset, pclass=Particle, lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart)) - pset.set_variable_write_status("depth", False) - pset.set_variable_write_status("lat", False) - pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1) - - def Update_lon(particle, fieldset, time): # pragma: no cover - particle.dlon += 0.1 - - pset.execute(Update_lon, runtime=10, output_file=pfile) - - ds = xr.open_zarr(tmp_zarrfile) - assert "time" in ds - assert "z" not in ds - assert "lat" not in ds - ds.close() - - # For pytest purposes, we need to reset to original status - pset.set_variable_write_status("depth", True) - pset.set_variable_write_status("lat", True) - - -@pytest.mark.parametrize("chunks_obs", [1, None]) -def test_pfile_array_remove_all_particles(fieldset, chunks_obs, tmp_zarrfile): - npart = 10 - pset = ParticleSet(fieldset, pclass=Particle, lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart), time=0) - chunks = (npart, chunks_obs) if chunks_obs else None - pfile = pset.ParticleFile(tmp_zarrfile, chunks=chunks, outputdt=1) - pfile.write(pset, 0) - for _ in range(npart): - pset.remove_indices(-1) - pfile.write(pset, 1) - pfile.write(pset, 2) - - ds = xr.open_zarr(tmp_zarrfile).load() - assert np.allclose(ds["time"][:, 0], np.timedelta64(0, "s"), atol=np.timedelta64(1, "ms")) - if chunks_obs is not None: - assert ds["time"][:].shape == chunks - else: - assert ds["time"][:].shape[0] == npart - assert np.all(np.isnan(ds["time"][:, 1:])) - ds.close() - - -@pytest.mark.xfail(reason="lonlatdepth_dtype removed. Update implementation to use a different particle") -def test_variable_write_double(fieldset, tmp_zarrfile): - def Update_lon(particle, fieldset, time): # pragma: no cover - particle.dlon += 0.1 - - pset = ParticleSet(fieldset, pclass=Particle, lon=[0], lat=[0], lonlatdepth_dtype=np.float64) - ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=0.00001) - pset.execute(pset.Kernel(Update_lon), endtime=0.001, dt=0.00001, output_file=ofile) - - ds = xr.open_zarr(tmp_zarrfile) - lons = ds["lon"][:] - assert isinstance(lons.values[0, 0], np.float64) - ds.close() - - -def test_write_dtypes_pfile(fieldset, tmp_zarrfile): - dtypes = [ - np.float32, - np.float64, - np.int32, - np.uint32, - np.int64, - np.uint64, - np.bool_, - np.int8, - np.uint8, - np.int16, - np.uint16, - ] - - extra_vars = [Variable(f"v_{d.__name__}", dtype=d, initial=0.0) for d in dtypes] - MyParticle = Particle.add_variables(extra_vars) - - pset = ParticleSet(fieldset, pclass=MyParticle, lon=0, lat=0, time=0) - pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=1) - pfile.write(pset, 0) - - ds = xr.open_zarr( - tmp_zarrfile, mask_and_scale=False - ) # Note masking issue at https://stackoverflow.com/questions/68460507/xarray-loading-int-data-as-float - for d in dtypes: - assert ds[f"v_{d.__name__}"].dtype == d - - -@pytest.mark.parametrize("npart", [1, 2, 5]) -def test_variable_written_once(fieldset, tmp_zarrfile, npart): - def Update_v(particle, fieldset, time): # pragma: no cover - particle.v_once += 1.0 - particle.age += particle.dt - - MyParticle = Particle.add_variables( - [ - Variable("v_once", dtype=np.float64, initial=0.0, to_write="once"), - Variable("age", dtype=np.float32, initial=0.0), - ] - ) - lon = np.linspace(0, 1, npart) - lat = np.linspace(1, 0, npart) - time = np.arange(0, npart / 10.0, 0.1, dtype=np.float64) - pset = ParticleSet(fieldset, pclass=MyParticle, lon=lon, lat=lat, time=time, v_once=time) - ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=0.1) - pset.execute(pset.Kernel(Update_v), endtime=1, dt=0.1, output_file=ofile) - - assert np.allclose(pset.v_once - time - pset.age * 10, 1, atol=1e-5) - ds = xr.open_zarr(tmp_zarrfile) - vfile = np.ma.filled(ds["v_once"][:], np.nan) - assert vfile.shape == (npart,) - ds.close() - - -@pytest.mark.parametrize("type", ["repeatdt", "timearr"]) -@pytest.mark.parametrize("repeatdt", range(1, 3)) -@pytest.mark.parametrize("dt", [-1, 1]) -@pytest.mark.parametrize("maxvar", [2, 4, 10]) -def test_pset_repeated_release_delayed_adding_deleting(type, fieldset, repeatdt, tmp_zarrfile, dt, maxvar): - runtime = 10 - fieldset.maxvar = maxvar - pset = None - - MyParticle = Particle.add_variables( - [Variable("sample_var", initial=0.0), Variable("v_once", dtype=np.float64, initial=0.0, to_write="once")] - ) - - if type == "repeatdt": - pset = ParticleSet(fieldset, lon=[0], lat=[0], pclass=MyParticle, repeatdt=repeatdt) - elif type == "timearr": - pset = ParticleSet( - fieldset, lon=np.zeros(runtime), lat=np.zeros(runtime), pclass=MyParticle, time=list(range(runtime)) - ) - pfile = pset.ParticleFile(tmp_zarrfile, outputdt=abs(dt), chunks=(1, 1)) - - def IncrLon(particle, fieldset, time): # pragma: no cover - particle.sample_var += 1.0 - if particle.sample_var > fieldset.maxvar: - particle.delete() - - for _ in range(runtime): - pset.execute(IncrLon, dt=dt, runtime=1.0, output_file=pfile) - - ds = xr.open_zarr(tmp_zarrfile) - samplevar = ds["sample_var"][:] - if type == "repeatdt": - assert samplevar.shape == (runtime // repeatdt, min(maxvar + 1, runtime)) - assert np.allclose(pset.sample_var, np.arange(maxvar, -1, -repeatdt)) - elif type == "timearr": - assert samplevar.shape == (runtime, min(maxvar + 1, runtime)) - # test whether samplevar[:, k] = k - for k in range(samplevar.shape[1]): - assert np.allclose([p for p in samplevar[:, k] if np.isfinite(p)], k + 1) - filesize = os.path.getsize(str(tmp_zarrfile)) - assert filesize < 1024 * 65 # test that chunking leads to filesize less than 65KB - ds.close() - - -@pytest.mark.parametrize("repeatdt", [1, 2]) -@pytest.mark.parametrize("nump", [1, 10]) -def test_pfile_chunks_repeatedrelease(fieldset, repeatdt, nump, tmp_zarrfile): - runtime = 8 - pset = ParticleSet(fieldset, pclass=Particle, lon=np.zeros((nump, 1)), lat=np.zeros((nump, 1)), repeatdt=repeatdt) - chunks = (20, 10) - pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1, chunks=chunks) - - def DoNothing(particle, fieldset, time): # pragma: no cover - pass - - pset.execute(DoNothing, dt=1, runtime=runtime, output_file=pfile) - ds = xr.open_zarr(tmp_zarrfile) - assert ds["time"].shape == (int(nump * runtime / repeatdt), chunks[1]) - - -def test_write_timebackward(fieldset, tmp_zarrfile): - def Update_lon(particle, fieldset, time): # pragma: no cover - particle.dlon -= 0.1 * particle.dt - - pset = ParticleSet(fieldset, pclass=Particle, lat=np.linspace(0, 1, 3), lon=[0, 0, 0], time=[1, 2, 3]) - pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=1.0) - pset.execute(pset.Kernel(Update_lon), runtime=4, dt=-1.0, output_file=pfile) - ds = xr.open_zarr(tmp_zarrfile) - trajs = ds["trajectory"][:] - assert trajs.values.dtype == "int64" - assert np.all(np.diff(trajs.values) < 0) # all particles written in order of release - ds.close() - - -def test_write_xiyi(fieldset, tmp_zarrfile): - fieldset.U.data[:] = 1 # set a non-zero zonal velocity - fieldset.add_field(Field(name="P", data=np.zeros((3, 20)), lon=np.linspace(0, 1, 20), lat=[-2, 0, 2])) - dt = 3600 - - XiYiParticle = Particle.add_variables( - [ - Variable("pxi0", dtype=np.int32, initial=0.0), - Variable("pxi1", dtype=np.int32, initial=0.0), - Variable("pyi", dtype=np.int32, initial=0.0), - ] - ) - - def Get_XiYi(particle, fieldset, time): # pragma: no cover - """Kernel to sample the grid indices of the particle. - Note that this sampling should be done _before_ the advection kernel - and that the first outputted value is zero. - Be careful when using multiple grids, as the index may be different for the grids. - """ - particle.pxi0 = fieldset.U.unravel_index(particle.ei)[2] - particle.pxi1 = fieldset.P.unravel_index(particle.ei)[2] - particle.pyi = fieldset.U.unravel_index(particle.ei)[1] - - def SampleP(particle, fieldset, time): # pragma: no cover - if time > 5 * 3600: - _ = fieldset.P[particle] # To trigger sampling of the P field - - pset = ParticleSet(fieldset, pclass=XiYiParticle, lon=[0, 0.2], lat=[0.2, 1], lonlatdepth_dtype=np.float64) - pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=dt) - pset.execute([SampleP, Get_XiYi, AdvectionRK4], endtime=10 * dt, dt=dt, output_file=pfile) - - ds = xr.open_zarr(tmp_zarrfile) - pxi0 = ds["pxi0"][:].values.astype(np.int32) - pxi1 = ds["pxi1"][:].values.astype(np.int32) - lons = ds["lon"][:].values - pyi = ds["pyi"][:].values.astype(np.int32) - lats = ds["lat"][:].values - - for p in range(pyi.shape[0]): - assert (pxi0[p, 0] == 0) and (pxi0[p, -1] == pset[p].pxi0) # check that particle has moved - assert np.all(pxi1[p, :6] == 0) # check that particle has not been sampled on grid 1 until time 6 - assert np.all(pxi1[p, 6:] > 0) # check that particle has not been sampled on grid 1 after time 6 - for xi, lon in zip(pxi0[p, 1:], lons[p, 1:], strict=True): - assert fieldset.U.grid.lon[xi] <= lon < fieldset.U.grid.lon[xi + 1] - for xi, lon in zip(pxi1[p, 6:], lons[p, 6:], strict=True): - assert fieldset.P.grid.lon[xi] <= lon < fieldset.P.grid.lon[xi + 1] - for yi, lat in zip(pyi[p, 1:], lats[p, 1:], strict=True): - assert fieldset.U.grid.lat[yi] <= lat < fieldset.U.grid.lat[yi + 1] - ds.close() - - -def test_reset_dt(fieldset, tmp_zarrfile): - # Assert that p.dt gets reset when a write_time is not a multiple of dt - # for p.dt=0.02 to reach outputdt=0.05 and endtime=0.1, the steps should be [0.2, 0.2, 0.1, 0.2, 0.2, 0.1], resulting in 6 kernel executions - - def Update_lon(particle, fieldset, time): # pragma: no cover - particle.dlon += 0.1 - - pset = ParticleSet(fieldset, pclass=Particle, lon=[0], lat=[0], lonlatdepth_dtype=np.float64) - ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=0.05) - pset.execute(pset.Kernel(Update_lon), endtime=0.12, dt=0.02, output_file=ofile) - - assert np.allclose(pset.lon, 0.6) - - -def test_correct_misaligned_outputdt_dt(fieldset, tmp_zarrfile): - """Testing that outputdt does not need to be a multiple of dt.""" - - def Update_lon(particle, fieldset, time): # pragma: no cover - particle.dlon += particle.dt - - pset = ParticleSet(fieldset, pclass=Particle, lon=[0], lat=[0], lonlatdepth_dtype=np.float64) - ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=3) - pset.execute(pset.Kernel(Update_lon), endtime=11, dt=2, output_file=ofile) - - ds = xr.open_zarr(tmp_zarrfile) - assert np.allclose(ds.lon.values, [0, 3, 6, 9]) - assert np.allclose( - ds.time.values[0, :], [np.timedelta64(t, "s") for t in [0, 3, 6, 9]], atol=np.timedelta64(1, "ns") - ) - - -def setup_pset_execute(*, fieldset: FieldSet, outputdt: timedelta, execute_kwargs, particle_class=Particle): - npart = 10 - - if fieldset is None: - fieldset = create_fieldset_zeros_simple() - - pset = ParticleSet( - fieldset, - pclass=particle_class, - lon=np.full(npart, fieldset.U.lon.mean()), - lat=np.full(npart, fieldset.U.lat.mean()), - ) - - with tempfile.TemporaryDirectory() as dir: - name = f"{dir}/test.zarr" - output_file = pset.ParticleFile(name=name, outputdt=outputdt) - - pset.execute(DoNothing, output_file=output_file, **execute_kwargs) - ds = xr.open_zarr(name).load() - - return ds - - -def test_pset_execute_outputdt_forwards(): - """Testing output data dt matches outputdt in forward time.""" - outputdt = timedelta(hours=1) - runtime = timedelta(hours=5) - dt = timedelta(minutes=5) - - ds = setup_pset_execute( - fieldset=create_fieldset_zeros_simple(), outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt) - ) - - assert np.all(ds.isel(trajectory=0).time.diff(dim="obs").values == np.timedelta64(outputdt)) - - -def test_pset_execute_outputdt_backwards(): - """Testing output data dt matches outputdt in backwards time.""" - outputdt = timedelta(hours=1) - runtime = timedelta(days=2) - dt = -timedelta(minutes=5) - - ds = setup_pset_execute( - fieldset=create_fieldset_zeros_simple(), outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt) - ) - file_outputdt = ds.isel(trajectory=0).time.diff(dim="obs").values - assert np.all(file_outputdt == np.timedelta64(-outputdt)) - - -def test_pset_execute_outputdt_backwards_fieldset_timevarying(): - """test_pset_execute_outputdt_backwards() still passed despite #1722 as it doesn't account for time-varying fields, - which for some reason #1722 - """ - outputdt = timedelta(hours=1) - runtime = timedelta(days=2) - dt = -timedelta(minutes=5) - - # TODO: Not ideal using the `download_example_dataset` here, but I'm struggling to recreate this error using the test suite fieldsets we have - example_dataset_folder = parcels.download_example_dataset("MovingEddies_data") - filenames = { - "U": str(example_dataset_folder / "moving_eddiesU.nc"), - "V": str(example_dataset_folder / "moving_eddiesV.nc"), - } - variables = {"U": "vozocrtx", "V": "vomecrty"} - dimensions = {"lon": "nav_lon", "lat": "nav_lat", "time": "time_counter"} - fieldset = parcels.FieldSet.from_netcdf(filenames, variables, dimensions) - - ds = setup_pset_execute(outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt), fieldset=fieldset) - file_outputdt = ds.isel(trajectory=0).time.diff(dim="obs").values - assert np.all(file_outputdt == np.timedelta64(-outputdt)), (file_outputdt, np.timedelta64(-outputdt)) diff --git a/tests/v4/test_particlefile.py b/tests/v4/test_particlefile.py old mode 100644 new mode 100755 index 0679cb0f8..39cee4ae1 --- a/tests/v4/test_particlefile.py +++ b/tests/v4/test_particlefile.py @@ -1,15 +1,406 @@ -from parcels.particle import Particle -from parcels.particlefile import _create_variables_attribute_dict +import os +import tempfile +from datetime import timedelta +import numpy as np +import pytest +import xarray as xr +from zarr.storage import MemoryStore -def test_particlefile_init(): ... +import parcels +from parcels import ( + AdvectionRK4, + Field, + FieldSet, + Particle, + ParticleSet, + Variable, +) +from tests.common_kernels import DoNothing +from tests.utils import create_fieldset_zeros_simple -def test_particlefile_init_read_only_store(): ... +@pytest.fixture +def fieldset(): + return create_fieldset_zeros_simple() -def test_particlefile_init_no_zarr_extension(): ... +def test_metadata(fieldset, tmp_zarrfile): + pset = ParticleSet(fieldset, pclass=Particle, lon=0, lat=0) + pset.execute(DoNothing, runtime=1, output_file=pset.ParticleFile(tmp_zarrfile, outputdt=1)) -def test_create_variables_attribute_dict(): - _create_variables_attribute_dict(Particle) + ds = xr.open_zarr(tmp_zarrfile) + assert ds.attrs["parcels_kernels"].lower() == "ParticleDoNothing".lower() + + +def test_pfile_array_write_zarr_memorystore(fieldset): + """Check that writing to a Zarr MemoryStore works.""" + npart = 10 + zarr_store = MemoryStore() + pset = ParticleSet(fieldset, pclass=Particle, lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart), time=0) + pfile = pset.ParticleFile(zarr_store, outputdt=1) + pfile.write(pset, 0) + + ds = xr.open_zarr(zarr_store) + assert ds.sizes["trajectory"] == npart + ds.close() + + +def test_pfile_array_remove_particles(fieldset, tmp_zarrfile): + npart = 10 + pset = ParticleSet(fieldset, pclass=Particle, lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart), time=0) + pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1) + pfile.write(pset, 0) + pset.remove_indices(3) + for p in pset: + p.time = 1 + pfile.write(pset, 1) + + ds = xr.open_zarr(tmp_zarrfile) + timearr = ds["time"][:] + assert (np.isnat(timearr[3, 1])) and (np.isfinite(timearr[3, 0])) + ds.close() + + +def test_pfile_set_towrite_False(fieldset, tmp_zarrfile): + npart = 10 + pset = ParticleSet(fieldset, pclass=Particle, lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart)) + pset.set_variable_write_status("depth", False) + pset.set_variable_write_status("lat", False) + pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1) + + def Update_lon(particle, fieldset, time): # pragma: no cover + particle.dlon += 0.1 + + pset.execute(Update_lon, runtime=10, output_file=pfile) + + ds = xr.open_zarr(tmp_zarrfile) + assert "time" in ds + assert "z" not in ds + assert "lat" not in ds + ds.close() + + # For pytest purposes, we need to reset to original status + pset.set_variable_write_status("depth", True) + pset.set_variable_write_status("lat", True) + + +@pytest.mark.parametrize("chunks_obs", [1, None]) +def test_pfile_array_remove_all_particles(fieldset, chunks_obs, tmp_zarrfile): + npart = 10 + pset = ParticleSet(fieldset, pclass=Particle, lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart), time=0) + chunks = (npart, chunks_obs) if chunks_obs else None + pfile = pset.ParticleFile(tmp_zarrfile, chunks=chunks, outputdt=1) + pfile.write(pset, 0) + for _ in range(npart): + pset.remove_indices(-1) + pfile.write(pset, 1) + pfile.write(pset, 2) + + ds = xr.open_zarr(tmp_zarrfile).load() + assert np.allclose(ds["time"][:, 0], np.timedelta64(0, "s"), atol=np.timedelta64(1, "ms")) + if chunks_obs is not None: + assert ds["time"][:].shape == chunks + else: + assert ds["time"][:].shape[0] == npart + assert np.all(np.isnan(ds["time"][:, 1:])) + ds.close() + + +@pytest.mark.xfail(reason="lonlatdepth_dtype removed. Update implementation to use a different particle") +def test_variable_write_double(fieldset, tmp_zarrfile): + def Update_lon(particle, fieldset, time): # pragma: no cover + particle.dlon += 0.1 + + pset = ParticleSet(fieldset, pclass=Particle, lon=[0], lat=[0], lonlatdepth_dtype=np.float64) + ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=0.00001) + pset.execute(pset.Kernel(Update_lon), endtime=0.001, dt=0.00001, output_file=ofile) + + ds = xr.open_zarr(tmp_zarrfile) + lons = ds["lon"][:] + assert isinstance(lons.values[0, 0], np.float64) + ds.close() + + +def test_write_dtypes_pfile(fieldset, tmp_zarrfile): + dtypes = [ + np.float32, + np.float64, + np.int32, + np.uint32, + np.int64, + np.uint64, + np.bool_, + np.int8, + np.uint8, + np.int16, + np.uint16, + ] + + extra_vars = [Variable(f"v_{d.__name__}", dtype=d, initial=0.0) for d in dtypes] + MyParticle = Particle.add_variables(extra_vars) + + pset = ParticleSet(fieldset, pclass=MyParticle, lon=0, lat=0, time=0) + pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=1) + pfile.write(pset, 0) + + ds = xr.open_zarr( + tmp_zarrfile, mask_and_scale=False + ) # Note masking issue at https://stackoverflow.com/questions/68460507/xarray-loading-int-data-as-float + for d in dtypes: + assert ds[f"v_{d.__name__}"].dtype == d + + +@pytest.mark.parametrize("npart", [1, 2, 5]) +def test_variable_written_once(fieldset, tmp_zarrfile, npart): + def Update_v(particle, fieldset, time): # pragma: no cover + particle.v_once += 1.0 + particle.age += particle.dt + + MyParticle = Particle.add_variables( + [ + Variable("v_once", dtype=np.float64, initial=0.0, to_write="once"), + Variable("age", dtype=np.float32, initial=0.0), + ] + ) + lon = np.linspace(0, 1, npart) + lat = np.linspace(1, 0, npart) + time = np.arange(0, npart / 10.0, 0.1, dtype=np.float64) + pset = ParticleSet(fieldset, pclass=MyParticle, lon=lon, lat=lat, time=time, v_once=time) + ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=0.1) + pset.execute(pset.Kernel(Update_v), endtime=1, dt=0.1, output_file=ofile) + + assert np.allclose(pset.v_once - time - pset.age * 10, 1, atol=1e-5) + ds = xr.open_zarr(tmp_zarrfile) + vfile = np.ma.filled(ds["v_once"][:], np.nan) + assert vfile.shape == (npart,) + ds.close() + + +@pytest.mark.parametrize("type", ["repeatdt", "timearr"]) +@pytest.mark.parametrize("repeatdt", range(1, 3)) +@pytest.mark.parametrize("dt", [-1, 1]) +@pytest.mark.parametrize("maxvar", [2, 4, 10]) +def test_pset_repeated_release_delayed_adding_deleting(type, fieldset, repeatdt, tmp_zarrfile, dt, maxvar): + runtime = 10 + fieldset.maxvar = maxvar + pset = None + + MyParticle = Particle.add_variables( + [Variable("sample_var", initial=0.0), Variable("v_once", dtype=np.float64, initial=0.0, to_write="once")] + ) + + if type == "repeatdt": + pset = ParticleSet(fieldset, lon=[0], lat=[0], pclass=MyParticle, repeatdt=repeatdt) + elif type == "timearr": + pset = ParticleSet( + fieldset, lon=np.zeros(runtime), lat=np.zeros(runtime), pclass=MyParticle, time=list(range(runtime)) + ) + pfile = pset.ParticleFile(tmp_zarrfile, outputdt=abs(dt), chunks=(1, 1)) + + def IncrLon(particle, fieldset, time): # pragma: no cover + particle.sample_var += 1.0 + if particle.sample_var > fieldset.maxvar: + particle.delete() + + for _ in range(runtime): + pset.execute(IncrLon, dt=dt, runtime=1.0, output_file=pfile) + + ds = xr.open_zarr(tmp_zarrfile) + samplevar = ds["sample_var"][:] + if type == "repeatdt": + assert samplevar.shape == (runtime // repeatdt, min(maxvar + 1, runtime)) + assert np.allclose(pset.sample_var, np.arange(maxvar, -1, -repeatdt)) + elif type == "timearr": + assert samplevar.shape == (runtime, min(maxvar + 1, runtime)) + # test whether samplevar[:, k] = k + for k in range(samplevar.shape[1]): + assert np.allclose([p for p in samplevar[:, k] if np.isfinite(p)], k + 1) + filesize = os.path.getsize(str(tmp_zarrfile)) + assert filesize < 1024 * 65 # test that chunking leads to filesize less than 65KB + ds.close() + + +@pytest.mark.parametrize("repeatdt", [1, 2]) +@pytest.mark.parametrize("nump", [1, 10]) +def test_pfile_chunks_repeatedrelease(fieldset, repeatdt, nump, tmp_zarrfile): + runtime = 8 + pset = ParticleSet(fieldset, pclass=Particle, lon=np.zeros((nump, 1)), lat=np.zeros((nump, 1)), repeatdt=repeatdt) + chunks = (20, 10) + pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1, chunks=chunks) + + def DoNothing(particle, fieldset, time): # pragma: no cover + pass + + pset.execute(DoNothing, dt=1, runtime=runtime, output_file=pfile) + ds = xr.open_zarr(tmp_zarrfile) + assert ds["time"].shape == (int(nump * runtime / repeatdt), chunks[1]) + + +def test_write_timebackward(fieldset, tmp_zarrfile): + def Update_lon(particle, fieldset, time): # pragma: no cover + particle.dlon -= 0.1 * particle.dt + + pset = ParticleSet(fieldset, pclass=Particle, lat=np.linspace(0, 1, 3), lon=[0, 0, 0], time=[1, 2, 3]) + pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=1.0) + pset.execute(pset.Kernel(Update_lon), runtime=4, dt=-1.0, output_file=pfile) + ds = xr.open_zarr(tmp_zarrfile) + trajs = ds["trajectory"][:] + assert trajs.values.dtype == "int64" + assert np.all(np.diff(trajs.values) < 0) # all particles written in order of release + ds.close() + + +def test_write_xiyi(fieldset, tmp_zarrfile): + fieldset.U.data[:] = 1 # set a non-zero zonal velocity + fieldset.add_field(Field(name="P", data=np.zeros((3, 20)), lon=np.linspace(0, 1, 20), lat=[-2, 0, 2])) + dt = 3600 + + XiYiParticle = Particle.add_variables( + [ + Variable("pxi0", dtype=np.int32, initial=0.0), + Variable("pxi1", dtype=np.int32, initial=0.0), + Variable("pyi", dtype=np.int32, initial=0.0), + ] + ) + + def Get_XiYi(particle, fieldset, time): # pragma: no cover + """Kernel to sample the grid indices of the particle. + Note that this sampling should be done _before_ the advection kernel + and that the first outputted value is zero. + Be careful when using multiple grids, as the index may be different for the grids. + """ + particle.pxi0 = fieldset.U.unravel_index(particle.ei)[2] + particle.pxi1 = fieldset.P.unravel_index(particle.ei)[2] + particle.pyi = fieldset.U.unravel_index(particle.ei)[1] + + def SampleP(particle, fieldset, time): # pragma: no cover + if time > 5 * 3600: + _ = fieldset.P[particle] # To trigger sampling of the P field + + pset = ParticleSet(fieldset, pclass=XiYiParticle, lon=[0, 0.2], lat=[0.2, 1], lonlatdepth_dtype=np.float64) + pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=dt) + pset.execute([SampleP, Get_XiYi, AdvectionRK4], endtime=10 * dt, dt=dt, output_file=pfile) + + ds = xr.open_zarr(tmp_zarrfile) + pxi0 = ds["pxi0"][:].values.astype(np.int32) + pxi1 = ds["pxi1"][:].values.astype(np.int32) + lons = ds["lon"][:].values + pyi = ds["pyi"][:].values.astype(np.int32) + lats = ds["lat"][:].values + + for p in range(pyi.shape[0]): + assert (pxi0[p, 0] == 0) and (pxi0[p, -1] == pset[p].pxi0) # check that particle has moved + assert np.all(pxi1[p, :6] == 0) # check that particle has not been sampled on grid 1 until time 6 + assert np.all(pxi1[p, 6:] > 0) # check that particle has not been sampled on grid 1 after time 6 + for xi, lon in zip(pxi0[p, 1:], lons[p, 1:], strict=True): + assert fieldset.U.grid.lon[xi] <= lon < fieldset.U.grid.lon[xi + 1] + for xi, lon in zip(pxi1[p, 6:], lons[p, 6:], strict=True): + assert fieldset.P.grid.lon[xi] <= lon < fieldset.P.grid.lon[xi + 1] + for yi, lat in zip(pyi[p, 1:], lats[p, 1:], strict=True): + assert fieldset.U.grid.lat[yi] <= lat < fieldset.U.grid.lat[yi + 1] + ds.close() + + +def test_reset_dt(fieldset, tmp_zarrfile): + # Assert that p.dt gets reset when a write_time is not a multiple of dt + # for p.dt=0.02 to reach outputdt=0.05 and endtime=0.1, the steps should be [0.2, 0.2, 0.1, 0.2, 0.2, 0.1], resulting in 6 kernel executions + + def Update_lon(particle, fieldset, time): # pragma: no cover + particle.dlon += 0.1 + + pset = ParticleSet(fieldset, pclass=Particle, lon=[0], lat=[0], lonlatdepth_dtype=np.float64) + ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=0.05) + pset.execute(pset.Kernel(Update_lon), endtime=0.12, dt=0.02, output_file=ofile) + + assert np.allclose(pset.lon, 0.6) + + +def test_correct_misaligned_outputdt_dt(fieldset, tmp_zarrfile): + """Testing that outputdt does not need to be a multiple of dt.""" + + def Update_lon(particle, fieldset, time): # pragma: no cover + particle.dlon += particle.dt + + pset = ParticleSet(fieldset, pclass=Particle, lon=[0], lat=[0], lonlatdepth_dtype=np.float64) + ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=3) + pset.execute(pset.Kernel(Update_lon), endtime=11, dt=2, output_file=ofile) + + ds = xr.open_zarr(tmp_zarrfile) + assert np.allclose(ds.lon.values, [0, 3, 6, 9]) + assert np.allclose( + ds.time.values[0, :], [np.timedelta64(t, "s") for t in [0, 3, 6, 9]], atol=np.timedelta64(1, "ns") + ) + + +def setup_pset_execute(*, fieldset: FieldSet, outputdt: timedelta, execute_kwargs, particle_class=Particle): + npart = 10 + + if fieldset is None: + fieldset = create_fieldset_zeros_simple() + + pset = ParticleSet( + fieldset, + pclass=particle_class, + lon=np.full(npart, fieldset.U.lon.mean()), + lat=np.full(npart, fieldset.U.lat.mean()), + ) + + with tempfile.TemporaryDirectory() as dir: + name = f"{dir}/test.zarr" + output_file = pset.ParticleFile(name=name, outputdt=outputdt) + + pset.execute(DoNothing, output_file=output_file, **execute_kwargs) + ds = xr.open_zarr(name).load() + + return ds + + +def test_pset_execute_outputdt_forwards(): + """Testing output data dt matches outputdt in forward time.""" + outputdt = timedelta(hours=1) + runtime = timedelta(hours=5) + dt = timedelta(minutes=5) + + ds = setup_pset_execute( + fieldset=create_fieldset_zeros_simple(), outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt) + ) + + assert np.all(ds.isel(trajectory=0).time.diff(dim="obs").values == np.timedelta64(outputdt)) + + +def test_pset_execute_outputdt_backwards(): + """Testing output data dt matches outputdt in backwards time.""" + outputdt = timedelta(hours=1) + runtime = timedelta(days=2) + dt = -timedelta(minutes=5) + + ds = setup_pset_execute( + fieldset=create_fieldset_zeros_simple(), outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt) + ) + file_outputdt = ds.isel(trajectory=0).time.diff(dim="obs").values + assert np.all(file_outputdt == np.timedelta64(-outputdt)) + + +def test_pset_execute_outputdt_backwards_fieldset_timevarying(): + """test_pset_execute_outputdt_backwards() still passed despite #1722 as it doesn't account for time-varying fields, + which for some reason #1722 + """ + outputdt = timedelta(hours=1) + runtime = timedelta(days=2) + dt = -timedelta(minutes=5) + + # TODO: Not ideal using the `download_example_dataset` here, but I'm struggling to recreate this error using the test suite fieldsets we have + example_dataset_folder = parcels.download_example_dataset("MovingEddies_data") + filenames = { + "U": str(example_dataset_folder / "moving_eddiesU.nc"), + "V": str(example_dataset_folder / "moving_eddiesV.nc"), + } + variables = {"U": "vozocrtx", "V": "vomecrty"} + dimensions = {"lon": "nav_lon", "lat": "nav_lat", "time": "time_counter"} + fieldset = parcels.FieldSet.from_netcdf(filenames, variables, dimensions) + + ds = setup_pset_execute(outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt), fieldset=fieldset) + file_outputdt = ds.isel(trajectory=0).time.diff(dim="obs").values + assert np.all(file_outputdt == np.timedelta64(-outputdt)), (file_outputdt, np.timedelta64(-outputdt)) From c4d083de5b86ce7238954d5ec6b205ef2e8a5977 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 7 Aug 2025 09:16:52 +0200 Subject: [PATCH 06/40] Update test_particlefile.py --- tests/v4/test_particlefile.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/tests/v4/test_particlefile.py b/tests/v4/test_particlefile.py index 39cee4ae1..f89ef6b44 100755 --- a/tests/v4/test_particlefile.py +++ b/tests/v4/test_particlefile.py @@ -8,21 +8,25 @@ from zarr.storage import MemoryStore import parcels -from parcels import ( - AdvectionRK4, - Field, - FieldSet, - Particle, - ParticleSet, - Variable, -) +from parcels import AdvectionRK4, Field, FieldSet, Particle, ParticleSet, Variable, VectorField +from parcels._datasets.structured.generic import datasets +from parcels.xgrid import XGrid from tests.common_kernels import DoNothing from tests.utils import create_fieldset_zeros_simple @pytest.fixture -def fieldset(): - return create_fieldset_zeros_simple() +def fieldset() -> FieldSet: # TODO v4: Move into a `conftest.py` file and remove duplicates + """Fixture to create a FieldSet object for testing.""" + ds = datasets["ds_2d_left"] + grid = XGrid.from_dataset(ds) + U = Field("U", ds["U (A grid)"], grid, mesh_type="flat") + V = Field("V", ds["V (A grid)"], grid, mesh_type="flat") + UV = VectorField("UV", U, V) + + return FieldSet( + [U, V, UV], + ) def test_metadata(fieldset, tmp_zarrfile): @@ -139,7 +143,7 @@ def test_write_dtypes_pfile(fieldset, tmp_zarrfile): ] extra_vars = [Variable(f"v_{d.__name__}", dtype=d, initial=0.0) for d in dtypes] - MyParticle = Particle.add_variables(extra_vars) + MyParticle = Particle.add_variable(extra_vars) pset = ParticleSet(fieldset, pclass=MyParticle, lon=0, lat=0, time=0) pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=1) @@ -158,7 +162,7 @@ def Update_v(particle, fieldset, time): # pragma: no cover particle.v_once += 1.0 particle.age += particle.dt - MyParticle = Particle.add_variables( + MyParticle = Particle.add_variable( [ Variable("v_once", dtype=np.float64, initial=0.0, to_write="once"), Variable("age", dtype=np.float32, initial=0.0), @@ -187,7 +191,7 @@ def test_pset_repeated_release_delayed_adding_deleting(type, fieldset, repeatdt, fieldset.maxvar = maxvar pset = None - MyParticle = Particle.add_variables( + MyParticle = Particle.add_variable( [Variable("sample_var", initial=0.0), Variable("v_once", dtype=np.float64, initial=0.0, to_write="once")] ) @@ -257,7 +261,7 @@ def test_write_xiyi(fieldset, tmp_zarrfile): fieldset.add_field(Field(name="P", data=np.zeros((3, 20)), lon=np.linspace(0, 1, 20), lat=[-2, 0, 2])) dt = 3600 - XiYiParticle = Particle.add_variables( + XiYiParticle = Particle.add_variable( [ Variable("pxi0", dtype=np.int32, initial=0.0), Variable("pxi1", dtype=np.int32, initial=0.0), From 803fe249ebb4e7c942f45c09e1140391d908fe14 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 7 Aug 2025 09:18:34 +0200 Subject: [PATCH 07/40] Update test_particlefile.py to remote repeatdt tests --- tests/v4/test_particlefile.py | 36 +++++------------------------------ 1 file changed, 5 insertions(+), 31 deletions(-) diff --git a/tests/v4/test_particlefile.py b/tests/v4/test_particlefile.py index f89ef6b44..04162f9ae 100755 --- a/tests/v4/test_particlefile.py +++ b/tests/v4/test_particlefile.py @@ -182,11 +182,9 @@ def Update_v(particle, fieldset, time): # pragma: no cover ds.close() -@pytest.mark.parametrize("type", ["repeatdt", "timearr"]) -@pytest.mark.parametrize("repeatdt", range(1, 3)) @pytest.mark.parametrize("dt", [-1, 1]) @pytest.mark.parametrize("maxvar", [2, 4, 10]) -def test_pset_repeated_release_delayed_adding_deleting(type, fieldset, repeatdt, tmp_zarrfile, dt, maxvar): +def test_pset_repeated_release_delayed_adding_deleting(fieldset, tmp_zarrfile, dt, maxvar): runtime = 10 fieldset.maxvar = maxvar pset = None @@ -195,12 +193,9 @@ def test_pset_repeated_release_delayed_adding_deleting(type, fieldset, repeatdt, [Variable("sample_var", initial=0.0), Variable("v_once", dtype=np.float64, initial=0.0, to_write="once")] ) - if type == "repeatdt": - pset = ParticleSet(fieldset, lon=[0], lat=[0], pclass=MyParticle, repeatdt=repeatdt) - elif type == "timearr": - pset = ParticleSet( - fieldset, lon=np.zeros(runtime), lat=np.zeros(runtime), pclass=MyParticle, time=list(range(runtime)) - ) + pset = ParticleSet( + fieldset, lon=np.zeros(runtime), lat=np.zeros(runtime), pclass=MyParticle, time=list(range(runtime)) + ) pfile = pset.ParticleFile(tmp_zarrfile, outputdt=abs(dt), chunks=(1, 1)) def IncrLon(particle, fieldset, time): # pragma: no cover @@ -213,33 +208,12 @@ def IncrLon(particle, fieldset, time): # pragma: no cover ds = xr.open_zarr(tmp_zarrfile) samplevar = ds["sample_var"][:] - if type == "repeatdt": - assert samplevar.shape == (runtime // repeatdt, min(maxvar + 1, runtime)) - assert np.allclose(pset.sample_var, np.arange(maxvar, -1, -repeatdt)) - elif type == "timearr": - assert samplevar.shape == (runtime, min(maxvar + 1, runtime)) + assert samplevar.shape == (runtime, min(maxvar + 1, runtime)) # test whether samplevar[:, k] = k for k in range(samplevar.shape[1]): assert np.allclose([p for p in samplevar[:, k] if np.isfinite(p)], k + 1) filesize = os.path.getsize(str(tmp_zarrfile)) assert filesize < 1024 * 65 # test that chunking leads to filesize less than 65KB - ds.close() - - -@pytest.mark.parametrize("repeatdt", [1, 2]) -@pytest.mark.parametrize("nump", [1, 10]) -def test_pfile_chunks_repeatedrelease(fieldset, repeatdt, nump, tmp_zarrfile): - runtime = 8 - pset = ParticleSet(fieldset, pclass=Particle, lon=np.zeros((nump, 1)), lat=np.zeros((nump, 1)), repeatdt=repeatdt) - chunks = (20, 10) - pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1, chunks=chunks) - - def DoNothing(particle, fieldset, time): # pragma: no cover - pass - - pset.execute(DoNothing, dt=1, runtime=runtime, output_file=pfile) - ds = xr.open_zarr(tmp_zarrfile) - assert ds["time"].shape == (int(nump * runtime / repeatdt), chunks[1]) def test_write_timebackward(fieldset, tmp_zarrfile): From 8b1c14e3989c6324517b398e80f1e0c9d5735ba8 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 7 Aug 2025 09:27:38 +0200 Subject: [PATCH 08/40] Update particlefile.py to use updated particle data --- parcels/particlefile.py | 55 +++++++++++++++++++++++------------ parcels/particleset.py | 2 +- tests/v4/test_particlefile.py | 17 ++++------- 3 files changed, 43 insertions(+), 31 deletions(-) diff --git a/parcels/particlefile.py b/parcels/particlefile.py index b9d646287..767b343f8 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -3,7 +3,6 @@ from __future__ import annotations import os -import warnings from datetime import timedelta from typing import TYPE_CHECKING, Literal @@ -133,30 +132,30 @@ def write(self, pset: ParticleSet, time: float | timedelta | np.timedelta64 | No time = timedelta_to_float(time) if time is not None else None pclass = pset._ptype vars_to_write = _get_vars_to_write(pclass) - if pset.particledata._ncount == 0: - warnings.warn( - f"ParticleSet is empty on writing as array at time {time:g}", - RuntimeWarning, - stacklevel=2, - ) - return + # if pset.particledata._ncount == 0: + # warnings.warn( + # f"ParticleSet is empty on writing as array at time {time:g}", + # RuntimeWarning, + # stacklevel=2, + # ) + # return if indices is None: - indices_to_write = pset.particledata._to_write_particles(time) + indices_to_write = _to_write_particles(pset.particledata, time) else: indices_to_write = indices if len(indices_to_write) == 0: return - pids = pset.particledata.getvardata("trajectory", indices_to_write) + pids = pset.particledata["trajectory"][indices_to_write] to_add = sorted(set(pids) - set(self._pids_written.keys())) for i, pid in enumerate(to_add): self._pids_written[pid] = self._maxids + i ids = np.array([self._pids_written[p] for p in pids], dtype=int) self._maxids = len(self._pids_written) - once_ids = np.where(pset.particledata.getvardata("obs_written", indices_to_write) == 0)[0] + once_ids = np.where(pset.particledata["obs_written"][indices_to_write] == 0)[0] if len(once_ids) > 0: ids_once = ids[once_ids] indices_to_write_once = indices_to_write[once_ids] @@ -184,11 +183,11 @@ def write(self, pset: ParticleSet, time: float | timedelta | np.timedelta64 | No DATATYPES_TO_FILL_VALUES[vars_to_write.dtype], dtype=var.dtype, ) - data[ids_once] = pset.particledata.getvardata(var, indices_to_write_once) + data[ids_once] = pset.particledata[var][indices_to_write_once] dims = ["trajectory"] else: data = np.full(arrsize, DATATYPES_TO_FILL_VALUES[var.dtype], dtype=var.dtype) - data[ids, 0] = pset.particledata.getvardata(var, indices_to_write) + data[ids, 0] = pset.particledata[var][indices_to_write] dims = ["trajectory", "obs"] ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[varout]) ds[varout].encoding["chunks"] = self.chunks[0] if var.to_write == "once" else self.chunks # type: ignore[index] @@ -196,20 +195,20 @@ def write(self, pset: ParticleSet, time: float | timedelta | np.timedelta64 | No self.create_new_zarrfile = False else: Z = zarr.group(store=store, overwrite=False) - obs = pset.particledata.getvardata("obs_written", indices_to_write) + obs = pset.particledata["obs_written"][indices_to_write] for var in vars_to_write: varout = self._convert_varout_name(var.name) if self._maxids > Z[varout].shape[0]: self._extend_zarr_dims(Z[varout], store, dtype=var.dtype, axis=0) if var.to_write == "once": if len(once_ids) > 0: - Z[varout].vindex[ids_once] = pset.particledata.getvardata(var, indices_to_write_once) + Z[varout].vindex[ids_once] = pset.particledata[var][indices_to_write_once] else: if max(obs) >= Z[varout].shape[1]: # type: ignore[type-var] self._extend_zarr_dims(Z[varout], store, dtype=var.dtype, axis=1) - Z[varout].vindex[ids, obs] = pset.particledata.getvardata(var, indices_to_write) + Z[varout].vindex[ids, obs] = pset.particledata[var][indices_to_write] - pset.particledata.setvardata("obs_written", indices_to_write, obs + 1) + pset.particledata["obs_written"][indices_to_write] = obs + 1 def write_latest_locations(self, pset, time): """Write the current (latest) particle locations to zarr file. @@ -224,7 +223,7 @@ def write_latest_locations(self, pset, time): Time at which to write ParticleSet. Note that typically this would be pset.time_nextloop """ for var in ["lon", "lat", "depth", "time"]: - pset.particledata.setallvardata(f"{var}", pset.particledata.getvardata(f"{var}_nextloop")) + pset.particledata[f"{var}"] = pset.particledata[f"{var}_nextloop"] self.write(pset, time) @@ -263,3 +262,23 @@ def _create_variables_attribute_dict(particle: ParticleClass) -> dict: attrs["time"]["calendar"] = "None" # TODO fix calendar return attrs + + +def _to_write_particles(particle_data, time): + """Return the Particles that need to be written at time: if particle.time is between time-dt/2 and time+dt (/2)""" + return np.where( + ( + np.less_equal( + time - np.abs(particle_data["dt"] / 2), particle_data["time"], where=np.isfinite(particle_data["time"]) + ) + & np.greater_equal( + time + np.abs(particle_data["dt"] / 2), particle_data["time"], where=np.isfinite(particle_data["time"]) + ) + | ( + (np.isnan(particle_data["dt"])) + & np.equal(time, particle_data["time"], where=np.isfinite(particle_data["time"])) + ) + ) + & (np.isfinite(particle_data["id"])) + & (np.isfinite(particle_data["time"])) + )[0] diff --git a/parcels/particleset.py b/parcels/particleset.py index d8e9cbac2..70730c3dd 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -394,7 +394,7 @@ def InteractionKernel(self, pyfunc_inter): def ParticleFile(self, *args, **kwargs): """Wrapper method to initialise a :class:`parcels.particlefile.ParticleFile` object from the ParticleSet.""" - return ParticleFile(*args, particleset=self, **kwargs) + return ParticleFile(*args, **kwargs) def data_indices(self, variable_name, compare_values, invert=False): """Get the indices of all particles where the value of `variable_name` equals (one of) `compare_values`. diff --git a/tests/v4/test_particlefile.py b/tests/v4/test_particlefile.py index 04162f9ae..aea3ceacd 100755 --- a/tests/v4/test_particlefile.py +++ b/tests/v4/test_particlefile.py @@ -12,7 +12,6 @@ from parcels._datasets.structured.generic import datasets from parcels.xgrid import XGrid from tests.common_kernels import DoNothing -from tests.utils import create_fieldset_zeros_simple @pytest.fixture @@ -315,9 +314,6 @@ def Update_lon(particle, fieldset, time): # pragma: no cover def setup_pset_execute(*, fieldset: FieldSet, outputdt: timedelta, execute_kwargs, particle_class=Particle): npart = 10 - if fieldset is None: - fieldset = create_fieldset_zeros_simple() - pset = ParticleSet( fieldset, pclass=particle_class, @@ -335,32 +331,29 @@ def setup_pset_execute(*, fieldset: FieldSet, outputdt: timedelta, execute_kwarg return ds -def test_pset_execute_outputdt_forwards(): +def test_pset_execute_outputdt_forwards(fieldset): """Testing output data dt matches outputdt in forward time.""" outputdt = timedelta(hours=1) runtime = timedelta(hours=5) dt = timedelta(minutes=5) - ds = setup_pset_execute( - fieldset=create_fieldset_zeros_simple(), outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt) - ) + ds = setup_pset_execute(fieldset=fieldset, outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt)) assert np.all(ds.isel(trajectory=0).time.diff(dim="obs").values == np.timedelta64(outputdt)) -def test_pset_execute_outputdt_backwards(): +def test_pset_execute_outputdt_backwards(fieldset): """Testing output data dt matches outputdt in backwards time.""" outputdt = timedelta(hours=1) runtime = timedelta(days=2) dt = -timedelta(minutes=5) - ds = setup_pset_execute( - fieldset=create_fieldset_zeros_simple(), outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt) - ) + ds = setup_pset_execute(fieldset=fieldset, outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt)) file_outputdt = ds.isel(trajectory=0).time.diff(dim="obs").values assert np.all(file_outputdt == np.timedelta64(-outputdt)) +@pytest.mark.xfail(reason="TODO v4: Update dataset loading") def test_pset_execute_outputdt_backwards_fieldset_timevarying(): """test_pset_execute_outputdt_backwards() still passed despite #1722 as it doesn't account for time-varying fields, which for some reason #1722 From 097bf443ca43506b9bca763c27739e1f66c4b82e Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 7 Aug 2025 09:37:03 +0200 Subject: [PATCH 09/40] Remove test_pfile_set_towrite_False Setting the particle write status during execution is no longer supported --- tests/v4/test_particlefile.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/tests/v4/test_particlefile.py b/tests/v4/test_particlefile.py index aea3ceacd..6a517b537 100755 --- a/tests/v4/test_particlefile.py +++ b/tests/v4/test_particlefile.py @@ -66,29 +66,6 @@ def test_pfile_array_remove_particles(fieldset, tmp_zarrfile): ds.close() -def test_pfile_set_towrite_False(fieldset, tmp_zarrfile): - npart = 10 - pset = ParticleSet(fieldset, pclass=Particle, lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart)) - pset.set_variable_write_status("depth", False) - pset.set_variable_write_status("lat", False) - pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1) - - def Update_lon(particle, fieldset, time): # pragma: no cover - particle.dlon += 0.1 - - pset.execute(Update_lon, runtime=10, output_file=pfile) - - ds = xr.open_zarr(tmp_zarrfile) - assert "time" in ds - assert "z" not in ds - assert "lat" not in ds - ds.close() - - # For pytest purposes, we need to reset to original status - pset.set_variable_write_status("depth", True) - pset.set_variable_write_status("lat", True) - - @pytest.mark.parametrize("chunks_obs", [1, None]) def test_pfile_array_remove_all_particles(fieldset, chunks_obs, tmp_zarrfile): npart = 10 From b161b160790ba179f61266981baa4dd1c9944619 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 7 Aug 2025 09:59:33 +0200 Subject: [PATCH 10/40] Update pfile writing to use time of fieldset Also update all references of particledata to _data --- parcels/particlefile.py | 32 +++++++++++++++----------------- tests/v4/test_particlefile.py | 34 ++++++++++++++++++++++++++-------- 2 files changed, 41 insertions(+), 25 deletions(-) diff --git a/parcels/particlefile.py b/parcels/particlefile.py index 767b343f8..f5544f48f 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -3,7 +3,6 @@ from __future__ import annotations import os -from datetime import timedelta from typing import TYPE_CHECKING, Literal import numpy as np @@ -57,7 +56,7 @@ def __init__(self, store, outputdt, chunks=None, create_new_zarrfile=True): self.metadata = None self.create_new_zarrfile = create_new_zarrfile - if isinstance(store, zarr.storage.Store): + if isinstance(store, zarr.abc.store.Store): self.store = store else: self.store = _get_store_from_pathlike(store) @@ -118,7 +117,7 @@ def _extend_zarr_dims(self, Z, store, dtype, axis): Z.append(a, axis=axis) zarr.consolidate_metadata(store) - def write(self, pset: ParticleSet, time: float | timedelta | np.timedelta64 | None, indices=None): + def write(self, pset: ParticleSet, time, indices=None): """Write all data from one time step to the zarr file, before the particle locations are updated. @@ -127,12 +126,11 @@ def write(self, pset: ParticleSet, time: float | timedelta | np.timedelta64 | No pset : ParticleSet object to write time : - Time at which to write ParticleSet + Time at which to write ParticleSet (same time object as fieldset) """ - time = timedelta_to_float(time) if time is not None else None pclass = pset._ptype vars_to_write = _get_vars_to_write(pclass) - # if pset.particledata._ncount == 0: + # if pset._data._ncount == 0: # warnings.warn( # f"ParticleSet is empty on writing as array at time {time:g}", # RuntimeWarning, @@ -141,21 +139,21 @@ def write(self, pset: ParticleSet, time: float | timedelta | np.timedelta64 | No # return if indices is None: - indices_to_write = _to_write_particles(pset.particledata, time) + indices_to_write = _to_write_particles(pset._data, time) else: indices_to_write = indices if len(indices_to_write) == 0: return - pids = pset.particledata["trajectory"][indices_to_write] + pids = pset._data["trajectory"][indices_to_write] to_add = sorted(set(pids) - set(self._pids_written.keys())) for i, pid in enumerate(to_add): self._pids_written[pid] = self._maxids + i ids = np.array([self._pids_written[p] for p in pids], dtype=int) self._maxids = len(self._pids_written) - once_ids = np.where(pset.particledata["obs_written"][indices_to_write] == 0)[0] + once_ids = np.where(pset._data["obs_written"][indices_to_write] == 0)[0] if len(once_ids) > 0: ids_once = ids[once_ids] indices_to_write_once = indices_to_write[once_ids] @@ -183,11 +181,11 @@ def write(self, pset: ParticleSet, time: float | timedelta | np.timedelta64 | No DATATYPES_TO_FILL_VALUES[vars_to_write.dtype], dtype=var.dtype, ) - data[ids_once] = pset.particledata[var][indices_to_write_once] + data[ids_once] = pset._data[var][indices_to_write_once] dims = ["trajectory"] else: data = np.full(arrsize, DATATYPES_TO_FILL_VALUES[var.dtype], dtype=var.dtype) - data[ids, 0] = pset.particledata[var][indices_to_write] + data[ids, 0] = pset._data[var][indices_to_write] dims = ["trajectory", "obs"] ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[varout]) ds[varout].encoding["chunks"] = self.chunks[0] if var.to_write == "once" else self.chunks # type: ignore[index] @@ -195,20 +193,20 @@ def write(self, pset: ParticleSet, time: float | timedelta | np.timedelta64 | No self.create_new_zarrfile = False else: Z = zarr.group(store=store, overwrite=False) - obs = pset.particledata["obs_written"][indices_to_write] + obs = pset._data["obs_written"][indices_to_write] for var in vars_to_write: varout = self._convert_varout_name(var.name) if self._maxids > Z[varout].shape[0]: self._extend_zarr_dims(Z[varout], store, dtype=var.dtype, axis=0) if var.to_write == "once": if len(once_ids) > 0: - Z[varout].vindex[ids_once] = pset.particledata[var][indices_to_write_once] + Z[varout].vindex[ids_once] = pset._data[var][indices_to_write_once] else: if max(obs) >= Z[varout].shape[1]: # type: ignore[type-var] self._extend_zarr_dims(Z[varout], store, dtype=var.dtype, axis=1) - Z[varout].vindex[ids, obs] = pset.particledata[var][indices_to_write] + Z[varout].vindex[ids, obs] = pset._data[var][indices_to_write] - pset.particledata["obs_written"][indices_to_write] = obs + 1 + pset._data["obs_written"][indices_to_write] = obs + 1 def write_latest_locations(self, pset, time): """Write the current (latest) particle locations to zarr file. @@ -223,7 +221,7 @@ def write_latest_locations(self, pset, time): Time at which to write ParticleSet. Note that typically this would be pset.time_nextloop """ for var in ["lon", "lat", "depth", "time"]: - pset.particledata[f"{var}"] = pset.particledata[f"{var}_nextloop"] + pset._data[f"{var}"] = pset._data[f"{var}_nextloop"] self.write(pset, time) @@ -279,6 +277,6 @@ def _to_write_particles(particle_data, time): & np.equal(time, particle_data["time"], where=np.isfinite(particle_data["time"])) ) ) - & (np.isfinite(particle_data["id"])) + & (np.isfinite(particle_data["trajectory"])) & (np.isfinite(particle_data["time"])) )[0] diff --git a/tests/v4/test_particlefile.py b/tests/v4/test_particlefile.py index 6a517b537..ca3776772 100755 --- a/tests/v4/test_particlefile.py +++ b/tests/v4/test_particlefile.py @@ -41,9 +41,15 @@ def test_pfile_array_write_zarr_memorystore(fieldset): """Check that writing to a Zarr MemoryStore works.""" npart = 10 zarr_store = MemoryStore() - pset = ParticleSet(fieldset, pclass=Particle, lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart), time=0) + pset = ParticleSet( + fieldset, + pclass=Particle, + lon=np.linspace(0, 1, npart), + lat=0.5 * np.ones(npart), + time=fieldset.time_interval.left, + ) pfile = pset.ParticleFile(zarr_store, outputdt=1) - pfile.write(pset, 0) + pfile.write(pset, time=fieldset.time_interval.left) ds = xr.open_zarr(zarr_store) assert ds.sizes["trajectory"] == npart @@ -52,9 +58,15 @@ def test_pfile_array_write_zarr_memorystore(fieldset): def test_pfile_array_remove_particles(fieldset, tmp_zarrfile): npart = 10 - pset = ParticleSet(fieldset, pclass=Particle, lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart), time=0) + pset = ParticleSet( + fieldset, + pclass=Particle, + lon=np.linspace(0, 1, npart), + lat=0.5 * np.ones(npart), + time=fieldset.time_interval.left, + ) pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1) - pfile.write(pset, 0) + pfile.write(pset, time=fieldset.time_interval.left) pset.remove_indices(3) for p in pset: p.time = 1 @@ -69,10 +81,16 @@ def test_pfile_array_remove_particles(fieldset, tmp_zarrfile): @pytest.mark.parametrize("chunks_obs", [1, None]) def test_pfile_array_remove_all_particles(fieldset, chunks_obs, tmp_zarrfile): npart = 10 - pset = ParticleSet(fieldset, pclass=Particle, lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart), time=0) + pset = ParticleSet( + fieldset, + pclass=Particle, + lon=np.linspace(0, 1, npart), + lat=0.5 * np.ones(npart), + time=fieldset.time_interval.left, + ) chunks = (npart, chunks_obs) if chunks_obs else None pfile = pset.ParticleFile(tmp_zarrfile, chunks=chunks, outputdt=1) - pfile.write(pset, 0) + pfile.write(pset, time=fieldset.time_interval.left) for _ in range(npart): pset.remove_indices(-1) pfile.write(pset, 1) @@ -121,9 +139,9 @@ def test_write_dtypes_pfile(fieldset, tmp_zarrfile): extra_vars = [Variable(f"v_{d.__name__}", dtype=d, initial=0.0) for d in dtypes] MyParticle = Particle.add_variable(extra_vars) - pset = ParticleSet(fieldset, pclass=MyParticle, lon=0, lat=0, time=0) + pset = ParticleSet(fieldset, pclass=MyParticle, lon=0, lat=0, time=fieldset.time_interval.left) pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=1) - pfile.write(pset, 0) + pfile.write(pset, time=fieldset.time_interval.left) ds = xr.open_zarr( tmp_zarrfile, mask_and_scale=False From 85b972273b2ad9f2e0d789e9707bd0cabfa0339b Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 7 Aug 2025 13:13:55 +0200 Subject: [PATCH 11/40] Updating particlefile.py to make compatible with v4-dev --- parcels/particlefile.py | 70 ++++++++++++++++++++++++++--------------- 1 file changed, 44 insertions(+), 26 deletions(-) diff --git a/parcels/particlefile.py b/parcels/particlefile.py index f5544f48f..dd772b034 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -56,20 +56,20 @@ def __init__(self, store, outputdt, chunks=None, create_new_zarrfile=True): self.metadata = None self.create_new_zarrfile = create_new_zarrfile - if isinstance(store, zarr.abc.store.Store): + if isinstance(store, zarr.storage.Store): self.store = store else: self.store = _get_store_from_pathlike(store) - if store.read_only: - raise ValueError(f"Store {store} is read-only. Please provide a writable store.") + # TODO v4: Enable once updating to zarr v3 + # if store.read_only: + # raise ValueError(f"Store {store} is read-only. Please provide a writable store.") # TODO v4: Add check that if create_new_zarrfile is False, the store already exists def __repr__(self) -> str: return ( f"{type(self).__name__}(" - f"name={self.fname!r}, " f"outputdt={self.outputdt!r}, " f"chunks={self.chunks!r}, " f"create_new_zarrfile={self.create_new_zarrfile!r})" @@ -92,10 +92,6 @@ def outputdt(self): def chunks(self): return self._chunks - @property - def fname(self): - return self._fname - def _convert_varout_name(self, var): if var == "depth": return "z" @@ -130,6 +126,10 @@ def write(self, pset: ParticleSet, time, indices=None): """ pclass = pset._ptype vars_to_write = _get_vars_to_write(pclass) + time_interval = pset.fieldset.time_interval + time = timedelta_to_float(time - time_interval.left) + particle_data = _convert_particle_data_time_to_float_seconds(pset._data, time_interval) + # if pset._data._ncount == 0: # warnings.warn( # f"ParticleSet is empty on writing as array at time {time:g}", @@ -139,21 +139,21 @@ def write(self, pset: ParticleSet, time, indices=None): # return if indices is None: - indices_to_write = _to_write_particles(pset._data, time) + indices_to_write = _to_write_particles(particle_data, time) else: indices_to_write = indices if len(indices_to_write) == 0: return - pids = pset._data["trajectory"][indices_to_write] + pids = particle_data["trajectory"][indices_to_write] to_add = sorted(set(pids) - set(self._pids_written.keys())) for i, pid in enumerate(to_add): self._pids_written[pid] = self._maxids + i ids = np.array([self._pids_written[p] for p in pids], dtype=int) self._maxids = len(self._pids_written) - once_ids = np.where(pset._data["obs_written"][indices_to_write] == 0)[0] + once_ids = np.where(particle_data["obs_written"][indices_to_write] == 0)[0] if len(once_ids) > 0: ids_once = ids[once_ids] indices_to_write_once = indices_to_write[once_ids] @@ -173,40 +173,42 @@ def write(self, pset: ParticleSet, time, indices=None): attrs = _create_variables_attribute_dict(pclass) obs = np.zeros((self._maxids), dtype=np.int32) for var in vars_to_write: - varout = self._convert_varout_name(var) + dtype = _maybe_convert_time_dtype(var.dtype) + varout = self._convert_varout_name(var.name) if varout not in ["trajectory"]: # because 'trajectory' is written as coordinate if var.to_write == "once": data = np.full( (arrsize[0],), - DATATYPES_TO_FILL_VALUES[vars_to_write.dtype], - dtype=var.dtype, + DATATYPES_TO_FILL_VALUES[dtype], + dtype=dtype, ) - data[ids_once] = pset._data[var][indices_to_write_once] + data[ids_once] = particle_data[var.name][indices_to_write_once] dims = ["trajectory"] else: - data = np.full(arrsize, DATATYPES_TO_FILL_VALUES[var.dtype], dtype=var.dtype) - data[ids, 0] = pset._data[var][indices_to_write] + data = np.full(arrsize, DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype) + data[ids, 0] = particle_data[var.name][indices_to_write] dims = ["trajectory", "obs"] - ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[varout]) + ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[var.name]) ds[varout].encoding["chunks"] = self.chunks[0] if var.to_write == "once" else self.chunks # type: ignore[index] ds.to_zarr(store, mode="w") self.create_new_zarrfile = False else: Z = zarr.group(store=store, overwrite=False) - obs = pset._data["obs_written"][indices_to_write] + obs = particle_data["obs_written"][indices_to_write] for var in vars_to_write: + dtype = _maybe_convert_time_dtype(var.dtype) varout = self._convert_varout_name(var.name) if self._maxids > Z[varout].shape[0]: - self._extend_zarr_dims(Z[varout], store, dtype=var.dtype, axis=0) + self._extend_zarr_dims(Z[varout], store, dtype=dtype, axis=0) if var.to_write == "once": if len(once_ids) > 0: - Z[varout].vindex[ids_once] = pset._data[var][indices_to_write_once] + Z[varout].vindex[ids_once] = particle_data[var.name][indices_to_write_once] else: if max(obs) >= Z[varout].shape[1]: # type: ignore[type-var] - self._extend_zarr_dims(Z[varout], store, dtype=var.dtype, axis=1) - Z[varout].vindex[ids, obs] = pset._data[var][indices_to_write] + self._extend_zarr_dims(Z[varout], store, dtype=dtype, axis=1) + Z[varout].vindex[ids, obs] = particle_data[var.name][indices_to_write] - pset._data["obs_written"][indices_to_write] = obs + 1 + particle_data["obs_written"][indices_to_write] = obs + 1 def write_latest_locations(self, pset, time): """Write the current (latest) particle locations to zarr file. @@ -256,8 +258,8 @@ def _create_variables_attribute_dict(particle: ParticleClass) -> dict: attrs[var.name] = {**var.attrs, **fill_value} - attrs["time"]["units"] = "seconds since " # TODO fix units - attrs["time"]["calendar"] = "None" # TODO fix calendar + # attrs["time"]["units"] = "seconds since " # TODO fix units + # attrs["time"]["calendar"] = "None" # TODO fix calendar return attrs @@ -280,3 +282,19 @@ def _to_write_particles(particle_data, time): & (np.isfinite(particle_data["trajectory"])) & (np.isfinite(particle_data["time"])) )[0] + + +def _convert_particle_data_time_to_float_seconds(particle_data, time_interval): + #! Important that this is a shallow copy, so that updates to this propogate back to the original data + particle_data = particle_data.copy() + + particle_data["time"] = ((particle_data["time"] - time_interval.left) / np.timedelta64(1, "s")).astype(np.float64) + particle_data["dt"] = (particle_data["dt"] / np.timedelta64(1, "s")).astype(np.float64) + return particle_data + + +def _maybe_convert_time_dtype(dtype: np.dtype | _SAME_AS_FIELDSET_TIME_INTERVAL) -> np.dtype: + """Convert the dtype of time to float64 if it is not already.""" + if dtype is _SAME_AS_FIELDSET_TIME_INTERVAL.VALUE: + return np.dtype(np.float64) + return dtype From e8d6adaf2f3fca5341659417bd5b7dff1165758b Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 7 Aug 2025 13:39:47 +0200 Subject: [PATCH 12/40] Fix time metadata --- parcels/particlefile.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/parcels/particlefile.py b/parcels/particlefile.py index dd772b034..47c5e0add 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -3,8 +3,10 @@ from __future__ import annotations import os +from datetime import datetime from typing import TYPE_CHECKING, Literal +import cftime import numpy as np import xarray as xr import zarr @@ -18,6 +20,7 @@ if TYPE_CHECKING: from pathlib import Path + from parcels._core.utils.time import TimeInterval from parcels.particle import Variable from parcels.particleset import ParticleSet @@ -170,7 +173,7 @@ def write(self, pset: ParticleSet, time, indices=None): attrs=self.metadata, coords={"trajectory": ("trajectory", pids), "obs": ("obs", np.arange(arrsize[1], dtype=np.int32))}, ) - attrs = _create_variables_attribute_dict(pclass) + attrs = _create_variables_attribute_dict(pclass, time_interval) obs = np.zeros((self._maxids), dtype=np.int32) for var in vars_to_write: dtype = _maybe_convert_time_dtype(var.dtype) @@ -241,7 +244,7 @@ def _get_vars_to_write(particle: ParticleClass) -> list[Variable]: return [v for v in particle.variables if v.to_write is not False] -def _create_variables_attribute_dict(particle: ParticleClass) -> dict: +def _create_variables_attribute_dict(particle: ParticleClass, time_interval: TimeInterval) -> dict: """Creates the dictionary with variable attributes. Notes @@ -258,8 +261,7 @@ def _create_variables_attribute_dict(particle: ParticleClass) -> dict: attrs[var.name] = {**var.attrs, **fill_value} - # attrs["time"]["units"] = "seconds since " # TODO fix units - # attrs["time"]["calendar"] = "None" # TODO fix calendar + attrs["time"].update(_get_calendar_and_units(time_interval)) return attrs @@ -298,3 +300,21 @@ def _maybe_convert_time_dtype(dtype: np.dtype | _SAME_AS_FIELDSET_TIME_INTERVAL) if dtype is _SAME_AS_FIELDSET_TIME_INTERVAL.VALUE: return np.dtype(np.float64) return dtype + + +def _get_calendar_and_units(time_interval: TimeInterval) -> dict[str, str]: + calendar = None + units = "seconds" + if isinstance(time_interval.left, (np.datetime64, datetime)): + calendar = "standard" + elif isinstance(time_interval.left, cftime.datetime): + calendar = time_interval.left.calendar + + if calendar is not None: + units += f" since {time_interval.left}" + + attrs = {"units": units} + if calendar is not None: + attrs["calendar"] = calendar + + return attrs From 86f3592685d6c6a7289e1d393aa3abf7dd44cdd4 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 7 Aug 2025 14:07:33 +0200 Subject: [PATCH 13/40] Misc bugfixing while porting particlefile test suite to v4-dev --- parcels/kernel.py | 3 ++- parcels/particlefile.py | 21 +++++++++++---------- parcels/particleset.py | 2 +- tests/v4/test_particlefile.py | 27 ++++++++++++++++----------- 4 files changed, 30 insertions(+), 23 deletions(-) diff --git a/parcels/kernel.py b/parcels/kernel.py index e4c42238f..29f96ac3c 100644 --- a/parcels/kernel.py +++ b/parcels/kernel.py @@ -94,7 +94,8 @@ def funcname(self): @property #! Ported from v3. To be removed in v4? (/find another way to name kernels in output file) def name(self): - return f"{self._ptype.name}{self.funcname}" + # return f"{self._ptype.name}{self.funcname}" # TODO v4: Should we propogate the name of the particle to the metadata? At the moment we don't have the concept of naming particles (somewhat incompatible with the .add_variable() API?) + return f"{self.funcname}" @property def ptype(self): diff --git a/parcels/particlefile.py b/parcels/particlefile.py index 47c5e0add..e0f987fe0 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -4,6 +4,7 @@ import os from datetime import datetime +from pathlib import Path from typing import TYPE_CHECKING, Literal import cftime @@ -18,8 +19,6 @@ from parcels.tools._helpers import timedelta_to_float if TYPE_CHECKING: - from pathlib import Path - from parcels._core.utils.time import TimeInterval from parcels.particle import Variable from parcels.particleset import ParticleSet @@ -56,7 +55,7 @@ def __init__(self, store, outputdt, chunks=None, create_new_zarrfile=True): self._chunks = chunks self._maxids = 0 self._pids_written = {} - self.metadata = None + self.metadata = {} self.create_new_zarrfile = create_new_zarrfile if isinstance(store, zarr.storage.Store): @@ -79,13 +78,15 @@ def __repr__(self) -> str: ) def set_metadata(self, parcels_mesh: Literal["spherical", "flat"]): - self.metadata = { - "feature_type": "trajectory", - "Conventions": "CF-1.6/CF-1.7", - "ncei_template_version": "NCEI_NetCDF_Trajectory_Template_v2.0", - "parcels_version": parcels.__version__, - "parcels_mesh": parcels_mesh, - } + self.metadata.update( + { + "feature_type": "trajectory", + "Conventions": "CF-1.6/CF-1.7", + "ncei_template_version": "NCEI_NetCDF_Trajectory_Template_v2.0", + "parcels_version": parcels.__version__, + "parcels_mesh": parcels_mesh, + } + ) @property def outputdt(self): diff --git a/parcels/particleset.py b/parcels/particleset.py index 70730c3dd..c0fe2b9d5 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -542,7 +542,7 @@ def execute( # Set up pbar if output_file: - logger.info(f"Output files are stored in {output_file.fname}.") + logger.info(f"Output files are stored in {output_file.store.path}.") if verbose_progress: pbar = tqdm(total=(end_time - start_time) / np.timedelta64(1, "s"), file=sys.stdout) diff --git a/tests/v4/test_particlefile.py b/tests/v4/test_particlefile.py index ca3776772..0f0e6f7ff 100755 --- a/tests/v4/test_particlefile.py +++ b/tests/v4/test_particlefile.py @@ -112,7 +112,7 @@ def Update_lon(particle, fieldset, time): # pragma: no cover particle.dlon += 0.1 pset = ParticleSet(fieldset, pclass=Particle, lon=[0], lat=[0], lonlatdepth_dtype=np.float64) - ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=0.00001) + ofile = pset.ParticleFile(tmp_zarrfile, outputdt=0.00001) pset.execute(pset.Kernel(Update_lon), endtime=0.001, dt=0.00001, output_file=ofile) ds = xr.open_zarr(tmp_zarrfile) @@ -140,7 +140,7 @@ def test_write_dtypes_pfile(fieldset, tmp_zarrfile): MyParticle = Particle.add_variable(extra_vars) pset = ParticleSet(fieldset, pclass=MyParticle, lon=0, lat=0, time=fieldset.time_interval.left) - pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=1) + pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1) pfile.write(pset, time=fieldset.time_interval.left) ds = xr.open_zarr( @@ -153,8 +153,9 @@ def test_write_dtypes_pfile(fieldset, tmp_zarrfile): @pytest.mark.parametrize("npart", [1, 2, 5]) def test_variable_written_once(fieldset, tmp_zarrfile, npart): def Update_v(particle, fieldset, time): # pragma: no cover + dt = particle.dt / np.timedelta64(1, "s") particle.v_once += 1.0 - particle.age += particle.dt + particle.age += dt MyParticle = Particle.add_variable( [ @@ -164,10 +165,14 @@ def Update_v(particle, fieldset, time): # pragma: no cover ) lon = np.linspace(0, 1, npart) lat = np.linspace(1, 0, npart) - time = np.arange(0, npart / 10.0, 0.1, dtype=np.float64) + time = xr.date_range( + start=fieldset.time_interval.left, end=fieldset.time_interval.right - np.timedelta64(30, "D"), periods=npart + ) pset = ParticleSet(fieldset, pclass=MyParticle, lon=lon, lat=lat, time=time, v_once=time) - ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=0.1) - pset.execute(pset.Kernel(Update_v), endtime=1, dt=0.1, output_file=ofile) + ofile = pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(4, "D")) + pset.execute( + pset.Kernel(Update_v), endtime=fieldset.time_interval.right, dt=np.timedelta64(1, "D"), output_file=ofile + ) assert np.allclose(pset.v_once - time - pset.age * 10, 1, atol=1e-5) ds = xr.open_zarr(tmp_zarrfile) @@ -215,7 +220,7 @@ def Update_lon(particle, fieldset, time): # pragma: no cover particle.dlon -= 0.1 * particle.dt pset = ParticleSet(fieldset, pclass=Particle, lat=np.linspace(0, 1, 3), lon=[0, 0, 0], time=[1, 2, 3]) - pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=1.0) + pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1.0) pset.execute(pset.Kernel(Update_lon), runtime=4, dt=-1.0, output_file=pfile) ds = xr.open_zarr(tmp_zarrfile) trajs = ds["trajectory"][:] @@ -252,7 +257,7 @@ def SampleP(particle, fieldset, time): # pragma: no cover _ = fieldset.P[particle] # To trigger sampling of the P field pset = ParticleSet(fieldset, pclass=XiYiParticle, lon=[0, 0.2], lat=[0.2, 1], lonlatdepth_dtype=np.float64) - pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=dt) + pfile = pset.ParticleFile(tmp_zarrfile, outputdt=dt) pset.execute([SampleP, Get_XiYi, AdvectionRK4], endtime=10 * dt, dt=dt, output_file=pfile) ds = xr.open_zarr(tmp_zarrfile) @@ -283,7 +288,7 @@ def Update_lon(particle, fieldset, time): # pragma: no cover particle.dlon += 0.1 pset = ParticleSet(fieldset, pclass=Particle, lon=[0], lat=[0], lonlatdepth_dtype=np.float64) - ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=0.05) + ofile = pset.ParticleFile(tmp_zarrfile, outputdt=0.05) pset.execute(pset.Kernel(Update_lon), endtime=0.12, dt=0.02, output_file=ofile) assert np.allclose(pset.lon, 0.6) @@ -296,7 +301,7 @@ def Update_lon(particle, fieldset, time): # pragma: no cover particle.dlon += particle.dt pset = ParticleSet(fieldset, pclass=Particle, lon=[0], lat=[0], lonlatdepth_dtype=np.float64) - ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=3) + ofile = pset.ParticleFile(tmp_zarrfile, outputdt=3) pset.execute(pset.Kernel(Update_lon), endtime=11, dt=2, output_file=ofile) ds = xr.open_zarr(tmp_zarrfile) @@ -318,7 +323,7 @@ def setup_pset_execute(*, fieldset: FieldSet, outputdt: timedelta, execute_kwarg with tempfile.TemporaryDirectory() as dir: name = f"{dir}/test.zarr" - output_file = pset.ParticleFile(name=name, outputdt=outputdt) + output_file = pset.ParticleFile(name, outputdt=outputdt) pset.execute(DoNothing, output_file=output_file, **execute_kwargs) ds = xr.open_zarr(name).load() From 00db91ab847c2d763995c6e6333f22495e7bbf68 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 8 Aug 2025 12:13:16 +0200 Subject: [PATCH 14/40] Update tests --- tests/v4/test_particlefile.py | 39 ++++------------------------------- 1 file changed, 4 insertions(+), 35 deletions(-) diff --git a/tests/v4/test_particlefile.py b/tests/v4/test_particlefile.py index 0f0e6f7ff..529107afe 100755 --- a/tests/v4/test_particlefile.py +++ b/tests/v4/test_particlefile.py @@ -28,6 +28,7 @@ def fieldset() -> FieldSet: # TODO v4: Move into a `conftest.py` file and remov ) +@pytest.mark.skip def test_metadata(fieldset, tmp_zarrfile): pset = ParticleSet(fieldset, pclass=Particle, lon=0, lat=0) @@ -53,7 +54,6 @@ def test_pfile_array_write_zarr_memorystore(fieldset): ds = xr.open_zarr(zarr_store) assert ds.sizes["trajectory"] == npart - ds.close() def test_pfile_array_remove_particles(fieldset, tmp_zarrfile): @@ -75,7 +75,6 @@ def test_pfile_array_remove_particles(fieldset, tmp_zarrfile): ds = xr.open_zarr(tmp_zarrfile) timearr = ds["time"][:] assert (np.isnat(timearr[3, 1])) and (np.isfinite(timearr[3, 0])) - ds.close() @pytest.mark.parametrize("chunks_obs", [1, None]) @@ -103,7 +102,6 @@ def test_pfile_array_remove_all_particles(fieldset, chunks_obs, tmp_zarrfile): else: assert ds["time"][:].shape[0] == npart assert np.all(np.isnan(ds["time"][:, 1:])) - ds.close() @pytest.mark.xfail(reason="lonlatdepth_dtype removed. Update implementation to use a different particle") @@ -118,7 +116,6 @@ def Update_lon(particle, fieldset, time): # pragma: no cover ds = xr.open_zarr(tmp_zarrfile) lons = ds["lon"][:] assert isinstance(lons.values[0, 0], np.float64) - ds.close() def test_write_dtypes_pfile(fieldset, tmp_zarrfile): @@ -150,35 +147,9 @@ def test_write_dtypes_pfile(fieldset, tmp_zarrfile): assert ds[f"v_{d.__name__}"].dtype == d -@pytest.mark.parametrize("npart", [1, 2, 5]) -def test_variable_written_once(fieldset, tmp_zarrfile, npart): - def Update_v(particle, fieldset, time): # pragma: no cover - dt = particle.dt / np.timedelta64(1, "s") - particle.v_once += 1.0 - particle.age += dt - - MyParticle = Particle.add_variable( - [ - Variable("v_once", dtype=np.float64, initial=0.0, to_write="once"), - Variable("age", dtype=np.float32, initial=0.0), - ] - ) - lon = np.linspace(0, 1, npart) - lat = np.linspace(1, 0, npart) - time = xr.date_range( - start=fieldset.time_interval.left, end=fieldset.time_interval.right - np.timedelta64(30, "D"), periods=npart - ) - pset = ParticleSet(fieldset, pclass=MyParticle, lon=lon, lat=lat, time=time, v_once=time) - ofile = pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(4, "D")) - pset.execute( - pset.Kernel(Update_v), endtime=fieldset.time_interval.right, dt=np.timedelta64(1, "D"), output_file=ofile - ) - - assert np.allclose(pset.v_once - time - pset.age * 10, 1, atol=1e-5) - ds = xr.open_zarr(tmp_zarrfile) - vfile = np.ma.filled(ds["v_once"][:], np.nan) - assert vfile.shape == (npart,) - ds.close() +def test_variable_written_once(): + # Test that a vaiable is only written once. This should also work with gradual particle release (so the written once time is actually after the release of the particle) + ... @pytest.mark.parametrize("dt", [-1, 1]) @@ -226,7 +197,6 @@ def Update_lon(particle, fieldset, time): # pragma: no cover trajs = ds["trajectory"][:] assert trajs.values.dtype == "int64" assert np.all(np.diff(trajs.values) < 0) # all particles written in order of release - ds.close() def test_write_xiyi(fieldset, tmp_zarrfile): @@ -277,7 +247,6 @@ def SampleP(particle, fieldset, time): # pragma: no cover assert fieldset.P.grid.lon[xi] <= lon < fieldset.P.grid.lon[xi + 1] for yi, lat in zip(pyi[p, 1:], lats[p, 1:], strict=True): assert fieldset.U.grid.lat[yi] <= lat < fieldset.U.grid.lat[yi + 1] - ds.close() def test_reset_dt(fieldset, tmp_zarrfile): From c9d22707bcb99f234924ecb827c5826c3f0d5e75 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 8 Aug 2025 12:58:16 +0200 Subject: [PATCH 15/40] Add 'new' test marker --- parcels/particlefile.py | 13 +++++++++++++ pyproject.toml | 1 + tests/v4/test_particlefile.py | 17 +++++++++++++++++ 3 files changed, 31 insertions(+) diff --git a/parcels/particlefile.py b/parcels/particlefile.py index e0f987fe0..c31fa07b5 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -52,6 +52,8 @@ class ParticleFile: def __init__(self, store, outputdt, chunks=None, create_new_zarrfile=True): self._outputdt = timedelta_to_float(outputdt) + + _assert_valid_chunks_tuple(chunks) self._chunks = chunks self._maxids = 0 self._pids_written = {} @@ -319,3 +321,14 @@ def _get_calendar_and_units(time_interval: TimeInterval) -> dict[str, str]: attrs["calendar"] = calendar return attrs + + +def _assert_valid_chunks_tuple(chunks): + e = ValueError(f"chunks must be a tuple of integers with length 2, got {chunks=!r} instead.") + + if not isinstance(chunks, tuple): + raise e + if len(chunks) != 2: + raise e + if not all(isinstance(c, int) for c in chunks): + raise e diff --git a/pyproject.toml b/pyproject.toml index ff2e8a315..e0d651634 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,6 +131,7 @@ markers = [ # can be skipped by doing `pytest -m "not slow"` etc. "v4alpha: failing tests that should work for v4alpha", "v4future: failing tests that should work for a future release of v4", "v4remove: failing tests that should probably be removed later", + "new: new tests likely to replace the old. Adding this temporarily so split suite. TODO: remove" ] filterwarnings = [ diff --git a/tests/v4/test_particlefile.py b/tests/v4/test_particlefile.py index 529107afe..19faf002a 100755 --- a/tests/v4/test_particlefile.py +++ b/tests/v4/test_particlefile.py @@ -10,6 +10,7 @@ import parcels from parcels import AdvectionRK4, Field, FieldSet, Particle, ParticleSet, Variable, VectorField from parcels._datasets.structured.generic import datasets +from parcels.particlefile import ParticleFile from parcels.xgrid import XGrid from tests.common_kernels import DoNothing @@ -344,3 +345,19 @@ def test_pset_execute_outputdt_backwards_fieldset_timevarying(): ds = setup_pset_execute(outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt), fieldset=fieldset) file_outputdt = ds.isel(trajectory=0).time.diff(dim="obs").values assert np.all(file_outputdt == np.timedelta64(-outputdt)), (file_outputdt, np.timedelta64(-outputdt)) + + +@pytest.fixture +def store(): + return MemoryStore() + + +@pytest.mark.new +def test_particlefile_init(store): + ParticleFile(store, outputdt=np.timedelta64(1, "s"), chunks=(1, 3)) + + +@pytest.mark.new +def test_particlefile_init_invalid(store): # TODO: Add test for read only store + with pytest.raises(ValueError, match="chunks must be a tuple"): + ParticleFile(store, outputdt=np.timedelta64(1, "s"), chunks=1) From df686214683a966c20abdd027e4add3e03b4f4cd Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 8 Aug 2025 13:36:00 +0200 Subject: [PATCH 16/40] Split ParticleFile.write into ParticleFile._write_particle_data --- parcels/particlefile.py | 10 ++++++++-- tests/v4/test_particlefile.py | 6 ++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/parcels/particlefile.py b/parcels/particlefile.py index c31fa07b5..e526f246d 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -131,11 +131,16 @@ def write(self, pset: ParticleSet, time, indices=None): Time at which to write ParticleSet (same time object as fieldset) """ pclass = pset._ptype - vars_to_write = _get_vars_to_write(pclass) time_interval = pset.fieldset.time_interval + particle_data = pset._data time = timedelta_to_float(time - time_interval.left) - particle_data = _convert_particle_data_time_to_float_seconds(pset._data, time_interval) + particle_data = _convert_particle_data_time_to_float_seconds(particle_data, time_interval) + + self._write_particle_data( + particle_data=particle_data, pclass=pclass, time_interval=time_interval, time=time, indices=indices + ) + def _write_particle_data(self, *, particle_data, pclass, time_interval, time, indices=None): # if pset._data._ncount == 0: # warnings.warn( # f"ParticleSet is empty on writing as array at time {time:g}", @@ -144,6 +149,7 @@ def write(self, pset: ParticleSet, time, indices=None): # ) # return + vars_to_write = _get_vars_to_write(pclass) if indices is None: indices_to_write = _to_write_particles(particle_data, time) else: diff --git a/tests/v4/test_particlefile.py b/tests/v4/test_particlefile.py index 19faf002a..d9e81df61 100755 --- a/tests/v4/test_particlefile.py +++ b/tests/v4/test_particlefile.py @@ -11,6 +11,8 @@ from parcels import AdvectionRK4, Field, FieldSet, Particle, ParticleSet, Variable, VectorField from parcels._datasets.structured.generic import datasets from parcels.particlefile import ParticleFile +from parcels.particle import create_particle_data +from parcels._core.utils.time import TimeInterval from parcels.xgrid import XGrid from tests.common_kernels import DoNothing @@ -361,3 +363,7 @@ def test_particlefile_init(store): def test_particlefile_init_invalid(store): # TODO: Add test for read only store with pytest.raises(ValueError, match="chunks must be a tuple"): ParticleFile(store, outputdt=np.timedelta64(1, "s"), chunks=1) + + +@pytest.mark.new +def test_particlefile_writing(store): ... From 1ba8ac7f2ab1c7c6a9174119a62993bead298587 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 8 Aug 2025 14:36:07 +0200 Subject: [PATCH 17/40] Update _get_time_interval_dtype Cftime needs to be stored as object type --- parcels/particle.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/parcels/particle.py b/parcels/particle.py index 50ba66114..c6a617975 100644 --- a/parcels/particle.py +++ b/parcels/particle.py @@ -263,7 +263,7 @@ def _create_array_for_variable(variable: Variable, nparticles: int, time_interva "This function cannot handle attrgetter initial values." ) if (dtype := variable.dtype) is _SAME_AS_FIELDSET_TIME_INTERVAL.VALUE: - dtype = type(time_interval.left) + dtype = _get_time_interval_dtype(time_interval) return np.full( shape=(nparticles,), fill_value=variable.initial, @@ -274,4 +274,8 @@ def _create_array_for_variable(variable: Variable, nparticles: int, time_interva def _get_time_interval_dtype(time_interval: TimeInterval | None) -> np.dtype: if time_interval is None: return np.timedelta64(1, "ns") - return type(time_interval.left) + time = time_interval.left + if isinstance(time, (np.datetime64, np.timedelta64)): + return time.dtype + else: + return object # cftime objects needs to be stored as object dtype From 0ce4439fdf03a3e176dc5448fa77ce970bf5c7c4 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 1 Sep 2025 15:47:27 +0200 Subject: [PATCH 18/40] Add test_particlefile_write_particle_data test --- parcels/particlefile.py | 8 ++++--- tests/v4/test_particlefile.py | 42 +++++++++++++++++++++++++++++++---- 2 files changed, 43 insertions(+), 7 deletions(-) diff --git a/parcels/particlefile.py b/parcels/particlefile.py index e526f246d..8c612c758 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -148,7 +148,7 @@ def _write_particle_data(self, *, particle_data, pclass, time_interval, time, in # stacklevel=2, # ) # return - + nparticles = len(particle_data["trajectory"]) vars_to_write = _get_vars_to_write(pclass) if indices is None: indices_to_write = _to_write_particles(particle_data, time) @@ -173,7 +173,7 @@ def _write_particle_data(self, *, particle_data, pclass, time_interval, time, in store = self.store if self.create_new_zarrfile: if self.chunks is None: - self._chunks = (len(pset), 1) + self._chunks = (nparticles, 1) if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]): # type: ignore[index] arrsize = (self._maxids, self.chunks[1]) # type: ignore[index] else: @@ -307,7 +307,9 @@ def _convert_particle_data_time_to_float_seconds(particle_data, time_interval): def _maybe_convert_time_dtype(dtype: np.dtype | _SAME_AS_FIELDSET_TIME_INTERVAL) -> np.dtype: """Convert the dtype of time to float64 if it is not already.""" if dtype is _SAME_AS_FIELDSET_TIME_INTERVAL.VALUE: - return np.dtype(np.float64) + return np.dtype( + np.uint64 + ) #! We need to have here some proper mechanism for converting particle data to the data that is to be output to zarr (namely the time needs to be converted to float seconds by subtracting the time_interval.left) return dtype diff --git a/tests/v4/test_particlefile.py b/tests/v4/test_particlefile.py index d9e81df61..036223d22 100755 --- a/tests/v4/test_particlefile.py +++ b/tests/v4/test_particlefile.py @@ -5,14 +5,14 @@ import numpy as np import pytest import xarray as xr -from zarr.storage import MemoryStore +from zarr.storage import DirectoryStore, MemoryStore import parcels from parcels import AdvectionRK4, Field, FieldSet, Particle, ParticleSet, Variable, VectorField +from parcels._core.utils.time import TimeInterval from parcels._datasets.structured.generic import datasets +from parcels.particle import Particle, create_particle_data from parcels.particlefile import ParticleFile -from parcels.particle import create_particle_data -from parcels._core.utils.time import TimeInterval from parcels.xgrid import XGrid from tests.common_kernels import DoNothing @@ -351,6 +351,7 @@ def test_pset_execute_outputdt_backwards_fieldset_timevarying(): @pytest.fixture def store(): + return DirectoryStore("/tmp/test.zarr") return MemoryStore() @@ -366,4 +367,37 @@ def test_particlefile_init_invalid(store): # TODO: Add test for read only store @pytest.mark.new -def test_particlefile_writing(store): ... +def test_particlefile_write_particle_data(store): + nparticles = 100 + + pfile = ParticleFile(store, outputdt=np.timedelta64(1, "s"), chunks=(nparticles, 40)) + pclass = Particle + + left, right = np.datetime64("2019-05-30T12:00:00.000000000", "ns"), np.datetime64("2020-01-02", "ns") + time_interval = TimeInterval(left=left, right=right) + + initial_lon = np.linspace(0, 1, nparticles) + data = create_particle_data( + pclass=pclass, + nparticles=nparticles, + ngrids=4, + time_interval=time_interval, + initial={ + "time": np.full(nparticles, fill_value=left), + "lon": initial_lon, + "dt": np.full(nparticles, fill_value=1.0), + "trajectory": np.arange(nparticles), + }, + ) + np.testing.assert_array_equal(data["time"], left) + pfile._write_particle_data( + particle_data=data, + pclass=pclass, + time_interval=time_interval, + time=left, + ) + ds = xr.open_zarr(store, decode_cf=False) # TODO: Fix metadata and re-enable decode_cf + # assert ds.time.dtype == "datetime64[ns]" + # np.testing.assert_equal(ds["time"].isel(obs=0).values, left) + assert ds.sizes["trajectory"] == nparticles + np.testing.assert_allclose(ds["lon"].isel(obs=0).values, initial_lon) From cee3082ad067628408430adab6401ef88e31421c Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 8 Aug 2025 11:27:31 +0200 Subject: [PATCH 19/40] Add comments --- parcels/particlefile.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/parcels/particlefile.py b/parcels/particlefile.py index 8c612c758..2f74bde3f 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -284,11 +284,11 @@ def _to_write_particles(particle_data, time): ) & np.greater_equal( time + np.abs(particle_data["dt"] / 2), particle_data["time"], where=np.isfinite(particle_data["time"]) - ) + ) # check time - dt/2 <= particle_data["time"] <= time + dt/2 | ( (np.isnan(particle_data["dt"])) & np.equal(time, particle_data["time"], where=np.isfinite(particle_data["time"])) - ) + ) # or dt is NaN and time matches particle_data["time"] ) & (np.isfinite(particle_data["trajectory"])) & (np.isfinite(particle_data["time"])) From 686411687ae6747360448691f595f7e328294d2c Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 1 Sep 2025 16:33:52 +0200 Subject: [PATCH 20/40] Allow chunks to be None --- parcels/particlefile.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/parcels/particlefile.py b/parcels/particlefile.py index 2f74bde3f..696cdd29e 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -333,6 +333,8 @@ def _get_calendar_and_units(time_interval: TimeInterval) -> dict[str, str]: def _assert_valid_chunks_tuple(chunks): e = ValueError(f"chunks must be a tuple of integers with length 2, got {chunks=!r} instead.") + if chunks is None: + return if not isinstance(chunks, tuple): raise e From ac72c56f7b2f59096a2fbca2ea0877c72235df8b Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 2 Sep 2025 09:57:19 +0200 Subject: [PATCH 21/40] Restrict outputdt to be datetime or timedelta --- parcels/particlefile.py | 5 ++++- tests/v4/test_particlefile.py | 19 +++++++++++++------ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/parcels/particlefile.py b/parcels/particlefile.py index 696cdd29e..b5dda4e62 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -51,7 +51,10 @@ class ParticleFile: """ def __init__(self, store, outputdt, chunks=None, create_new_zarrfile=True): - self._outputdt = timedelta_to_float(outputdt) + if not isinstance(outputdt, (np.datetime64, np.timedelta64)): + raise ValueError(f"Expected outputdt to be a np.timedelta64 or datetime64, got {type(outputdt)}") + + self._outputdt = outputdt _assert_valid_chunks_tuple(chunks) self._chunks = chunks diff --git a/tests/v4/test_particlefile.py b/tests/v4/test_particlefile.py index 036223d22..1118f6370 100755 --- a/tests/v4/test_particlefile.py +++ b/tests/v4/test_particlefile.py @@ -22,8 +22,8 @@ def fieldset() -> FieldSet: # TODO v4: Move into a `conftest.py` file and remov """Fixture to create a FieldSet object for testing.""" ds = datasets["ds_2d_left"] grid = XGrid.from_dataset(ds) - U = Field("U", ds["U (A grid)"], grid, mesh_type="flat") - V = Field("V", ds["V (A grid)"], grid, mesh_type="flat") + U = Field("U", ds["U (A grid)"], grid) + V = Field("V", ds["V (A grid)"], grid) UV = VectorField("UV", U, V) return FieldSet( @@ -191,11 +191,18 @@ def IncrLon(particle, fieldset, time): # pragma: no cover def test_write_timebackward(fieldset, tmp_zarrfile): def Update_lon(particle, fieldset, time): # pragma: no cover - particle.dlon -= 0.1 * particle.dt + dt = particle.dt / np.timedelta64(1, "s") + particle.dlon -= 0.1 * dt - pset = ParticleSet(fieldset, pclass=Particle, lat=np.linspace(0, 1, 3), lon=[0, 0, 0], time=[1, 2, 3]) - pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1.0) - pset.execute(pset.Kernel(Update_lon), runtime=4, dt=-1.0, output_file=pfile) + pset = ParticleSet( + fieldset, + pclass=Particle, + lat=np.linspace(0, 1, 3), + lon=[0, 0, 0], + time=np.array([np.datetime64("2000-01-01") for _ in range(3)]), + ) + pfile = pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")) + pset.execute(pset.Kernel(Update_lon), runtime=np.timedelta64(1, "s"), dt=-np.timedelta64(1, "s"), output_file=pfile) ds = xr.open_zarr(tmp_zarrfile) trajs = ds["trajectory"][:] assert trajs.values.dtype == "int64" From 2e5e287f4d870f0dfd0977ebd5f7288dcefc8132 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 3 Sep 2025 16:08:47 +0200 Subject: [PATCH 22/40] Update outputdt in tests --- tests/v4/test_particlefile.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/v4/test_particlefile.py b/tests/v4/test_particlefile.py index 1118f6370..3fb39db48 100755 --- a/tests/v4/test_particlefile.py +++ b/tests/v4/test_particlefile.py @@ -35,7 +35,7 @@ def fieldset() -> FieldSet: # TODO v4: Move into a `conftest.py` file and remov def test_metadata(fieldset, tmp_zarrfile): pset = ParticleSet(fieldset, pclass=Particle, lon=0, lat=0) - pset.execute(DoNothing, runtime=1, output_file=pset.ParticleFile(tmp_zarrfile, outputdt=1)) + pset.execute(DoNothing, runtime=1, output_file=pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s"))) ds = xr.open_zarr(tmp_zarrfile) assert ds.attrs["parcels_kernels"].lower() == "ParticleDoNothing".lower() @@ -52,7 +52,7 @@ def test_pfile_array_write_zarr_memorystore(fieldset): lat=0.5 * np.ones(npart), time=fieldset.time_interval.left, ) - pfile = pset.ParticleFile(zarr_store, outputdt=1) + pfile = pset.ParticleFile(zarr_store, outputdt=np.timedelta64(1, "s")) pfile.write(pset, time=fieldset.time_interval.left) ds = xr.open_zarr(zarr_store) @@ -68,7 +68,7 @@ def test_pfile_array_remove_particles(fieldset, tmp_zarrfile): lat=0.5 * np.ones(npart), time=fieldset.time_interval.left, ) - pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1) + pfile = pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")) pfile.write(pset, time=fieldset.time_interval.left) pset.remove_indices(3) for p in pset: @@ -91,7 +91,7 @@ def test_pfile_array_remove_all_particles(fieldset, chunks_obs, tmp_zarrfile): time=fieldset.time_interval.left, ) chunks = (npart, chunks_obs) if chunks_obs else None - pfile = pset.ParticleFile(tmp_zarrfile, chunks=chunks, outputdt=1) + pfile = pset.ParticleFile(tmp_zarrfile, chunks=chunks, outputdt=np.timedelta64(1, "s")) pfile.write(pset, time=fieldset.time_interval.left) for _ in range(npart): pset.remove_indices(-1) @@ -113,7 +113,7 @@ def Update_lon(particle, fieldset, time): # pragma: no cover particle.dlon += 0.1 pset = ParticleSet(fieldset, pclass=Particle, lon=[0], lat=[0], lonlatdepth_dtype=np.float64) - ofile = pset.ParticleFile(tmp_zarrfile, outputdt=0.00001) + ofile = pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(10, "us")) pset.execute(pset.Kernel(Update_lon), endtime=0.001, dt=0.00001, output_file=ofile) ds = xr.open_zarr(tmp_zarrfile) @@ -140,7 +140,7 @@ def test_write_dtypes_pfile(fieldset, tmp_zarrfile): MyParticle = Particle.add_variable(extra_vars) pset = ParticleSet(fieldset, pclass=MyParticle, lon=0, lat=0, time=fieldset.time_interval.left) - pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1) + pfile = pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")) pfile.write(pset, time=fieldset.time_interval.left) ds = xr.open_zarr( @@ -267,7 +267,7 @@ def Update_lon(particle, fieldset, time): # pragma: no cover particle.dlon += 0.1 pset = ParticleSet(fieldset, pclass=Particle, lon=[0], lat=[0], lonlatdepth_dtype=np.float64) - ofile = pset.ParticleFile(tmp_zarrfile, outputdt=0.05) + ofile = pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(50, "ms")) pset.execute(pset.Kernel(Update_lon), endtime=0.12, dt=0.02, output_file=ofile) assert np.allclose(pset.lon, 0.6) @@ -280,7 +280,7 @@ def Update_lon(particle, fieldset, time): # pragma: no cover particle.dlon += particle.dt pset = ParticleSet(fieldset, pclass=Particle, lon=[0], lat=[0], lonlatdepth_dtype=np.float64) - ofile = pset.ParticleFile(tmp_zarrfile, outputdt=3) + ofile = pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(3, "s")) pset.execute(pset.Kernel(Update_lon), endtime=11, dt=2, output_file=ofile) ds = xr.open_zarr(tmp_zarrfile) From 3c65501a555979c993ae49901af1db5108b8d9e5 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 3 Sep 2025 16:13:10 +0200 Subject: [PATCH 23/40] Remove directorystore from store fixture --- tests/v4/test_particlefile.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/v4/test_particlefile.py b/tests/v4/test_particlefile.py index 3fb39db48..312177a28 100755 --- a/tests/v4/test_particlefile.py +++ b/tests/v4/test_particlefile.py @@ -5,7 +5,7 @@ import numpy as np import pytest import xarray as xr -from zarr.storage import DirectoryStore, MemoryStore +from zarr.storage import MemoryStore import parcels from parcels import AdvectionRK4, Field, FieldSet, Particle, ParticleSet, Variable, VectorField @@ -358,7 +358,6 @@ def test_pset_execute_outputdt_backwards_fieldset_timevarying(): @pytest.fixture def store(): - return DirectoryStore("/tmp/test.zarr") return MemoryStore() From f04aab127386f00ec182146f8ff3d8c71fbe0f6a Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 3 Sep 2025 16:30:05 +0200 Subject: [PATCH 24/40] Move zarr store fixture to project level --- tests/conftest.py | 6 ++++++ tests/v4/test_particlefile.py | 19 +++++++------------ 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 69787802c..82020c37e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,13 @@ import pytest +from zarr.storage import MemoryStore @pytest.fixture() def tmp_zarrfile(tmp_path, request): test_name = request.node.name yield tmp_path / f"{test_name}-output.zarr" + + +@pytest.fixture +def tmp_store(): + return MemoryStore() diff --git a/tests/v4/test_particlefile.py b/tests/v4/test_particlefile.py index 312177a28..89f0347bb 100755 --- a/tests/v4/test_particlefile.py +++ b/tests/v4/test_particlefile.py @@ -356,27 +356,22 @@ def test_pset_execute_outputdt_backwards_fieldset_timevarying(): assert np.all(file_outputdt == np.timedelta64(-outputdt)), (file_outputdt, np.timedelta64(-outputdt)) -@pytest.fixture -def store(): - return MemoryStore() - - @pytest.mark.new -def test_particlefile_init(store): - ParticleFile(store, outputdt=np.timedelta64(1, "s"), chunks=(1, 3)) +def test_particlefile_init(tmp_store): + ParticleFile(tmp_store, outputdt=np.timedelta64(1, "s"), chunks=(1, 3)) @pytest.mark.new -def test_particlefile_init_invalid(store): # TODO: Add test for read only store +def test_particlefile_init_invalid(tmp_store): # TODO: Add test for read only store with pytest.raises(ValueError, match="chunks must be a tuple"): - ParticleFile(store, outputdt=np.timedelta64(1, "s"), chunks=1) + ParticleFile(tmp_store, outputdt=np.timedelta64(1, "s"), chunks=1) @pytest.mark.new -def test_particlefile_write_particle_data(store): +def test_particlefile_write_particle_data(tmp_store): nparticles = 100 - pfile = ParticleFile(store, outputdt=np.timedelta64(1, "s"), chunks=(nparticles, 40)) + pfile = ParticleFile(tmp_store, outputdt=np.timedelta64(1, "s"), chunks=(nparticles, 40)) pclass = Particle left, right = np.datetime64("2019-05-30T12:00:00.000000000", "ns"), np.datetime64("2020-01-02", "ns") @@ -402,7 +397,7 @@ def test_particlefile_write_particle_data(store): time_interval=time_interval, time=left, ) - ds = xr.open_zarr(store, decode_cf=False) # TODO: Fix metadata and re-enable decode_cf + ds = xr.open_zarr(tmp_store, decode_cf=False) # TODO: Fix metadata and re-enable decode_cf # assert ds.time.dtype == "datetime64[ns]" # np.testing.assert_equal(ds["time"].isel(obs=0).values, left) assert ds.sizes["trajectory"] == nparticles From 28e14b0f002830c4c1899fd3e7f5733aa75c0869 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 3 Sep 2025 17:34:06 +0200 Subject: [PATCH 25/40] Update particleset.py to be compatible with particle writing again --- parcels/particleset.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/parcels/particleset.py b/parcels/particleset.py index c0fe2b9d5..e50e10e12 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -542,24 +542,24 @@ def execute( # Set up pbar if output_file: - logger.info(f"Output files are stored in {output_file.store.path}.") + logger.info(f"Output files are stored in {output_file.store}.") if verbose_progress: pbar = tqdm(total=(end_time - start_time) / np.timedelta64(1, "s"), file=sys.stdout) - next_output = outputdt if output_file else None + next_output = start_time + sign_dt * outputdt if output_file else None time = start_time while sign_dt * (time - end_time) < 0: if sign_dt > 0: - next_time = end_time # TODO update to min(next_output, end_time) when ParticleFile works + next_time = min(next_output, end_time) else: - next_time = end_time # TODO update to max(next_output, end_time) when ParticleFile works + next_time = max(next_output, end_time) self._kernel.execute(self, endtime=next_time, dt=dt) # TODO: Handle IO timing based of timedelta or datetime objects if next_output: - if abs(next_time - next_output) < 1e-12: + if np.abs(next_time - next_output) < np.timedelta64(1000, "ns"): if output_file: output_file.write(self, next_output) if np.isfinite(outputdt): From 6db3737d02fc27881472f8438848406413aa4a6c Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 3 Sep 2025 17:35:47 +0200 Subject: [PATCH 26/40] Fix selection of particlefiles for writing Looking at time_nextloop instead. We might rework the kernel loop in future (and the responsibilities of time, time_nextloop, as well as their effect on particle writing). --- parcels/particlefile.py | 15 +++++++++++---- tests/v4/test_advection.py | 20 ++++++++++++++++++++ tests/v4/test_particlefile.py | 1 + 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/parcels/particlefile.py b/parcels/particlefile.py index b5dda4e62..3911c92cf 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -283,18 +283,22 @@ def _to_write_particles(particle_data, time): return np.where( ( np.less_equal( - time - np.abs(particle_data["dt"] / 2), particle_data["time"], where=np.isfinite(particle_data["time"]) + time - np.abs(particle_data["dt"] / 2), + particle_data["time_nextloop"], + where=np.isfinite(particle_data["time_nextloop"]), ) & np.greater_equal( - time + np.abs(particle_data["dt"] / 2), particle_data["time"], where=np.isfinite(particle_data["time"]) + time + np.abs(particle_data["dt"] / 2), + particle_data["time_nextloop"], + where=np.isfinite(particle_data["time_nextloop"]), ) # check time - dt/2 <= particle_data["time"] <= time + dt/2 | ( (np.isnan(particle_data["dt"])) - & np.equal(time, particle_data["time"], where=np.isfinite(particle_data["time"])) + & np.equal(time, particle_data["time_nextloop"], where=np.isfinite(particle_data["time_nextloop"])) ) # or dt is NaN and time matches particle_data["time"] ) & (np.isfinite(particle_data["trajectory"])) - & (np.isfinite(particle_data["time"])) + & (np.isfinite(particle_data["time_nextloop"])) )[0] @@ -303,6 +307,9 @@ def _convert_particle_data_time_to_float_seconds(particle_data, time_interval): particle_data = particle_data.copy() particle_data["time"] = ((particle_data["time"] - time_interval.left) / np.timedelta64(1, "s")).astype(np.float64) + particle_data["time_nextloop"] = ( + (particle_data["time_nextloop"] - time_interval.left) / np.timedelta64(1, "s") + ).astype(np.float64) particle_data["dt"] = (particle_data["dt"] / np.timedelta64(1, "s")).astype(np.float64) return particle_data diff --git a/tests/v4/test_advection.py b/tests/v4/test_advection.py index 2522a8b6f..08536dce8 100644 --- a/tests/v4/test_advection.py +++ b/tests/v4/test_advection.py @@ -52,6 +52,26 @@ def test_advection_zonal(mesh, npart=10): assert (np.diff(pset.lon) < 1.0e-4).all() +def test_advection_zonal_with_particlefile(tmp_store): + """Particles at high latitude move geographically faster due to the pole correction in `GeographicPolar`.""" + npart = 10 + ds = simple_UV_dataset(mesh="flat") + ds["U"].data[:] = 1.0 + grid = XGrid.from_dataset(ds, mesh="flat") + U = Field("U", ds["U"], grid, interp_method=XLinear) + V = Field("V", ds["V"], grid, interp_method=XLinear) + UV = VectorField("UV", U, V) + fieldset = FieldSet([U, V, UV]) + + pset = ParticleSet(fieldset, lon=np.zeros(npart) + 20.0, lat=np.linspace(0, 80, npart)) + pfile = pset.ParticleFile(tmp_store, outputdt=np.timedelta64(15, "m")) + pset.execute(AdvectionRK4, runtime=np.timedelta64(2, "h"), dt=np.timedelta64(15, "m"), output_file=pfile) + + assert (np.diff(pset.lon) < 1.0e-4).all() + ds = xr.open_zarr(tmp_store) + breakpoint() + + def periodicBC(particle, fieldset, time): particle.total_dlon += particle.dlon particle.lon = np.fmod(particle.lon, 2) diff --git a/tests/v4/test_particlefile.py b/tests/v4/test_particlefile.py index 89f0347bb..b62551a5a 100755 --- a/tests/v4/test_particlefile.py +++ b/tests/v4/test_particlefile.py @@ -385,6 +385,7 @@ def test_particlefile_write_particle_data(tmp_store): time_interval=time_interval, initial={ "time": np.full(nparticles, fill_value=left), + "time_nextloop": np.full(nparticles, fill_value=left), "lon": initial_lon, "dt": np.full(nparticles, fill_value=1.0), "trajectory": np.arange(nparticles), From d09ee25423a927587d0d41c4bd3feb17314a3dc9 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 4 Sep 2025 09:24:34 +0200 Subject: [PATCH 27/40] Move _DATATYPES_TO_FILL_VALUES back to particlefile.py --- parcels/_constants.py | 16 ---------------- parcels/particlefile.py | 28 +++++++++++++++++++++------- 2 files changed, 21 insertions(+), 23 deletions(-) delete mode 100644 parcels/_constants.py diff --git a/parcels/_constants.py b/parcels/_constants.py deleted file mode 100644 index 784d44676..000000000 --- a/parcels/_constants.py +++ /dev/null @@ -1,16 +0,0 @@ -import numpy as np - -DATATYPES_TO_FILL_VALUES = { - np.dtype(np.float16): np.nan, - np.dtype(np.float32): np.nan, - np.dtype(np.float64): np.nan, - np.dtype(np.bool_): np.iinfo(np.int8).max, - np.dtype(np.int8): np.iinfo(np.int8).max, - np.dtype(np.int16): np.iinfo(np.int16).max, - np.dtype(np.int32): np.iinfo(np.int32).max, - np.dtype(np.int64): np.iinfo(np.int64).max, - np.dtype(np.uint8): np.iinfo(np.uint8).max, - np.dtype(np.uint16): np.iinfo(np.uint16).max, - np.dtype(np.uint32): np.iinfo(np.uint32).max, - np.dtype(np.uint64): np.iinfo(np.uint64).max, -} diff --git a/parcels/particlefile.py b/parcels/particlefile.py index 3911c92cf..f250ad7c3 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -14,7 +14,6 @@ from zarr.storage import DirectoryStore import parcels -from parcels._constants import DATATYPES_TO_FILL_VALUES from parcels.particle import _SAME_AS_FIELDSET_TIME_INTERVAL, ParticleClass from parcels.tools._helpers import timedelta_to_float @@ -25,6 +24,21 @@ __all__ = ["ParticleFile"] +_DATATYPES_TO_FILL_VALUES = { + np.dtype(np.float16): np.nan, + np.dtype(np.float32): np.nan, + np.dtype(np.float64): np.nan, + np.dtype(np.bool_): np.iinfo(np.int8).max, + np.dtype(np.int8): np.iinfo(np.int8).max, + np.dtype(np.int16): np.iinfo(np.int16).max, + np.dtype(np.int32): np.iinfo(np.int32).max, + np.dtype(np.int64): np.iinfo(np.int64).max, + np.dtype(np.uint8): np.iinfo(np.uint8).max, + np.dtype(np.uint16): np.iinfo(np.uint16).max, + np.dtype(np.uint32): np.iinfo(np.uint32).max, + np.dtype(np.uint64): np.iinfo(np.uint64).max, +} + class ParticleFile: """Initialise trajectory output. @@ -109,16 +123,16 @@ def _convert_varout_name(self, var): def _extend_zarr_dims(self, Z, store, dtype, axis): if axis == 1: - a = np.full((Z.shape[0], self.chunks[1]), DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype) + a = np.full((Z.shape[0], self.chunks[1]), _DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype) obs = zarr.group(store=store, overwrite=False)["obs"] if len(obs) == Z.shape[1]: obs.append(np.arange(self.chunks[1]) + obs[-1] + 1) else: extra_trajs = self._maxids - Z.shape[0] if len(Z.shape) == 2: - a = np.full((extra_trajs, Z.shape[1]), DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype) + a = np.full((extra_trajs, Z.shape[1]), _DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype) else: - a = np.full((extra_trajs,), DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype) + a = np.full((extra_trajs,), _DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype) Z.append(a, axis=axis) zarr.consolidate_metadata(store) @@ -194,13 +208,13 @@ def _write_particle_data(self, *, particle_data, pclass, time_interval, time, in if var.to_write == "once": data = np.full( (arrsize[0],), - DATATYPES_TO_FILL_VALUES[dtype], + _DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype, ) data[ids_once] = particle_data[var.name][indices_to_write_once] dims = ["trajectory"] else: - data = np.full(arrsize, DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype) + data = np.full(arrsize, _DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype) data[ids, 0] = particle_data[var.name][indices_to_write] dims = ["trajectory", "obs"] ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[var.name]) @@ -269,7 +283,7 @@ def _create_variables_attribute_dict(particle: ParticleClass, time_interval: Tim for var in vars: fill_value = {} if var.dtype is not _SAME_AS_FIELDSET_TIME_INTERVAL.VALUE: - fill_value = {"_FillValue": DATATYPES_TO_FILL_VALUES[var.dtype]} + fill_value = {"_FillValue": _DATATYPES_TO_FILL_VALUES[var.dtype]} attrs[var.name] = {**var.attrs, **fill_value} From 66d94a0d4789d6727ba199ae19dcc91fa7f76b8e Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 4 Sep 2025 09:30:16 +0200 Subject: [PATCH 28/40] Remove np.datetime64 from outputdt Not sure how that got in there.... Typo --- parcels/particlefile.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/parcels/particlefile.py b/parcels/particlefile.py index f250ad7c3..bd2026b57 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -65,8 +65,8 @@ class ParticleFile: """ def __init__(self, store, outputdt, chunks=None, create_new_zarrfile=True): - if not isinstance(outputdt, (np.datetime64, np.timedelta64)): - raise ValueError(f"Expected outputdt to be a np.timedelta64 or datetime64, got {type(outputdt)}") + if not isinstance(outputdt, np.timedelta64): + raise ValueError(f"Expected outputdt to be a np.timedelta64, got {type(outputdt)}") self._outputdt = outputdt From 6bc8284eb613215f8af7b09e5cf51ff2281bbe16 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 4 Sep 2025 09:59:23 +0200 Subject: [PATCH 29/40] Update test to remove breakpoint --- tests/v4/test_advection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v4/test_advection.py b/tests/v4/test_advection.py index 08536dce8..7dfca0aa0 100644 --- a/tests/v4/test_advection.py +++ b/tests/v4/test_advection.py @@ -69,7 +69,7 @@ def test_advection_zonal_with_particlefile(tmp_store): assert (np.diff(pset.lon) < 1.0e-4).all() ds = xr.open_zarr(tmp_store) - breakpoint() + np.testing.assert_allclose(ds.isel(obs=-1).lon.values, pset.lon) def periodicBC(particle, fieldset, time): From 68df63cae5c5697ad772df09e222583bf302ca0a Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 4 Sep 2025 10:36:21 +0200 Subject: [PATCH 30/40] Self review - Set metadata in file - Remove temp pytest marker --- parcels/particlefile.py | 7 +++++-- parcels/particleset.py | 3 ++- pyproject.toml | 1 - tests/v4/test_particlefile.py | 7 ++----- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/parcels/particlefile.py b/parcels/particlefile.py index bd2026b57..22251e405 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -3,7 +3,7 @@ from __future__ import annotations import os -from datetime import datetime +from datetime import datetime, timedelta from pathlib import Path from typing import TYPE_CHECKING, Literal @@ -65,8 +65,11 @@ class ParticleFile: """ def __init__(self, store, outputdt, chunks=None, create_new_zarrfile=True): + if isinstance(outputdt, timedelta): + outputdt = np.timedelta64(int(outputdt.total_seconds()), "s") + if not isinstance(outputdt, np.timedelta64): - raise ValueError(f"Expected outputdt to be a np.timedelta64, got {type(outputdt)}") + raise ValueError(f"Expected outputdt to be a np.timedelta64 or datetime.timedelta, got {type(outputdt)}") self._outputdt = outputdt diff --git a/parcels/particleset.py b/parcels/particleset.py index e50e10e12..d36e6d827 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -505,7 +505,8 @@ def execute( self._kernel = pyfunc - if output_file: + if output_file is not None: + output_file.set_metadata(self.fieldset.gridset[0]._mesh) output_file.metadata["parcels_kernels"] = self._kernel.name if dt is None: diff --git a/pyproject.toml b/pyproject.toml index e0d651634..ff2e8a315 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,7 +131,6 @@ markers = [ # can be skipped by doing `pytest -m "not slow"` etc. "v4alpha: failing tests that should work for v4alpha", "v4future: failing tests that should work for a future release of v4", "v4remove: failing tests that should probably be removed later", - "new: new tests likely to replace the old. Adding this temporarily so split suite. TODO: remove" ] filterwarnings = [ diff --git a/tests/v4/test_particlefile.py b/tests/v4/test_particlefile.py index b62551a5a..35cd56bfb 100755 --- a/tests/v4/test_particlefile.py +++ b/tests/v4/test_particlefile.py @@ -296,8 +296,8 @@ def setup_pset_execute(*, fieldset: FieldSet, outputdt: timedelta, execute_kwarg pset = ParticleSet( fieldset, pclass=particle_class, - lon=np.full(npart, fieldset.U.lon.mean()), - lat=np.full(npart, fieldset.U.lat.mean()), + lon=np.full(npart, fieldset.U.data.lon.mean()), + lat=np.full(npart, fieldset.U.data.lat.mean()), ) with tempfile.TemporaryDirectory() as dir: @@ -356,18 +356,15 @@ def test_pset_execute_outputdt_backwards_fieldset_timevarying(): assert np.all(file_outputdt == np.timedelta64(-outputdt)), (file_outputdt, np.timedelta64(-outputdt)) -@pytest.mark.new def test_particlefile_init(tmp_store): ParticleFile(tmp_store, outputdt=np.timedelta64(1, "s"), chunks=(1, 3)) -@pytest.mark.new def test_particlefile_init_invalid(tmp_store): # TODO: Add test for read only store with pytest.raises(ValueError, match="chunks must be a tuple"): ParticleFile(tmp_store, outputdt=np.timedelta64(1, "s"), chunks=1) -@pytest.mark.new def test_particlefile_write_particle_data(tmp_store): nparticles = 100 From 9d0a0a7dda71623c446b43c5add99ab21fd60acf Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 4 Sep 2025 10:48:33 +0200 Subject: [PATCH 31/40] Fix selection of next_time in execute --- parcels/particleset.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/parcels/particleset.py b/parcels/particleset.py index d36e6d827..193b480c6 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -552,10 +552,12 @@ def execute( time = start_time while sign_dt * (time - end_time) < 0: - if sign_dt > 0: - next_time = min(next_output, end_time) + if next_output is not None: + f = min if sign_dt > 0 else max + next_time = f(next_output, end_time) else: - next_time = max(next_output, end_time) + next_time = end_time + self._kernel.execute(self, endtime=next_time, dt=dt) # TODO: Handle IO timing based of timedelta or datetime objects From faec4616bef3cf015bf5f0a7a6329fd4e682ad9b Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 4 Sep 2025 14:49:31 +0200 Subject: [PATCH 32/40] Remove Kernel.name in favour of using funcname --- parcels/kernel.py | 5 ----- parcels/particleset.py | 2 +- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/parcels/kernel.py b/parcels/kernel.py index 29f96ac3c..41409c15a 100644 --- a/parcels/kernel.py +++ b/parcels/kernel.py @@ -92,11 +92,6 @@ def funcname(self): ret += f.__name__ return ret - @property #! Ported from v3. To be removed in v4? (/find another way to name kernels in output file) - def name(self): - # return f"{self._ptype.name}{self.funcname}" # TODO v4: Should we propogate the name of the particle to the metadata? At the moment we don't have the concept of naming particles (somewhat incompatible with the .add_variable() API?) - return f"{self.funcname}" - @property def ptype(self): return self._ptype diff --git a/parcels/particleset.py b/parcels/particleset.py index 193b480c6..939ebdde9 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -507,7 +507,7 @@ def execute( if output_file is not None: output_file.set_metadata(self.fieldset.gridset[0]._mesh) - output_file.metadata["parcels_kernels"] = self._kernel.name + output_file.metadata["parcels_kernels"] = self._kernel.funcname if dt is None: dt = np.timedelta64(1, "s") From d5f4d1c78f9ebd615872218acfd20e40ff41e8d0 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 4 Sep 2025 14:54:53 +0200 Subject: [PATCH 33/40] Remove empty long_name s --- parcels/particle.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/parcels/particle.py b/parcels/particle.py index c6a617975..3fbbacb30 100644 --- a/parcels/particle.py +++ b/parcels/particle.py @@ -168,19 +168,19 @@ def get_default_particle(spatial_dtype: np.float32 | np.float64) -> ParticleClas Variable( "lon", dtype=spatial_dtype, - attrs={"long_name": "", "standard_name": "longitude", "units": "degrees_east", "axis": "X"}, + attrs={"standard_name": "longitude", "units": "degrees_east", "axis": "X"}, ), Variable("lon_nextloop", dtype=spatial_dtype, to_write=False), Variable( "lat", dtype=spatial_dtype, - attrs={"long_name": "", "standard_name": "latitude", "units": "degrees_north", "axis": "Y"}, + attrs={"standard_name": "latitude", "units": "degrees_north", "axis": "Y"}, ), Variable("lat_nextloop", dtype=spatial_dtype, to_write=False), Variable( "depth", dtype=spatial_dtype, - attrs={"long_name": "", "standard_name": "depth", "units": "m", "positive": "down"}, + attrs={"standard_name": "depth", "units": "m", "positive": "down"}, ), Variable("dlon", dtype=spatial_dtype, to_write=False), Variable("dlat", dtype=spatial_dtype, to_write=False), @@ -189,7 +189,7 @@ def get_default_particle(spatial_dtype: np.float32 | np.float64) -> ParticleClas Variable( "time", dtype=_SAME_AS_FIELDSET_TIME_INTERVAL.VALUE, - attrs={"long_name": "", "standard_name": "time", "units": "seconds", "axis": "T"}, + attrs={"standard_name": "time", "units": "seconds", "axis": "T"}, ), Variable("time_nextloop", dtype=_SAME_AS_FIELDSET_TIME_INTERVAL.VALUE, to_write=False), Variable( From 16032a4dc1955ab8f891e841c303d914ee15e64b Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 5 Sep 2025 14:29:41 +0200 Subject: [PATCH 34/40] Fix tests in test_particlefile.py Also flag tests that should be looked at later - as they are out of scope for this PR (or may be able to be tested another way) --- tests/v4/test_particlefile.py | 109 ++++++++++++++++++++++++---------- 1 file changed, 77 insertions(+), 32 deletions(-) diff --git a/tests/v4/test_particlefile.py b/tests/v4/test_particlefile.py index 35cd56bfb..956a46097 100755 --- a/tests/v4/test_particlefile.py +++ b/tests/v4/test_particlefile.py @@ -11,8 +11,9 @@ from parcels import AdvectionRK4, Field, FieldSet, Particle, ParticleSet, Variable, VectorField from parcels._core.utils.time import TimeInterval from parcels._datasets.structured.generic import datasets -from parcels.particle import Particle, create_particle_data +from parcels.particle import Particle, create_particle_data, get_default_particle from parcels.particlefile import ParticleFile +from parcels.tools.statuscodes import StatusCode from parcels.xgrid import XGrid from tests.common_kernels import DoNothing @@ -37,7 +38,7 @@ def test_metadata(fieldset, tmp_zarrfile): pset.execute(DoNothing, runtime=1, output_file=pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s"))) - ds = xr.open_zarr(tmp_zarrfile) + ds = xr.open_zarr(tmp_zarrfile, decode_cf=False) # TODO v4: Fix metadata and re-enable decode_cf assert ds.attrs["parcels_kernels"].lower() == "ParticleDoNothing".lower() @@ -69,14 +70,19 @@ def test_pfile_array_remove_particles(fieldset, tmp_zarrfile): time=fieldset.time_interval.left, ) pfile = pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")) + pset._data["time"][:] = fieldset.time_interval.left + pset._data["time_nextloop"][:] = fieldset.time_interval.left pfile.write(pset, time=fieldset.time_interval.left) pset.remove_indices(3) - for p in pset: - p.time = 1 - pfile.write(pset, 1) - - ds = xr.open_zarr(tmp_zarrfile) + new_time = fieldset.time_interval.left + np.timedelta64(1, "D") + pset._data["time"][:] = new_time + pset._data["time_nextloop"][:] = new_time + pfile.write(pset, new_time) + ds = xr.open_zarr(tmp_zarrfile, decode_cf=False) timearr = ds["time"][:] + pytest.skip( + "TODO v4: Set decode_cf=True, which will mean that missing values get decoded to NaT rather than fill value" + ) assert (np.isnat(timearr[3, 1])) and (np.isfinite(timearr[3, 0])) @@ -95,10 +101,13 @@ def test_pfile_array_remove_all_particles(fieldset, chunks_obs, tmp_zarrfile): pfile.write(pset, time=fieldset.time_interval.left) for _ in range(npart): pset.remove_indices(-1) - pfile.write(pset, 1) - pfile.write(pset, 2) + pfile.write(pset, fieldset.time_interval.left + np.timedelta64(1, "D")) + pfile.write(pset, fieldset.time_interval.left + np.timedelta64(2, "D")) - ds = xr.open_zarr(tmp_zarrfile).load() + ds = xr.open_zarr(tmp_zarrfile, decode_cf=False).load() + pytest.skip( + "TODO v4: Set decode_cf=True, which will mean that missing values get decoded to NaT rather than fill value" + ) assert np.allclose(ds["time"][:, 0], np.timedelta64(0, "s"), atol=np.timedelta64(1, "ms")) if chunks_obs is not None: assert ds["time"][:].shape == chunks @@ -107,16 +116,22 @@ def test_pfile_array_remove_all_particles(fieldset, chunks_obs, tmp_zarrfile): assert np.all(np.isnan(ds["time"][:, 1:])) -@pytest.mark.xfail(reason="lonlatdepth_dtype removed. Update implementation to use a different particle") +@pytest.mark.skip(reason="TODO v4: stuck in infinite loop") def test_variable_write_double(fieldset, tmp_zarrfile): def Update_lon(particle, fieldset, time): # pragma: no cover particle.dlon += 0.1 - pset = ParticleSet(fieldset, pclass=Particle, lon=[0], lat=[0], lonlatdepth_dtype=np.float64) + particle = get_default_particle(np.float64) + pset = ParticleSet(fieldset, pclass=particle, lon=[0], lat=[0]) ofile = pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(10, "us")) - pset.execute(pset.Kernel(Update_lon), endtime=0.001, dt=0.00001, output_file=ofile) + pset.execute( + pset.Kernel(Update_lon), + runtime=np.timedelta64(1, "ms"), + dt=np.timedelta64(10, "us"), + output_file=ofile, + ) - ds = xr.open_zarr(tmp_zarrfile) + ds = xr.open_zarr(tmp_zarrfile, decode_cf=False) # TODO v4: Fix metadata and re-enable decode_cf lons = ds["lon"][:] assert isinstance(lons.values[0, 0], np.float64) @@ -155,11 +170,19 @@ def test_variable_written_once(): ... -@pytest.mark.parametrize("dt", [-1, 1]) +@pytest.mark.parametrize( + "dt", + [ + pytest.param(-np.timedelta64(1, "s"), marks=pytest.mark.xfail(reason="need to fix backwards in time")), + np.timedelta64(1, "s"), + ], +) @pytest.mark.parametrize("maxvar", [2, 4, 10]) def test_pset_repeated_release_delayed_adding_deleting(fieldset, tmp_zarrfile, dt, maxvar): - runtime = 10 - fieldset.maxvar = maxvar + """Tests that if particles are released and deleted based on age that resulting output file is correct.""" + npart = 10 + runtime = np.timedelta64(npart, "s") + fieldset.add_constant("maxvar", maxvar) pset = None MyParticle = Particle.add_variable( @@ -167,19 +190,29 @@ def test_pset_repeated_release_delayed_adding_deleting(fieldset, tmp_zarrfile, d ) pset = ParticleSet( - fieldset, lon=np.zeros(runtime), lat=np.zeros(runtime), pclass=MyParticle, time=list(range(runtime)) + fieldset, + lon=np.zeros(npart), + lat=np.zeros(npart), + pclass=MyParticle, + time=fieldset.time_interval.left + np.array([np.timedelta64(i, "s") for i in range(npart)]), ) pfile = pset.ParticleFile(tmp_zarrfile, outputdt=abs(dt), chunks=(1, 1)) def IncrLon(particle, fieldset, time): # pragma: no cover particle.sample_var += 1.0 - if particle.sample_var > fieldset.maxvar: - particle.delete() + particle.state = np.where( + particle.sample_var > fieldset.maxvar, + StatusCode.Delete, + particle.state, + ) - for _ in range(runtime): - pset.execute(IncrLon, dt=dt, runtime=1.0, output_file=pfile) + for _ in range(npart): + pset.execute(IncrLon, dt=dt, runtime=np.timedelta64(1, "s"), output_file=pfile) - ds = xr.open_zarr(tmp_zarrfile) + ds = xr.open_zarr(tmp_zarrfile, decode_cf=False) + pytest.skip( + "TODO v4: Set decode_cf=True, which will mean that missing values get decoded to NaT rather than fill value" + ) samplevar = ds["sample_var"][:] assert samplevar.shape == (runtime, min(maxvar + 1, runtime)) # test whether samplevar[:, k] = k @@ -189,6 +222,7 @@ def IncrLon(particle, fieldset, time): # pragma: no cover assert filesize < 1024 * 65 # test that chunking leads to filesize less than 65KB +@pytest.mark.xfail(reason="need to fix backwards in time") def test_write_timebackward(fieldset, tmp_zarrfile): def Update_lon(particle, fieldset, time): # pragma: no cover dt = particle.dt / np.timedelta64(1, "s") @@ -209,12 +243,15 @@ def Update_lon(particle, fieldset, time): # pragma: no cover assert np.all(np.diff(trajs.values) < 0) # all particles written in order of release +@pytest.mark.xfail +@pytest.mark.v4alpha def test_write_xiyi(fieldset, tmp_zarrfile): fieldset.U.data[:] = 1 # set a non-zero zonal velocity fieldset.add_field(Field(name="P", data=np.zeros((3, 20)), lon=np.linspace(0, 1, 20), lat=[-2, 0, 2])) - dt = 3600 + dt = np.timedelta64(3600, "s") - XiYiParticle = Particle.add_variable( + particle = get_default_particle(np.float64) + XiYiParticle = particle.add_variable( [ Variable("pxi0", dtype=np.int32, initial=0.0), Variable("pxi1", dtype=np.int32, initial=0.0), @@ -236,7 +273,7 @@ def SampleP(particle, fieldset, time): # pragma: no cover if time > 5 * 3600: _ = fieldset.P[particle] # To trigger sampling of the P field - pset = ParticleSet(fieldset, pclass=XiYiParticle, lon=[0, 0.2], lat=[0.2, 1], lonlatdepth_dtype=np.float64) + pset = ParticleSet(fieldset, pclass=XiYiParticle, lon=[0, 0.2], lat=[0.2, 1]) pfile = pset.ParticleFile(tmp_zarrfile, outputdt=dt) pset.execute([SampleP, Get_XiYi, AdvectionRK4], endtime=10 * dt, dt=dt, output_file=pfile) @@ -259,6 +296,8 @@ def SampleP(particle, fieldset, time): # pragma: no cover assert fieldset.U.grid.lat[yi] <= lat < fieldset.U.grid.lat[yi + 1] +@pytest.mark.skip +@pytest.mark.v4alpha def test_reset_dt(fieldset, tmp_zarrfile): # Assert that p.dt gets reset when a write_time is not a multiple of dt # for p.dt=0.02 to reach outputdt=0.05 and endtime=0.1, the steps should be [0.2, 0.2, 0.1, 0.2, 0.2, 0.1], resulting in 6 kernel executions @@ -266,22 +305,27 @@ def test_reset_dt(fieldset, tmp_zarrfile): def Update_lon(particle, fieldset, time): # pragma: no cover particle.dlon += 0.1 - pset = ParticleSet(fieldset, pclass=Particle, lon=[0], lat=[0], lonlatdepth_dtype=np.float64) + particle = get_default_particle(np.float64) + pset = ParticleSet(fieldset, pclass=particle, lon=[0], lat=[0]) ofile = pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(50, "ms")) - pset.execute(pset.Kernel(Update_lon), endtime=0.12, dt=0.02, output_file=ofile) + dt = np.timedelta64(20, "ms") + pset.execute(pset.Kernel(Update_lon), runtime=6 * dt, dt=dt, output_file=ofile) assert np.allclose(pset.lon, 0.6) +@pytest.mark.v4alpha +@pytest.mark.xfail def test_correct_misaligned_outputdt_dt(fieldset, tmp_zarrfile): """Testing that outputdt does not need to be a multiple of dt.""" def Update_lon(particle, fieldset, time): # pragma: no cover - particle.dlon += particle.dt + particle.dlon += particle.dt / np.timedelta64(1, "s") - pset = ParticleSet(fieldset, pclass=Particle, lon=[0], lat=[0], lonlatdepth_dtype=np.float64) + particle = get_default_particle(np.float64) + pset = ParticleSet(fieldset, pclass=particle, lon=[0], lat=[0]) ofile = pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(3, "s")) - pset.execute(pset.Kernel(Update_lon), endtime=11, dt=2, output_file=ofile) + pset.execute(pset.Kernel(Update_lon), runtime=np.timedelta64(11, "s"), dt=np.timedelta64(2, "s"), output_file=ofile) ds = xr.open_zarr(tmp_zarrfile) assert np.allclose(ds.lon.values, [0, 3, 6, 9]) @@ -321,6 +365,7 @@ def test_pset_execute_outputdt_forwards(fieldset): assert np.all(ds.isel(trajectory=0).time.diff(dim="obs").values == np.timedelta64(outputdt)) +@pytest.mark.skip(reason="backwards in time not yet working") def test_pset_execute_outputdt_backwards(fieldset): """Testing output data dt matches outputdt in backwards time.""" outputdt = timedelta(hours=1) @@ -395,7 +440,7 @@ def test_particlefile_write_particle_data(tmp_store): time_interval=time_interval, time=left, ) - ds = xr.open_zarr(tmp_store, decode_cf=False) # TODO: Fix metadata and re-enable decode_cf + ds = xr.open_zarr(tmp_store, decode_cf=False) # TODO v4: Fix metadata and re-enable decode_cf # assert ds.time.dtype == "datetime64[ns]" # np.testing.assert_equal(ds["time"].isel(obs=0).values, left) assert ds.sizes["trajectory"] == nparticles From e3d6a61452270ba7ed49615077453d2a68a58bfc Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 5 Sep 2025 14:42:41 +0200 Subject: [PATCH 35/40] Make ParticleFile store and create_new_zarrfile attrs read only --- parcels/particlefile.py | 20 ++++++++++++++------ tests/v4/test_particlefile.py | 7 +++++++ 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/parcels/particlefile.py b/parcels/particlefile.py index 22251e405..055fe7bd6 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -78,12 +78,12 @@ def __init__(self, store, outputdt, chunks=None, create_new_zarrfile=True): self._maxids = 0 self._pids_written = {} self.metadata = {} - self.create_new_zarrfile = create_new_zarrfile + self._create_new_zarrfile = create_new_zarrfile - if isinstance(store, zarr.storage.Store): - self.store = store - else: - self.store = _get_store_from_pathlike(store) + if not isinstance(store, zarr.storage.Store): + store = _get_store_from_pathlike(store) + + self._store = store # TODO v4: Enable once updating to zarr v3 # if store.read_only: @@ -118,6 +118,14 @@ def outputdt(self): def chunks(self): return self._chunks + @property + def store(self): + return self._store + + @property + def create_new_zarrfile(self): + return self._create_new_zarrfile + def _convert_varout_name(self, var): if var == "depth": return "z" @@ -223,7 +231,7 @@ def _write_particle_data(self, *, particle_data, pclass, time_interval, time, in ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[var.name]) ds[varout].encoding["chunks"] = self.chunks[0] if var.to_write == "once" else self.chunks # type: ignore[index] ds.to_zarr(store, mode="w") - self.create_new_zarrfile = False + self._create_new_zarrfile = False else: Z = zarr.group(store=store, overwrite=False) obs = particle_data["obs_written"][indices_to_write] diff --git a/tests/v4/test_particlefile.py b/tests/v4/test_particlefile.py index 956a46097..0f9728d93 100755 --- a/tests/v4/test_particlefile.py +++ b/tests/v4/test_particlefile.py @@ -405,6 +405,13 @@ def test_particlefile_init(tmp_store): ParticleFile(tmp_store, outputdt=np.timedelta64(1, "s"), chunks=(1, 3)) +@pytest.mark.parametrize("name", ["store", "outputdt", "chunks", "create_new_zarrfile"]) +def test_particlefile_readonly_attrs(tmp_store, name): + pfile = ParticleFile(tmp_store, outputdt=np.timedelta64(1, "s"), chunks=(1, 3)) + with pytest.raises(AttributeError, match="property .* of 'ParticleFile' object has no setter"): + setattr(pfile, name, "something") + + def test_particlefile_init_invalid(tmp_store): # TODO: Add test for read only store with pytest.raises(ValueError, match="chunks must be a tuple"): ParticleFile(tmp_store, outputdt=np.timedelta64(1, "s"), chunks=1) From 5393f6f48d31fc7f6cf5fd7a04ae3eb6099c38c5 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 5 Sep 2025 14:58:43 +0200 Subject: [PATCH 36/40] Remove pset.ParticleFile alias In favour of directly using the ParticleFile object --- parcels/particleset.py | 5 ----- tests/tools/test_warnings.py | 3 ++- tests/v4/test_advection.py | 3 ++- tests/v4/test_interpolation.py | 3 ++- tests/v4/test_particlefile.py | 24 ++++++++++++------------ tests/v4/test_particleset_execute.py | 3 ++- v3to4-breaking-changes.md | 4 ++++ 7 files changed, 24 insertions(+), 21 deletions(-) diff --git a/parcels/particleset.py b/parcels/particleset.py index 939ebdde9..7fc058799 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -15,7 +15,6 @@ from parcels.basegrid import GridType from parcels.kernel import Kernel from parcels.particle import KernelParticle, Particle, create_particle_data -from parcels.particlefile import ParticleFile from parcels.tools.converters import convert_to_flat_array from parcels.tools.loggers import logger from parcels.tools.statuscodes import StatusCode @@ -392,10 +391,6 @@ def InteractionKernel(self, pyfunc_inter): return None return InteractionKernel(self.fieldset, self._ptype, pyfunc=pyfunc_inter) - def ParticleFile(self, *args, **kwargs): - """Wrapper method to initialise a :class:`parcels.particlefile.ParticleFile` object from the ParticleSet.""" - return ParticleFile(*args, **kwargs) - def data_indices(self, variable_name, compare_values, invert=False): """Get the indices of all particles where the value of `variable_name` equals (one of) `compare_values`. diff --git a/tests/tools/test_warnings.py b/tests/tools/test_warnings.py index 49007b671..1570d44bf 100644 --- a/tests/tools/test_warnings.py +++ b/tests/tools/test_warnings.py @@ -11,6 +11,7 @@ ParticleSet, ParticleSetWarning, ) +from parcels.particlefile import ParticleFile from tests.utils import TEST_DATA @@ -30,7 +31,7 @@ def test_file_warnings(tmp_zarrfile): data={"U": np.zeros((1, 1)), "V": np.zeros((1, 1))}, dimensions={"lon": [0], "lat": [0]} ) pset = ParticleSet(fieldset=fieldset, pclass=Particle, lon=[0, 0], lat=[0, 0], time=[0, 1]) - pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=2) + pfile = ParticleFile(name=tmp_zarrfile, outputdt=2) with pytest.warns(ParticleSetWarning, match="Some of the particles have a start time difference.*"): pset.execute(AdvectionRK4, runtime=3, dt=1, output_file=pfile) diff --git a/tests/v4/test_advection.py b/tests/v4/test_advection.py index 7dfca0aa0..9247c25bc 100644 --- a/tests/v4/test_advection.py +++ b/tests/v4/test_advection.py @@ -17,6 +17,7 @@ from parcels.field import Field, VectorField from parcels.fieldset import FieldSet from parcels.particle import Particle, Variable +from parcels.particlefile import ParticleFile from parcels.particleset import ParticleSet from parcels.tools.statuscodes import StatusCode from parcels.xgrid import XGrid @@ -64,7 +65,7 @@ def test_advection_zonal_with_particlefile(tmp_store): fieldset = FieldSet([U, V, UV]) pset = ParticleSet(fieldset, lon=np.zeros(npart) + 20.0, lat=np.linspace(0, 80, npart)) - pfile = pset.ParticleFile(tmp_store, outputdt=np.timedelta64(15, "m")) + pfile = ParticleFile(tmp_store, outputdt=np.timedelta64(15, "m")) pset.execute(AdvectionRK4, runtime=np.timedelta64(2, "h"), dt=np.timedelta64(15, "m"), output_file=pfile) assert (np.diff(pset.lon) < 1.0e-4).all() diff --git a/tests/v4/test_interpolation.py b/tests/v4/test_interpolation.py index d3ca4dea3..5fa92ab71 100644 --- a/tests/v4/test_interpolation.py +++ b/tests/v4/test_interpolation.py @@ -10,6 +10,7 @@ from parcels.field import Field, VectorField from parcels.fieldset import FieldSet from parcels.particle import Particle, Variable +from parcels.particlefile import ParticleFile from parcels.particleset import ParticleSet from parcels.tools.statuscodes import StatusCode from parcels.uxgrid import UxGrid @@ -169,7 +170,7 @@ def DeleteParticle(particle, fieldset, time): if particle.state >= 50: particle.state = StatusCode.Delete - outfile = pset.ParticleFile(f"test_interpolation_v4_{interp_name}", outputdt=np.timedelta64(1, "s")) + outfile = ParticleFile(f"test_interpolation_v4_{interp_name}", outputdt=np.timedelta64(1, "s")) pset.execute( [AdvectionRK4_3D, DeleteParticle], runtime=np.timedelta64(4, "s"), diff --git a/tests/v4/test_particlefile.py b/tests/v4/test_particlefile.py index 0f9728d93..e2388a6a5 100755 --- a/tests/v4/test_particlefile.py +++ b/tests/v4/test_particlefile.py @@ -36,7 +36,7 @@ def fieldset() -> FieldSet: # TODO v4: Move into a `conftest.py` file and remov def test_metadata(fieldset, tmp_zarrfile): pset = ParticleSet(fieldset, pclass=Particle, lon=0, lat=0) - pset.execute(DoNothing, runtime=1, output_file=pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s"))) + pset.execute(DoNothing, runtime=1, output_file=ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s"))) ds = xr.open_zarr(tmp_zarrfile, decode_cf=False) # TODO v4: Fix metadata and re-enable decode_cf assert ds.attrs["parcels_kernels"].lower() == "ParticleDoNothing".lower() @@ -53,7 +53,7 @@ def test_pfile_array_write_zarr_memorystore(fieldset): lat=0.5 * np.ones(npart), time=fieldset.time_interval.left, ) - pfile = pset.ParticleFile(zarr_store, outputdt=np.timedelta64(1, "s")) + pfile = ParticleFile(zarr_store, outputdt=np.timedelta64(1, "s")) pfile.write(pset, time=fieldset.time_interval.left) ds = xr.open_zarr(zarr_store) @@ -69,7 +69,7 @@ def test_pfile_array_remove_particles(fieldset, tmp_zarrfile): lat=0.5 * np.ones(npart), time=fieldset.time_interval.left, ) - pfile = pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")) + pfile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")) pset._data["time"][:] = fieldset.time_interval.left pset._data["time_nextloop"][:] = fieldset.time_interval.left pfile.write(pset, time=fieldset.time_interval.left) @@ -97,7 +97,7 @@ def test_pfile_array_remove_all_particles(fieldset, chunks_obs, tmp_zarrfile): time=fieldset.time_interval.left, ) chunks = (npart, chunks_obs) if chunks_obs else None - pfile = pset.ParticleFile(tmp_zarrfile, chunks=chunks, outputdt=np.timedelta64(1, "s")) + pfile = ParticleFile(tmp_zarrfile, chunks=chunks, outputdt=np.timedelta64(1, "s")) pfile.write(pset, time=fieldset.time_interval.left) for _ in range(npart): pset.remove_indices(-1) @@ -123,7 +123,7 @@ def Update_lon(particle, fieldset, time): # pragma: no cover particle = get_default_particle(np.float64) pset = ParticleSet(fieldset, pclass=particle, lon=[0], lat=[0]) - ofile = pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(10, "us")) + ofile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(10, "us")) pset.execute( pset.Kernel(Update_lon), runtime=np.timedelta64(1, "ms"), @@ -155,7 +155,7 @@ def test_write_dtypes_pfile(fieldset, tmp_zarrfile): MyParticle = Particle.add_variable(extra_vars) pset = ParticleSet(fieldset, pclass=MyParticle, lon=0, lat=0, time=fieldset.time_interval.left) - pfile = pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")) + pfile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")) pfile.write(pset, time=fieldset.time_interval.left) ds = xr.open_zarr( @@ -196,7 +196,7 @@ def test_pset_repeated_release_delayed_adding_deleting(fieldset, tmp_zarrfile, d pclass=MyParticle, time=fieldset.time_interval.left + np.array([np.timedelta64(i, "s") for i in range(npart)]), ) - pfile = pset.ParticleFile(tmp_zarrfile, outputdt=abs(dt), chunks=(1, 1)) + pfile = ParticleFile(tmp_zarrfile, outputdt=abs(dt), chunks=(1, 1)) def IncrLon(particle, fieldset, time): # pragma: no cover particle.sample_var += 1.0 @@ -235,7 +235,7 @@ def Update_lon(particle, fieldset, time): # pragma: no cover lon=[0, 0, 0], time=np.array([np.datetime64("2000-01-01") for _ in range(3)]), ) - pfile = pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")) + pfile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")) pset.execute(pset.Kernel(Update_lon), runtime=np.timedelta64(1, "s"), dt=-np.timedelta64(1, "s"), output_file=pfile) ds = xr.open_zarr(tmp_zarrfile) trajs = ds["trajectory"][:] @@ -274,7 +274,7 @@ def SampleP(particle, fieldset, time): # pragma: no cover _ = fieldset.P[particle] # To trigger sampling of the P field pset = ParticleSet(fieldset, pclass=XiYiParticle, lon=[0, 0.2], lat=[0.2, 1]) - pfile = pset.ParticleFile(tmp_zarrfile, outputdt=dt) + pfile = ParticleFile(tmp_zarrfile, outputdt=dt) pset.execute([SampleP, Get_XiYi, AdvectionRK4], endtime=10 * dt, dt=dt, output_file=pfile) ds = xr.open_zarr(tmp_zarrfile) @@ -307,7 +307,7 @@ def Update_lon(particle, fieldset, time): # pragma: no cover particle = get_default_particle(np.float64) pset = ParticleSet(fieldset, pclass=particle, lon=[0], lat=[0]) - ofile = pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(50, "ms")) + ofile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(50, "ms")) dt = np.timedelta64(20, "ms") pset.execute(pset.Kernel(Update_lon), runtime=6 * dt, dt=dt, output_file=ofile) @@ -324,7 +324,7 @@ def Update_lon(particle, fieldset, time): # pragma: no cover particle = get_default_particle(np.float64) pset = ParticleSet(fieldset, pclass=particle, lon=[0], lat=[0]) - ofile = pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(3, "s")) + ofile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(3, "s")) pset.execute(pset.Kernel(Update_lon), runtime=np.timedelta64(11, "s"), dt=np.timedelta64(2, "s"), output_file=ofile) ds = xr.open_zarr(tmp_zarrfile) @@ -346,7 +346,7 @@ def setup_pset_execute(*, fieldset: FieldSet, outputdt: timedelta, execute_kwarg with tempfile.TemporaryDirectory() as dir: name = f"{dir}/test.zarr" - output_file = pset.ParticleFile(name, outputdt=outputdt) + output_file = ParticleFile(name, outputdt=outputdt) pset.execute(DoNothing, output_file=output_file, **execute_kwargs) ds = xr.open_zarr(name).load() diff --git a/tests/v4/test_particleset_execute.py b/tests/v4/test_particleset_execute.py index 4e4ba21e1..80087d120 100644 --- a/tests/v4/test_particleset_execute.py +++ b/tests/v4/test_particleset_execute.py @@ -17,6 +17,7 @@ from parcels._datasets.structured.generated import simple_UV_dataset from parcels._datasets.structured.generic import datasets as datasets_structured from parcels._datasets.unstructured.generic import datasets as datasets_unstructured +from parcels.particlefile import ParticleFile from parcels.tools.statuscodes import FieldInterpolationError, FieldOutOfBoundError, TimeExtrapolationError from parcels.uxgrid import UxGrid from parcels.xgrid import XGrid @@ -433,7 +434,7 @@ def test_uxstommelgyre_pset_execute_output(): time=[0.0], pclass=Particle, ) - output_file = pset.ParticleFile( + output_file = ParticleFile( name="stommel_uxarray_particles.zarr", # the file name outputdt=np.timedelta64(5, "m"), # the time step of the outputs ) diff --git a/v3to4-breaking-changes.md b/v3to4-breaking-changes.md index 3afdb3e05..f913a0d33 100644 --- a/v3to4-breaking-changes.md +++ b/v3to4-breaking-changes.md @@ -16,3 +16,7 @@ - `repeatdt` and `lonlatdepth_dtype` have been removed from the ParticleSet. - ParticleSet.execute() expects `numpy.datetime64`/`numpy.timedelta.64` for `runtime`, `endtime` and `dt`. - `ParticleSet.from_field()`, `ParticleSet.from_line()`, `ParticleSet.from_list()` have been removed. + +## ParticleFile + +- Particlefiles should be created by `ParticleFile(...)` instead of `pset.ParticleFile(...)` From 67df94e345711ddb39a1bd09fd7347712b6ac093 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 5 Sep 2025 15:00:16 +0200 Subject: [PATCH 37/40] Update outputdt for test --- tests/v4/test_advection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v4/test_advection.py b/tests/v4/test_advection.py index 9247c25bc..751c149a3 100644 --- a/tests/v4/test_advection.py +++ b/tests/v4/test_advection.py @@ -65,7 +65,7 @@ def test_advection_zonal_with_particlefile(tmp_store): fieldset = FieldSet([U, V, UV]) pset = ParticleSet(fieldset, lon=np.zeros(npart) + 20.0, lat=np.linspace(0, 80, npart)) - pfile = ParticleFile(tmp_store, outputdt=np.timedelta64(15, "m")) + pfile = ParticleFile(tmp_store, outputdt=np.timedelta64(30, "m")) pset.execute(AdvectionRK4, runtime=np.timedelta64(2, "h"), dt=np.timedelta64(15, "m"), output_file=pfile) assert (np.diff(pset.lon) < 1.0e-4).all() From 1b9f389fcdeafaa0b514ac236e52d65828aaed5f Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 5 Sep 2025 15:37:47 +0200 Subject: [PATCH 38/40] Remove outdated TODO --- parcels/particleset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/parcels/particleset.py b/parcels/particleset.py index 7fc058799..b44cef0be 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -555,7 +555,6 @@ def execute( self._kernel.execute(self, endtime=next_time, dt=dt) - # TODO: Handle IO timing based of timedelta or datetime objects if next_output: if np.abs(next_time - next_output) < np.timedelta64(1000, "ns"): if output_file: From 855ed76f7aa800c2d12fc96b01818a19325474c0 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 8 Sep 2025 09:14:43 +0200 Subject: [PATCH 39/40] Update pfile metadata to 'parcels_grid_mesh' --- parcels/particlefile.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/parcels/particlefile.py b/parcels/particlefile.py index 055fe7bd6..8fd4ea25d 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -99,14 +99,14 @@ def __repr__(self) -> str: f"create_new_zarrfile={self.create_new_zarrfile!r})" ) - def set_metadata(self, parcels_mesh: Literal["spherical", "flat"]): + def set_metadata(self, parcels_grid_mesh: Literal["spherical", "flat"]): self.metadata.update( { "feature_type": "trajectory", "Conventions": "CF-1.6/CF-1.7", "ncei_template_version": "NCEI_NetCDF_Trajectory_Template_v2.0", "parcels_version": parcels.__version__, - "parcels_mesh": parcels_mesh, + "parcels_grid_mesh": parcels_grid_mesh, } ) From 097381fff78a429775820c8ba1752a9fffd8d6ce Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 8 Sep 2025 09:39:11 +0200 Subject: [PATCH 40/40] Add xfailed test_pfile_set_towrite_False back to test_particlefile.py --- tests/v4/test_particlefile.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/v4/test_particlefile.py b/tests/v4/test_particlefile.py index e2388a6a5..567793b4e 100755 --- a/tests/v4/test_particlefile.py +++ b/tests/v4/test_particlefile.py @@ -452,3 +452,35 @@ def test_particlefile_write_particle_data(tmp_store): # np.testing.assert_equal(ds["time"].isel(obs=0).values, left) assert ds.sizes["trajectory"] == nparticles np.testing.assert_allclose(ds["lon"].isel(obs=0).values, initial_lon) + + +def test_pfile_write_custom_particle(): + # Test the writing of a custom particle with variables that are to_write, some to_write once, and some not to_write + # ? This is more of an integration test... Should it be housed here? + ... + + +@pytest.mark.xfail( + reason="set_variable_write_status should be removed - with Particle writing defined on the particle level. GH2186" +) +def test_pfile_set_towrite_False(fieldset, tmp_zarrfile): + npart = 10 + pset = ParticleSet(fieldset, pclass=Particle, lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart)) + pset.set_variable_write_status("depth", False) + pset.set_variable_write_status("lat", False) + pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1) + + def Update_lon(particle, fieldset, time): # pragma: no cover + particle.dlon += 0.1 + + pset.execute(Update_lon, runtime=10, output_file=pfile) + + ds = xr.open_zarr(tmp_zarrfile) + assert "time" in ds + assert "z" not in ds + assert "lat" not in ds + ds.close() + + # For pytest purposes, we need to reset to original status + pset.set_variable_write_status("depth", True) + pset.set_variable_write_status("lat", True)