Commit 8cfc912

Merge pull request #1843 from OceanParcels/v/small-ref
Small refactors in `particlefile.py`
2 parents 429d3ae + 01ee0e7 commit 8cfc912

File tree: 5 files changed, +83 −80 lines

parcels/_index_search.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -79,7 +79,7 @@ def search_indices_vertical_s(
         eta = 1
     if time < grid.time[ti]:
         ti -= 1
-    if grid._z4d:
+    if grid._z4d:  # type: ignore[attr-defined]
         if ti == len(grid.time) - 1:
             depth_vector = (
                 (1 - xsi) * (1 - eta) * grid.depth[-1, :, yi, xi]
@@ -232,7 +232,7 @@ def _search_indices_curvilinear(field: Field, time, z, y, x, ti=-1, particle=None):
     else:
         xi = int(field.grid.xdim / 2) - 1
         yi = int(field.grid.ydim / 2) - 1
-    xsi = eta = -1
+    xsi = eta = -1.0
     grid = field.grid
     invA = np.array([[1, 0, 0, 0], [-1, 1, 0, 0], [-1, 0, 0, 1], [1, -1, 1, -1]])
     maxIterSearch = 1e6
```
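Both hunks are typing fixes rather than behaviour changes: the `# type: ignore[attr-defined]` comment silences mypy on the dynamically set `_z4d` attribute, and `-1.0` seeds `xsi`/`eta` as floats so mypy does not lock them to `int`. A minimal sketch (illustrative code, not Parcels) of the error the float literal avoids:

```python
# mypy infers a variable's type from its initializer, so an int sentinel
# that is later overwritten with a float coordinate fails type checking.

def locate_bad() -> tuple[float, float]:
    xsi = eta = -1  # inferred as int
    xsi, eta = 0.25, 0.75  # mypy error: expression has type "float", variable has type "int"
    return xsi, eta

def locate_good() -> tuple[float, float]:
    xsi = eta = -1.0  # inferred as float; sentinel and result share one type
    xsi, eta = 0.25, 0.75  # fine
    return xsi, eta
```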

parcels/_interpolation.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -146,7 +146,7 @@ def _linear_invdist_land_tracer_2d(ctx: InterpolationContext2D) -> float:
         return 0
     elif nb_land > 0:
         val = 0
-        w_sum = 0
+        w_sum = 0.0
         for j in range(2):
             for i in range(2):
                 distance = pow((eta - j), 2) + pow((xsi - i), 2)
@@ -196,8 +196,8 @@ def _linear_invdist_land_tracer_3d(ctx: InterpolationContext3D) -> float:
     if nb_land == 8:
         return 0
     elif nb_land > 0:
-        val = 0
-        w_sum = 0
+        val = 0.0
+        w_sum = 0.0
         for k in range(2):
             for j in range(2):
                 for i in range(2):
```
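The same rationale applies here: `val` and `w_sum` accumulate float weights, so they should start life as floats. A simplified, self-contained sketch of the 2D inverse-distance-weighted pattern these functions use (assumed names and layout, not the Parcels implementation):

```python
import numpy as np

def invdist_2d(corners: np.ndarray, eta: float, xsi: float) -> float:
    """Inverse-distance-weighted value of a 2x2 block of corner values."""
    val = 0.0    # float literals keep the inferred type consistent with the weights
    w_sum = 0.0
    for j in range(2):
        for i in range(2):
            distance = (eta - j) ** 2 + (xsi - i) ** 2
            if distance == 0:
                return float(corners[j, i])  # the sample sits exactly on a corner
            w = 1.0 / distance
            val += w * corners[j, i]
            w_sum += w
    return val / w_sum
```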

parcels/grid.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -830,7 +830,7 @@ def _calc_cell_edge_sizes(grid: RectilinearGrid) -> None:
     attribute to the grid.
     """
     if not grid.cell_edge_sizes:
-        if grid._gtype in (GridType.RectilinearZGrid, GridType.RectilinearSGrid):
+        if grid._gtype in (GridType.RectilinearZGrid, GridType.RectilinearSGrid):  # type: ignore[attr-defined]
             grid.cell_edge_sizes["x"] = np.zeros((grid.ydim, grid.xdim), dtype=np.float32)
             grid.cell_edge_sizes["y"] = np.zeros((grid.ydim, grid.xdim), dtype=np.float32)
 
@@ -842,7 +842,7 @@ def _calc_cell_edge_sizes(grid: RectilinearGrid) -> None:
                 grid.cell_edge_sizes["y"][y, x] = y_conv.to_source(dy, grid.depth[0], lat, lon)
         else:
             raise ValueError(
-                f"_cell_edge_sizes() not implemented for {grid._gtype} grids. "
+                f"_cell_edge_sizes() not implemented for {grid._gtype} grids. "  # type: ignore[attr-defined]
                 "You can provide Field.grid.cell_edge_sizes yourself by in, e.g., "
                 "NEMO using the e1u fields etc from the mesh_mask.nc file."
             )
```
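`# type: ignore[attr-defined]` tells mypy to accept an attribute it cannot find on the annotated type. One common way the error arises, shown with a toy hierarchy (illustrative, not the Parcels classes):

```python
class Grid:
    xdim: int  # declared attributes are visible to mypy

class RectilinearGrid(Grid):
    def __init__(self) -> None:
        self._gtype = "RectilinearZGrid"  # only set at runtime, not declared on Grid

def describe(grid: Grid) -> str:
    # mypy: "Grid" has no attribute "_gtype"  [attr-defined]
    return grid._gtype  # type: ignore[attr-defined]
```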

parcels/particledata.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -350,8 +350,9 @@ def flatten_dense_data_array(vname):
         cstruct = CParticles(*cdata)
         return cstruct
 
-    def _to_write_particles(self, pd, time):
+    def _to_write_particles(self, time):
         """Return the Particles that need to be written at time: if particle.time is between time-dt/2 and time+dt (/2)"""
+        pd = self._data
         return np.where(
             (
                 np.less_equal(time - np.abs(pd["dt"] / 2), pd["time"], where=np.isfinite(pd["time"]))
```
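The refactor removes a redundant parameter: callers previously passed `pset.particledata._data` into a method of the very object that owns `_data`, so the method now just reads `self._data`. A standalone sketch of the time-window test the docstring describes, with assumed array names:

```python
import numpy as np

def to_write(times: np.ndarray, dts: np.ndarray, time: float) -> np.ndarray:
    """Indices of particles whose own time lies within |dt|/2 of the output time."""
    finite = np.isfinite(times)
    # where= skips non-finite entries; `& finite` then masks the lanes
    # that where= leaves uninitialized.
    lower = np.less_equal(time - np.abs(dts) / 2, times, where=finite)
    upper = np.less_equal(times, time + np.abs(dts) / 2, where=finite)
    return np.where(lower & upper & finite)[0]
```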

parcels/particlefile.py

Lines changed: 74 additions & 72 deletions
```diff
@@ -285,86 +285,88 @@ def write(self, pset, time: float | timedelta | np.timedelta64 | None, indices=None):
             return
 
         if indices is None:
-            indices_to_write = pset.particledata._to_write_particles(pset.particledata._data, time)
+            indices_to_write = pset.particledata._to_write_particles(time)
         else:
             indices_to_write = indices
 
-        if len(indices_to_write) > 0:
-            pids = pset.particledata.getvardata("id", indices_to_write)
-            to_add = sorted(set(pids) - set(self._pids_written.keys()))
-            for i, pid in enumerate(to_add):
-                self._pids_written[pid] = self._maxids + i
-            ids = np.array([self._pids_written[p] for p in pids], dtype=int)
-            self._maxids = len(self._pids_written)
-
-            once_ids = np.where(pset.particledata.getvardata("obs_written", indices_to_write) == 0)[0]
-            if len(once_ids) > 0:
-                ids_once = ids[once_ids]
-                indices_to_write_once = indices_to_write[once_ids]
-
-            if self.create_new_zarrfile:
-                if self.chunks is None:
-                    self._chunks = (len(pset), 1)
-                if pset._repeatpclass is not None and self.chunks[0] < 1e4:  # type: ignore[index]
-                    warnings.warn(
-                        f"ParticleFile chunks are set to {self.chunks}, but this may lead to "
-                        f"a significant slowdown in Parcels when many calls to repeatdt. "
-                        f"Consider setting a larger chunk size for your ParticleFile (e.g. chunks=(int(1e4), 1)).",
-                        FileWarning,
-                        stacklevel=2,
-                    )
-                if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]):  # type: ignore[index]
-                    arrsize = (self._maxids, self.chunks[1])  # type: ignore[index]
-                else:
-                    arrsize = (len(ids), self.chunks[1])  # type: ignore[index]
-                ds = xr.Dataset(
-                    attrs=self.metadata,
-                    coords={"trajectory": ("trajectory", pids), "obs": ("obs", np.arange(arrsize[1], dtype=np.int32))},
+        if len(indices_to_write) == 0:
+            return
+
+        pids = pset.particledata.getvardata("id", indices_to_write)
+        to_add = sorted(set(pids) - set(self._pids_written.keys()))
+        for i, pid in enumerate(to_add):
+            self._pids_written[pid] = self._maxids + i
+        ids = np.array([self._pids_written[p] for p in pids], dtype=int)
+        self._maxids = len(self._pids_written)
+
+        once_ids = np.where(pset.particledata.getvardata("obs_written", indices_to_write) == 0)[0]
+        if len(once_ids) > 0:
+            ids_once = ids[once_ids]
+            indices_to_write_once = indices_to_write[once_ids]
+
+        if self.create_new_zarrfile:
+            if self.chunks is None:
+                self._chunks = (len(pset), 1)
+            if pset._repeatpclass is not None and self.chunks[0] < 1e4:  # type: ignore[index]
+                warnings.warn(
+                    f"ParticleFile chunks are set to {self.chunks}, but this may lead to "
+                    f"a significant slowdown in Parcels when many calls to repeatdt. "
+                    f"Consider setting a larger chunk size for your ParticleFile (e.g. chunks=(int(1e4), 1)).",
+                    FileWarning,
+                    stacklevel=2,
                 )
-                attrs = self._create_variables_attribute_dict()
-                obs = np.zeros((self._maxids), dtype=np.int32)
-                for var in self.vars_to_write:
-                    varout = self._convert_varout_name(var)
-                    if varout not in ["trajectory"]:  # because 'trajectory' is written as coordinate
-                        if self._write_once(var):
-                            data = np.full(
-                                (arrsize[0],),
-                                self._fill_value_map[self.vars_to_write[var]],
-                                dtype=self.vars_to_write[var],
-                            )
-                            data[ids_once] = pset.particledata.getvardata(var, indices_to_write_once)
-                            dims = ["trajectory"]
-                        else:
-                            data = np.full(
-                                arrsize, self._fill_value_map[self.vars_to_write[var]], dtype=self.vars_to_write[var]
-                            )
-                            data[ids, 0] = pset.particledata.getvardata(var, indices_to_write)
-                            dims = ["trajectory", "obs"]
-                        ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[varout])
-                        ds[varout].encoding["chunks"] = self.chunks[0] if self._write_once(var) else self.chunks  # type: ignore[index]
-                ds.to_zarr(self.fname, mode="w")
-                self._create_new_zarrfile = False
+            if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]):  # type: ignore[index]
+                arrsize = (self._maxids, self.chunks[1])  # type: ignore[index]
             else:
-                # Either use the store that was provided directly or create a DirectoryStore:
-                if issubclass(type(self.fname), zarr.storage.Store):
-                    store = self.fname
-                else:
-                    store = zarr.DirectoryStore(self.fname)
-                Z = zarr.group(store=store, overwrite=False)
-                obs = pset.particledata.getvardata("obs_written", indices_to_write)
-                for var in self.vars_to_write:
-                    varout = self._convert_varout_name(var)
-                    if self._maxids > Z[varout].shape[0]:
-                        self._extend_zarr_dims(Z[varout], store, dtype=self.vars_to_write[var], axis=0)
+                arrsize = (len(ids), self.chunks[1])  # type: ignore[index]
+            ds = xr.Dataset(
+                attrs=self.metadata,
+                coords={"trajectory": ("trajectory", pids), "obs": ("obs", np.arange(arrsize[1], dtype=np.int32))},
+            )
+            attrs = self._create_variables_attribute_dict()
+            obs = np.zeros((self._maxids), dtype=np.int32)
+            for var in self.vars_to_write:
+                varout = self._convert_varout_name(var)
+                if varout not in ["trajectory"]:  # because 'trajectory' is written as coordinate
                     if self._write_once(var):
-                        if len(once_ids) > 0:
-                            Z[varout].vindex[ids_once] = pset.particledata.getvardata(var, indices_to_write_once)
+                        data = np.full(
+                            (arrsize[0],),
+                            self._fill_value_map[self.vars_to_write[var]],
+                            dtype=self.vars_to_write[var],
+                        )
+                        data[ids_once] = pset.particledata.getvardata(var, indices_to_write_once)
+                        dims = ["trajectory"]
                     else:
-                        if max(obs) >= Z[varout].shape[1]:  # type: ignore[type-var]
-                            self._extend_zarr_dims(Z[varout], store, dtype=self.vars_to_write[var], axis=1)
-                        Z[varout].vindex[ids, obs] = pset.particledata.getvardata(var, indices_to_write)
+                        data = np.full(
+                            arrsize, self._fill_value_map[self.vars_to_write[var]], dtype=self.vars_to_write[var]
+                        )
+                        data[ids, 0] = pset.particledata.getvardata(var, indices_to_write)
+                        dims = ["trajectory", "obs"]
+                    ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[varout])
+                    ds[varout].encoding["chunks"] = self.chunks[0] if self._write_once(var) else self.chunks  # type: ignore[index]
+            ds.to_zarr(self.fname, mode="w")
+            self._create_new_zarrfile = False
+        else:
+            # Either use the store that was provided directly or create a DirectoryStore:
+            if isinstance(self.fname, zarr.storage.Store):
+                store = self.fname
+            else:
+                store = zarr.DirectoryStore(self.fname)
+            Z = zarr.group(store=store, overwrite=False)
+            obs = pset.particledata.getvardata("obs_written", indices_to_write)
+            for var in self.vars_to_write:
+                varout = self._convert_varout_name(var)
+                if self._maxids > Z[varout].shape[0]:
+                    self._extend_zarr_dims(Z[varout], store, dtype=self.vars_to_write[var], axis=0)
+                if self._write_once(var):
+                    if len(once_ids) > 0:
+                        Z[varout].vindex[ids_once] = pset.particledata.getvardata(var, indices_to_write_once)
+                else:
+                    if max(obs) >= Z[varout].shape[1]:  # type: ignore[type-var]
+                        self._extend_zarr_dims(Z[varout], store, dtype=self.vars_to_write[var], axis=1)
+                    Z[varout].vindex[ids, obs] = pset.particledata.getvardata(var, indices_to_write)
 
-            pset.particledata.setvardata("obs_written", indices_to_write, obs + 1)
+        pset.particledata.setvardata("obs_written", indices_to_write, obs + 1)
 
     def write_latest_locations(self, pset, time):
         """Write the current (latest) particle locations to zarr file.
```
