@@ -32,13 +32,12 @@ def fieldset():
3232
3333
3434@pytest .mark .parametrize ("mode" , ["scipy" , "jit" ])
35- def test_metadata (fieldset , mode , tmpdir ):
36- filepath = tmpdir .join ("pfile_metadata.zarr" )
35+ def test_metadata (fieldset , mode , tmp_zarrfile ):
3736 pset = ParticleSet (fieldset , pclass = ptype [mode ], lon = 0 , lat = 0 )
3837
39- pset .execute (DoNothing , runtime = 1 , output_file = pset .ParticleFile (filepath ))
38+ pset .execute (DoNothing , runtime = 1 , output_file = pset .ParticleFile (tmp_zarrfile ))
4039
41- ds = xr .open_zarr (filepath )
40+ ds = xr .open_zarr (tmp_zarrfile )
4241 assert ds .attrs ["parcels_kernels" ].lower () == f"{ mode } ParticleDoNothing" .lower ()
4342
4443
@@ -57,38 +56,36 @@ def test_pfile_array_write_zarr_memorystore(fieldset, mode):
5756
5857
5958@pytest .mark .parametrize ("mode" , ["scipy" , "jit" ])
60- def test_pfile_array_remove_particles (fieldset , mode , tmpdir ):
59+ def test_pfile_array_remove_particles (fieldset , mode , tmp_zarrfile ):
6160 npart = 10
62- filepath = tmpdir .join ("pfile_array_remove_particles.zarr" )
6361 pset = ParticleSet (fieldset , pclass = ptype [mode ], lon = np .linspace (0 , 1 , npart ), lat = 0.5 * np .ones (npart ), time = 0 )
64- pfile = pset .ParticleFile (filepath )
62+ pfile = pset .ParticleFile (tmp_zarrfile )
6563 pfile .write (pset , 0 )
6664 pset .remove_indices (3 )
6765 for p in pset :
6866 p .time = 1
6967 pfile .write (pset , 1 )
7068
71- ds = xr .open_zarr (filepath )
69+ ds = xr .open_zarr (tmp_zarrfile )
7270 timearr = ds ["time" ][:]
7371 assert (np .isnat (timearr [3 , 1 ])) and (np .isfinite (timearr [3 , 0 ]))
7472 ds .close ()
7573
7674
7775@pytest .mark .parametrize ("mode" , ["scipy" , "jit" ])
78- def test_pfile_set_towrite_False (fieldset , mode , tmpdir ):
76+ def test_pfile_set_towrite_False (fieldset , mode , tmp_zarrfile ):
7977 npart = 10
80- filepath = tmpdir .join ("pfile_set_towrite_False.zarr" )
8178 pset = ParticleSet (fieldset , pclass = ptype [mode ], lon = np .linspace (0 , 1 , npart ), lat = 0.5 * np .ones (npart ))
8279 pset .set_variable_write_status ("depth" , False )
8380 pset .set_variable_write_status ("lat" , False )
84- pfile = pset .ParticleFile (filepath , outputdt = 1 )
81+ pfile = pset .ParticleFile (tmp_zarrfile , outputdt = 1 )
8582
8683 def Update_lon (particle , fieldset , time ):
8784 particle_dlon += 0.1 # noqa
8885
8986 pset .execute (Update_lon , runtime = 10 , output_file = pfile )
9087
91- ds = xr .open_zarr (filepath )
88+ ds = xr .open_zarr (tmp_zarrfile )
9289 assert "time" in ds
9390 assert "z" not in ds
9491 assert "lat" not in ds
@@ -101,19 +98,18 @@ def Update_lon(particle, fieldset, time):
10198
10299@pytest .mark .parametrize ("mode" , ["scipy" , "jit" ])
103100@pytest .mark .parametrize ("chunks_obs" , [1 , None ])
104- def test_pfile_array_remove_all_particles (fieldset , mode , chunks_obs , tmpdir ):
101+ def test_pfile_array_remove_all_particles (fieldset , mode , chunks_obs , tmp_zarrfile ):
105102 npart = 10
106- filepath = tmpdir .join ("pfile_array_remove_particles.zarr" )
107103 pset = ParticleSet (fieldset , pclass = ptype [mode ], lon = np .linspace (0 , 1 , npart ), lat = 0.5 * np .ones (npart ), time = 0 )
108104 chunks = (npart , chunks_obs ) if chunks_obs else None
109- pfile = pset .ParticleFile (filepath , chunks = chunks )
105+ pfile = pset .ParticleFile (tmp_zarrfile , chunks = chunks )
110106 pfile .write (pset , 0 )
111107 for _ in range (npart ):
112108 pset .remove_indices (- 1 )
113109 pfile .write (pset , 1 )
114110 pfile .write (pset , 2 )
115111
116- ds = xr .open_zarr (filepath )
112+ ds = xr .open_zarr (tmp_zarrfile )
117113 assert np .allclose (ds ["time" ][:, 0 ], np .timedelta64 (0 , "s" ), atol = np .timedelta64 (1 , "ms" ))
118114 if chunks_obs is not None :
119115 assert ds ["time" ][:].shape == chunks
@@ -124,26 +120,22 @@ def test_pfile_array_remove_all_particles(fieldset, mode, chunks_obs, tmpdir):
124120
125121
126122@pytest .mark .parametrize ("mode" , ["scipy" , "jit" ])
127- def test_variable_write_double (fieldset , mode , tmpdir ):
128- filepath = tmpdir .join ("pfile_variable_write_double.zarr" )
129-
123+ def test_variable_write_double (fieldset , mode , tmp_zarrfile ):
130124 def Update_lon (particle , fieldset , time ):
131125 particle_dlon += 0.1 # noqa
132126
133127 pset = ParticleSet (fieldset , pclass = ptype [mode ], lon = [0 ], lat = [0 ], lonlatdepth_dtype = np .float64 )
134- ofile = pset .ParticleFile (name = filepath , outputdt = 0.00001 )
128+ ofile = pset .ParticleFile (name = tmp_zarrfile , outputdt = 0.00001 )
135129 pset .execute (pset .Kernel (Update_lon ), endtime = 0.001 , dt = 0.00001 , output_file = ofile )
136130
137- ds = xr .open_zarr (filepath )
131+ ds = xr .open_zarr (tmp_zarrfile )
138132 lons = ds ["lon" ][:]
139133 assert isinstance (lons .values [0 , 0 ], np .float64 )
140134 ds .close ()
141135
142136
143137@pytest .mark .parametrize ("mode" , ["scipy" , "jit" ])
144- def test_write_dtypes_pfile (fieldset , mode , tmpdir ):
145- filepath = tmpdir .join ("pfile_dtypes.zarr" )
146-
138+ def test_write_dtypes_pfile (fieldset , mode , tmp_zarrfile ):
147139 dtypes = [np .float32 , np .float64 , np .int32 , np .uint32 , np .int64 , np .uint64 ]
148140 if mode == "scipy" :
149141 dtypes .extend ([np .bool_ , np .int8 , np .uint8 , np .int16 , np .uint16 ])
@@ -152,21 +144,19 @@ def test_write_dtypes_pfile(fieldset, mode, tmpdir):
152144 MyParticle = ptype [mode ].add_variables (extra_vars )
153145
154146 pset = ParticleSet (fieldset , pclass = MyParticle , lon = 0 , lat = 0 , time = 0 )
155- pfile = pset .ParticleFile (name = filepath , outputdt = 1 )
147+ pfile = pset .ParticleFile (name = tmp_zarrfile , outputdt = 1 )
156148 pfile .write (pset , 0 )
157149
158150 ds = xr .open_zarr (
159- filepath , mask_and_scale = False
151+ tmp_zarrfile , mask_and_scale = False
160152 ) # Note masking issue at https://stackoverflow.com/questions/68460507/xarray-loading-int-data-as-float
161153 for d in dtypes :
162154 assert ds [f"v_{ d .__name__ } " ].dtype == d
163155
164156
165157@pytest .mark .parametrize ("mode" , ["scipy" , "jit" ])
166158@pytest .mark .parametrize ("npart" , [1 , 2 , 5 ])
167- def test_variable_written_once (fieldset , mode , tmpdir , npart ):
168- filepath = tmpdir .join ("pfile_once_written_variables.zarr" )
169-
159+ def test_variable_written_once (fieldset , mode , tmp_zarrfile , npart ):
170160 def Update_v (particle , fieldset , time ):
171161 particle .v_once += 1.0
172162 particle .age += particle .dt
@@ -181,11 +171,11 @@ def Update_v(particle, fieldset, time):
181171 lat = np .linspace (1 , 0 , npart )
182172 time = np .arange (0 , npart / 10.0 , 0.1 , dtype = np .float64 )
183173 pset = ParticleSet (fieldset , pclass = MyParticle , lon = lon , lat = lat , time = time , v_once = time )
184- ofile = pset .ParticleFile (name = filepath , outputdt = 0.1 )
174+ ofile = pset .ParticleFile (name = tmp_zarrfile , outputdt = 0.1 )
185175 pset .execute (pset .Kernel (Update_v ), endtime = 1 , dt = 0.1 , output_file = ofile )
186176
187177 assert np .allclose (pset .v_once - time - pset .age * 10 , 1 , atol = 1e-5 )
188- ds = xr .open_zarr (filepath )
178+ ds = xr .open_zarr (tmp_zarrfile )
189179 vfile = np .ma .filled (ds ["v_once" ][:], np .nan )
190180 assert vfile .shape == (npart ,)
191181 ds .close ()
@@ -196,7 +186,7 @@ def Update_v(particle, fieldset, time):
196186@pytest .mark .parametrize ("repeatdt" , range (1 , 3 ))
197187@pytest .mark .parametrize ("dt" , [- 1 , 1 ])
198188@pytest .mark .parametrize ("maxvar" , [2 , 4 , 10 ])
199- def test_pset_repeated_release_delayed_adding_deleting (type , fieldset , mode , repeatdt , tmpdir , dt , maxvar ):
189+ def test_pset_repeated_release_delayed_adding_deleting (type , fieldset , mode , repeatdt , tmp_zarrfile , dt , maxvar ):
200190 runtime = 10
201191 fieldset .maxvar = maxvar
202192 pset = None
@@ -211,8 +201,7 @@ def test_pset_repeated_release_delayed_adding_deleting(type, fieldset, mode, rep
211201 pset = ParticleSet (
212202 fieldset , lon = np .zeros (runtime ), lat = np .zeros (runtime ), pclass = MyParticle , time = list (range (runtime ))
213203 )
214- outfilepath = tmpdir .join ("pfile_repeated_release.zarr" )
215- pfile = pset .ParticleFile (outfilepath , outputdt = abs (dt ), chunks = (1 , 1 ))
204+ pfile = pset .ParticleFile (tmp_zarrfile , outputdt = abs (dt ), chunks = (1 , 1 ))
216205
217206 def IncrLon (particle , fieldset , time ):
218207 particle .sample_var += 1.0
@@ -222,7 +211,7 @@ def IncrLon(particle, fieldset, time):
222211 for _ in range (runtime ):
223212 pset .execute (IncrLon , dt = dt , runtime = 1.0 , output_file = pfile )
224213
225- ds = xr .open_zarr (outfilepath )
214+ ds = xr .open_zarr (tmp_zarrfile )
226215 samplevar = ds ["sample_var" ][:]
227216 if type == "repeatdt" :
228217 assert samplevar .shape == (runtime // repeatdt , min (maxvar + 1 , runtime ))
@@ -232,51 +221,47 @@ def IncrLon(particle, fieldset, time):
232221 # test whether samplevar[:, k] = k
233222 for k in range (samplevar .shape [1 ]):
234223 assert np .allclose ([p for p in samplevar [:, k ] if np .isfinite (p )], k + 1 )
235- filesize = os .path .getsize (str (outfilepath ))
224+ filesize = os .path .getsize (str (tmp_zarrfile ))
236225 assert filesize < 1024 * 65 # test that chunking leads to filesize less than 65KB
237226 ds .close ()
238227
239228
240229@pytest .mark .parametrize ("mode" , ["scipy" , "jit" ])
241230@pytest .mark .parametrize ("repeatdt" , [1 , 2 ])
242231@pytest .mark .parametrize ("nump" , [1 , 10 ])
243- def test_pfile_chunks_repeatedrelease (fieldset , mode , repeatdt , nump , tmpdir ):
232+ def test_pfile_chunks_repeatedrelease (fieldset , mode , repeatdt , nump , tmp_zarrfile ):
244233 runtime = 8
245234 pset = ParticleSet (
246235 fieldset , pclass = ptype [mode ], lon = np .zeros ((nump , 1 )), lat = np .zeros ((nump , 1 )), repeatdt = repeatdt
247236 )
248- outfilepath = tmpdir .join ("pfile_chunks_repeatedrelease.zarr" )
249237 chunks = (20 , 10 )
250- pfile = pset .ParticleFile (outfilepath , outputdt = 1 , chunks = chunks )
238+ pfile = pset .ParticleFile (tmp_zarrfile , outputdt = 1 , chunks = chunks )
251239
252240 def DoNothing (particle , fieldset , time ):
253241 pass
254242
255243 pset .execute (DoNothing , dt = 1 , runtime = runtime , output_file = pfile )
256- ds = xr .open_zarr (outfilepath )
244+ ds = xr .open_zarr (tmp_zarrfile )
257245 assert ds ["time" ].shape == (int (nump * runtime / repeatdt ), chunks [1 ])
258246
259247
260248@pytest .mark .parametrize ("mode" , ["scipy" , "jit" ])
261- def test_write_timebackward (fieldset , mode , tmpdir ):
262- outfilepath = tmpdir .join ("pfile_write_timebackward.zarr" )
263-
249+ def test_write_timebackward (fieldset , mode , tmp_zarrfile ):
264250 def Update_lon (particle , fieldset , time ):
265251 particle_dlon -= 0.1 * particle .dt # noqa
266252
267253 pset = ParticleSet (fieldset , pclass = ptype [mode ], lat = np .linspace (0 , 1 , 3 ), lon = [0 , 0 , 0 ], time = [1 , 2 , 3 ])
268- pfile = pset .ParticleFile (name = outfilepath , outputdt = 1.0 )
254+ pfile = pset .ParticleFile (name = tmp_zarrfile , outputdt = 1.0 )
269255 pset .execute (pset .Kernel (Update_lon ), runtime = 4 , dt = - 1.0 , output_file = pfile )
270- ds = xr .open_zarr (outfilepath )
256+ ds = xr .open_zarr (tmp_zarrfile )
271257 trajs = ds ["trajectory" ][:]
272258 assert trajs .values .dtype == "int64"
273259 assert np .all (np .diff (trajs .values ) < 0 ) # all particles written in order of release
274260 ds .close ()
275261
276262
277263@pytest .mark .parametrize ("mode" , ["scipy" , "jit" ])
278- def test_write_xiyi (fieldset , mode , tmpdir ):
279- outfilepath = tmpdir .join ("pfile_xiyi.zarr" )
264+ def test_write_xiyi (fieldset , mode , tmp_zarrfile ):
280265 fieldset .U .data [:] = 1 # set a non-zero zonal velocity
281266 fieldset .add_field (Field (name = "P" , data = np .zeros ((3 , 20 )), lon = np .linspace (0 , 1 , 20 ), lat = [- 2 , 0 , 2 ]))
282267 dt = 3600
@@ -304,10 +289,10 @@ def SampleP(particle, fieldset, time):
304289 _ = fieldset .P [particle ] # To trigger sampling of the P field
305290
306291 pset = ParticleSet (fieldset , pclass = XiYiParticle , lon = [0 , 0.2 ], lat = [0.2 , 1 ], lonlatdepth_dtype = np .float64 )
307- pfile = pset .ParticleFile (name = outfilepath , outputdt = dt )
292+ pfile = pset .ParticleFile (name = tmp_zarrfile , outputdt = dt )
308293 pset .execute ([SampleP , Get_XiYi , AdvectionRK4 ], endtime = 10 * dt , dt = dt , output_file = pfile )
309294
310- ds = xr .open_zarr (outfilepath )
295+ ds = xr .open_zarr (tmp_zarrfile )
311296 pxi0 = ds ["pxi0" ][:].values .astype (np .int32 )
312297 pxi1 = ds ["pxi1" ][:].values .astype (np .int32 )
313298 lons = ds ["lon" ][:].values
@@ -335,16 +320,15 @@ def test_set_calendar():
335320
336321
337322@pytest .mark .parametrize ("mode" , ["scipy" , "jit" ])
338- def test_reset_dt (fieldset , mode , tmpdir ):
323+ def test_reset_dt (fieldset , mode , tmp_zarrfile ):
339324 # Assert that p.dt gets reset when a write_time is not a multiple of dt
340325 # for p.dt=0.02 to reach outputdt=0.05 and endtime=0.1, the steps should be [0.2, 0.2, 0.1, 0.2, 0.2, 0.1], resulting in 6 kernel executions
341- filepath = tmpdir .join ("pfile_reset_dt.zarr" )
342326
343327 def Update_lon (particle , fieldset , time ):
344328 particle_dlon += 0.1 # noqa
345329
346330 pset = ParticleSet (fieldset , pclass = ptype [mode ], lon = [0 ], lat = [0 ], lonlatdepth_dtype = np .float64 )
347- ofile = pset .ParticleFile (name = filepath , outputdt = 0.05 )
331+ ofile = pset .ParticleFile (name = tmp_zarrfile , outputdt = 0.05 )
348332 pset .execute (pset .Kernel (Update_lon ), endtime = 0.12 , dt = 0.02 , output_file = ofile )
349333
350334 assert np .allclose (pset .lon , 0.6 )
0 commit comments