
Commit 238c580

ref: Add early return for particlefile writing with no data
1 parent 429d3ae commit 238c580

2 files changed: 76 additions, 73 deletions


parcels/particledata.py

Lines changed: 2 additions & 1 deletion
@@ -350,8 +350,9 @@ def flatten_dense_data_array(vname):
         cstruct = CParticles(*cdata)
         return cstruct

-    def _to_write_particles(self, pd, time):
+    def _to_write_particles(self, time):
         """Return the Particles that need to be written at time: if particle.time is between time-dt/2 and time+dt (/2)"""
+        pd = self._data
         return np.where(
             (
                 np.less_equal(time - np.abs(pd["dt"] / 2), pd["time"], where=np.isfinite(pd["time"]))
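For context, the docstring describes a half-timestep window around the requested output time. Below is a rough, standalone sketch of that windowing logic with toy data; it mirrors only the comparison visible in the hunk above (the rest of the condition is not shown there), and the dict here is a made-up stand-in for the internal particle data, not Parcels API:

import numpy as np

# Toy stand-in for the internal particle data dict read via pd = self._data
pd = {
    "time": np.array([0.0, 5.0, 10.0, np.nan]),
    "dt": np.array([10.0, 10.0, 10.0, 10.0]),
}
time = 5.0  # requested output time

finite = np.isfinite(pd["time"])
# Keep particles whose time lies within [time - |dt|/2, time + |dt|/2]
in_window = (
    np.less_equal(time - np.abs(pd["dt"] / 2), pd["time"], where=finite)
    & np.greater_equal(time + np.abs(pd["dt"] / 2), pd["time"], where=finite)
    & finite  # mask out particles whose time is NaN
)
print(np.where(in_window)[0])  # -> [0 1 2]; the NaN-time particle is skipped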

parcels/particlefile.py

Lines changed: 74 additions & 72 deletions
@@ -285,86 +285,88 @@ def write(self, pset, time: float | timedelta | np.timedelta64 | None, indices=N
             return

         if indices is None:
-            indices_to_write = pset.particledata._to_write_particles(pset.particledata._data, time)
+            indices_to_write = pset.particledata._to_write_particles(time)
         else:
             indices_to_write = indices

-        if len(indices_to_write) > 0:
-            pids = pset.particledata.getvardata("id", indices_to_write)
-            to_add = sorted(set(pids) - set(self._pids_written.keys()))
-            for i, pid in enumerate(to_add):
-                self._pids_written[pid] = self._maxids + i
-            ids = np.array([self._pids_written[p] for p in pids], dtype=int)
-            self._maxids = len(self._pids_written)
-
-            once_ids = np.where(pset.particledata.getvardata("obs_written", indices_to_write) == 0)[0]
-            if len(once_ids) > 0:
-                ids_once = ids[once_ids]
-                indices_to_write_once = indices_to_write[once_ids]
-
-            if self.create_new_zarrfile:
-                if self.chunks is None:
-                    self._chunks = (len(pset), 1)
-                if pset._repeatpclass is not None and self.chunks[0] < 1e4:  # type: ignore[index]
-                    warnings.warn(
-                        f"ParticleFile chunks are set to {self.chunks}, but this may lead to "
-                        f"a significant slowdown in Parcels when many calls to repeatdt. "
-                        f"Consider setting a larger chunk size for your ParticleFile (e.g. chunks=(int(1e4), 1)).",
-                        FileWarning,
-                        stacklevel=2,
-                    )
-                if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]):  # type: ignore[index]
-                    arrsize = (self._maxids, self.chunks[1])  # type: ignore[index]
-                else:
-                    arrsize = (len(ids), self.chunks[1])  # type: ignore[index]
-                ds = xr.Dataset(
-                    attrs=self.metadata,
-                    coords={"trajectory": ("trajectory", pids), "obs": ("obs", np.arange(arrsize[1], dtype=np.int32))},
+        if len(indices_to_write) == 0:
+            return
+
+        pids = pset.particledata.getvardata("id", indices_to_write)
+        to_add = sorted(set(pids) - set(self._pids_written.keys()))
+        for i, pid in enumerate(to_add):
+            self._pids_written[pid] = self._maxids + i
+        ids = np.array([self._pids_written[p] for p in pids], dtype=int)
+        self._maxids = len(self._pids_written)
+
+        once_ids = np.where(pset.particledata.getvardata("obs_written", indices_to_write) == 0)[0]
+        if len(once_ids) > 0:
+            ids_once = ids[once_ids]
+            indices_to_write_once = indices_to_write[once_ids]
+
+        if self.create_new_zarrfile:
+            if self.chunks is None:
+                self._chunks = (len(pset), 1)
+            if pset._repeatpclass is not None and self.chunks[0] < 1e4:  # type: ignore[index]
+                warnings.warn(
+                    f"ParticleFile chunks are set to {self.chunks}, but this may lead to "
+                    f"a significant slowdown in Parcels when many calls to repeatdt. "
+                    f"Consider setting a larger chunk size for your ParticleFile (e.g. chunks=(int(1e4), 1)).",
+                    FileWarning,
+                    stacklevel=2,
                 )
-                attrs = self._create_variables_attribute_dict()
-                obs = np.zeros((self._maxids), dtype=np.int32)
-                for var in self.vars_to_write:
-                    varout = self._convert_varout_name(var)
-                    if varout not in ["trajectory"]:  # because 'trajectory' is written as coordinate
-                        if self._write_once(var):
-                            data = np.full(
-                                (arrsize[0],),
-                                self._fill_value_map[self.vars_to_write[var]],
-                                dtype=self.vars_to_write[var],
-                            )
-                            data[ids_once] = pset.particledata.getvardata(var, indices_to_write_once)
-                            dims = ["trajectory"]
-                        else:
-                            data = np.full(
-                                arrsize, self._fill_value_map[self.vars_to_write[var]], dtype=self.vars_to_write[var]
-                            )
-                            data[ids, 0] = pset.particledata.getvardata(var, indices_to_write)
-                            dims = ["trajectory", "obs"]
-                        ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[varout])
-                        ds[varout].encoding["chunks"] = self.chunks[0] if self._write_once(var) else self.chunks  # type: ignore[index]
-                ds.to_zarr(self.fname, mode="w")
-                self._create_new_zarrfile = False
+            if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]):  # type: ignore[index]
+                arrsize = (self._maxids, self.chunks[1])  # type: ignore[index]
             else:
-                # Either use the store that was provided directly or create a DirectoryStore:
-                if issubclass(type(self.fname), zarr.storage.Store):
-                    store = self.fname
-                else:
-                    store = zarr.DirectoryStore(self.fname)
-                Z = zarr.group(store=store, overwrite=False)
-                obs = pset.particledata.getvardata("obs_written", indices_to_write)
-                for var in self.vars_to_write:
-                    varout = self._convert_varout_name(var)
-                    if self._maxids > Z[varout].shape[0]:
-                        self._extend_zarr_dims(Z[varout], store, dtype=self.vars_to_write[var], axis=0)
+                arrsize = (len(ids), self.chunks[1])  # type: ignore[index]
+            ds = xr.Dataset(
+                attrs=self.metadata,
+                coords={"trajectory": ("trajectory", pids), "obs": ("obs", np.arange(arrsize[1], dtype=np.int32))},
+            )
+            attrs = self._create_variables_attribute_dict()
+            obs = np.zeros((self._maxids), dtype=np.int32)
+            for var in self.vars_to_write:
+                varout = self._convert_varout_name(var)
+                if varout not in ["trajectory"]:  # because 'trajectory' is written as coordinate
                     if self._write_once(var):
-                        if len(once_ids) > 0:
-                            Z[varout].vindex[ids_once] = pset.particledata.getvardata(var, indices_to_write_once)
+                        data = np.full(
+                            (arrsize[0],),
+                            self._fill_value_map[self.vars_to_write[var]],
+                            dtype=self.vars_to_write[var],
+                        )
+                        data[ids_once] = pset.particledata.getvardata(var, indices_to_write_once)
+                        dims = ["trajectory"]
                     else:
-                        if max(obs) >= Z[varout].shape[1]:  # type: ignore[type-var]
-                            self._extend_zarr_dims(Z[varout], store, dtype=self.vars_to_write[var], axis=1)
-                        Z[varout].vindex[ids, obs] = pset.particledata.getvardata(var, indices_to_write)
+                        data = np.full(
+                            arrsize, self._fill_value_map[self.vars_to_write[var]], dtype=self.vars_to_write[var]
+                        )
+                        data[ids, 0] = pset.particledata.getvardata(var, indices_to_write)
+                        dims = ["trajectory", "obs"]
+                    ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[varout])
+                    ds[varout].encoding["chunks"] = self.chunks[0] if self._write_once(var) else self.chunks  # type: ignore[index]
+            ds.to_zarr(self.fname, mode="w")
+            self._create_new_zarrfile = False
+        else:
+            # Either use the store that was provided directly or create a DirectoryStore:
+            if issubclass(type(self.fname), zarr.storage.Store):
+                store = self.fname
+            else:
+                store = zarr.DirectoryStore(self.fname)
+            Z = zarr.group(store=store, overwrite=False)
+            obs = pset.particledata.getvardata("obs_written", indices_to_write)
+            for var in self.vars_to_write:
+                varout = self._convert_varout_name(var)
+                if self._maxids > Z[varout].shape[0]:
+                    self._extend_zarr_dims(Z[varout], store, dtype=self.vars_to_write[var], axis=0)
+                if self._write_once(var):
+                    if len(once_ids) > 0:
+                        Z[varout].vindex[ids_once] = pset.particledata.getvardata(var, indices_to_write_once)
+                else:
+                    if max(obs) >= Z[varout].shape[1]:  # type: ignore[type-var]
+                        self._extend_zarr_dims(Z[varout], store, dtype=self.vars_to_write[var], axis=1)
+                    Z[varout].vindex[ids, obs] = pset.particledata.getvardata(var, indices_to_write)

-            pset.particledata.setvardata("obs_written", indices_to_write, obs + 1)
+        pset.particledata.setvardata("obs_written", indices_to_write, obs + 1)

     def write_latest_locations(self, pset, time):
         """Write the current (latest) particle locations to zarr file.
