Skip to content

Commit 8b80038

Browse files
committed
Review feedback
1 parent 5e3c3a2 commit 8b80038

File tree

7 files changed

+51
-51
lines changed

7 files changed

+51
-51
lines changed

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22

33

44
@pytest.fixture()
5-
def tmp_zarr(tmp_path, request):
5+
def tmp_zarrfile(tmp_path, request):
66
test_name = request.node.name
77
yield tmp_path / f"{test_name}-output.zarr"

tests/test_advection.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ def test_analyticalAgrid(mode):
605605
@pytest.mark.parametrize("v", [1, -0.3, 0, -1])
606606
@pytest.mark.parametrize("w", [None, 1, -0.3, 0, -1])
607607
@pytest.mark.parametrize("direction", [1, -1])
608-
def test_uniform_analytical(mode, u, v, w, direction, tmp_zarr):
608+
def test_uniform_analytical(mode, u, v, w, direction, tmp_zarrfile):
609609
lon = np.arange(0, 15, dtype=np.float32)
610610
lat = np.arange(0, 15, dtype=np.float32)
611611
if w is not None:
@@ -625,14 +625,14 @@ def test_uniform_analytical(mode, u, v, w, direction, tmp_zarr):
625625
x0, y0, z0 = 6.1, 6.2, 20
626626
pset = ParticleSet(fieldset, pclass=ptype[mode], lon=x0, lat=y0, depth=z0)
627627

628-
outfile = pset.ParticleFile(name=tmp_zarr, outputdt=1, chunks=(1, 1))
628+
outfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=1, chunks=(1, 1))
629629
pset.execute(AdvectionAnalytical, runtime=4, dt=direction, output_file=outfile)
630630
assert np.abs(pset.lon - x0 - pset.time * u) < 1e-6
631631
assert np.abs(pset.lat - y0 - pset.time * v) < 1e-6
632632
if w is not None:
633633
assert np.abs(pset.depth - z0 - pset.time * w) < 1e-4
634634

635-
ds = xr.open_zarr(tmp_zarr)
635+
ds = xr.open_zarr(tmp_zarrfile)
636636
times = (direction * ds["time"][:]).values.astype("timedelta64[s]")[0]
637637
timeref = np.arange(1, 5).astype("timedelta64[s]")
638638
assert np.allclose(times, timeref, atol=np.timedelta64(1, "ms"))

tests/test_fieldset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,7 @@ def SampleUV2(particle, fieldset, time):
655655
assert abs(pset.lat[0] - 0.5) < 1e-9
656656

657657

658-
def test_fieldset_write(tmp_zarr):
658+
def test_fieldset_write(tmp_zarrfile):
659659
xdim, ydim = 3, 4
660660
lon = np.linspace(0.0, 10.0, xdim, dtype=np.float32)
661661
lat = np.linspace(0.0, 10.0, ydim, dtype=np.float32)
@@ -673,12 +673,12 @@ def UpdateU(particle, fieldset, time):
673673
fieldset.U.grid.time[0] = time
674674

675675
pset = ParticleSet(fieldset, pclass=ScipyParticle, lon=5, lat=5)
676-
ofile = pset.ParticleFile(name=tmp_zarr, outputdt=2.0)
676+
ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=2.0)
677677
pset.execute(UpdateU, dt=1, runtime=10, output_file=ofile)
678678

679679
assert fieldset.U.data[0, 1, 0] == 11
680680

681-
da = xr.open_dataset(str(tmp_zarr).replace(".zarr", "_0005U.nc"))
681+
da = xr.open_dataset(str(tmp_zarrfile).replace(".zarr", "_0005U.nc"))
682682
assert np.allclose(fieldset.U.data, da["U"].values, atol=1.0)
683683

684684

tests/test_fieldset_sampling.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -891,7 +891,7 @@ def Recover(particle, fieldset, time):
891891

892892

893893
@pytest.mark.parametrize("mode", ["jit", "scipy"])
894-
def test_fieldset_sampling_updating_order(mode, tmp_zarr):
894+
def test_fieldset_sampling_updating_order(mode, tmp_zarrfile):
895895
def calc_p(t, y, x):
896896
return 10 * t + x + 0.2 * y
897897

@@ -923,10 +923,10 @@ def SampleP(particle, fieldset, time):
923923

924924
kernels = [AdvectionRK4, SampleP]
925925

926-
pfile = pset.ParticleFile(tmp_zarr, outputdt=1)
926+
pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1)
927927
pset.execute(kernels, endtime=1, dt=1, output_file=pfile)
928928

929-
ds = xr.open_zarr(tmp_zarr)
929+
ds = xr.open_zarr(tmp_zarrfile)
930930
for t in range(len(ds["obs"])):
931931
for i in range(len(ds["trajectory"])):
932932
assert np.isclose(

tests/test_particlefile.py

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@ def fieldset():
3232

3333

3434
@pytest.mark.parametrize("mode", ["scipy", "jit"])
35-
def test_metadata(fieldset, mode, tmp_zarr):
35+
def test_metadata(fieldset, mode, tmp_zarrfile):
3636
pset = ParticleSet(fieldset, pclass=ptype[mode], lon=0, lat=0)
3737

38-
pset.execute(DoNothing, runtime=1, output_file=pset.ParticleFile(tmp_zarr))
38+
pset.execute(DoNothing, runtime=1, output_file=pset.ParticleFile(tmp_zarrfile))
3939

40-
ds = xr.open_zarr(tmp_zarr)
40+
ds = xr.open_zarr(tmp_zarrfile)
4141
assert ds.attrs["parcels_kernels"].lower() == f"{mode}ParticleDoNothing".lower()
4242

4343

@@ -56,36 +56,36 @@ def test_pfile_array_write_zarr_memorystore(fieldset, mode):
5656

5757

5858
@pytest.mark.parametrize("mode", ["scipy", "jit"])
59-
def test_pfile_array_remove_particles(fieldset, mode, tmp_zarr):
59+
def test_pfile_array_remove_particles(fieldset, mode, tmp_zarrfile):
6060
npart = 10
6161
pset = ParticleSet(fieldset, pclass=ptype[mode], lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart), time=0)
62-
pfile = pset.ParticleFile(tmp_zarr)
62+
pfile = pset.ParticleFile(tmp_zarrfile)
6363
pfile.write(pset, 0)
6464
pset.remove_indices(3)
6565
for p in pset:
6666
p.time = 1
6767
pfile.write(pset, 1)
6868

69-
ds = xr.open_zarr(tmp_zarr)
69+
ds = xr.open_zarr(tmp_zarrfile)
7070
timearr = ds["time"][:]
7171
assert (np.isnat(timearr[3, 1])) and (np.isfinite(timearr[3, 0]))
7272
ds.close()
7373

7474

7575
@pytest.mark.parametrize("mode", ["scipy", "jit"])
76-
def test_pfile_set_towrite_False(fieldset, mode, tmp_zarr):
76+
def test_pfile_set_towrite_False(fieldset, mode, tmp_zarrfile):
7777
npart = 10
7878
pset = ParticleSet(fieldset, pclass=ptype[mode], lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart))
7979
pset.set_variable_write_status("depth", False)
8080
pset.set_variable_write_status("lat", False)
81-
pfile = pset.ParticleFile(tmp_zarr, outputdt=1)
81+
pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1)
8282

8383
def Update_lon(particle, fieldset, time):
8484
particle_dlon += 0.1 # noqa
8585

8686
pset.execute(Update_lon, runtime=10, output_file=pfile)
8787

88-
ds = xr.open_zarr(tmp_zarr)
88+
ds = xr.open_zarr(tmp_zarrfile)
8989
assert "time" in ds
9090
assert "z" not in ds
9191
assert "lat" not in ds
@@ -98,18 +98,18 @@ def Update_lon(particle, fieldset, time):
9898

9999
@pytest.mark.parametrize("mode", ["scipy", "jit"])
100100
@pytest.mark.parametrize("chunks_obs", [1, None])
101-
def test_pfile_array_remove_all_particles(fieldset, mode, chunks_obs, tmp_zarr):
101+
def test_pfile_array_remove_all_particles(fieldset, mode, chunks_obs, tmp_zarrfile):
102102
npart = 10
103103
pset = ParticleSet(fieldset, pclass=ptype[mode], lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart), time=0)
104104
chunks = (npart, chunks_obs) if chunks_obs else None
105-
pfile = pset.ParticleFile(tmp_zarr, chunks=chunks)
105+
pfile = pset.ParticleFile(tmp_zarrfile, chunks=chunks)
106106
pfile.write(pset, 0)
107107
for _ in range(npart):
108108
pset.remove_indices(-1)
109109
pfile.write(pset, 1)
110110
pfile.write(pset, 2)
111111

112-
ds = xr.open_zarr(tmp_zarr)
112+
ds = xr.open_zarr(tmp_zarrfile)
113113
assert np.allclose(ds["time"][:, 0], np.timedelta64(0, "s"), atol=np.timedelta64(1, "ms"))
114114
if chunks_obs is not None:
115115
assert ds["time"][:].shape == chunks
@@ -120,22 +120,22 @@ def test_pfile_array_remove_all_particles(fieldset, mode, chunks_obs, tmp_zarr):
120120

121121

122122
@pytest.mark.parametrize("mode", ["scipy", "jit"])
123-
def test_variable_write_double(fieldset, mode, tmp_zarr):
123+
def test_variable_write_double(fieldset, mode, tmp_zarrfile):
124124
def Update_lon(particle, fieldset, time):
125125
particle_dlon += 0.1 # noqa
126126

127127
pset = ParticleSet(fieldset, pclass=ptype[mode], lon=[0], lat=[0], lonlatdepth_dtype=np.float64)
128-
ofile = pset.ParticleFile(name=tmp_zarr, outputdt=0.00001)
128+
ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=0.00001)
129129
pset.execute(pset.Kernel(Update_lon), endtime=0.001, dt=0.00001, output_file=ofile)
130130

131-
ds = xr.open_zarr(tmp_zarr)
131+
ds = xr.open_zarr(tmp_zarrfile)
132132
lons = ds["lon"][:]
133133
assert isinstance(lons.values[0, 0], np.float64)
134134
ds.close()
135135

136136

137137
@pytest.mark.parametrize("mode", ["scipy", "jit"])
138-
def test_write_dtypes_pfile(fieldset, mode, tmp_zarr):
138+
def test_write_dtypes_pfile(fieldset, mode, tmp_zarrfile):
139139
dtypes = [np.float32, np.float64, np.int32, np.uint32, np.int64, np.uint64]
140140
if mode == "scipy":
141141
dtypes.extend([np.bool_, np.int8, np.uint8, np.int16, np.uint16])
@@ -144,19 +144,19 @@ def test_write_dtypes_pfile(fieldset, mode, tmp_zarr):
144144
MyParticle = ptype[mode].add_variables(extra_vars)
145145

146146
pset = ParticleSet(fieldset, pclass=MyParticle, lon=0, lat=0, time=0)
147-
pfile = pset.ParticleFile(name=tmp_zarr, outputdt=1)
147+
pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=1)
148148
pfile.write(pset, 0)
149149

150150
ds = xr.open_zarr(
151-
tmp_zarr, mask_and_scale=False
151+
tmp_zarrfile, mask_and_scale=False
152152
) # Note masking issue at https://stackoverflow.com/questions/68460507/xarray-loading-int-data-as-float
153153
for d in dtypes:
154154
assert ds[f"v_{d.__name__}"].dtype == d
155155

156156

157157
@pytest.mark.parametrize("mode", ["scipy", "jit"])
158158
@pytest.mark.parametrize("npart", [1, 2, 5])
159-
def test_variable_written_once(fieldset, mode, tmp_zarr, npart):
159+
def test_variable_written_once(fieldset, mode, tmp_zarrfile, npart):
160160
def Update_v(particle, fieldset, time):
161161
particle.v_once += 1.0
162162
particle.age += particle.dt
@@ -171,11 +171,11 @@ def Update_v(particle, fieldset, time):
171171
lat = np.linspace(1, 0, npart)
172172
time = np.arange(0, npart / 10.0, 0.1, dtype=np.float64)
173173
pset = ParticleSet(fieldset, pclass=MyParticle, lon=lon, lat=lat, time=time, v_once=time)
174-
ofile = pset.ParticleFile(name=tmp_zarr, outputdt=0.1)
174+
ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=0.1)
175175
pset.execute(pset.Kernel(Update_v), endtime=1, dt=0.1, output_file=ofile)
176176

177177
assert np.allclose(pset.v_once - time - pset.age * 10, 1, atol=1e-5)
178-
ds = xr.open_zarr(tmp_zarr)
178+
ds = xr.open_zarr(tmp_zarrfile)
179179
vfile = np.ma.filled(ds["v_once"][:], np.nan)
180180
assert vfile.shape == (npart,)
181181
ds.close()
@@ -186,7 +186,7 @@ def Update_v(particle, fieldset, time):
186186
@pytest.mark.parametrize("repeatdt", range(1, 3))
187187
@pytest.mark.parametrize("dt", [-1, 1])
188188
@pytest.mark.parametrize("maxvar", [2, 4, 10])
189-
def test_pset_repeated_release_delayed_adding_deleting(type, fieldset, mode, repeatdt, tmp_zarr, dt, maxvar):
189+
def test_pset_repeated_release_delayed_adding_deleting(type, fieldset, mode, repeatdt, tmp_zarrfile, dt, maxvar):
190190
runtime = 10
191191
fieldset.maxvar = maxvar
192192
pset = None
@@ -201,7 +201,7 @@ def test_pset_repeated_release_delayed_adding_deleting(type, fieldset, mode, rep
201201
pset = ParticleSet(
202202
fieldset, lon=np.zeros(runtime), lat=np.zeros(runtime), pclass=MyParticle, time=list(range(runtime))
203203
)
204-
pfile = pset.ParticleFile(tmp_zarr, outputdt=abs(dt), chunks=(1, 1))
204+
pfile = pset.ParticleFile(tmp_zarrfile, outputdt=abs(dt), chunks=(1, 1))
205205

206206
def IncrLon(particle, fieldset, time):
207207
particle.sample_var += 1.0
@@ -211,7 +211,7 @@ def IncrLon(particle, fieldset, time):
211211
for _ in range(runtime):
212212
pset.execute(IncrLon, dt=dt, runtime=1.0, output_file=pfile)
213213

214-
ds = xr.open_zarr(tmp_zarr)
214+
ds = xr.open_zarr(tmp_zarrfile)
215215
samplevar = ds["sample_var"][:]
216216
if type == "repeatdt":
217217
assert samplevar.shape == (runtime // repeatdt, min(maxvar + 1, runtime))
@@ -221,47 +221,47 @@ def IncrLon(particle, fieldset, time):
221221
# test whether samplevar[:, k] = k
222222
for k in range(samplevar.shape[1]):
223223
assert np.allclose([p for p in samplevar[:, k] if np.isfinite(p)], k + 1)
224-
filesize = os.path.getsize(str(tmp_zarr))
224+
filesize = os.path.getsize(str(tmp_zarrfile))
225225
assert filesize < 1024 * 65 # test that chunking leads to filesize less than 65KB
226226
ds.close()
227227

228228

229229
@pytest.mark.parametrize("mode", ["scipy", "jit"])
230230
@pytest.mark.parametrize("repeatdt", [1, 2])
231231
@pytest.mark.parametrize("nump", [1, 10])
232-
def test_pfile_chunks_repeatedrelease(fieldset, mode, repeatdt, nump, tmp_zarr):
232+
def test_pfile_chunks_repeatedrelease(fieldset, mode, repeatdt, nump, tmp_zarrfile):
233233
runtime = 8
234234
pset = ParticleSet(
235235
fieldset, pclass=ptype[mode], lon=np.zeros((nump, 1)), lat=np.zeros((nump, 1)), repeatdt=repeatdt
236236
)
237237
chunks = (20, 10)
238-
pfile = pset.ParticleFile(tmp_zarr, outputdt=1, chunks=chunks)
238+
pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1, chunks=chunks)
239239

240240
def DoNothing(particle, fieldset, time):
241241
pass
242242

243243
pset.execute(DoNothing, dt=1, runtime=runtime, output_file=pfile)
244-
ds = xr.open_zarr(tmp_zarr)
244+
ds = xr.open_zarr(tmp_zarrfile)
245245
assert ds["time"].shape == (int(nump * runtime / repeatdt), chunks[1])
246246

247247

248248
@pytest.mark.parametrize("mode", ["scipy", "jit"])
249-
def test_write_timebackward(fieldset, mode, tmp_zarr):
249+
def test_write_timebackward(fieldset, mode, tmp_zarrfile):
250250
def Update_lon(particle, fieldset, time):
251251
particle_dlon -= 0.1 * particle.dt # noqa
252252

253253
pset = ParticleSet(fieldset, pclass=ptype[mode], lat=np.linspace(0, 1, 3), lon=[0, 0, 0], time=[1, 2, 3])
254-
pfile = pset.ParticleFile(name=tmp_zarr, outputdt=1.0)
254+
pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=1.0)
255255
pset.execute(pset.Kernel(Update_lon), runtime=4, dt=-1.0, output_file=pfile)
256-
ds = xr.open_zarr(tmp_zarr)
256+
ds = xr.open_zarr(tmp_zarrfile)
257257
trajs = ds["trajectory"][:]
258258
assert trajs.values.dtype == "int64"
259259
assert np.all(np.diff(trajs.values) < 0) # all particles written in order of release
260260
ds.close()
261261

262262

263263
@pytest.mark.parametrize("mode", ["scipy", "jit"])
264-
def test_write_xiyi(fieldset, mode, tmp_zarr):
264+
def test_write_xiyi(fieldset, mode, tmp_zarrfile):
265265
fieldset.U.data[:] = 1 # set a non-zero zonal velocity
266266
fieldset.add_field(Field(name="P", data=np.zeros((3, 20)), lon=np.linspace(0, 1, 20), lat=[-2, 0, 2]))
267267
dt = 3600
@@ -289,10 +289,10 @@ def SampleP(particle, fieldset, time):
289289
_ = fieldset.P[particle] # To trigger sampling of the P field
290290

291291
pset = ParticleSet(fieldset, pclass=XiYiParticle, lon=[0, 0.2], lat=[0.2, 1], lonlatdepth_dtype=np.float64)
292-
pfile = pset.ParticleFile(name=tmp_zarr, outputdt=dt)
292+
pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=dt)
293293
pset.execute([SampleP, Get_XiYi, AdvectionRK4], endtime=10 * dt, dt=dt, output_file=pfile)
294294

295-
ds = xr.open_zarr(tmp_zarr)
295+
ds = xr.open_zarr(tmp_zarrfile)
296296
pxi0 = ds["pxi0"][:].values.astype(np.int32)
297297
pxi1 = ds["pxi1"][:].values.astype(np.int32)
298298
lons = ds["lon"][:].values
@@ -320,15 +320,15 @@ def test_set_calendar():
320320

321321

322322
@pytest.mark.parametrize("mode", ["scipy", "jit"])
323-
def test_reset_dt(fieldset, mode, tmp_zarr):
323+
def test_reset_dt(fieldset, mode, tmp_zarrfile):
324324
# Assert that p.dt gets reset when a write_time is not a multiple of dt
325325
# for p.dt=0.02 to reach outputdt=0.05 and endtime=0.1, the steps should be [0.2, 0.2, 0.1, 0.2, 0.2, 0.1], resulting in 6 kernel executions
326326

327327
def Update_lon(particle, fieldset, time):
328328
particle_dlon += 0.1 # noqa
329329

330330
pset = ParticleSet(fieldset, pclass=ptype[mode], lon=[0], lat=[0], lonlatdepth_dtype=np.float64)
331-
ofile = pset.ParticleFile(name=tmp_zarr, outputdt=0.05)
331+
ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=0.05)
332332
pset.execute(pset.Kernel(Update_lon), endtime=0.12, dt=0.02, output_file=ofile)
333333

334334
assert np.allclose(pset.lon, 0.6)

tests/test_particlesets.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def test_pset_create_list_with_customvariable(fieldset, mode):
7979

8080
@pytest.mark.parametrize("mode", ["scipy", "jit"])
8181
@pytest.mark.parametrize("restart", [True, False])
82-
def test_pset_create_fromparticlefile(fieldset, mode, restart, tmp_zarr):
82+
def test_pset_create_fromparticlefile(fieldset, mode, restart, tmp_zarrfile):
8383
lon = np.linspace(0, 1, 10, dtype=np.float32)
8484
lat = np.linspace(1, 0, 10, dtype=np.float32)
8585

@@ -88,7 +88,7 @@ def test_pset_create_fromparticlefile(fieldset, mode, restart, tmp_zarr):
8888
TestParticle = TestParticle.add_variable("p3", np.float64, to_write="once")
8989

9090
pset = ParticleSet(fieldset, lon=lon, lat=lat, depth=[4] * len(lon), pclass=TestParticle, p3=np.arange(len(lon)))
91-
pfile = pset.ParticleFile(tmp_zarr, outputdt=1)
91+
pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1)
9292

9393
def Kernel(particle, fieldset, time):
9494
particle.p = 2.0
@@ -98,7 +98,7 @@ def Kernel(particle, fieldset, time):
9898
pset.execute(Kernel, runtime=2, dt=1, output_file=pfile)
9999

100100
pset_new = ParticleSet.from_particlefile(
101-
fieldset, pclass=TestParticle, filename=tmp_zarr, restart=restart, repeatdt=1
101+
fieldset, pclass=TestParticle, filename=tmp_zarrfile, restart=restart, repeatdt=1
102102
)
103103

104104
for var in ["lon", "lat", "depth", "time", "p", "p2", "p3"]:

tests/tools/test_warnings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,12 @@ def test_fieldset_warnings():
5656
fieldset = FieldSet.from_pop(filenames, variables, dimensions, mesh="flat", timestamps=[0, 1, 2, 3])
5757

5858

59-
def test_file_warnings(tmp_zarr):
59+
def test_file_warnings(tmp_zarrfile):
6060
fieldset = FieldSet.from_data(
6161
data={"U": np.zeros((1, 1)), "V": np.zeros((1, 1))}, dimensions={"lon": [0], "lat": [0]}
6262
)
6363
pset = ParticleSet(fieldset=fieldset, pclass=ScipyParticle, lon=[0, 0], lat=[0, 0], time=[0, 1])
64-
pfile = pset.ParticleFile(name=tmp_zarr, outputdt=2)
64+
pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=2)
6565
with pytest.warns(FileWarning, match="Some of the particles have a start time difference.*"):
6666
pset.execute(AdvectionRK4, runtime=3, dt=1, output_file=pfile)
6767

0 commit comments

Comments
 (0)