diff --git a/src/earthkit/data/core/field.py b/src/earthkit/data/core/field.py index fa7720a0..82868e0d 100644 --- a/src/earthkit/data/core/field.py +++ b/src/earthkit/data/core/field.py @@ -18,7 +18,7 @@ from earthkit.data.core.order import Patch, Remapping, build_remapping from earthkit.data.decorators import normalise from earthkit.data.utils.args import metadata_argument_new -from earthkit.data.utils.array import flatten_array, reshape_array, target_shape +from earthkit.data.utils.array import flatten_array, outer_indexing, reshape_array, target_shape from earthkit.data.utils.compute import wrap_maths GRIB = "grib" @@ -485,7 +485,7 @@ def to_numpy(self, flatten=False, dtype=None, copy=True, index=None): v = flatten_array(v) if flatten else reshape_array(v, self.shape) if index is not None: - v = v[index] + v = outer_indexing(v, index) return v @@ -531,7 +531,7 @@ def to_array( v = flatten_array(v) if flatten else reshape_array(v, self.shape) if index is not None: - v = v[index] + v = outer_indexing(v, index) return v @@ -615,7 +615,7 @@ def _reshape(v, flatten): raise ValueError(f"data: {k} not available") v = _reshape(v, flatten) if index is not None: - v = v[index] + v = outer_indexing(v, index) r[k] = v # convert latlon to array format diff --git a/src/earthkit/data/data/wrappers/xarray.py b/src/earthkit/data/data/wrappers/xarray.py index 43206b4c..fff5177f 100644 --- a/src/earthkit/data/data/wrappers/xarray.py +++ b/src/earthkit/data/data/wrappers/xarray.py @@ -7,6 +7,7 @@ # nor does it submit to any jurisdiction. # + from . 
import ObjectWrapperData diff --git a/src/earthkit/data/indexing/tensor.py b/src/earthkit/data/indexing/tensor.py index cc3b45ac..18fd2355 100644 --- a/src/earthkit/data/indexing/tensor.py +++ b/src/earthkit/data/indexing/tensor.py @@ -419,14 +419,13 @@ def _prepare_tensor_data(self, source_to_array_func, index=None): # * `field_shape` does lose the dimension `dim` if `dim` is a field dimension current_field_shape = [] for n, _idx in zip(self.field_shape, index): - if isinstance(_idx, int): - # simply, ignore this index + _sizes = np.arange(n)[_idx].shape + if len(_sizes) == 0: + # _idx is a scalar indexer, and thus we ignore it continue - if isinstance(_idx, slice): - _size = len(range(n)[_idx]) - else: - # _idx must be an iterable of integers - _size = len(np.arange(n)[_idx]) + # _idx is a slice, an array of int's or a boolean mask + # get the size of the selection made by the indexer _idx + (_size,) = _sizes current_field_shape.append(_size) current_field_shape = tuple(current_field_shape) @@ -461,13 +460,16 @@ def field_indexes(self, indexes): return indexes[len(self._user_shape) :] def is_full_field(self, indexes): - assert len(indexes) == len(self._field_shape) - for i, s in enumerate(indexes): - if not ( - s is None - or isinstance(s, slice) - and (s == slice(None, None, None) or s == slice(0, self._field_shape[i], 1)) - ): + def is_full_indexer(index, size): + if index is None: + return True + if isinstance(index, slice): + return index.indices(size) == (0, size, 1) + full_indexer = np.arange(size) + return np.array_equal(full_indexer[index], full_indexer) + + for index, size in zip(indexes, self._field_shape, strict=True): + if not is_full_indexer(index, size): return False return True @@ -481,7 +483,7 @@ def _subset(self, indexes): user_indexes = [] for s, c in zip(indexes, self._user_shape): - lst = np.array(list(range(c)))[s].tolist() + lst = np.arange(c)[s].tolist() if not isinstance(lst, list): lst = [lst] user_coords.append(lst) @@ -683,7 +685,7 @@ 
def _subset(self, indexes): user_indexes = [] for s, c in zip(indexes, self._user_shape): - lst = np.array(list(range(c)))[s].tolist() + lst = np.arange(c)[s].tolist() if not isinstance(lst, list): lst = [lst] user_icoords.append(lst) diff --git a/src/earthkit/data/utils/array.py b/src/earthkit/data/utils/array.py index 97c3d317..7cdb38ff 100644 --- a/src/earthkit/data/utils/array.py +++ b/src/earthkit/data/utils/array.py @@ -84,3 +84,42 @@ def adjust_array(v, flatten=False, dtype=None): v = target_xp.astype(v, target_dtype, copy=False) return v + + +def outer_indexing(v, indices): + """Performs an outer indexing of an array ``v``. + The parameter ``indices`` is a tuple of indices. + Each index is either an int, a slice or an array of int's. + Each index from ``indices`` restricts/sub-selects the corresponding dimension of the array ``v``, + independently (orthogonally) to other indices (hence the name: "outer indexing"). + + This function mimics the behaviour of ``xarray.DataArray(v)[indices].values``, + and in general it is different from the behaviour of e.g. + the numpy indexing, i.e. ``v[indices]`` + (see https://numpy.org/doc/stable/user/basics.indexing.html#advanced-indexing). + + Parameters + ---------- + v: array-like + The array to be indexed. 
+ indices: tuple of items, each being an int, a slice or an array-like of int's + + Returns + ------- + array-like + Sub-selection of the array ``v`` according to ``indices`` + """ + if not isinstance(indices, tuple): + indices = (indices,) + full_slices = () + ndim = v.ndim + for idx in indices: + _1d_index = full_slices + (idx,) + v = v[_1d_index] + v_ndim = v.ndim + if v_ndim == ndim: + full_slices = full_slices + (slice(None),) + else: + # the current dimension has collapsed + ndim = v_ndim + return v diff --git a/tests/xr_engine/test_xr_engine_core.py b/tests/xr_engine/test_xr_engine_core.py index 4cdc5531..6a0653dc 100644 --- a/tests/xr_engine/test_xr_engine_core.py +++ b/tests/xr_engine/test_xr_engine_core.py @@ -216,35 +216,44 @@ def test_xr_engine_detailed_check_1(allow_holes, lazy_load, api): vals_ref = np.array([[269.00918579, 268.78610229], [268.57771301, 268.08932495]]) assert np.allclose(r.values, vals_ref) - r = da[:, 0, :, 2, 9:12, :2] - assert r.shape == (2, 2, 3, 2) - vals_ref = np.array([ - [ - [ - [269.00918579, 269.31680298], - [269.70254517, 269.81387329], - [267.50527954, 266.83828735], - ], - [ - [268.78610229, 268.80758667], - [269.52731323, 269.75680542], - [266.61813354, 267.12106323], - ], - ], - [ + r = da[:, 0, :, 2, [9], [True] + [False] * 35] + assert r.shape == (2, 2, 1, 1) + vals_ref = np.array([[[[269.00918579]], [[268.78610229]]], [[[268.57771301]], [[268.08932495]]]]) + assert np.allclose(r.values, vals_ref) + + for r in ( + da[:, 0, :, 2, 9:12, :2], + da[:, 0, :, 2, [9, 10, 11], [0, 1]], + da[:, 0, :, 2, [9, 10, 11], [True, True] + [False] * 34], + ): + assert r.shape == (2, 2, 3, 2) + vals_ref = np.array([ [ - [268.57771301, 269.03767395], - [269.33357239, 269.56111145], - [264.75154114, 266.55036926], + [ + [269.00918579, 269.31680298], + [269.70254517, 269.81387329], + [267.50527954, 266.83828735], + ], + [ + [268.78610229, 268.80758667], + [269.52731323, 269.75680542], + [266.61813354, 267.12106323], + ], ], [ - 
[268.08932495, 268.35983276], - [269.01803589, 269.02389526], - [264.29733276, 266.08248901], + [ + [268.57771301, 269.03767395], + [269.33357239, 269.56111145], + [264.75154114, 266.55036926], + ], + [ + [268.08932495, 268.35983276], + [269.01803589, 269.02389526], + [264.29733276, 266.08248901], + ], ], - ], - ]) - assert np.allclose(r.values, vals_ref) + ]) + assert np.allclose(r.values, vals_ref) r = da.loc[:, 0, :, 500, 0, 0] assert r.shape == (2, 2) @@ -408,35 +417,44 @@ def test_xr_engine_detailed_check_2(allow_holes, lazy_load, api): vals_ref = np.array([[269.00918579, 268.78610229], [268.57771301, 268.08932495]]) assert np.allclose(r.values, vals_ref) - r = da[:, 0, :, 2, 9:12, :2] - assert r.shape == (2, 2, 3, 2) - vals_ref = np.array([ - [ - [ - [269.00918579, 269.31680298], - [269.70254517, 269.81387329], - [267.50527954, 266.83828735], - ], - [ - [268.78610229, 268.80758667], - [269.52731323, 269.75680542], - [266.61813354, 267.12106323], - ], - ], - [ + r = da[:, 0, :, 2, [9], [True] + [False] * 35] + assert r.shape == (2, 2, 1, 1) + vals_ref = np.array([[[[269.00918579]], [[268.78610229]]], [[[268.57771301]], [[268.08932495]]]]) + assert np.allclose(r.values, vals_ref) + + for r in ( + da[:, 0, :, 2, 9:12, :2], + da[:, 0, :, 2, [9, 10, 11], [0, 1]], + da[:, 0, :, 2, [9, 10, 11], [True, True] + [False] * 34], + ): + assert r.shape == (2, 2, 3, 2) + vals_ref = np.array([ [ - [268.57771301, 269.03767395], - [269.33357239, 269.56111145], - [264.75154114, 266.55036926], + [ + [269.00918579, 269.31680298], + [269.70254517, 269.81387329], + [267.50527954, 266.83828735], + ], + [ + [268.78610229, 268.80758667], + [269.52731323, 269.75680542], + [266.61813354, 267.12106323], + ], ], [ - [268.08932495, 268.35983276], - [269.01803589, 269.02389526], - [264.29733276, 266.08248901], + [ + [268.57771301, 269.03767395], + [269.33357239, 269.56111145], + [264.75154114, 266.55036926], + ], + [ + [268.08932495, 268.35983276], + [269.01803589, 269.02389526], + 
[264.29733276, 266.08248901], + ], ], - ], - ]) - assert np.allclose(r.values, vals_ref) + ]) + assert np.allclose(r.values, vals_ref) r = da.loc[:, datetime.time(0, 0), :, 500, 0, 0] assert r.shape == (2, 2) diff --git a/tests/xr_engine/test_xr_engine_indexing.py b/tests/xr_engine/test_xr_engine_indexing.py index 68382814..a15c0577 100644 --- a/tests/xr_engine/test_xr_engine_indexing.py +++ b/tests/xr_engine/test_xr_engine_indexing.py @@ -51,3 +51,26 @@ def test_xr_engine_daily_mean(allow_holes, kwargs): ds.resample({"valid_time": "24h"}).mean().groupby("valid_time.month") - monthly_mean_ds ) assert np.allclose(np.abs(daily_anomaly_from_monthly_mean_ds).max()["2t"].values, 28.98466590143022) + + +@pytest.mark.cache +@pytest.mark.long_test +@pytest.mark.timeout(60) +def test_xr_engine_groupby_forecast_valid_time(): + seas5_data = from_source( + "url", + "https://sites.ecmwf.int/repository/earthkit-data/test-data/seas5_2m_temperature_201501-201503_europe_1deg.grib", + ) + seas5_data = seas5_data.to_fieldlist() + seas5_xr = seas5_data.to_xarray( + time_dim_mode="forecast", + add_valid_time_coord=True, + ).rename({"2t": "t2m"}) + + seas5_xr_mean = seas5_xr.groupby("valid_time.day").mean() + assert np.allclose(seas5_xr_mean.max()["t2m"].values, 288.32264709) + + seas5_xr_mean2 = seas5_xr.sel(latitude=[50, 55, 60], longitude=[-10, 0, 10, 20]).groupby("valid_time.day").mean() + assert np.allclose(seas5_xr_mean2["latitude"].values, [50, 55, 60]) + assert np.allclose(seas5_xr_mean2["longitude"].values, [-10, 0, 10, 20]) + assert np.allclose(seas5_xr_mean2.max()["t2m"].values, 283.53151703)