Skip to content

Commit 0d14640

Browse files
Add tmp_zarr fixture (#1784)
* Add tmp_zarr fixture And specify warning in test_file_warnings() * Review feedback
1 parent 1995da0 commit 0d14640

File tree

7 files changed

+58
-72
lines changed

7 files changed

+58
-72
lines changed

tests/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import pytest
2+
3+
4+
@pytest.fixture()
5+
def tmp_zarrfile(tmp_path, request):
6+
test_name = request.node.name
7+
yield tmp_path / f"{test_name}-output.zarr"

tests/test_advection.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ def test_analyticalAgrid(mode):
605605
@pytest.mark.parametrize("v", [1, -0.3, 0, -1])
606606
@pytest.mark.parametrize("w", [None, 1, -0.3, 0, -1])
607607
@pytest.mark.parametrize("direction", [1, -1])
608-
def test_uniform_analytical(mode, u, v, w, direction, tmpdir):
608+
def test_uniform_analytical(mode, u, v, w, direction, tmp_zarrfile):
609609
lon = np.arange(0, 15, dtype=np.float32)
610610
lat = np.arange(0, 15, dtype=np.float32)
611611
if w is not None:
@@ -625,15 +625,14 @@ def test_uniform_analytical(mode, u, v, w, direction, tmpdir):
625625
x0, y0, z0 = 6.1, 6.2, 20
626626
pset = ParticleSet(fieldset, pclass=ptype[mode], lon=x0, lat=y0, depth=z0)
627627

628-
outfile_path = tmpdir.join("uniformanalytical.zarr")
629-
outfile = pset.ParticleFile(name=outfile_path, outputdt=1, chunks=(1, 1))
628+
outfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=1, chunks=(1, 1))
630629
pset.execute(AdvectionAnalytical, runtime=4, dt=direction, output_file=outfile)
631630
assert np.abs(pset.lon - x0 - pset.time * u) < 1e-6
632631
assert np.abs(pset.lat - y0 - pset.time * v) < 1e-6
633632
if w is not None:
634633
assert np.abs(pset.depth - z0 - pset.time * w) < 1e-4
635634

636-
ds = xr.open_zarr(outfile_path)
635+
ds = xr.open_zarr(tmp_zarrfile)
637636
times = (direction * ds["time"][:]).values.astype("timedelta64[s]")[0]
638637
timeref = np.arange(1, 5).astype("timedelta64[s]")
639638
assert np.allclose(times, timeref, atol=np.timedelta64(1, "ms"))

tests/test_fieldset.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -655,8 +655,7 @@ def SampleUV2(particle, fieldset, time):
655655
assert abs(pset.lat[0] - 0.5) < 1e-9
656656

657657

658-
def test_fieldset_write(tmpdir):
659-
filepath = tmpdir.join("fieldset_write.zarr")
658+
def test_fieldset_write(tmp_zarrfile):
660659
xdim, ydim = 3, 4
661660
lon = np.linspace(0.0, 10.0, xdim, dtype=np.float32)
662661
lat = np.linspace(0.0, 10.0, ydim, dtype=np.float32)
@@ -674,12 +673,12 @@ def UpdateU(particle, fieldset, time):
674673
fieldset.U.grid.time[0] = time
675674

676675
pset = ParticleSet(fieldset, pclass=ScipyParticle, lon=5, lat=5)
677-
ofile = pset.ParticleFile(name=filepath, outputdt=2.0)
676+
ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=2.0)
678677
pset.execute(UpdateU, dt=1, runtime=10, output_file=ofile)
679678

680679
assert fieldset.U.data[0, 1, 0] == 11
681680

682-
da = xr.open_dataset(str(filepath).replace(".zarr", "_0005U.nc"))
681+
da = xr.open_dataset(str(tmp_zarrfile).replace(".zarr", "_0005U.nc"))
683682
assert np.allclose(fieldset.U.data, da["U"].values, atol=1.0)
684683

685684

tests/test_fieldset_sampling.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -891,7 +891,7 @@ def Recover(particle, fieldset, time):
891891

892892

893893
@pytest.mark.parametrize("mode", ["jit", "scipy"])
894-
def test_fieldset_sampling_updating_order(mode, tmpdir):
894+
def test_fieldset_sampling_updating_order(mode, tmp_zarrfile):
895895
def calc_p(t, y, x):
896896
return 10 * t + x + 0.2 * y
897897

@@ -923,11 +923,10 @@ def SampleP(particle, fieldset, time):
923923

924924
kernels = [AdvectionRK4, SampleP]
925925

926-
filename = tmpdir.join("interpolation_offset.zarr")
927-
pfile = pset.ParticleFile(filename, outputdt=1)
926+
pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1)
928927
pset.execute(kernels, endtime=1, dt=1, output_file=pfile)
929928

930-
ds = xr.open_zarr(filename)
929+
ds = xr.open_zarr(tmp_zarrfile)
931930
for t in range(len(ds["obs"])):
932931
for i in range(len(ds["trajectory"])):
933932
assert np.isclose(

tests/test_particlefile.py

Lines changed: 36 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,12 @@ def fieldset():
3232

3333

3434
@pytest.mark.parametrize("mode", ["scipy", "jit"])
35-
def test_metadata(fieldset, mode, tmpdir):
36-
filepath = tmpdir.join("pfile_metadata.zarr")
35+
def test_metadata(fieldset, mode, tmp_zarrfile):
3736
pset = ParticleSet(fieldset, pclass=ptype[mode], lon=0, lat=0)
3837

39-
pset.execute(DoNothing, runtime=1, output_file=pset.ParticleFile(filepath))
38+
pset.execute(DoNothing, runtime=1, output_file=pset.ParticleFile(tmp_zarrfile))
4039

41-
ds = xr.open_zarr(filepath)
40+
ds = xr.open_zarr(tmp_zarrfile)
4241
assert ds.attrs["parcels_kernels"].lower() == f"{mode}ParticleDoNothing".lower()
4342

4443

@@ -57,38 +56,36 @@ def test_pfile_array_write_zarr_memorystore(fieldset, mode):
5756

5857

5958
@pytest.mark.parametrize("mode", ["scipy", "jit"])
60-
def test_pfile_array_remove_particles(fieldset, mode, tmpdir):
59+
def test_pfile_array_remove_particles(fieldset, mode, tmp_zarrfile):
6160
npart = 10
62-
filepath = tmpdir.join("pfile_array_remove_particles.zarr")
6361
pset = ParticleSet(fieldset, pclass=ptype[mode], lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart), time=0)
64-
pfile = pset.ParticleFile(filepath)
62+
pfile = pset.ParticleFile(tmp_zarrfile)
6563
pfile.write(pset, 0)
6664
pset.remove_indices(3)
6765
for p in pset:
6866
p.time = 1
6967
pfile.write(pset, 1)
7068

71-
ds = xr.open_zarr(filepath)
69+
ds = xr.open_zarr(tmp_zarrfile)
7270
timearr = ds["time"][:]
7371
assert (np.isnat(timearr[3, 1])) and (np.isfinite(timearr[3, 0]))
7472
ds.close()
7573

7674

7775
@pytest.mark.parametrize("mode", ["scipy", "jit"])
78-
def test_pfile_set_towrite_False(fieldset, mode, tmpdir):
76+
def test_pfile_set_towrite_False(fieldset, mode, tmp_zarrfile):
7977
npart = 10
80-
filepath = tmpdir.join("pfile_set_towrite_False.zarr")
8178
pset = ParticleSet(fieldset, pclass=ptype[mode], lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart))
8279
pset.set_variable_write_status("depth", False)
8380
pset.set_variable_write_status("lat", False)
84-
pfile = pset.ParticleFile(filepath, outputdt=1)
81+
pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1)
8582

8683
def Update_lon(particle, fieldset, time):
8784
particle_dlon += 0.1 # noqa
8885

8986
pset.execute(Update_lon, runtime=10, output_file=pfile)
9087

91-
ds = xr.open_zarr(filepath)
88+
ds = xr.open_zarr(tmp_zarrfile)
9289
assert "time" in ds
9390
assert "z" not in ds
9491
assert "lat" not in ds
@@ -101,19 +98,18 @@ def Update_lon(particle, fieldset, time):
10198

10299
@pytest.mark.parametrize("mode", ["scipy", "jit"])
103100
@pytest.mark.parametrize("chunks_obs", [1, None])
104-
def test_pfile_array_remove_all_particles(fieldset, mode, chunks_obs, tmpdir):
101+
def test_pfile_array_remove_all_particles(fieldset, mode, chunks_obs, tmp_zarrfile):
105102
npart = 10
106-
filepath = tmpdir.join("pfile_array_remove_particles.zarr")
107103
pset = ParticleSet(fieldset, pclass=ptype[mode], lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart), time=0)
108104
chunks = (npart, chunks_obs) if chunks_obs else None
109-
pfile = pset.ParticleFile(filepath, chunks=chunks)
105+
pfile = pset.ParticleFile(tmp_zarrfile, chunks=chunks)
110106
pfile.write(pset, 0)
111107
for _ in range(npart):
112108
pset.remove_indices(-1)
113109
pfile.write(pset, 1)
114110
pfile.write(pset, 2)
115111

116-
ds = xr.open_zarr(filepath)
112+
ds = xr.open_zarr(tmp_zarrfile)
117113
assert np.allclose(ds["time"][:, 0], np.timedelta64(0, "s"), atol=np.timedelta64(1, "ms"))
118114
if chunks_obs is not None:
119115
assert ds["time"][:].shape == chunks
@@ -124,26 +120,22 @@ def test_pfile_array_remove_all_particles(fieldset, mode, chunks_obs, tmpdir):
124120

125121

126122
@pytest.mark.parametrize("mode", ["scipy", "jit"])
127-
def test_variable_write_double(fieldset, mode, tmpdir):
128-
filepath = tmpdir.join("pfile_variable_write_double.zarr")
129-
123+
def test_variable_write_double(fieldset, mode, tmp_zarrfile):
130124
def Update_lon(particle, fieldset, time):
131125
particle_dlon += 0.1 # noqa
132126

133127
pset = ParticleSet(fieldset, pclass=ptype[mode], lon=[0], lat=[0], lonlatdepth_dtype=np.float64)
134-
ofile = pset.ParticleFile(name=filepath, outputdt=0.00001)
128+
ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=0.00001)
135129
pset.execute(pset.Kernel(Update_lon), endtime=0.001, dt=0.00001, output_file=ofile)
136130

137-
ds = xr.open_zarr(filepath)
131+
ds = xr.open_zarr(tmp_zarrfile)
138132
lons = ds["lon"][:]
139133
assert isinstance(lons.values[0, 0], np.float64)
140134
ds.close()
141135

142136

143137
@pytest.mark.parametrize("mode", ["scipy", "jit"])
144-
def test_write_dtypes_pfile(fieldset, mode, tmpdir):
145-
filepath = tmpdir.join("pfile_dtypes.zarr")
146-
138+
def test_write_dtypes_pfile(fieldset, mode, tmp_zarrfile):
147139
dtypes = [np.float32, np.float64, np.int32, np.uint32, np.int64, np.uint64]
148140
if mode == "scipy":
149141
dtypes.extend([np.bool_, np.int8, np.uint8, np.int16, np.uint16])
@@ -152,21 +144,19 @@ def test_write_dtypes_pfile(fieldset, mode, tmpdir):
152144
MyParticle = ptype[mode].add_variables(extra_vars)
153145

154146
pset = ParticleSet(fieldset, pclass=MyParticle, lon=0, lat=0, time=0)
155-
pfile = pset.ParticleFile(name=filepath, outputdt=1)
147+
pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=1)
156148
pfile.write(pset, 0)
157149

158150
ds = xr.open_zarr(
159-
filepath, mask_and_scale=False
151+
tmp_zarrfile, mask_and_scale=False
160152
) # Note masking issue at https://stackoverflow.com/questions/68460507/xarray-loading-int-data-as-float
161153
for d in dtypes:
162154
assert ds[f"v_{d.__name__}"].dtype == d
163155

164156

165157
@pytest.mark.parametrize("mode", ["scipy", "jit"])
166158
@pytest.mark.parametrize("npart", [1, 2, 5])
167-
def test_variable_written_once(fieldset, mode, tmpdir, npart):
168-
filepath = tmpdir.join("pfile_once_written_variables.zarr")
169-
159+
def test_variable_written_once(fieldset, mode, tmp_zarrfile, npart):
170160
def Update_v(particle, fieldset, time):
171161
particle.v_once += 1.0
172162
particle.age += particle.dt
@@ -181,11 +171,11 @@ def Update_v(particle, fieldset, time):
181171
lat = np.linspace(1, 0, npart)
182172
time = np.arange(0, npart / 10.0, 0.1, dtype=np.float64)
183173
pset = ParticleSet(fieldset, pclass=MyParticle, lon=lon, lat=lat, time=time, v_once=time)
184-
ofile = pset.ParticleFile(name=filepath, outputdt=0.1)
174+
ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=0.1)
185175
pset.execute(pset.Kernel(Update_v), endtime=1, dt=0.1, output_file=ofile)
186176

187177
assert np.allclose(pset.v_once - time - pset.age * 10, 1, atol=1e-5)
188-
ds = xr.open_zarr(filepath)
178+
ds = xr.open_zarr(tmp_zarrfile)
189179
vfile = np.ma.filled(ds["v_once"][:], np.nan)
190180
assert vfile.shape == (npart,)
191181
ds.close()
@@ -196,7 +186,7 @@ def Update_v(particle, fieldset, time):
196186
@pytest.mark.parametrize("repeatdt", range(1, 3))
197187
@pytest.mark.parametrize("dt", [-1, 1])
198188
@pytest.mark.parametrize("maxvar", [2, 4, 10])
199-
def test_pset_repeated_release_delayed_adding_deleting(type, fieldset, mode, repeatdt, tmpdir, dt, maxvar):
189+
def test_pset_repeated_release_delayed_adding_deleting(type, fieldset, mode, repeatdt, tmp_zarrfile, dt, maxvar):
200190
runtime = 10
201191
fieldset.maxvar = maxvar
202192
pset = None
@@ -211,8 +201,7 @@ def test_pset_repeated_release_delayed_adding_deleting(type, fieldset, mode, rep
211201
pset = ParticleSet(
212202
fieldset, lon=np.zeros(runtime), lat=np.zeros(runtime), pclass=MyParticle, time=list(range(runtime))
213203
)
214-
outfilepath = tmpdir.join("pfile_repeated_release.zarr")
215-
pfile = pset.ParticleFile(outfilepath, outputdt=abs(dt), chunks=(1, 1))
204+
pfile = pset.ParticleFile(tmp_zarrfile, outputdt=abs(dt), chunks=(1, 1))
216205

217206
def IncrLon(particle, fieldset, time):
218207
particle.sample_var += 1.0
@@ -222,7 +211,7 @@ def IncrLon(particle, fieldset, time):
222211
for _ in range(runtime):
223212
pset.execute(IncrLon, dt=dt, runtime=1.0, output_file=pfile)
224213

225-
ds = xr.open_zarr(outfilepath)
214+
ds = xr.open_zarr(tmp_zarrfile)
226215
samplevar = ds["sample_var"][:]
227216
if type == "repeatdt":
228217
assert samplevar.shape == (runtime // repeatdt, min(maxvar + 1, runtime))
@@ -232,51 +221,47 @@ def IncrLon(particle, fieldset, time):
232221
# test whether samplevar[:, k] = k
233222
for k in range(samplevar.shape[1]):
234223
assert np.allclose([p for p in samplevar[:, k] if np.isfinite(p)], k + 1)
235-
filesize = os.path.getsize(str(outfilepath))
224+
filesize = os.path.getsize(str(tmp_zarrfile))
236225
assert filesize < 1024 * 65 # test that chunking leads to filesize less than 65KB
237226
ds.close()
238227

239228

240229
@pytest.mark.parametrize("mode", ["scipy", "jit"])
241230
@pytest.mark.parametrize("repeatdt", [1, 2])
242231
@pytest.mark.parametrize("nump", [1, 10])
243-
def test_pfile_chunks_repeatedrelease(fieldset, mode, repeatdt, nump, tmpdir):
232+
def test_pfile_chunks_repeatedrelease(fieldset, mode, repeatdt, nump, tmp_zarrfile):
244233
runtime = 8
245234
pset = ParticleSet(
246235
fieldset, pclass=ptype[mode], lon=np.zeros((nump, 1)), lat=np.zeros((nump, 1)), repeatdt=repeatdt
247236
)
248-
outfilepath = tmpdir.join("pfile_chunks_repeatedrelease.zarr")
249237
chunks = (20, 10)
250-
pfile = pset.ParticleFile(outfilepath, outputdt=1, chunks=chunks)
238+
pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1, chunks=chunks)
251239

252240
def DoNothing(particle, fieldset, time):
253241
pass
254242

255243
pset.execute(DoNothing, dt=1, runtime=runtime, output_file=pfile)
256-
ds = xr.open_zarr(outfilepath)
244+
ds = xr.open_zarr(tmp_zarrfile)
257245
assert ds["time"].shape == (int(nump * runtime / repeatdt), chunks[1])
258246

259247

260248
@pytest.mark.parametrize("mode", ["scipy", "jit"])
261-
def test_write_timebackward(fieldset, mode, tmpdir):
262-
outfilepath = tmpdir.join("pfile_write_timebackward.zarr")
263-
249+
def test_write_timebackward(fieldset, mode, tmp_zarrfile):
264250
def Update_lon(particle, fieldset, time):
265251
particle_dlon -= 0.1 * particle.dt # noqa
266252

267253
pset = ParticleSet(fieldset, pclass=ptype[mode], lat=np.linspace(0, 1, 3), lon=[0, 0, 0], time=[1, 2, 3])
268-
pfile = pset.ParticleFile(name=outfilepath, outputdt=1.0)
254+
pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=1.0)
269255
pset.execute(pset.Kernel(Update_lon), runtime=4, dt=-1.0, output_file=pfile)
270-
ds = xr.open_zarr(outfilepath)
256+
ds = xr.open_zarr(tmp_zarrfile)
271257
trajs = ds["trajectory"][:]
272258
assert trajs.values.dtype == "int64"
273259
assert np.all(np.diff(trajs.values) < 0) # all particles written in order of release
274260
ds.close()
275261

276262

277263
@pytest.mark.parametrize("mode", ["scipy", "jit"])
278-
def test_write_xiyi(fieldset, mode, tmpdir):
279-
outfilepath = tmpdir.join("pfile_xiyi.zarr")
264+
def test_write_xiyi(fieldset, mode, tmp_zarrfile):
280265
fieldset.U.data[:] = 1 # set a non-zero zonal velocity
281266
fieldset.add_field(Field(name="P", data=np.zeros((3, 20)), lon=np.linspace(0, 1, 20), lat=[-2, 0, 2]))
282267
dt = 3600
@@ -304,10 +289,10 @@ def SampleP(particle, fieldset, time):
304289
_ = fieldset.P[particle] # To trigger sampling of the P field
305290

306291
pset = ParticleSet(fieldset, pclass=XiYiParticle, lon=[0, 0.2], lat=[0.2, 1], lonlatdepth_dtype=np.float64)
307-
pfile = pset.ParticleFile(name=outfilepath, outputdt=dt)
292+
pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=dt)
308293
pset.execute([SampleP, Get_XiYi, AdvectionRK4], endtime=10 * dt, dt=dt, output_file=pfile)
309294

310-
ds = xr.open_zarr(outfilepath)
295+
ds = xr.open_zarr(tmp_zarrfile)
311296
pxi0 = ds["pxi0"][:].values.astype(np.int32)
312297
pxi1 = ds["pxi1"][:].values.astype(np.int32)
313298
lons = ds["lon"][:].values
@@ -335,16 +320,15 @@ def test_set_calendar():
335320

336321

337322
@pytest.mark.parametrize("mode", ["scipy", "jit"])
338-
def test_reset_dt(fieldset, mode, tmpdir):
323+
def test_reset_dt(fieldset, mode, tmp_zarrfile):
339324
# Assert that p.dt gets reset when a write_time is not a multiple of dt
340325
# for p.dt=0.02 to reach outputdt=0.05 and endtime=0.1, the steps should be [0.2, 0.2, 0.1, 0.2, 0.2, 0.1], resulting in 6 kernel executions
341-
filepath = tmpdir.join("pfile_reset_dt.zarr")
342326

343327
def Update_lon(particle, fieldset, time):
344328
particle_dlon += 0.1 # noqa
345329

346330
pset = ParticleSet(fieldset, pclass=ptype[mode], lon=[0], lat=[0], lonlatdepth_dtype=np.float64)
347-
ofile = pset.ParticleFile(name=filepath, outputdt=0.05)
331+
ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=0.05)
348332
pset.execute(pset.Kernel(Update_lon), endtime=0.12, dt=0.02, output_file=ofile)
349333

350334
assert np.allclose(pset.lon, 0.6)

tests/test_particlesets.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,7 @@ def test_pset_create_list_with_customvariable(fieldset, mode):
7979

8080
@pytest.mark.parametrize("mode", ["scipy", "jit"])
8181
@pytest.mark.parametrize("restart", [True, False])
82-
def test_pset_create_fromparticlefile(fieldset, mode, restart, tmpdir):
83-
filename = tmpdir.join("pset_fromparticlefile.zarr")
82+
def test_pset_create_fromparticlefile(fieldset, mode, restart, tmp_zarrfile):
8483
lon = np.linspace(0, 1, 10, dtype=np.float32)
8584
lat = np.linspace(1, 0, 10, dtype=np.float32)
8685

@@ -89,7 +88,7 @@ def test_pset_create_fromparticlefile(fieldset, mode, restart, tmpdir):
8988
TestParticle = TestParticle.add_variable("p3", np.float64, to_write="once")
9089

9190
pset = ParticleSet(fieldset, lon=lon, lat=lat, depth=[4] * len(lon), pclass=TestParticle, p3=np.arange(len(lon)))
92-
pfile = pset.ParticleFile(filename, outputdt=1)
91+
pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1)
9392

9493
def Kernel(particle, fieldset, time):
9594
particle.p = 2.0
@@ -99,7 +98,7 @@ def Kernel(particle, fieldset, time):
9998
pset.execute(Kernel, runtime=2, dt=1, output_file=pfile)
10099

101100
pset_new = ParticleSet.from_particlefile(
102-
fieldset, pclass=TestParticle, filename=filename, restart=restart, repeatdt=1
101+
fieldset, pclass=TestParticle, filename=tmp_zarrfile, restart=restart, repeatdt=1
103102
)
104103

105104
for var in ["lon", "lat", "depth", "time", "p", "p2", "p3"]:

0 commit comments

Comments
 (0)