Skip to content

Commit 823d180

Browse files
Creating helper functions for ParticleSetViewArray operations
1 parent 9aa0f32 commit 823d180

File tree

1 file changed

+35
-73
lines changed

1 file changed

+35
-73
lines changed

src/parcels/_core/particle.py

Lines changed: 35 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -193,19 +193,18 @@ def __len__(self):
193193
return len(self._index)
194194

195195

196-
class ParticleSetViewArray:
197-
"""Array-like proxy for a particle variable that writes through to the
198-
parent arrays when mutated.
196+
def _unwrap(other):
197+
"""Return ndarray for ParticleSetViewArray or the value unchanged."""
198+
return other.__array__() if isinstance(other, ParticleSetViewArray) else other
199199

200-
Parameters
201-
----------
202-
data : dict-like
203-
Parent particle storage (mapping varname -> ndarray)
204-
index : array-like
205-
Index representing the subset in the parent arrays (boolean mask or integer indices)
206-
name : str
207-
Variable name in `data` to proxy
208-
"""
200+
201+
def _asarray(other):
202+
"""Return numpy array for ParticleSetViewArray, otherwise return argument."""
203+
return np.asarray(other.__array__()) if isinstance(other, ParticleSetViewArray) else other
204+
205+
206+
class ParticleSetViewArray:
207+
"""Array-like proxy for a ParticleSetView that writes through to the parent arrays when mutated."""
209208

210209
def __init__(self, data, index, name):
211210
self._data = data
@@ -253,24 +252,11 @@ def _to_global_index(self, subindex=None):
253252
# map the first index (local selection) to global particle indices
254253
if base.dtype == bool:
255254
particle_idxs = np.flatnonzero(base)
256-
if isinstance(first, slice):
257-
sel = particle_idxs[first]
258-
elif isinstance(first, (np.ndarray, list)):
259-
first_arr = np.asarray(first)
260-
if first_arr.dtype == bool:
261-
sel = particle_idxs[first_arr]
262-
else:
263-
sel = particle_idxs[first_arr]
264-
elif isinstance(first, int):
265-
sel = particle_idxs[first]
266-
else:
267-
sel = particle_idxs[first]
255+
first_arr = np.asarray(first) if isinstance(first, (np.ndarray, list)) else first
256+
sel = particle_idxs[first_arr]
268257
else:
269258
base_arr = np.asarray(base)
270-
if isinstance(first, slice):
271-
sel = base_arr[first]
272-
else:
273-
sel = base_arr[first]
259+
sel = base_arr[first]
274260

275261
# if rest contains a single int (e.g., column), return tuple index
276262
if len(rest) == 1:
@@ -323,44 +309,38 @@ def __setitem__(self, subindex, value):
323309

324310
# in-place ops must write back into the parent array
325311
def __iadd__(self, other):
326-
vals = self._data[self._name][self._index] + (
327-
other.__array__() if isinstance(other, ParticleSetViewArray) else other
328-
)
312+
vals = self._data[self._name][self._index] + _unwrap(other)
329313
self._data[self._name][self._index] = vals
330314
return self
331315

332316
def __isub__(self, other):
333-
vals = self._data[self._name][self._index] - (
334-
other.__array__() if isinstance(other, ParticleSetViewArray) else other
335-
)
317+
vals = self._data[self._name][self._index] - _unwrap(other)
336318
self._data[self._name][self._index] = vals
337319
return self
338320

339321
def __imul__(self, other):
340-
vals = self._data[self._name][self._index] * (
341-
other.__array__() if isinstance(other, ParticleSetViewArray) else other
342-
)
322+
vals = self._data[self._name][self._index] * _unwrap(other)
343323
self._data[self._name][self._index] = vals
344324
return self
345325

346326
# Provide simple numpy-like evaluation for binary ops by delegating to ndarray
347327
def __add__(self, other):
348-
return self.__array__() + (other.__array__() if isinstance(other, ParticleSetViewArray) else other)
328+
return self.__array__() + _unwrap(other)
349329

350330
def __sub__(self, other):
351-
return self.__array__() - (other.__array__() if isinstance(other, ParticleSetViewArray) else other)
331+
return self.__array__() - _unwrap(other)
352332

353333
def __mul__(self, other):
354-
return self.__array__() * (other.__array__() if isinstance(other, ParticleSetViewArray) else other)
334+
return self.__array__() * _unwrap(other)
355335

356336
def __truediv__(self, other):
357-
return self.__array__() / (other.__array__() if isinstance(other, ParticleSetViewArray) else other)
337+
return self.__array__() / _unwrap(other)
358338

359339
def __floordiv__(self, other):
360-
return self.__array__() // (other.__array__() if isinstance(other, ParticleSetViewArray) else other)
340+
return self.__array__() // _unwrap(other)
361341

362342
def __pow__(self, other):
363-
return self.__array__() ** (other.__array__() if isinstance(other, ParticleSetViewArray) else other)
343+
return self.__array__() ** _unwrap(other)
364344

365345
def __neg__(self):
366346
return -self.__array__()
@@ -373,72 +353,54 @@ def __abs__(self):
373353

374354
# Right-hand operations to handle cases like `scalar - ParticleSetViewArray`
375355
def __radd__(self, other):
376-
return (other.__array__() if isinstance(other, ParticleSetViewArray) else other) + self.__array__()
356+
return _unwrap(other) + self.__array__()
377357

378358
def __rsub__(self, other):
379-
return (other.__array__() if isinstance(other, ParticleSetViewArray) else other) - self.__array__()
359+
return _unwrap(other) - self.__array__()
380360

381361
def __rmul__(self, other):
382-
return (other.__array__() if isinstance(other, ParticleSetViewArray) else other) * self.__array__()
362+
return _unwrap(other) * self.__array__()
383363

384364
def __rtruediv__(self, other):
385-
return (other.__array__() if isinstance(other, ParticleSetViewArray) else other) / self.__array__()
365+
return _unwrap(other) / self.__array__()
386366

387367
def __rfloordiv__(self, other):
388-
return (other.__array__() if isinstance(other, ParticleSetViewArray) else other) // self.__array__()
368+
return _unwrap(other) // self.__array__()
389369

390370
def __rpow__(self, other):
391-
return (other.__array__() if isinstance(other, ParticleSetViewArray) else other) ** self.__array__()
371+
return _unwrap(other) ** self.__array__()
392372

393373
# Comparison operators should return plain numpy boolean arrays so that
394374
# expressions like `mask = particles.gridID == gid` produce an ndarray
395375
# usable for indexing (rather than another ParticleSetViewArray).
396376
def __eq__(self, other):
397377
left = np.asarray(self.__array__())
398-
if isinstance(other, ParticleSetViewArray):
399-
right = np.asarray(other.__array__())
400-
else:
401-
right = other
378+
right = _asarray(other)
402379
return left == right
403380

404381
def __ne__(self, other):
405382
left = np.asarray(self.__array__())
406-
if isinstance(other, ParticleSetViewArray):
407-
right = np.asarray(other.__array__())
408-
else:
409-
right = other
383+
right = _asarray(other)
410384
return left != right
411385

412386
def __lt__(self, other):
413387
left = np.asarray(self.__array__())
414-
if isinstance(other, ParticleSetViewArray):
415-
right = np.asarray(other.__array__())
416-
else:
417-
right = other
388+
right = _asarray(other)
418389
return left < right
419390

420391
def __le__(self, other):
421392
left = np.asarray(self.__array__())
422-
if isinstance(other, ParticleSetViewArray):
423-
right = np.asarray(other.__array__())
424-
else:
425-
right = other
393+
right = _asarray(other)
426394
return left <= right
427395

428396
def __gt__(self, other):
429397
left = np.asarray(self.__array__())
430-
if isinstance(other, ParticleSetViewArray):
431-
right = np.asarray(other.__array__())
432-
else:
433-
right = other
398+
right = _asarray(other)
434399
return left > right
435400

436401
def __ge__(self, other):
437402
left = np.asarray(self.__array__())
438-
if isinstance(other, ParticleSetViewArray):
439-
right = np.asarray(other.__array__())
440-
else:
441-
right = other
403+
right = _asarray(other)
442404
return left >= right
443405

444406
# Allow attribute access like .dtype etc. by forwarding to the ndarray

0 commit comments

Comments
 (0)