Skip to content

Commit 5393f6f

Browse files
committed
Remove pset.ParticleFile alias
In favour of directly using the ParticleFile object
1 parent e3d6a61 commit 5393f6f

File tree

7 files changed

+24
-21
lines changed

7 files changed

+24
-21
lines changed

parcels/particleset.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from parcels.basegrid import GridType
1616
from parcels.kernel import Kernel
1717
from parcels.particle import KernelParticle, Particle, create_particle_data
18-
from parcels.particlefile import ParticleFile
1918
from parcels.tools.converters import convert_to_flat_array
2019
from parcels.tools.loggers import logger
2120
from parcels.tools.statuscodes import StatusCode
@@ -392,10 +391,6 @@ def InteractionKernel(self, pyfunc_inter):
392391
return None
393392
return InteractionKernel(self.fieldset, self._ptype, pyfunc=pyfunc_inter)
394393

395-
def ParticleFile(self, *args, **kwargs):
396-
"""Wrapper method to initialise a :class:`parcels.particlefile.ParticleFile` object from the ParticleSet."""
397-
return ParticleFile(*args, **kwargs)
398-
399394
def data_indices(self, variable_name, compare_values, invert=False):
400395
"""Get the indices of all particles where the value of `variable_name` equals (one of) `compare_values`.
401396

tests/tools/test_warnings.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
ParticleSet,
1212
ParticleSetWarning,
1313
)
14+
from parcels.particlefile import ParticleFile
1415
from tests.utils import TEST_DATA
1516

1617

@@ -30,7 +31,7 @@ def test_file_warnings(tmp_zarrfile):
3031
data={"U": np.zeros((1, 1)), "V": np.zeros((1, 1))}, dimensions={"lon": [0], "lat": [0]}
3132
)
3233
pset = ParticleSet(fieldset=fieldset, pclass=Particle, lon=[0, 0], lat=[0, 0], time=[0, 1])
33-
pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=2)
34+
pfile = ParticleFile(name=tmp_zarrfile, outputdt=2)
3435
with pytest.warns(ParticleSetWarning, match="Some of the particles have a start time difference.*"):
3536
pset.execute(AdvectionRK4, runtime=3, dt=1, output_file=pfile)
3637

tests/v4/test_advection.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from parcels.field import Field, VectorField
1818
from parcels.fieldset import FieldSet
1919
from parcels.particle import Particle, Variable
20+
from parcels.particlefile import ParticleFile
2021
from parcels.particleset import ParticleSet
2122
from parcels.tools.statuscodes import StatusCode
2223
from parcels.xgrid import XGrid
@@ -64,7 +65,7 @@ def test_advection_zonal_with_particlefile(tmp_store):
6465
fieldset = FieldSet([U, V, UV])
6566

6667
pset = ParticleSet(fieldset, lon=np.zeros(npart) + 20.0, lat=np.linspace(0, 80, npart))
67-
pfile = pset.ParticleFile(tmp_store, outputdt=np.timedelta64(15, "m"))
68+
pfile = ParticleFile(tmp_store, outputdt=np.timedelta64(15, "m"))
6869
pset.execute(AdvectionRK4, runtime=np.timedelta64(2, "h"), dt=np.timedelta64(15, "m"), output_file=pfile)
6970

7071
assert (np.diff(pset.lon) < 1.0e-4).all()

tests/v4/test_interpolation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from parcels.field import Field, VectorField
1111
from parcels.fieldset import FieldSet
1212
from parcels.particle import Particle, Variable
13+
from parcels.particlefile import ParticleFile
1314
from parcels.particleset import ParticleSet
1415
from parcels.tools.statuscodes import StatusCode
1516
from parcels.uxgrid import UxGrid
@@ -169,7 +170,7 @@ def DeleteParticle(particle, fieldset, time):
169170
if particle.state >= 50:
170171
particle.state = StatusCode.Delete
171172

172-
outfile = pset.ParticleFile(f"test_interpolation_v4_{interp_name}", outputdt=np.timedelta64(1, "s"))
173+
outfile = ParticleFile(f"test_interpolation_v4_{interp_name}", outputdt=np.timedelta64(1, "s"))
173174
pset.execute(
174175
[AdvectionRK4_3D, DeleteParticle],
175176
runtime=np.timedelta64(4, "s"),

tests/v4/test_particlefile.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def fieldset() -> FieldSet: # TODO v4: Move into a `conftest.py` file and remov
3636
def test_metadata(fieldset, tmp_zarrfile):
3737
pset = ParticleSet(fieldset, pclass=Particle, lon=0, lat=0)
3838

39-
pset.execute(DoNothing, runtime=1, output_file=pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")))
39+
pset.execute(DoNothing, runtime=1, output_file=ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")))
4040

4141
ds = xr.open_zarr(tmp_zarrfile, decode_cf=False) # TODO v4: Fix metadata and re-enable decode_cf
4242
assert ds.attrs["parcels_kernels"].lower() == "ParticleDoNothing".lower()
@@ -53,7 +53,7 @@ def test_pfile_array_write_zarr_memorystore(fieldset):
5353
lat=0.5 * np.ones(npart),
5454
time=fieldset.time_interval.left,
5555
)
56-
pfile = pset.ParticleFile(zarr_store, outputdt=np.timedelta64(1, "s"))
56+
pfile = ParticleFile(zarr_store, outputdt=np.timedelta64(1, "s"))
5757
pfile.write(pset, time=fieldset.time_interval.left)
5858

5959
ds = xr.open_zarr(zarr_store)
@@ -69,7 +69,7 @@ def test_pfile_array_remove_particles(fieldset, tmp_zarrfile):
6969
lat=0.5 * np.ones(npart),
7070
time=fieldset.time_interval.left,
7171
)
72-
pfile = pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s"))
72+
pfile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s"))
7373
pset._data["time"][:] = fieldset.time_interval.left
7474
pset._data["time_nextloop"][:] = fieldset.time_interval.left
7575
pfile.write(pset, time=fieldset.time_interval.left)
@@ -97,7 +97,7 @@ def test_pfile_array_remove_all_particles(fieldset, chunks_obs, tmp_zarrfile):
9797
time=fieldset.time_interval.left,
9898
)
9999
chunks = (npart, chunks_obs) if chunks_obs else None
100-
pfile = pset.ParticleFile(tmp_zarrfile, chunks=chunks, outputdt=np.timedelta64(1, "s"))
100+
pfile = ParticleFile(tmp_zarrfile, chunks=chunks, outputdt=np.timedelta64(1, "s"))
101101
pfile.write(pset, time=fieldset.time_interval.left)
102102
for _ in range(npart):
103103
pset.remove_indices(-1)
@@ -123,7 +123,7 @@ def Update_lon(particle, fieldset, time): # pragma: no cover
123123

124124
particle = get_default_particle(np.float64)
125125
pset = ParticleSet(fieldset, pclass=particle, lon=[0], lat=[0])
126-
ofile = pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(10, "us"))
126+
ofile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(10, "us"))
127127
pset.execute(
128128
pset.Kernel(Update_lon),
129129
runtime=np.timedelta64(1, "ms"),
@@ -155,7 +155,7 @@ def test_write_dtypes_pfile(fieldset, tmp_zarrfile):
155155
MyParticle = Particle.add_variable(extra_vars)
156156

157157
pset = ParticleSet(fieldset, pclass=MyParticle, lon=0, lat=0, time=fieldset.time_interval.left)
158-
pfile = pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s"))
158+
pfile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s"))
159159
pfile.write(pset, time=fieldset.time_interval.left)
160160

161161
ds = xr.open_zarr(
@@ -196,7 +196,7 @@ def test_pset_repeated_release_delayed_adding_deleting(fieldset, tmp_zarrfile, d
196196
pclass=MyParticle,
197197
time=fieldset.time_interval.left + np.array([np.timedelta64(i, "s") for i in range(npart)]),
198198
)
199-
pfile = pset.ParticleFile(tmp_zarrfile, outputdt=abs(dt), chunks=(1, 1))
199+
pfile = ParticleFile(tmp_zarrfile, outputdt=abs(dt), chunks=(1, 1))
200200

201201
def IncrLon(particle, fieldset, time): # pragma: no cover
202202
particle.sample_var += 1.0
@@ -235,7 +235,7 @@ def Update_lon(particle, fieldset, time): # pragma: no cover
235235
lon=[0, 0, 0],
236236
time=np.array([np.datetime64("2000-01-01") for _ in range(3)]),
237237
)
238-
pfile = pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s"))
238+
pfile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s"))
239239
pset.execute(pset.Kernel(Update_lon), runtime=np.timedelta64(1, "s"), dt=-np.timedelta64(1, "s"), output_file=pfile)
240240
ds = xr.open_zarr(tmp_zarrfile)
241241
trajs = ds["trajectory"][:]
@@ -274,7 +274,7 @@ def SampleP(particle, fieldset, time): # pragma: no cover
274274
_ = fieldset.P[particle] # To trigger sampling of the P field
275275

276276
pset = ParticleSet(fieldset, pclass=XiYiParticle, lon=[0, 0.2], lat=[0.2, 1])
277-
pfile = pset.ParticleFile(tmp_zarrfile, outputdt=dt)
277+
pfile = ParticleFile(tmp_zarrfile, outputdt=dt)
278278
pset.execute([SampleP, Get_XiYi, AdvectionRK4], endtime=10 * dt, dt=dt, output_file=pfile)
279279

280280
ds = xr.open_zarr(tmp_zarrfile)
@@ -307,7 +307,7 @@ def Update_lon(particle, fieldset, time): # pragma: no cover
307307

308308
particle = get_default_particle(np.float64)
309309
pset = ParticleSet(fieldset, pclass=particle, lon=[0], lat=[0])
310-
ofile = pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(50, "ms"))
310+
ofile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(50, "ms"))
311311
dt = np.timedelta64(20, "ms")
312312
pset.execute(pset.Kernel(Update_lon), runtime=6 * dt, dt=dt, output_file=ofile)
313313

@@ -324,7 +324,7 @@ def Update_lon(particle, fieldset, time): # pragma: no cover
324324

325325
particle = get_default_particle(np.float64)
326326
pset = ParticleSet(fieldset, pclass=particle, lon=[0], lat=[0])
327-
ofile = pset.ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(3, "s"))
327+
ofile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(3, "s"))
328328
pset.execute(pset.Kernel(Update_lon), runtime=np.timedelta64(11, "s"), dt=np.timedelta64(2, "s"), output_file=ofile)
329329

330330
ds = xr.open_zarr(tmp_zarrfile)
@@ -346,7 +346,7 @@ def setup_pset_execute(*, fieldset: FieldSet, outputdt: timedelta, execute_kwarg
346346

347347
with tempfile.TemporaryDirectory() as dir:
348348
name = f"{dir}/test.zarr"
349-
output_file = pset.ParticleFile(name, outputdt=outputdt)
349+
output_file = ParticleFile(name, outputdt=outputdt)
350350

351351
pset.execute(DoNothing, output_file=output_file, **execute_kwargs)
352352
ds = xr.open_zarr(name).load()

tests/v4/test_particleset_execute.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from parcels._datasets.structured.generated import simple_UV_dataset
1818
from parcels._datasets.structured.generic import datasets as datasets_structured
1919
from parcels._datasets.unstructured.generic import datasets as datasets_unstructured
20+
from parcels.particlefile import ParticleFile
2021
from parcels.tools.statuscodes import FieldInterpolationError, FieldOutOfBoundError, TimeExtrapolationError
2122
from parcels.uxgrid import UxGrid
2223
from parcels.xgrid import XGrid
@@ -433,7 +434,7 @@ def test_uxstommelgyre_pset_execute_output():
433434
time=[0.0],
434435
pclass=Particle,
435436
)
436-
output_file = pset.ParticleFile(
437+
output_file = ParticleFile(
437438
name="stommel_uxarray_particles.zarr", # the file name
438439
outputdt=np.timedelta64(5, "m"), # the time step of the outputs
439440
)

v3to4-breaking-changes.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,7 @@
1616
- `repeatdt` and `lonlatdepth_dtype` have been removed from the ParticleSet.
1717
- ParticleSet.execute() expects `numpy.datetime64`/`numpy.timedelta.64` for `runtime`, `endtime` and `dt`.
1818
- `ParticleSet.from_field()`, `ParticleSet.from_line()`, `ParticleSet.from_list()` have been removed.
19+
20+
## ParticleFile
21+
22+
- Particlefiles should be created by `ParticleFile(...)` instead of `pset.ParticleFile(...)`

0 commit comments

Comments
 (0)