Skip to content

Commit fff9d4e

Browse files
committed
Checking if an indexer is a scalar or not is now more robust
1 parent 9ea3b51 commit fff9d4e

2 files changed

Lines changed: 9 additions & 10 deletions

File tree

src/earthkit/data/indexing/tensor.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -422,14 +422,13 @@ def _prepare_tensor_data(self, source_to_array_func, index=None):
422422
# * `field_shape` does lose the dimension `dim` if `dim` is a field dimension
423423
current_field_shape = []
424424
for n, _idx in zip(self.field_shape, index):
425-
if isinstance(_idx, int):
426-
# simply, ignore this index
425+
_sizes = np.arange(n)[_idx].shape
426+
if len(_sizes) == 0:
427+
# _idx is a scalar indexer, and thus we ignore it
427428
continue
428-
if isinstance(_idx, slice):
429-
_size = len(range(n)[_idx])
430-
else:
431-
# _idx must be an iterable of integers
432-
_size = len(np.arange(n)[_idx])
429+
# _idx is a slice, an array of int's or a boolean mask
430+
# get the size of the selection made by the indexer _idx
431+
(_size,) = _sizes
433432
current_field_shape.append(_size)
434433
current_field_shape = tuple(current_field_shape)
435434

@@ -485,7 +484,7 @@ def _subset(self, indexes):
485484
user_indexes = []
486485

487486
for s, c in zip(indexes, self._user_shape):
488-
lst = np.array(list(range(c)))[s].tolist()
487+
lst = np.arange(c)[s].tolist()
489488
if not isinstance(lst, list):
490489
lst = [lst]
491490
user_coords.append(lst)
@@ -687,7 +686,7 @@ def _subset(self, indexes):
687686
user_indexes = []
688687

689688
for s, c in zip(indexes, self._user_shape):
690-
lst = np.array(list(range(c)))[s].tolist()
689+
lst = np.arange(c)[s].tolist()
691690
if not isinstance(lst, list):
692691
lst = [lst]
693692
user_icoords.append(lst)

src/earthkit/data/utils/array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def outer_indexing(v, indices):
120120
v = v[_1d_index]
121121
v_ndim = v.ndim
122122
if v_ndim == ndim:
123-
full_slices = full_slices + (slice(None, None),)
123+
full_slices = full_slices + (slice(None),)
124124
else:
125125
# the current dimension has collapsed
126126
ndim = v_ndim

0 commit comments

Comments
 (0)