@@ -32,12 +32,12 @@ def fieldset():
3232
3333
3434@pytest .mark .parametrize ("mode" , ["scipy" , "jit" ])
35- def test_metadata (fieldset , mode , tmp_zarr ):
35+ def test_metadata (fieldset , mode , tmp_zarrfile ):
3636 pset = ParticleSet (fieldset , pclass = ptype [mode ], lon = 0 , lat = 0 )
3737
38- pset .execute (DoNothing , runtime = 1 , output_file = pset .ParticleFile (tmp_zarr ))
38+ pset .execute (DoNothing , runtime = 1 , output_file = pset .ParticleFile (tmp_zarrfile ))
3939
40- ds = xr .open_zarr (tmp_zarr )
40+ ds = xr .open_zarr (tmp_zarrfile )
4141 assert ds .attrs ["parcels_kernels" ].lower () == f"{ mode } ParticleDoNothing" .lower ()
4242
4343
@@ -56,36 +56,36 @@ def test_pfile_array_write_zarr_memorystore(fieldset, mode):
5656
5757
5858@pytest .mark .parametrize ("mode" , ["scipy" , "jit" ])
59- def test_pfile_array_remove_particles (fieldset , mode , tmp_zarr ):
59+ def test_pfile_array_remove_particles (fieldset , mode , tmp_zarrfile ):
6060 npart = 10
6161 pset = ParticleSet (fieldset , pclass = ptype [mode ], lon = np .linspace (0 , 1 , npart ), lat = 0.5 * np .ones (npart ), time = 0 )
62- pfile = pset .ParticleFile (tmp_zarr )
62+ pfile = pset .ParticleFile (tmp_zarrfile )
6363 pfile .write (pset , 0 )
6464 pset .remove_indices (3 )
6565 for p in pset :
6666 p .time = 1
6767 pfile .write (pset , 1 )
6868
69- ds = xr .open_zarr (tmp_zarr )
69+ ds = xr .open_zarr (tmp_zarrfile )
7070 timearr = ds ["time" ][:]
7171 assert (np .isnat (timearr [3 , 1 ])) and (np .isfinite (timearr [3 , 0 ]))
7272 ds .close ()
7373
7474
7575@pytest .mark .parametrize ("mode" , ["scipy" , "jit" ])
76- def test_pfile_set_towrite_False (fieldset , mode , tmp_zarr ):
76+ def test_pfile_set_towrite_False (fieldset , mode , tmp_zarrfile ):
7777 npart = 10
7878 pset = ParticleSet (fieldset , pclass = ptype [mode ], lon = np .linspace (0 , 1 , npart ), lat = 0.5 * np .ones (npart ))
7979 pset .set_variable_write_status ("depth" , False )
8080 pset .set_variable_write_status ("lat" , False )
81- pfile = pset .ParticleFile (tmp_zarr , outputdt = 1 )
81+ pfile = pset .ParticleFile (tmp_zarrfile , outputdt = 1 )
8282
8383 def Update_lon (particle , fieldset , time ):
8484 particle_dlon += 0.1 # noqa
8585
8686 pset .execute (Update_lon , runtime = 10 , output_file = pfile )
8787
88- ds = xr .open_zarr (tmp_zarr )
88+ ds = xr .open_zarr (tmp_zarrfile )
8989 assert "time" in ds
9090 assert "z" not in ds
9191 assert "lat" not in ds
@@ -98,18 +98,18 @@ def Update_lon(particle, fieldset, time):
9898
9999@pytest .mark .parametrize ("mode" , ["scipy" , "jit" ])
100100@pytest .mark .parametrize ("chunks_obs" , [1 , None ])
101- def test_pfile_array_remove_all_particles (fieldset , mode , chunks_obs , tmp_zarr ):
101+ def test_pfile_array_remove_all_particles (fieldset , mode , chunks_obs , tmp_zarrfile ):
102102 npart = 10
103103 pset = ParticleSet (fieldset , pclass = ptype [mode ], lon = np .linspace (0 , 1 , npart ), lat = 0.5 * np .ones (npart ), time = 0 )
104104 chunks = (npart , chunks_obs ) if chunks_obs else None
105- pfile = pset .ParticleFile (tmp_zarr , chunks = chunks )
105+ pfile = pset .ParticleFile (tmp_zarrfile , chunks = chunks )
106106 pfile .write (pset , 0 )
107107 for _ in range (npart ):
108108 pset .remove_indices (- 1 )
109109 pfile .write (pset , 1 )
110110 pfile .write (pset , 2 )
111111
112- ds = xr .open_zarr (tmp_zarr )
112+ ds = xr .open_zarr (tmp_zarrfile )
113113 assert np .allclose (ds ["time" ][:, 0 ], np .timedelta64 (0 , "s" ), atol = np .timedelta64 (1 , "ms" ))
114114 if chunks_obs is not None :
115115 assert ds ["time" ][:].shape == chunks
@@ -120,22 +120,22 @@ def test_pfile_array_remove_all_particles(fieldset, mode, chunks_obs, tmp_zarr):
120120
121121
122122@pytest .mark .parametrize ("mode" , ["scipy" , "jit" ])
123- def test_variable_write_double (fieldset , mode , tmp_zarr ):
123+ def test_variable_write_double (fieldset , mode , tmp_zarrfile ):
124124 def Update_lon (particle , fieldset , time ):
125125 particle_dlon += 0.1 # noqa
126126
127127 pset = ParticleSet (fieldset , pclass = ptype [mode ], lon = [0 ], lat = [0 ], lonlatdepth_dtype = np .float64 )
128- ofile = pset .ParticleFile (name = tmp_zarr , outputdt = 0.00001 )
128+ ofile = pset .ParticleFile (name = tmp_zarrfile , outputdt = 0.00001 )
129129 pset .execute (pset .Kernel (Update_lon ), endtime = 0.001 , dt = 0.00001 , output_file = ofile )
130130
131- ds = xr .open_zarr (tmp_zarr )
131+ ds = xr .open_zarr (tmp_zarrfile )
132132 lons = ds ["lon" ][:]
133133 assert isinstance (lons .values [0 , 0 ], np .float64 )
134134 ds .close ()
135135
136136
137137@pytest .mark .parametrize ("mode" , ["scipy" , "jit" ])
138- def test_write_dtypes_pfile (fieldset , mode , tmp_zarr ):
138+ def test_write_dtypes_pfile (fieldset , mode , tmp_zarrfile ):
139139 dtypes = [np .float32 , np .float64 , np .int32 , np .uint32 , np .int64 , np .uint64 ]
140140 if mode == "scipy" :
141141 dtypes .extend ([np .bool_ , np .int8 , np .uint8 , np .int16 , np .uint16 ])
@@ -144,19 +144,19 @@ def test_write_dtypes_pfile(fieldset, mode, tmp_zarr):
144144 MyParticle = ptype [mode ].add_variables (extra_vars )
145145
146146 pset = ParticleSet (fieldset , pclass = MyParticle , lon = 0 , lat = 0 , time = 0 )
147- pfile = pset .ParticleFile (name = tmp_zarr , outputdt = 1 )
147+ pfile = pset .ParticleFile (name = tmp_zarrfile , outputdt = 1 )
148148 pfile .write (pset , 0 )
149149
150150 ds = xr .open_zarr (
151- tmp_zarr , mask_and_scale = False
151+ tmp_zarrfile , mask_and_scale = False
152152 ) # Note masking issue at https://stackoverflow.com/questions/68460507/xarray-loading-int-data-as-float
153153 for d in dtypes :
154154 assert ds [f"v_{ d .__name__ } " ].dtype == d
155155
156156
157157@pytest .mark .parametrize ("mode" , ["scipy" , "jit" ])
158158@pytest .mark .parametrize ("npart" , [1 , 2 , 5 ])
159- def test_variable_written_once (fieldset , mode , tmp_zarr , npart ):
159+ def test_variable_written_once (fieldset , mode , tmp_zarrfile , npart ):
160160 def Update_v (particle , fieldset , time ):
161161 particle .v_once += 1.0
162162 particle .age += particle .dt
@@ -171,11 +171,11 @@ def Update_v(particle, fieldset, time):
171171 lat = np .linspace (1 , 0 , npart )
172172 time = np .arange (0 , npart / 10.0 , 0.1 , dtype = np .float64 )
173173 pset = ParticleSet (fieldset , pclass = MyParticle , lon = lon , lat = lat , time = time , v_once = time )
174- ofile = pset .ParticleFile (name = tmp_zarr , outputdt = 0.1 )
174+ ofile = pset .ParticleFile (name = tmp_zarrfile , outputdt = 0.1 )
175175 pset .execute (pset .Kernel (Update_v ), endtime = 1 , dt = 0.1 , output_file = ofile )
176176
177177 assert np .allclose (pset .v_once - time - pset .age * 10 , 1 , atol = 1e-5 )
178- ds = xr .open_zarr (tmp_zarr )
178+ ds = xr .open_zarr (tmp_zarrfile )
179179 vfile = np .ma .filled (ds ["v_once" ][:], np .nan )
180180 assert vfile .shape == (npart ,)
181181 ds .close ()
@@ -186,7 +186,7 @@ def Update_v(particle, fieldset, time):
186186@pytest .mark .parametrize ("repeatdt" , range (1 , 3 ))
187187@pytest .mark .parametrize ("dt" , [- 1 , 1 ])
188188@pytest .mark .parametrize ("maxvar" , [2 , 4 , 10 ])
189- def test_pset_repeated_release_delayed_adding_deleting (type , fieldset , mode , repeatdt , tmp_zarr , dt , maxvar ):
189+ def test_pset_repeated_release_delayed_adding_deleting (type , fieldset , mode , repeatdt , tmp_zarrfile , dt , maxvar ):
190190 runtime = 10
191191 fieldset .maxvar = maxvar
192192 pset = None
@@ -201,7 +201,7 @@ def test_pset_repeated_release_delayed_adding_deleting(type, fieldset, mode, rep
201201 pset = ParticleSet (
202202 fieldset , lon = np .zeros (runtime ), lat = np .zeros (runtime ), pclass = MyParticle , time = list (range (runtime ))
203203 )
204- pfile = pset .ParticleFile (tmp_zarr , outputdt = abs (dt ), chunks = (1 , 1 ))
204+ pfile = pset .ParticleFile (tmp_zarrfile , outputdt = abs (dt ), chunks = (1 , 1 ))
205205
206206 def IncrLon (particle , fieldset , time ):
207207 particle .sample_var += 1.0
@@ -211,7 +211,7 @@ def IncrLon(particle, fieldset, time):
211211 for _ in range (runtime ):
212212 pset .execute (IncrLon , dt = dt , runtime = 1.0 , output_file = pfile )
213213
214- ds = xr .open_zarr (tmp_zarr )
214+ ds = xr .open_zarr (tmp_zarrfile )
215215 samplevar = ds ["sample_var" ][:]
216216 if type == "repeatdt" :
217217 assert samplevar .shape == (runtime // repeatdt , min (maxvar + 1 , runtime ))
@@ -221,47 +221,47 @@ def IncrLon(particle, fieldset, time):
221221 # test whether samplevar[:, k] = k
222222 for k in range (samplevar .shape [1 ]):
223223 assert np .allclose ([p for p in samplevar [:, k ] if np .isfinite (p )], k + 1 )
224- filesize = os .path .getsize (str (tmp_zarr ))
224+ filesize = os .path .getsize (str (tmp_zarrfile ))
225225 assert filesize < 1024 * 65 # test that chunking leads to filesize less than 65KB
226226 ds .close ()
227227
228228
229229@pytest .mark .parametrize ("mode" , ["scipy" , "jit" ])
230230@pytest .mark .parametrize ("repeatdt" , [1 , 2 ])
231231@pytest .mark .parametrize ("nump" , [1 , 10 ])
232- def test_pfile_chunks_repeatedrelease (fieldset , mode , repeatdt , nump , tmp_zarr ):
232+ def test_pfile_chunks_repeatedrelease (fieldset , mode , repeatdt , nump , tmp_zarrfile ):
233233 runtime = 8
234234 pset = ParticleSet (
235235 fieldset , pclass = ptype [mode ], lon = np .zeros ((nump , 1 )), lat = np .zeros ((nump , 1 )), repeatdt = repeatdt
236236 )
237237 chunks = (20 , 10 )
238- pfile = pset .ParticleFile (tmp_zarr , outputdt = 1 , chunks = chunks )
238+ pfile = pset .ParticleFile (tmp_zarrfile , outputdt = 1 , chunks = chunks )
239239
240240 def DoNothing (particle , fieldset , time ):
241241 pass
242242
243243 pset .execute (DoNothing , dt = 1 , runtime = runtime , output_file = pfile )
244- ds = xr .open_zarr (tmp_zarr )
244+ ds = xr .open_zarr (tmp_zarrfile )
245245 assert ds ["time" ].shape == (int (nump * runtime / repeatdt ), chunks [1 ])
246246
247247
248248@pytest .mark .parametrize ("mode" , ["scipy" , "jit" ])
249- def test_write_timebackward (fieldset , mode , tmp_zarr ):
249+ def test_write_timebackward (fieldset , mode , tmp_zarrfile ):
250250 def Update_lon (particle , fieldset , time ):
251251 particle_dlon -= 0.1 * particle .dt # noqa
252252
253253 pset = ParticleSet (fieldset , pclass = ptype [mode ], lat = np .linspace (0 , 1 , 3 ), lon = [0 , 0 , 0 ], time = [1 , 2 , 3 ])
254- pfile = pset .ParticleFile (name = tmp_zarr , outputdt = 1.0 )
254+ pfile = pset .ParticleFile (name = tmp_zarrfile , outputdt = 1.0 )
255255 pset .execute (pset .Kernel (Update_lon ), runtime = 4 , dt = - 1.0 , output_file = pfile )
256- ds = xr .open_zarr (tmp_zarr )
256+ ds = xr .open_zarr (tmp_zarrfile )
257257 trajs = ds ["trajectory" ][:]
258258 assert trajs .values .dtype == "int64"
259259 assert np .all (np .diff (trajs .values ) < 0 ) # all particles written in order of release
260260 ds .close ()
261261
262262
263263@pytest .mark .parametrize ("mode" , ["scipy" , "jit" ])
264- def test_write_xiyi (fieldset , mode , tmp_zarr ):
264+ def test_write_xiyi (fieldset , mode , tmp_zarrfile ):
265265 fieldset .U .data [:] = 1 # set a non-zero zonal velocity
266266 fieldset .add_field (Field (name = "P" , data = np .zeros ((3 , 20 )), lon = np .linspace (0 , 1 , 20 ), lat = [- 2 , 0 , 2 ]))
267267 dt = 3600
@@ -289,10 +289,10 @@ def SampleP(particle, fieldset, time):
289289 _ = fieldset .P [particle ] # To trigger sampling of the P field
290290
291291 pset = ParticleSet (fieldset , pclass = XiYiParticle , lon = [0 , 0.2 ], lat = [0.2 , 1 ], lonlatdepth_dtype = np .float64 )
292- pfile = pset .ParticleFile (name = tmp_zarr , outputdt = dt )
292+ pfile = pset .ParticleFile (name = tmp_zarrfile , outputdt = dt )
293293 pset .execute ([SampleP , Get_XiYi , AdvectionRK4 ], endtime = 10 * dt , dt = dt , output_file = pfile )
294294
295- ds = xr .open_zarr (tmp_zarr )
295+ ds = xr .open_zarr (tmp_zarrfile )
296296 pxi0 = ds ["pxi0" ][:].values .astype (np .int32 )
297297 pxi1 = ds ["pxi1" ][:].values .astype (np .int32 )
298298 lons = ds ["lon" ][:].values
@@ -320,15 +320,15 @@ def test_set_calendar():
320320
321321
322322@pytest .mark .parametrize ("mode" , ["scipy" , "jit" ])
323- def test_reset_dt (fieldset , mode , tmp_zarr ):
323+ def test_reset_dt (fieldset , mode , tmp_zarrfile ):
324324 # Assert that p.dt gets reset when a write_time is not a multiple of dt
325325 # for p.dt=0.02 to reach outputdt=0.05 and endtime=0.1, the steps should be [0.2, 0.2, 0.1, 0.2, 0.2, 0.1], resulting in 6 kernel executions
326326
327327 def Update_lon (particle , fieldset , time ):
328328 particle_dlon += 0.1 # noqa
329329
330330 pset = ParticleSet (fieldset , pclass = ptype [mode ], lon = [0 ], lat = [0 ], lonlatdepth_dtype = np .float64 )
331- ofile = pset .ParticleFile (name = tmp_zarr , outputdt = 0.05 )
331+ ofile = pset .ParticleFile (name = tmp_zarrfile , outputdt = 0.05 )
332332 pset .execute (pset .Kernel (Update_lon ), endtime = 0.12 , dt = 0.02 , output_file = ofile )
333333
334334 assert np .allclose (pset .lon , 0.6 )
0 commit comments